From fef8f933d994a0cfdde32f7931339f20a2ba62c4 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Mon, 28 Dec 2015 08:28:24 -0800 Subject: [PATCH] Add SafeThread to fix thread JNI shutdown races. During JVM shutdown, some JNI calls may not return, so it's not possible to reliably perform a join() during static variable destruction (which occurs as the JVM unloads the JNI module). Also, due to static variable destruction, it's not safe to use any members of a static class instance from a separate thread of execution. SafeThread is a templated thread class and a related owner class that's designed for safe operation and shutdown of threads in the presence of callbacks that may not return. It also passes ownership of variables from the static instance to the thread, so the thread can safely operate until it exits (the last operation of the thread being to destroy its instance). Notifiers, RpcServer, and Logger now use SafeThread to ensure race-free destruction in both C++ and Java. All Java callback threads are now marked as Java daemon threads so they don't keep the JVM running after main() terminates. All Java callback threads are now named so their purpose is more easily identified in a debugger. Add SetRpcServerOnStart and SetRpcServerOnExit (similar to Listener). --- include/ntcore_c.h | 3 + include/ntcore_cpp.h | 3 + java/lib/NetworkTablesJNI.cpp | 153 ++++++++++++---------------------- ntcore.def | 4 + src/Notifier.cpp | 149 +++++++++++++++++++-------------- src/Notifier.h | 60 +------------ src/RpcServer.cpp | 78 +++++++---------- src/RpcServer.h | 29 ++++--- src/SafeThread.cpp | 31 +++++++ src/SafeThread.h | 93 +++++++++++++++++++++ src/Storage.cpp | 2 +- src/ntcore_c.cpp | 8 ++ src/ntcore_cpp.cpp | 17 ++-- 13 files changed, 343 insertions(+), 287 deletions(-) create mode 100644 src/SafeThread.cpp create mode 100644 src/SafeThread.h diff --git a/include/ntcore_c.h b/include/ntcore_c.h index 59c57d77d9..f3dee735fc 100644 --- a/include/ntcore_c.h +++ b/include/ntcore_c.h @@ -293,6 +293,9 @@ int NT_NotifierDestroyed(); * Remote Procedure Call Functions */ +void NT_SetRpcServerOnStart(void (*on_start)(void *data), void *data); +void NT_SetRpcServerOnExit(void (*on_exit)(void *data), void *data); + typedef char *(*NT_RpcCallback)(void *data, const char *name, size_t name_len, const char *params, size_t params_len, size_t *results_len); diff --git a/include/ntcore_cpp.h b/include/ntcore_cpp.h index 5ebc4098de..5fdff76b45 100644 --- a/include/ntcore_cpp.h +++ b/include/ntcore_cpp.h @@ -204,6 +204,9 @@ bool NotifierDestroyed(); * Remote Procedure Call Functions */ +void SetRpcServerOnStart(std::function on_start); +void SetRpcServerOnExit(std::function on_exit); + typedef std::function RpcCallback; diff --git a/java/lib/NetworkTablesJNI.cpp b/java/lib/NetworkTablesJNI.cpp index d235fd34ff..8faaa2b433 100644 --- a/java/lib/NetworkTablesJNI.cpp +++ b/java/lib/NetworkTablesJNI.cpp @@ -10,6 +10,7 @@ #include "edu_wpi_first_wpilibj_networktables_NetworkTablesJNI.h" #include "ntcore.h" #include "atomic_static.h" +#include "SafeThread.h" // // Globals and load/unload @@ -30,8 +31,12 @@ static JNIEnv *listenerEnv = nullptr; static void ListenerOnStart() { if (!jvm) return; JNIEnv *env; - if (jvm->AttachCurrentThread(reinterpret_cast(&env), - nullptr) != JNI_OK) + JavaVMAttachArgs args; + args.version = JNI_VERSION_1_2; + args.name = const_cast("NTListener"); + args.group = nullptr; + if (jvm->AttachCurrentThreadAsDaemon(reinterpret_cast(&env), + &args) != JNI_OK) return; if (!env || !env->functions) return; listenerEnv = env; @@ -1281,28 +1286,10 @@ JNIEXPORT jlong JNICALL Java_edu_wpi_first_wpilibj_networktables_NetworkTablesJN // Instead, this class attaches just once. When a hardware notification // occurs, a condition variable wakes up this thread and this thread actually // makes the call into Java. -class LoggerThreadJNI { +class LoggerThreadJNI : public nt::SafeThread { public: - static LoggerThreadJNI& GetInstance() { - ATOMIC_STATIC(LoggerThreadJNI, instance); - return instance; - } - LoggerThreadJNI(); - ~LoggerThreadJNI(); - void SetFunc(JNIEnv* env, jobject func, jmethodID mid); - void Start(); - void Stop(); + void Main(); - void Log(unsigned int level, const char* file, unsigned int line, - const char* msg); - - private: - void ThreadMain(); - - std::thread m_thread; - std::mutex m_mutex; - std::condition_variable m_cond; - std::atomic_bool m_active; struct LogMessage { LogMessage(unsigned int level_, const char* file_, unsigned int line_, const char* msg_) @@ -1313,84 +1300,56 @@ class LoggerThreadJNI { std::string msg; }; std::queue m_queue; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; jobject m_func = nullptr; jmethodID m_mid; - - ATOMIC_STATIC_DECL(LoggerThreadJNI) }; -ATOMIC_STATIC_INIT(LoggerThreadJNI) +class LoggerJNI : public nt::SafeThreadOwner { + public: + static LoggerJNI& GetInstance() { + ATOMIC_STATIC(LoggerJNI, instance); + return instance; + } + void SetFunc(JNIEnv* env, jobject func, jmethodID mid); + void Log(unsigned int level, const char* file, unsigned int line, + const char* msg); -LoggerThreadJNI::LoggerThreadJNI() { - m_active = false; -} + private: + ATOMIC_STATIC_DECL(LoggerJNI) +}; -LoggerThreadJNI::~LoggerThreadJNI() { - Stop(); -} +ATOMIC_STATIC_INIT(LoggerJNI) -void LoggerThreadJNI::SetFunc(JNIEnv* env, jobject func, jmethodID mid) { - std::lock_guard lock(m_mutex); +void LoggerJNI::SetFunc(JNIEnv* env, jobject func, jmethodID mid) { + auto thr = GetThread(); + if (!thr) return; // free global reference - if (m_func) env->DeleteGlobalRef(m_func); + if (thr->m_func) env->DeleteGlobalRef(thr->m_func); // create global reference - m_func = env->NewGlobalRef(func); - m_mid = mid; + thr->m_func = env->NewGlobalRef(func); + thr->m_mid = mid; } -void LoggerThreadJNI::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&LoggerThreadJNI::ThreadMain, this); +void LoggerJNI::Log(unsigned int level, const char *file, unsigned int line, + const char *msg) { + auto thr = GetThread(); + if (!thr) return; + thr->m_queue.emplace(level, file, line, msg); + thr->m_cond.notify_one(); } -void LoggerThreadJNI::Stop() { - { - std::lock_guard lock(m_mutex); - if (!m_active) return; - m_active = false; - } - m_cond.notify_one(); // wake up thread - - // join threads, with timeout - if (m_thread.joinable()) { - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} - -void LoggerThreadJNI::Log(unsigned int level, const char *file, - unsigned int line, const char *msg) { - std::lock_guard lock(m_mutex); - if (!m_active) return; - m_queue.emplace(level, file, line, msg); - m_cond.notify_one(); -} - -void LoggerThreadJNI::ThreadMain() { +void LoggerThreadJNI::Main() { JNIEnv *env; - jint rs = jvm->AttachCurrentThread((void**)&env, NULL); + JavaVMAttachArgs args; + args.version = JNI_VERSION_1_2; + args.name = const_cast("NTLogger"); + args.group = nullptr; + jint rs = jvm->AttachCurrentThreadAsDaemon((void**)&env, &args); if (rs != JNI_OK) return; std::unique_lock lock(m_mutex); while (m_active) { - m_cond.wait(lock, [&] { return !m_active || !m_queue.empty(); }); + m_cond.wait(lock, [&] { return !(m_active && m_queue.empty()); }); if (!m_active) break; while (!m_queue.empty()) { if (!m_active) break; @@ -1399,25 +1358,21 @@ void LoggerThreadJNI::ThreadMain() { auto func = m_func; auto mid = m_mid; lock.unlock(); // don't hold mutex during callback execution - env->CallVoidMethod(func, mid, (jint)item.level, - ToJavaString(env, item.file), (jint)item.line, - ToJavaString(env, item.msg)); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); + { + JavaLocal file(env, ToJavaString(env, item.file)); + JavaLocal msg(env, ToJavaString(env, item.msg)); + env->CallVoidMethod(func, mid, (jint)item.level, file.obj(), + (jint)item.line, msg.obj()); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } } lock.lock(); } } if (jvm) jvm->DetachCurrentThread(); - - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } } extern "C" { @@ -1439,14 +1394,14 @@ JNIEXPORT void JNICALL Java_edu_wpi_first_wpilibj_networktables_NetworkTablesJNI cls, "apply", "(ILjava/lang/String;ILjava/lang/String;)V"); if (!mid) return; - auto& thread = LoggerThreadJNI::GetInstance(); - thread.SetFunc(env, func, mid); - thread.Start(); + auto& logger = LoggerJNI::GetInstance(); + logger.Start(); + logger.SetFunc(env, func, mid); nt::SetLogger( [](unsigned int level, const char *file, unsigned int line, const char *msg) { - LoggerThreadJNI::GetInstance().Log(level, file, line, msg); + LoggerJNI::GetInstance().Log(level, file, line, msg); }, minLevel); } diff --git a/ntcore.def b/ntcore.def index d731352a2f..5ddea50198 100644 --- a/ntcore.def +++ b/ntcore.def @@ -76,6 +76,10 @@ NT_FreeCharArray @78 NT_NotifierDestroyed @79 NT_StopRpcServer @80 NT_StopNotifier @81 +NT_SetListenerOnStart @82 +NT_SetListenerOnExit @83 +NT_SetRpcServerOnStart @84 +NT_SetRpcServerOnExit @85 ; JNI functions JNI_OnLoad diff --git a/src/Notifier.cpp b/src/Notifier.cpp index fb2e8dcbc7..00beb9f0fb 100644 --- a/src/Notifier.cpp +++ b/src/Notifier.cpp @@ -7,53 +7,75 @@ #include "Notifier.h" +#include +#include + using namespace nt; ATOMIC_STATIC_INIT(Notifier) bool Notifier::s_destroyed = false; +class Notifier::Thread : public SafeThread { + public: + Thread(std::function on_start, std::function on_exit) + : m_on_start(on_start), m_on_exit(on_exit) {} + + void Main(); + + struct EntryListener { + EntryListener(StringRef prefix_, EntryListenerCallback callback_, + unsigned int flags_) + : prefix(prefix_), callback(callback_), flags(flags_) {} + + std::string prefix; + EntryListenerCallback callback; + unsigned int flags; + }; + std::vector m_entry_listeners; + std::vector m_conn_listeners; + + struct EntryNotification { + EntryNotification(StringRef name_, std::shared_ptr value_, + unsigned int flags_, EntryListenerCallback only_) + : name(name_), + value(value_), + flags(flags_), + only(only_) {} + + std::string name; + std::shared_ptr value; + unsigned int flags; + EntryListenerCallback only; + }; + std::queue m_entry_notifications; + + struct ConnectionNotification { + ConnectionNotification(bool connected_, const ConnectionInfo& conn_info_, + ConnectionListenerCallback only_) + : connected(connected_), conn_info(conn_info_), only(only_) {} + + bool connected; + ConnectionInfo conn_info; + ConnectionListenerCallback only; + }; + std::queue m_conn_notifications; + + std::function m_on_start; + std::function m_on_exit; +}; + Notifier::Notifier() { - m_active = false; m_local_notifiers = false; s_destroyed = false; } -Notifier::~Notifier() { - s_destroyed = true; - Stop(); -} +Notifier::~Notifier() { s_destroyed = true; } -void Notifier::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&Notifier::ThreadMain, this); -} +void Notifier::Start() { m_owner.Start(new Thread(m_on_start, m_on_exit)); } -void Notifier::Stop() { - m_active = false; - // send notification so the thread terminates - m_cond.notify_one(); - if (m_thread.joinable()) { - // join with timeout - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} +void Notifier::Stop() { m_owner.Stop(); } -void Notifier::ThreadMain() { +void Notifier::Thread::Main() { if (m_on_start) m_on_start(); std::unique_lock lock(m_mutex); @@ -138,65 +160,66 @@ void Notifier::ThreadMain() { done: if (m_on_exit) m_on_exit(); - - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } } unsigned int Notifier::AddEntryListener(StringRef prefix, EntryListenerCallback callback, unsigned int flags) { - std::lock_guard lock(m_mutex); - unsigned int uid = m_entry_listeners.size(); - m_entry_listeners.emplace_back(prefix, callback, flags); + auto thr = m_owner.GetThread(); + if (!thr) { + Start(); + thr = m_owner.GetThread(); + } + unsigned int uid = thr->m_entry_listeners.size(); + thr->m_entry_listeners.emplace_back(prefix, callback, flags); if ((flags & NT_NOTIFY_LOCAL) != 0) m_local_notifiers = true; return uid + 1; } void Notifier::RemoveEntryListener(unsigned int entry_listener_uid) { + auto thr = m_owner.GetThread(); + if (!thr) return; --entry_listener_uid; - std::lock_guard lock(m_mutex); - if (entry_listener_uid < m_entry_listeners.size()) - m_entry_listeners[entry_listener_uid].callback = nullptr; + if (entry_listener_uid < thr->m_entry_listeners.size()) + thr->m_entry_listeners[entry_listener_uid].callback = nullptr; } void Notifier::NotifyEntry(StringRef name, std::shared_ptr value, unsigned int flags, EntryListenerCallback only) { - if (!m_active) return; // optimization: don't generate needless local queue entries if we have // no local listeners (as this is a common case on the server side) if ((flags & NT_NOTIFY_LOCAL) != 0 && !m_local_notifiers) return; - std::unique_lock lock(m_mutex); - m_entry_notifications.emplace(name, value, flags, only); - lock.unlock(); - m_cond.notify_one(); + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_entry_notifications.emplace(name, value, flags, only); + thr->m_cond.notify_one(); } unsigned int Notifier::AddConnectionListener( ConnectionListenerCallback callback) { - std::lock_guard lock(m_mutex); - unsigned int uid = m_entry_listeners.size(); - m_conn_listeners.emplace_back(callback); + auto thr = m_owner.GetThread(); + if (!thr) { + Start(); + thr = m_owner.GetThread(); + } + unsigned int uid = thr->m_entry_listeners.size(); + thr->m_conn_listeners.emplace_back(callback); return uid + 1; } void Notifier::RemoveConnectionListener(unsigned int conn_listener_uid) { + auto thr = m_owner.GetThread(); + if (!thr) return; --conn_listener_uid; - std::lock_guard lock(m_mutex); - if (conn_listener_uid < m_conn_listeners.size()) - m_conn_listeners[conn_listener_uid] = nullptr; + if (conn_listener_uid < thr->m_conn_listeners.size()) + thr->m_conn_listeners[conn_listener_uid] = nullptr; } void Notifier::NotifyConnection(bool connected, const ConnectionInfo& conn_info, ConnectionListenerCallback only) { - if (!m_active) return; - std::unique_lock lock(m_mutex); - m_conn_notifications.emplace(connected, conn_info, only); - lock.unlock(); - m_cond.notify_one(); + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_conn_notifications.emplace(connected, conn_info, only); + thr->m_cond.notify_one(); } diff --git a/src/Notifier.h b/src/Notifier.h index d10054cf95..4a878ad718 100644 --- a/src/Notifier.h +++ b/src/Notifier.h @@ -8,16 +8,11 @@ #ifndef NT_NOTIFIER_H_ #define NT_NOTIFIER_H_ -#include -#include -#include -#include -#include -#include -#include +#include #include "atomic_static.h" #include "ntcore_cpp.h" +#include "SafeThread.h" namespace nt { @@ -33,7 +28,6 @@ class Notifier { void Start(); void Stop(); - bool active() const { return m_active; } bool local_notifiers() const { return m_local_notifiers; } static bool destroyed() { return s_destroyed; } @@ -57,57 +51,11 @@ class Notifier { private: Notifier(); - void ThreadMain(); + class Thread; + SafeThreadOwner m_owner; - std::atomic_bool m_active; std::atomic_bool m_local_notifiers; - std::mutex m_mutex; - std::condition_variable m_cond; - - struct EntryListener { - EntryListener(StringRef prefix_, EntryListenerCallback callback_, - unsigned int flags_) - : prefix(prefix_), callback(callback_), flags(flags_) {} - - std::string prefix; - EntryListenerCallback callback; - unsigned int flags; - }; - std::vector m_entry_listeners; - std::vector m_conn_listeners; - - struct EntryNotification { - EntryNotification(StringRef name_, std::shared_ptr value_, - unsigned int flags_, EntryListenerCallback only_) - : name(name_), - value(value_), - flags(flags_), - only(only_) {} - - std::string name; - std::shared_ptr value; - unsigned int flags; - EntryListenerCallback only; - }; - std::queue m_entry_notifications; - - struct ConnectionNotification { - ConnectionNotification(bool connected_, const ConnectionInfo& conn_info_, - ConnectionListenerCallback only_) - : connected(connected_), conn_info(conn_info_), only(only_) {} - - bool connected; - ConnectionInfo conn_info; - ConnectionListenerCallback only; - }; - std::queue m_conn_notifications; - - std::thread m_thread; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; - std::function m_on_start; std::function m_on_exit; diff --git a/src/RpcServer.cpp b/src/RpcServer.cpp index 43d37de1ae..669931f0c8 100644 --- a/src/RpcServer.cpp +++ b/src/RpcServer.cpp @@ -7,70 +7,57 @@ #include "RpcServer.h" +#include + #include "Log.h" using namespace nt; ATOMIC_STATIC_INIT(RpcServer) +class RpcServer::Thread : public SafeThread { + public: + Thread(std::function on_start, std::function on_exit) + : m_on_start(on_start), m_on_exit(on_exit) {} + + void Main(); + + std::queue m_call_queue; + + std::function m_on_start; + std::function m_on_exit; +}; + RpcServer::RpcServer() { - m_active = false; m_terminating = false; } RpcServer::~RpcServer() { Logger::GetInstance().SetLogger(nullptr); - Stop(); m_terminating = true; m_poll_cond.notify_all(); } void RpcServer::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&RpcServer::ThreadMain, this); + auto thr = m_owner.GetThread(); + if (!thr) m_owner.Start(new Thread(m_on_start, m_on_exit)); } -void RpcServer::Stop() { - m_active = false; - if (m_thread.joinable()) { - // send notification so the thread terminates - m_call_cond.notify_one(); - // join with timeout - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} +void RpcServer::Stop() { m_owner.Stop(); } void RpcServer::ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, SendMsgFunc send_response) { - std::unique_lock lock(m_mutex); - - if (func) - m_call_queue.emplace(name, msg, func, conn_id, send_response); - else + if (func) { + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_call_queue.emplace(name, msg, func, conn_id, send_response); + thr->m_cond.notify_one(); + } else { + std::lock_guard lock(m_mutex); m_poll_queue.emplace(name, msg, func, conn_id, send_response); - - lock.unlock(); - - if (func) - m_call_cond.notify_one(); - else m_poll_cond.notify_one(); + } } bool RpcServer::PollRpc(bool blocking, RpcCallInfo* call_info) { @@ -103,12 +90,14 @@ void RpcServer::PostRpcResponse(unsigned int rpc_id, unsigned int call_uid, m_response_map.erase(i); } -void RpcServer::ThreadMain() { +void RpcServer::Thread::Main() { + if (m_on_start) m_on_start(); + std::unique_lock lock(m_mutex); std::string tmp; while (m_active) { while (m_call_queue.empty()) { - m_call_cond.wait(lock); + m_cond.wait(lock); if (!m_active) goto done; } @@ -132,10 +121,5 @@ void RpcServer::ThreadMain() { } done: - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } + if (m_on_exit) m_on_exit(); } diff --git a/src/RpcServer.h b/src/RpcServer.h index 726034d9b0..8bae64bba6 100644 --- a/src/RpcServer.h +++ b/src/RpcServer.h @@ -12,14 +12,13 @@ #include #include #include -#include #include -#include #include "llvm/DenseMap.h" #include "atomic_static.h" #include "Message.h" #include "ntcore_cpp.h" +#include "SafeThread.h" namespace nt { @@ -37,7 +36,8 @@ class RpcServer { void Start(); void Stop(); - bool active() const { return m_active; } + void SetOnStart(std::function on_start) { m_on_start = on_start; } + void SetOnExit(std::function on_exit) { m_on_exit = on_exit; } void ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, @@ -50,13 +50,8 @@ class RpcServer { private: RpcServer(); - void ThreadMain(); - - std::atomic_bool m_active; - std::atomic_bool m_terminating; - - std::mutex m_mutex; - std::condition_variable m_call_cond, m_poll_cond; + class Thread; + SafeThreadOwner m_owner; struct RpcCall { RpcCall(StringRef name_, std::shared_ptr msg_, RpcCallback func_, @@ -73,15 +68,19 @@ class RpcServer { unsigned int conn_id; SendMsgFunc send_response; }; - std::queue m_call_queue, m_poll_queue; + std::mutex m_mutex; + + std::queue m_poll_queue; llvm::DenseMap, SendMsgFunc> m_response_map; - std::thread m_thread; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; + std::condition_variable m_poll_cond; + + std::atomic_bool m_terminating; + + std::function m_on_start; + std::function m_on_exit; ATOMIC_STATIC_DECL(RpcServer) }; diff --git a/src/SafeThread.cpp b/src/SafeThread.cpp new file mode 100644 index 0000000000..3ef6375bf1 --- /dev/null +++ b/src/SafeThread.cpp @@ -0,0 +1,31 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) FIRST 2015. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "SafeThread.h" + +using namespace nt; + +void detail::SafeThreadOwnerBase::Start(SafeThread* thr) { + SafeThread* curthr = nullptr; + SafeThread* newthr = thr; + if (!m_thread.compare_exchange_strong(curthr, newthr)) { + delete newthr; + return; + } + std::thread([=]() { + newthr->Main(); + delete newthr; + }).detach(); +} + +void detail::SafeThreadOwnerBase::Stop() { + SafeThread* thr = m_thread.exchange(nullptr); + if (!thr) return; + std::lock_guard lock(thr->m_mutex); + thr->m_active = false; + thr->m_cond.notify_one(); +} diff --git a/src/SafeThread.h b/src/SafeThread.h new file mode 100644 index 0000000000..a4973bbc83 --- /dev/null +++ b/src/SafeThread.h @@ -0,0 +1,93 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) FIRST 2015. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#ifndef NT_SAFETHREAD_H_ +#define NT_SAFETHREAD_H_ + +#include +#include +#include +#include + +namespace nt { + +// Base class for SafeThreadOwner threads. +class SafeThread { + public: + virtual ~SafeThread() = default; + virtual void Main() = 0; + + std::mutex m_mutex; + bool m_active = true; + std::condition_variable m_cond; +}; + +namespace detail { + +// Non-template proxy base class for common proxy code. +class SafeThreadProxyBase { + public: + SafeThreadProxyBase(SafeThread* thr) : m_thread(thr) { + if (!m_thread) return; + std::unique_lock(m_thread->m_mutex).swap(m_lock); + if (!m_thread->m_active) { + m_lock.unlock(); + m_thread = nullptr; + return; + } + } + explicit operator bool() const { return m_thread != nullptr; } + std::unique_lock& GetLock() { return m_lock; } + + protected: + SafeThread* m_thread; + std::unique_lock m_lock; +}; + +// A proxy for SafeThread. +// Also serves as a scoped lock on SafeThread::m_mutex. +template +class SafeThreadProxy : public SafeThreadProxyBase { + public: + SafeThreadProxy(SafeThread* thr) : SafeThreadProxyBase(thr) {} + T& operator*() const { return *static_cast(m_thread); } + T* operator->() const { return static_cast(m_thread); } +}; + +// Non-template owner base class for common owner code. +class SafeThreadOwnerBase { + public: + void Stop(); + + protected: + SafeThreadOwnerBase() { m_thread = nullptr; } + SafeThreadOwnerBase(const SafeThreadOwnerBase&) = delete; + SafeThreadOwnerBase& operator=(const SafeThreadOwnerBase&) = delete; + ~SafeThreadOwnerBase() { Stop(); } + + void Start(SafeThread* thr); + SafeThread* GetThread() { return m_thread.load(); } + + private: + std::atomic m_thread; +}; + +} // namespace detail + +template +class SafeThreadOwner : public detail::SafeThreadOwnerBase { + public: + void Start() { Start(new T); } + void Start(T* thr) { detail::SafeThreadOwnerBase::Start(thr); } + + using Proxy = typename detail::SafeThreadProxy; + Proxy GetThread() { return Proxy(detail::SafeThreadOwnerBase::GetThread()); } +}; + +} // namespace nt + +#endif // NT_SAFETHREAD_H_ diff --git a/src/Storage.cpp b/src/Storage.cpp index e2903e0a5e..9acad48cbc 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -1262,7 +1262,7 @@ void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) { entry->rpc_callback = callback; // start the RPC server - if (!m_rpc_server.active()) m_rpc_server.Start(); + m_rpc_server.Start(); if (old_value && *old_value == *value) return; diff --git a/src/ntcore_c.cpp b/src/ntcore_c.cpp index d7f545687f..05a3aa1999 100644 --- a/src/ntcore_c.cpp +++ b/src/ntcore_c.cpp @@ -215,6 +215,14 @@ int NT_NotifierDestroyed() { return nt::NotifierDestroyed(); } * Remote Procedure Call Functions */ +void NT_SetRpcServerOnStart(void (*on_start)(void *data), void *data) { + nt::SetRpcServerOnStart([=]() { on_start(data); }); +} + +void NT_SetRpcServerOnExit(void (*on_exit)(void *data), void *data) { + nt::SetRpcServerOnExit([=]() { on_exit(data); }); +} + void NT_CreateRpc(const char *name, size_t name_len, const char *def, size_t def_len, void *data, NT_RpcCallback callback) { nt::CreateRpc( diff --git a/src/ntcore_cpp.cpp b/src/ntcore_cpp.cpp index c50b3d4b64..0803a3aaff 100644 --- a/src/ntcore_cpp.cpp +++ b/src/ntcore_cpp.cpp @@ -75,9 +75,8 @@ void SetListenerOnExit(std::function on_exit) { unsigned int AddEntryListener(StringRef prefix, EntryListenerCallback callback, unsigned int flags) { - Notifier& notifier = Notifier::GetInstance(); - unsigned int uid = notifier.AddEntryListener(prefix, callback, flags); - notifier.Start(); + unsigned int uid = + Notifier::GetInstance().AddEntryListener(prefix, callback, flags); if ((flags & NT_NOTIFY_IMMEDIATE) != 0) Storage::GetInstance().NotifyEntries(prefix, callback); return uid; @@ -89,9 +88,7 @@ void RemoveEntryListener(unsigned int entry_listener_uid) { unsigned int AddConnectionListener(ConnectionListenerCallback callback, bool immediate_notify) { - Notifier& notifier = Notifier::GetInstance(); - unsigned int uid = notifier.AddConnectionListener(callback); - Notifier::GetInstance().Start(); + unsigned int uid = Notifier::GetInstance().AddConnectionListener(callback); if (immediate_notify) Dispatcher::GetInstance().NotifyConnections(callback); return uid; } @@ -106,6 +103,14 @@ bool NotifierDestroyed() { return Notifier::destroyed(); } * Remote Procedure Call Functions */ +void SetRpcServerOnStart(std::function on_start) { + RpcServer::GetInstance().SetOnStart(on_start); +} + +void SetRpcServerOnExit(std::function on_exit) { + RpcServer::GetInstance().SetOnExit(on_exit); +} + void CreateRpc(StringRef name, StringRef def, RpcCallback callback) { Storage::GetInstance().CreateRpc(name, def, callback); }