From ec1d09c018eead8ca2c958163a72f44260b91e9a Mon Sep 17 00:00:00 2001
From: Avi Kivity <avi@cloudius-systems.com>
Date: Mon, 25 Feb 2013 17:25:33 +0200
Subject: [PATCH] mutex: allow recusrive locking

prex code depends on this.

TODO: make it optional
---
 core/mutex.cc       | 19 +++++++++++++------
 include/osv/mutex.h |  2 ++
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/core/mutex.cc b/core/mutex.cc
index e1a52d349..ff5b449e5 100644
--- a/core/mutex.cc
+++ b/core/mutex.cc
@@ -14,8 +14,9 @@ extern "C" void mutex_lock(mutex_t *mutex)
     w.thread = sched::thread::current();
 
     spin_lock(&mutex->_wait_lock);
-    if (!mutex->_owner) {
+    if (!mutex->_owner || mutex->_owner == w.thread) {
         mutex->_owner = w.thread;
+        ++mutex->_depth;
         spin_unlock(&mutex->_wait_lock);
         return;
     }
@@ -43,8 +44,9 @@ extern "C" bool mutex_trylock(mutex_t *mutex)
 {
     bool ret = false;
     spin_lock(&mutex->_wait_lock);
-    if (!mutex->_owner) {
+    if (!mutex->_owner || mutex->_owner == sched::thread::current()) {
         mutex->_owner = sched::thread::current();
+        ++mutex->_depth;
         ret = true;
     }
     spin_unlock(&mutex->_wait_lock);
@@ -54,11 +56,16 @@ extern "C" bool mutex_trylock(mutex_t *mutex)
 extern "C" void mutex_unlock(mutex_t *mutex)
 {
     spin_lock(&mutex->_wait_lock);
-    if (mutex->_wait_list.first) {
-        mutex->_owner = mutex->_wait_list.first->thread;
-        mutex->_wait_list.first->thread->wake();
+    if (mutex->_depth == 1) {
+        if (mutex->_wait_list.first) {
+            mutex->_owner = mutex->_wait_list.first->thread;
+            mutex->_wait_list.first->thread->wake();
+        } else {
+            mutex->_owner = nullptr;
+            --mutex->_depth;
+        }
     } else {
-        mutex->_owner = nullptr;
+        --mutex->_depth;
     }
     spin_unlock(&mutex->_wait_lock);
 }
diff --git a/include/osv/mutex.h b/include/osv/mutex.h
index 63ba016af..8c9c4f34a 100644
--- a/include/osv/mutex.h
+++ b/include/osv/mutex.h
@@ -21,6 +21,7 @@ static inline void spinlock_init(spinlock_t *sl)
 
 struct cmutex {
     spinlock_t _wait_lock;
+    unsigned _depth;
     void *_owner;
     struct wait_list {
         struct waiter *first;
@@ -38,6 +39,7 @@ void mutex_unlock(mutex_t* m);
 
 static __always_inline void mutex_init(mutex_t* m)
 {
+    m->_depth = 0;
     m->_owner = 0;
     m->_wait_list.first = 0;
     m->_wait_list.last = 0;
-- 
GitLab