diff --git a/include/ntcore_c.h b/include/ntcore_c.h index 59c57d77d9..f3dee735fc 100644 --- a/include/ntcore_c.h +++ b/include/ntcore_c.h @@ -293,6 +293,9 @@ int NT_NotifierDestroyed(); * Remote Procedure Call Functions */ +void NT_SetRpcServerOnStart(void (*on_start)(void *data), void *data); +void NT_SetRpcServerOnExit(void (*on_exit)(void *data), void *data); + typedef char *(*NT_RpcCallback)(void *data, const char *name, size_t name_len, const char *params, size_t params_len, size_t *results_len); diff --git a/include/ntcore_cpp.h b/include/ntcore_cpp.h index 5ebc4098de..5fdff76b45 100644 --- a/include/ntcore_cpp.h +++ b/include/ntcore_cpp.h @@ -204,6 +204,9 @@ bool NotifierDestroyed(); * Remote Procedure Call Functions */ +void SetRpcServerOnStart(std::function on_start); +void SetRpcServerOnExit(std::function on_exit); + typedef std::function RpcCallback; diff --git a/java/lib/NetworkTablesJNI.cpp b/java/lib/NetworkTablesJNI.cpp index d235fd34ff..8faaa2b433 100644 --- a/java/lib/NetworkTablesJNI.cpp +++ b/java/lib/NetworkTablesJNI.cpp @@ -10,6 +10,7 @@ #include "edu_wpi_first_wpilibj_networktables_NetworkTablesJNI.h" #include "ntcore.h" #include "atomic_static.h" +#include "SafeThread.h" // // Globals and load/unload @@ -30,8 +31,12 @@ static JNIEnv *listenerEnv = nullptr; static void ListenerOnStart() { if (!jvm) return; JNIEnv *env; - if (jvm->AttachCurrentThread(reinterpret_cast(&env), - nullptr) != JNI_OK) + JavaVMAttachArgs args; + args.version = JNI_VERSION_1_2; + args.name = const_cast("NTListener"); + args.group = nullptr; + if (jvm->AttachCurrentThreadAsDaemon(reinterpret_cast(&env), + &args) != JNI_OK) return; if (!env || !env->functions) return; listenerEnv = env; @@ -1281,28 +1286,10 @@ JNIEXPORT jlong JNICALL Java_edu_wpi_first_wpilibj_networktables_NetworkTablesJN // Instead, this class attaches just once. When a hardware notification // occurs, a condition variable wakes up this thread and this thread actually // makes the call into Java. -class LoggerThreadJNI { +class LoggerThreadJNI : public nt::SafeThread { public: - static LoggerThreadJNI& GetInstance() { - ATOMIC_STATIC(LoggerThreadJNI, instance); - return instance; - } - LoggerThreadJNI(); - ~LoggerThreadJNI(); - void SetFunc(JNIEnv* env, jobject func, jmethodID mid); - void Start(); - void Stop(); + void Main(); - void Log(unsigned int level, const char* file, unsigned int line, - const char* msg); - - private: - void ThreadMain(); - - std::thread m_thread; - std::mutex m_mutex; - std::condition_variable m_cond; - std::atomic_bool m_active; struct LogMessage { LogMessage(unsigned int level_, const char* file_, unsigned int line_, const char* msg_) @@ -1313,84 +1300,56 @@ class LoggerThreadJNI { std::string msg; }; std::queue m_queue; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; jobject m_func = nullptr; jmethodID m_mid; - - ATOMIC_STATIC_DECL(LoggerThreadJNI) }; -ATOMIC_STATIC_INIT(LoggerThreadJNI) +class LoggerJNI : public nt::SafeThreadOwner { + public: + static LoggerJNI& GetInstance() { + ATOMIC_STATIC(LoggerJNI, instance); + return instance; + } + void SetFunc(JNIEnv* env, jobject func, jmethodID mid); + void Log(unsigned int level, const char* file, unsigned int line, + const char* msg); -LoggerThreadJNI::LoggerThreadJNI() { - m_active = false; -} + private: + ATOMIC_STATIC_DECL(LoggerJNI) +}; -LoggerThreadJNI::~LoggerThreadJNI() { - Stop(); -} +ATOMIC_STATIC_INIT(LoggerJNI) -void LoggerThreadJNI::SetFunc(JNIEnv* env, jobject func, jmethodID mid) { - std::lock_guard lock(m_mutex); +void LoggerJNI::SetFunc(JNIEnv* env, jobject func, jmethodID mid) { + auto thr = GetThread(); + if (!thr) return; // free global reference - if (m_func) env->DeleteGlobalRef(m_func); + if (thr->m_func) env->DeleteGlobalRef(thr->m_func); // create global reference - m_func = env->NewGlobalRef(func); - m_mid = mid; + thr->m_func = env->NewGlobalRef(func); + thr->m_mid = mid; } -void LoggerThreadJNI::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&LoggerThreadJNI::ThreadMain, this); +void LoggerJNI::Log(unsigned int level, const char *file, unsigned int line, + const char *msg) { + auto thr = GetThread(); + if (!thr) return; + thr->m_queue.emplace(level, file, line, msg); + thr->m_cond.notify_one(); } -void LoggerThreadJNI::Stop() { - { - std::lock_guard lock(m_mutex); - if (!m_active) return; - m_active = false; - } - m_cond.notify_one(); // wake up thread - - // join threads, with timeout - if (m_thread.joinable()) { - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} - -void LoggerThreadJNI::Log(unsigned int level, const char *file, - unsigned int line, const char *msg) { - std::lock_guard lock(m_mutex); - if (!m_active) return; - m_queue.emplace(level, file, line, msg); - m_cond.notify_one(); -} - -void LoggerThreadJNI::ThreadMain() { +void LoggerThreadJNI::Main() { JNIEnv *env; - jint rs = jvm->AttachCurrentThread((void**)&env, NULL); + JavaVMAttachArgs args; + args.version = JNI_VERSION_1_2; + args.name = const_cast("NTLogger"); + args.group = nullptr; + jint rs = jvm->AttachCurrentThreadAsDaemon((void**)&env, &args); if (rs != JNI_OK) return; std::unique_lock lock(m_mutex); while (m_active) { - m_cond.wait(lock, [&] { return !m_active || !m_queue.empty(); }); + m_cond.wait(lock, [&] { return !(m_active && m_queue.empty()); }); if (!m_active) break; while (!m_queue.empty()) { if (!m_active) break; @@ -1399,25 +1358,21 @@ void LoggerThreadJNI::ThreadMain() { auto func = m_func; auto mid = m_mid; lock.unlock(); // don't hold mutex during callback execution - env->CallVoidMethod(func, mid, (jint)item.level, - ToJavaString(env, item.file), (jint)item.line, - ToJavaString(env, item.msg)); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); + { + JavaLocal file(env, ToJavaString(env, item.file)); + JavaLocal msg(env, ToJavaString(env, item.msg)); + env->CallVoidMethod(func, mid, (jint)item.level, file.obj(), + (jint)item.line, msg.obj()); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } } lock.lock(); } } if (jvm) jvm->DetachCurrentThread(); - - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } } extern "C" { @@ -1439,14 +1394,14 @@ JNIEXPORT void JNICALL Java_edu_wpi_first_wpilibj_networktables_NetworkTablesJNI cls, "apply", "(ILjava/lang/String;ILjava/lang/String;)V"); if (!mid) return; - auto& thread = LoggerThreadJNI::GetInstance(); - thread.SetFunc(env, func, mid); - thread.Start(); + auto& logger = LoggerJNI::GetInstance(); + logger.Start(); + logger.SetFunc(env, func, mid); nt::SetLogger( [](unsigned int level, const char *file, unsigned int line, const char *msg) { - LoggerThreadJNI::GetInstance().Log(level, file, line, msg); + LoggerJNI::GetInstance().Log(level, file, line, msg); }, minLevel); } diff --git a/ntcore.def b/ntcore.def index d731352a2f..5ddea50198 100644 --- a/ntcore.def +++ b/ntcore.def @@ -76,6 +76,10 @@ NT_FreeCharArray @78 NT_NotifierDestroyed @79 NT_StopRpcServer @80 NT_StopNotifier @81 +NT_SetListenerOnStart @82 +NT_SetListenerOnExit @83 +NT_SetRpcServerOnStart @84 +NT_SetRpcServerOnExit @85 ; JNI functions JNI_OnLoad diff --git a/src/Notifier.cpp b/src/Notifier.cpp index fb2e8dcbc7..00beb9f0fb 100644 --- a/src/Notifier.cpp +++ b/src/Notifier.cpp @@ -7,53 +7,75 @@ #include "Notifier.h" +#include +#include + using namespace nt; ATOMIC_STATIC_INIT(Notifier) bool Notifier::s_destroyed = false; +class Notifier::Thread : public SafeThread { + public: + Thread(std::function on_start, std::function on_exit) + : m_on_start(on_start), m_on_exit(on_exit) {} + + void Main(); + + struct EntryListener { + EntryListener(StringRef prefix_, EntryListenerCallback callback_, + unsigned int flags_) + : prefix(prefix_), callback(callback_), flags(flags_) {} + + std::string prefix; + EntryListenerCallback callback; + unsigned int flags; + }; + std::vector m_entry_listeners; + std::vector m_conn_listeners; + + struct EntryNotification { + EntryNotification(StringRef name_, std::shared_ptr value_, + unsigned int flags_, EntryListenerCallback only_) + : name(name_), + value(value_), + flags(flags_), + only(only_) {} + + std::string name; + std::shared_ptr value; + unsigned int flags; + EntryListenerCallback only; + }; + std::queue m_entry_notifications; + + struct ConnectionNotification { + ConnectionNotification(bool connected_, const ConnectionInfo& conn_info_, + ConnectionListenerCallback only_) + : connected(connected_), conn_info(conn_info_), only(only_) {} + + bool connected; + ConnectionInfo conn_info; + ConnectionListenerCallback only; + }; + std::queue m_conn_notifications; + + std::function m_on_start; + std::function m_on_exit; +}; + Notifier::Notifier() { - m_active = false; m_local_notifiers = false; s_destroyed = false; } -Notifier::~Notifier() { - s_destroyed = true; - Stop(); -} +Notifier::~Notifier() { s_destroyed = true; } -void Notifier::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&Notifier::ThreadMain, this); -} +void Notifier::Start() { m_owner.Start(new Thread(m_on_start, m_on_exit)); } -void Notifier::Stop() { - m_active = false; - // send notification so the thread terminates - m_cond.notify_one(); - if (m_thread.joinable()) { - // join with timeout - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} +void Notifier::Stop() { m_owner.Stop(); } -void Notifier::ThreadMain() { +void Notifier::Thread::Main() { if (m_on_start) m_on_start(); std::unique_lock lock(m_mutex); @@ -138,65 +160,66 @@ void Notifier::ThreadMain() { done: if (m_on_exit) m_on_exit(); - - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } } unsigned int Notifier::AddEntryListener(StringRef prefix, EntryListenerCallback callback, unsigned int flags) { - std::lock_guard lock(m_mutex); - unsigned int uid = m_entry_listeners.size(); - m_entry_listeners.emplace_back(prefix, callback, flags); + auto thr = m_owner.GetThread(); + if (!thr) { + Start(); + thr = m_owner.GetThread(); + } + unsigned int uid = thr->m_entry_listeners.size(); + thr->m_entry_listeners.emplace_back(prefix, callback, flags); if ((flags & NT_NOTIFY_LOCAL) != 0) m_local_notifiers = true; return uid + 1; } void Notifier::RemoveEntryListener(unsigned int entry_listener_uid) { + auto thr = m_owner.GetThread(); + if (!thr) return; --entry_listener_uid; - std::lock_guard lock(m_mutex); - if (entry_listener_uid < m_entry_listeners.size()) - m_entry_listeners[entry_listener_uid].callback = nullptr; + if (entry_listener_uid < thr->m_entry_listeners.size()) + thr->m_entry_listeners[entry_listener_uid].callback = nullptr; } void Notifier::NotifyEntry(StringRef name, std::shared_ptr value, unsigned int flags, EntryListenerCallback only) { - if (!m_active) return; // optimization: don't generate needless local queue entries if we have // no local listeners (as this is a common case on the server side) if ((flags & NT_NOTIFY_LOCAL) != 0 && !m_local_notifiers) return; - std::unique_lock lock(m_mutex); - m_entry_notifications.emplace(name, value, flags, only); - lock.unlock(); - m_cond.notify_one(); + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_entry_notifications.emplace(name, value, flags, only); + thr->m_cond.notify_one(); } unsigned int Notifier::AddConnectionListener( ConnectionListenerCallback callback) { - std::lock_guard lock(m_mutex); - unsigned int uid = m_entry_listeners.size(); - m_conn_listeners.emplace_back(callback); + auto thr = m_owner.GetThread(); + if (!thr) { + Start(); + thr = m_owner.GetThread(); + } + unsigned int uid = thr->m_entry_listeners.size(); + thr->m_conn_listeners.emplace_back(callback); return uid + 1; } void Notifier::RemoveConnectionListener(unsigned int conn_listener_uid) { + auto thr = m_owner.GetThread(); + if (!thr) return; --conn_listener_uid; - std::lock_guard lock(m_mutex); - if (conn_listener_uid < m_conn_listeners.size()) - m_conn_listeners[conn_listener_uid] = nullptr; + if (conn_listener_uid < thr->m_conn_listeners.size()) + thr->m_conn_listeners[conn_listener_uid] = nullptr; } void Notifier::NotifyConnection(bool connected, const ConnectionInfo& conn_info, ConnectionListenerCallback only) { - if (!m_active) return; - std::unique_lock lock(m_mutex); - m_conn_notifications.emplace(connected, conn_info, only); - lock.unlock(); - m_cond.notify_one(); + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_conn_notifications.emplace(connected, conn_info, only); + thr->m_cond.notify_one(); } diff --git a/src/Notifier.h b/src/Notifier.h index d10054cf95..4a878ad718 100644 --- a/src/Notifier.h +++ b/src/Notifier.h @@ -8,16 +8,11 @@ #ifndef NT_NOTIFIER_H_ #define NT_NOTIFIER_H_ -#include -#include -#include -#include -#include -#include -#include +#include #include "atomic_static.h" #include "ntcore_cpp.h" +#include "SafeThread.h" namespace nt { @@ -33,7 +28,6 @@ class Notifier { void Start(); void Stop(); - bool active() const { return m_active; } bool local_notifiers() const { return m_local_notifiers; } static bool destroyed() { return s_destroyed; } @@ -57,57 +51,11 @@ class Notifier { private: Notifier(); - void ThreadMain(); + class Thread; + SafeThreadOwner m_owner; - std::atomic_bool m_active; std::atomic_bool m_local_notifiers; - std::mutex m_mutex; - std::condition_variable m_cond; - - struct EntryListener { - EntryListener(StringRef prefix_, EntryListenerCallback callback_, - unsigned int flags_) - : prefix(prefix_), callback(callback_), flags(flags_) {} - - std::string prefix; - EntryListenerCallback callback; - unsigned int flags; - }; - std::vector m_entry_listeners; - std::vector m_conn_listeners; - - struct EntryNotification { - EntryNotification(StringRef name_, std::shared_ptr value_, - unsigned int flags_, EntryListenerCallback only_) - : name(name_), - value(value_), - flags(flags_), - only(only_) {} - - std::string name; - std::shared_ptr value; - unsigned int flags; - EntryListenerCallback only; - }; - std::queue m_entry_notifications; - - struct ConnectionNotification { - ConnectionNotification(bool connected_, const ConnectionInfo& conn_info_, - ConnectionListenerCallback only_) - : connected(connected_), conn_info(conn_info_), only(only_) {} - - bool connected; - ConnectionInfo conn_info; - ConnectionListenerCallback only; - }; - std::queue m_conn_notifications; - - std::thread m_thread; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; - std::function m_on_start; std::function m_on_exit; diff --git a/src/RpcServer.cpp b/src/RpcServer.cpp index 43d37de1ae..669931f0c8 100644 --- a/src/RpcServer.cpp +++ b/src/RpcServer.cpp @@ -7,70 +7,57 @@ #include "RpcServer.h" +#include + #include "Log.h" using namespace nt; ATOMIC_STATIC_INIT(RpcServer) +class RpcServer::Thread : public SafeThread { + public: + Thread(std::function on_start, std::function on_exit) + : m_on_start(on_start), m_on_exit(on_exit) {} + + void Main(); + + std::queue m_call_queue; + + std::function m_on_start; + std::function m_on_exit; +}; + RpcServer::RpcServer() { - m_active = false; m_terminating = false; } RpcServer::~RpcServer() { Logger::GetInstance().SetLogger(nullptr); - Stop(); m_terminating = true; m_poll_cond.notify_all(); } void RpcServer::Start() { - { - std::lock_guard lock(m_mutex); - if (m_active) return; - m_active = true; - } - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = false; - } - m_thread = std::thread(&RpcServer::ThreadMain, this); + auto thr = m_owner.GetThread(); + if (!thr) m_owner.Start(new Thread(m_on_start, m_on_exit)); } -void RpcServer::Stop() { - m_active = false; - if (m_thread.joinable()) { - // send notification so the thread terminates - m_call_cond.notify_one(); - // join with timeout - std::unique_lock lock(m_shutdown_mutex); - auto timeout_time = - std::chrono::steady_clock::now() + std::chrono::seconds(1); - if (m_shutdown_cv.wait_until(lock, timeout_time, - [&] { return m_shutdown; })) - m_thread.join(); - else - m_thread.detach(); // timed out, detach it - } -} +void RpcServer::Stop() { m_owner.Stop(); } void RpcServer::ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, SendMsgFunc send_response) { - std::unique_lock lock(m_mutex); - - if (func) - m_call_queue.emplace(name, msg, func, conn_id, send_response); - else + if (func) { + auto thr = m_owner.GetThread(); + if (!thr) return; + thr->m_call_queue.emplace(name, msg, func, conn_id, send_response); + thr->m_cond.notify_one(); + } else { + std::lock_guard lock(m_mutex); m_poll_queue.emplace(name, msg, func, conn_id, send_response); - - lock.unlock(); - - if (func) - m_call_cond.notify_one(); - else m_poll_cond.notify_one(); + } } bool RpcServer::PollRpc(bool blocking, RpcCallInfo* call_info) { @@ -103,12 +90,14 @@ void RpcServer::PostRpcResponse(unsigned int rpc_id, unsigned int call_uid, m_response_map.erase(i); } -void RpcServer::ThreadMain() { +void RpcServer::Thread::Main() { + if (m_on_start) m_on_start(); + std::unique_lock lock(m_mutex); std::string tmp; while (m_active) { while (m_call_queue.empty()) { - m_call_cond.wait(lock); + m_cond.wait(lock); if (!m_active) goto done; } @@ -132,10 +121,5 @@ void RpcServer::ThreadMain() { } done: - // use condition variable to signal thread shutdown - { - std::lock_guard lock(m_shutdown_mutex); - m_shutdown = true; - m_shutdown_cv.notify_one(); - } + if (m_on_exit) m_on_exit(); } diff --git a/src/RpcServer.h b/src/RpcServer.h index 726034d9b0..8bae64bba6 100644 --- a/src/RpcServer.h +++ b/src/RpcServer.h @@ -12,14 +12,13 @@ #include #include #include -#include #include -#include #include "llvm/DenseMap.h" #include "atomic_static.h" #include "Message.h" #include "ntcore_cpp.h" +#include "SafeThread.h" namespace nt { @@ -37,7 +36,8 @@ class RpcServer { void Start(); void Stop(); - bool active() const { return m_active; } + void SetOnStart(std::function on_start) { m_on_start = on_start; } + void SetOnExit(std::function on_exit) { m_on_exit = on_exit; } void ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, @@ -50,13 +50,8 @@ class RpcServer { private: RpcServer(); - void ThreadMain(); - - std::atomic_bool m_active; - std::atomic_bool m_terminating; - - std::mutex m_mutex; - std::condition_variable m_call_cond, m_poll_cond; + class Thread; + SafeThreadOwner m_owner; struct RpcCall { RpcCall(StringRef name_, std::shared_ptr msg_, RpcCallback func_, @@ -73,15 +68,19 @@ class RpcServer { unsigned int conn_id; SendMsgFunc send_response; }; - std::queue m_call_queue, m_poll_queue; + std::mutex m_mutex; + + std::queue m_poll_queue; llvm::DenseMap, SendMsgFunc> m_response_map; - std::thread m_thread; - std::mutex m_shutdown_mutex; - std::condition_variable m_shutdown_cv; - bool m_shutdown = false; + std::condition_variable m_poll_cond; + + std::atomic_bool m_terminating; + + std::function m_on_start; + std::function m_on_exit; ATOMIC_STATIC_DECL(RpcServer) }; diff --git a/src/SafeThread.cpp b/src/SafeThread.cpp new file mode 100644 index 0000000000..3ef6375bf1 --- /dev/null +++ b/src/SafeThread.cpp @@ -0,0 +1,31 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) FIRST 2015. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "SafeThread.h" + +using namespace nt; + +void detail::SafeThreadOwnerBase::Start(SafeThread* thr) { + SafeThread* curthr = nullptr; + SafeThread* newthr = thr; + if (!m_thread.compare_exchange_strong(curthr, newthr)) { + delete newthr; + return; + } + std::thread([=]() { + newthr->Main(); + delete newthr; + }).detach(); +} + +void detail::SafeThreadOwnerBase::Stop() { + SafeThread* thr = m_thread.exchange(nullptr); + if (!thr) return; + std::lock_guard lock(thr->m_mutex); + thr->m_active = false; + thr->m_cond.notify_one(); +} diff --git a/src/SafeThread.h b/src/SafeThread.h new file mode 100644 index 0000000000..a4973bbc83 --- /dev/null +++ b/src/SafeThread.h @@ -0,0 +1,93 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) FIRST 2015. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#ifndef NT_SAFETHREAD_H_ +#define NT_SAFETHREAD_H_ + +#include +#include +#include +#include + +namespace nt { + +// Base class for SafeThreadOwner threads. +class SafeThread { + public: + virtual ~SafeThread() = default; + virtual void Main() = 0; + + std::mutex m_mutex; + bool m_active = true; + std::condition_variable m_cond; +}; + +namespace detail { + +// Non-template proxy base class for common proxy code. +class SafeThreadProxyBase { + public: + SafeThreadProxyBase(SafeThread* thr) : m_thread(thr) { + if (!m_thread) return; + std::unique_lock(m_thread->m_mutex).swap(m_lock); + if (!m_thread->m_active) { + m_lock.unlock(); + m_thread = nullptr; + return; + } + } + explicit operator bool() const { return m_thread != nullptr; } + std::unique_lock& GetLock() { return m_lock; } + + protected: + SafeThread* m_thread; + std::unique_lock m_lock; +}; + +// A proxy for SafeThread. +// Also serves as a scoped lock on SafeThread::m_mutex. +template +class SafeThreadProxy : public SafeThreadProxyBase { + public: + SafeThreadProxy(SafeThread* thr) : SafeThreadProxyBase(thr) {} + T& operator*() const { return *static_cast(m_thread); } + T* operator->() const { return static_cast(m_thread); } +}; + +// Non-template owner base class for common owner code. +class SafeThreadOwnerBase { + public: + void Stop(); + + protected: + SafeThreadOwnerBase() { m_thread = nullptr; } + SafeThreadOwnerBase(const SafeThreadOwnerBase&) = delete; + SafeThreadOwnerBase& operator=(const SafeThreadOwnerBase&) = delete; + ~SafeThreadOwnerBase() { Stop(); } + + void Start(SafeThread* thr); + SafeThread* GetThread() { return m_thread.load(); } + + private: + std::atomic m_thread; +}; + +} // namespace detail + +template +class SafeThreadOwner : public detail::SafeThreadOwnerBase { + public: + void Start() { Start(new T); } + void Start(T* thr) { detail::SafeThreadOwnerBase::Start(thr); } + + using Proxy = typename detail::SafeThreadProxy; + Proxy GetThread() { return Proxy(detail::SafeThreadOwnerBase::GetThread()); } +}; + +} // namespace nt + +#endif // NT_SAFETHREAD_H_ diff --git a/src/Storage.cpp b/src/Storage.cpp index e2903e0a5e..9acad48cbc 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -1262,7 +1262,7 @@ void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) { entry->rpc_callback = callback; // start the RPC server - if (!m_rpc_server.active()) m_rpc_server.Start(); + m_rpc_server.Start(); if (old_value && *old_value == *value) return; diff --git a/src/ntcore_c.cpp b/src/ntcore_c.cpp index d7f545687f..05a3aa1999 100644 --- a/src/ntcore_c.cpp +++ b/src/ntcore_c.cpp @@ -215,6 +215,14 @@ int NT_NotifierDestroyed() { return nt::NotifierDestroyed(); } * Remote Procedure Call Functions */ +void NT_SetRpcServerOnStart(void (*on_start)(void *data), void *data) { + nt::SetRpcServerOnStart([=]() { on_start(data); }); +} + +void NT_SetRpcServerOnExit(void (*on_exit)(void *data), void *data) { + nt::SetRpcServerOnExit([=]() { on_exit(data); }); +} + void NT_CreateRpc(const char *name, size_t name_len, const char *def, size_t def_len, void *data, NT_RpcCallback callback) { nt::CreateRpc( diff --git a/src/ntcore_cpp.cpp b/src/ntcore_cpp.cpp index c50b3d4b64..0803a3aaff 100644 --- a/src/ntcore_cpp.cpp +++ b/src/ntcore_cpp.cpp @@ -75,9 +75,8 @@ void SetListenerOnExit(std::function on_exit) { unsigned int AddEntryListener(StringRef prefix, EntryListenerCallback callback, unsigned int flags) { - Notifier& notifier = Notifier::GetInstance(); - unsigned int uid = notifier.AddEntryListener(prefix, callback, flags); - notifier.Start(); + unsigned int uid = + Notifier::GetInstance().AddEntryListener(prefix, callback, flags); if ((flags & NT_NOTIFY_IMMEDIATE) != 0) Storage::GetInstance().NotifyEntries(prefix, callback); return uid; @@ -89,9 +88,7 @@ void RemoveEntryListener(unsigned int entry_listener_uid) { unsigned int AddConnectionListener(ConnectionListenerCallback callback, bool immediate_notify) { - Notifier& notifier = Notifier::GetInstance(); - unsigned int uid = notifier.AddConnectionListener(callback); - Notifier::GetInstance().Start(); + unsigned int uid = Notifier::GetInstance().AddConnectionListener(callback); if (immediate_notify) Dispatcher::GetInstance().NotifyConnections(callback); return uid; } @@ -106,6 +103,14 @@ bool NotifierDestroyed() { return Notifier::destroyed(); } * Remote Procedure Call Functions */ +void SetRpcServerOnStart(std::function on_start) { + RpcServer::GetInstance().SetOnStart(on_start); +} + +void SetRpcServerOnExit(std::function on_exit) { + RpcServer::GetInstance().SetOnExit(on_exit); +} + void CreateRpc(StringRef name, StringRef def, RpcCallback callback) { Storage::GetInstance().CreateRpc(name, def, callback); }