From 9fe982bea1ad43cad506e2cc57631c3e2d09adc2 Mon Sep 17 00:00:00 2001
From: Christoph Hellwig <hch@cloudius-systems.com>
Date: Tue, 5 Feb 2013 08:42:41 +0100
Subject: [PATCH] provide a spinlock/mutex implementation that is SMP safe and
 callable from C code

---
 core/mutex.cc       |  78 +++++++++++++++++++----------
 include/mutex.hh    |  10 ++--
 include/osv/prex.h  |   1 -
 include/osv/vnode.h |   1 +
 scripts/loader.py   | 116 --------------------------------------------
 5 files changed, 57 insertions(+), 149 deletions(-)

diff --git a/core/mutex.cc b/core/mutex.cc
index 2882e37e3..68f7991c2 100644
--- a/core/mutex.cc
+++ b/core/mutex.cc
@@ -1,47 +1,75 @@
+
 #include "mutex.hh"
-#include "sched.hh"
+#include <sched.hh>
 
-void mutex::lock()
+struct waiter {
+    struct waiter*	prev;
+    struct waiter*	next;
+    sched::thread*	thread;
+};
+
+extern "C" void mutex_wait(mutex_t *mutex)
 {
-    // FIXME: use atomics
-    if (!_locked) {
-        _locked = true;
-        return;
+    struct waiter w;
+
+    w.thread = sched::thread::current();
+
+    spin_lock(&mutex->_wait_lock);
+    if (!mutex->_wait_list.first) {
+        mutex->_wait_list.first = &w;
     } else {
-        auto me = sched::thread::current();
-        _waiters.push_back(me);
-        sched::thread::wait_until([=] {
-            return !_locked && _waiters.front() == me;
-        });
-        _waiters.pop_front();
+        mutex->_wait_list.last->next = &w;
+        w.prev = mutex->_wait_list.last;
     }
+    mutex->_wait_list.last = &w;
+    spin_unlock(&mutex->_wait_lock);
+
+    sched::thread::wait_until([=] {
+        return !mutex->_locked && mutex->_wait_list.first == &w;
+    });
+
+    spin_lock(&mutex->_wait_lock);
+    if (mutex->_wait_list.first == &w)
+        mutex->_wait_list.first = w.next;
+    else
+        w.prev->next = w.next;
+
+    if (mutex->_wait_list.last == &w)
+        mutex->_wait_list.last = w.prev;
+    spin_unlock(&mutex->_wait_lock);
+}
+
+extern "C" void mutex_unlock(mutex_t *mutex)
+{
+    __sync_lock_release(&mutex->_locked, 0);
+
+    spin_lock(&mutex->_wait_lock);
+    if (mutex->_wait_list.first)
+        mutex->_wait_list.first->thread->wake();
+    spin_unlock(&mutex->_wait_lock);
+}
+
+void mutex::lock()
+{
+    mutex_lock(&_mutex);
 }
 
 bool mutex::try_lock()
 {
-    if (_locked) {
-        return false;
-    } else {
-        _locked = true;
-        return true;
-    }
+    return mutex_trylock(&_mutex);
 }
 
 void mutex::unlock()
 {
-    _locked = false;
-    if (!_waiters.empty()) {
-        _waiters.front()->wake();
-    }
+    mutex_unlock(&_mutex);
 }
 
 void spinlock::lock()
 {
-    while (__sync_lock_test_and_set(&_locked, 1))
-        ;
+    spin_lock(&_lock);
 }
 
 void spinlock::unlock()
 {
-    __sync_lock_release(&_locked, 0);
+    spin_unlock(&_lock);
 }
diff --git a/include/mutex.hh b/include/mutex.hh
index 289d10504..58d0fdff0 100644
--- a/include/mutex.hh
+++ b/include/mutex.hh
@@ -3,10 +3,7 @@
 
 #include <mutex>
 #include <list>
-
-namespace sched {
-class thread;
-}
+#include <osv/mutex.h>
 
 class mutex {
 public:
@@ -14,8 +11,7 @@ public:
     bool try_lock();
     void unlock();
 private:
-    bool _locked;
-    std::list<sched::thread*> _waiters;
+    mutex_t _mutex;
 };
 
 // Use mutex instead, except where impossible
@@ -24,7 +20,7 @@ public:
     void lock();
     void unlock();
 private:
-    bool _locked;
+    spinlock_t _lock;
 };
 
 template <class Lock, class Func>
diff --git a/include/osv/prex.h b/include/osv/prex.h
index dc1745900..6bebac8b3 100644
--- a/include/osv/prex.h
+++ b/include/osv/prex.h
@@ -29,7 +29,6 @@ __BEGIN_DECLS
 typedef unsigned long   object_t;
 typedef unsigned long   task_t;
 typedef unsigned long   thread_t;
-typedef unsigned long   mutex_t;
 typedef unsigned long   cond_t;
 typedef unsigned long   sem_t;
 typedef unsigned long   device_t;
diff --git a/include/osv/vnode.h b/include/osv/vnode.h
index 18aedf9ea..8f48eaadd 100755
--- a/include/osv/vnode.h
+++ b/include/osv/vnode.h
@@ -34,6 +34,7 @@
 #include <sys/stat.h>
 #include <osv/prex.h>
 #include <osv/uio.h>
+#include <osv/mutex.h>
 #include "file.h"
 #include <osv/list.h>
 #include "dirent.h"
diff --git a/scripts/loader.py b/scripts/loader.py
index 6bb20af50..3596d92f2 100644
--- a/scripts/loader.py
+++ b/scripts/loader.py
@@ -83,119 +83,3 @@ def ulong(x):
     return x
 
 ulong_type = gdb.lookup_type('unsigned long')
-
-class osv_syms(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv syms',
-                             gdb.COMMAND_USER, gdb.COMPLETE_NONE)
-    def invoke(self, arg, from_tty):
-        p = gdb.lookup_global_symbol('elf::program::s_objs').value()
-        p = p.dereference().address
-        while long(p.dereference()):
-            obj = p.dereference().dereference()
-            base = long(obj['_base'])
-            path = obj['_pathname']['_M_dataplus']['_M_p'].string()
-            path = translate(path)
-            print path, hex(base)
-            load_elf(path, base)
-            p += 1
-
-class osv_info(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv info', gdb.COMMAND_USER,
-                             gdb.COMPLETE_COMMAND, True)
-
-def thread_list():
-    ret = []
-    thread_list = gdb.lookup_global_symbol('sched::thread_list').value()
-    root = thread_list['data_']['root_plus_size_']['root_']
-    node = root['next_']
-    thread_type = gdb.lookup_type('sched::thread')
-    void_ptr = gdb.lookup_type('void').pointer()
-    for f in thread_type.fields():
-        if f.name == '_thread_list_link':
-            link_offset = f.bitpos / 8
-    while node != root.address:
-        t = node.cast(void_ptr) - link_offset
-        t = t.cast(thread_type.pointer())
-        ret.append(t.dereference())
-        node = node['next_']
-    return ret
-
-class thread_context(object):
-    def __init__(self, thread):
-        self.old_frame = gdb.selected_frame()
-        self.new_frame = gdb.newest_frame()
-        self.new_frame.select()
-        self.running = (not long(thread['_on_runqueue'])
-                        and not long(thread['_waiting']))
-        self.old_rsp = ulong(gdb.parse_and_eval('$rsp').cast(ulong_type))
-        self.old_rip = ulong(gdb.parse_and_eval('$rip').cast(ulong_type))
-        self.old_rbp = ulong(gdb.parse_and_eval('$rbp').cast(ulong_type))
-        if not self.running:
-            self.old_frame.select()
-            self.new_rsp = thread['_state']['rsp'].cast(ulong_type)
-    def __enter__(self):
-        self.new_frame.select()
-        if not self.running:
-            gdb.execute('set $rsp = %s' % (self.new_rsp + 16))
-            inf = gdb.selected_inferior()
-            stack = inf.read_memory(self.new_rsp, 16)
-            (new_rip, new_rbp) = struct.unpack('qq', stack)
-            gdb.execute('set $rip = %s' % (new_rip + 1))
-            gdb.execute('set $rbp = %s' % new_rbp)
-    def __exit__(self, *_):
-        if not self.running:
-            gdb.execute('set $rsp = %s' % self.old_rsp)
-            gdb.execute('set $rip = %s' % self.old_rip)
-            gdb.execute('set $rbp = %s' % self.old_rbp)
-        self.old_frame.select()
-
-class osv_info_threads(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv info threads',
-                             gdb.COMMAND_USER, gdb.COMPLETE_NONE)
-    def invoke(self, arg, for_tty):
-        for t in thread_list():
-            with thread_context(t):
-                fr = gdb.selected_frame()
-                sal = fr.find_sal()
-                gdb.write('%s %s at %s:%s\n' % (ulong(t.address),
-                                                fr.function().name,
-                                                sal.symtab.filename,
-                                                sal.line))
-
-class osv_thread(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv thread', gdb.COMMAND_USER,
-                             gdb.COMPLETE_COMMAND, True)
-    #def invoke(self, arg, for_tty):
-    #    for t in thread_list():
-    #        if t.address.cast(ulong_type) == long(arg, 0):
-    #            print 'match'
-
-class osv_thread_apply(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv thread apply', gdb.COMMAND_USER,
-                             gdb.COMPLETE_COMMAND, True)
-
-class osv_thread_apply_all(gdb.Command):
-    def __init__(self):
-        gdb.Command.__init__(self, 'osv thread apply all', gdb.COMMAND_USER,
-                             gdb.COMPLETE_NONE)
-    def invoke(self, arg, from_tty):
-        for t in thread_list():
-            gdb.write('thread %s\n\n' % t.address)
-            with thread_context(t):
-                gdb.execute(arg, from_tty)
-            gdb.write('\n')
-
-
-osv()
-osv_heap()
-osv_syms()
-osv_info()
-osv_info_threads()
-osv_thread()
-osv_thread_apply()
-osv_thread_apply_all()
\ No newline at end of file
-- 
GitLab