diff --git a/wpiutil/src/main/native/cpp/SafeThread.cpp b/wpiutil/src/main/native/cpp/SafeThread.cpp index 5dd4d44e8d..ba1eba3131 100644 --- a/wpiutil/src/main/native/cpp/SafeThread.cpp +++ b/wpiutil/src/main/native/cpp/SafeThread.cpp @@ -4,8 +4,38 @@ #include "wpi/SafeThread.h" +#include + using namespace wpi; +// thread start/stop notifications for bindings that need to set up +// per-thread state + +static void* DefaultOnThreadStart() { + return nullptr; +} +static void DefaultOnThreadEnd(void*) {} + +using OnThreadStartFn = void* (*)(); +using OnThreadEndFn = void (*)(void*); +static std::atomic gSafeThreadRefcount; +static std::atomic gOnSafeThreadStart{DefaultOnThreadStart}; +static std::atomic gOnSafeThreadEnd{DefaultOnThreadEnd}; + +namespace wpi::impl { +void SetSafeThreadNotifiers(OnThreadStartFn OnStart, OnThreadEndFn OnEnd) { + if (gSafeThreadRefcount != 0) { + throw std::runtime_error( + "cannot set notifier while safe threads are running"); + } + // Note: there's a race here, but if you're not calling this function on + // the main thread before you start anything else, you're using this function + // incorrectly + gOnSafeThreadStart = OnStart ? OnStart : DefaultOnThreadStart; + gOnSafeThreadEnd = OnEnd ? OnEnd : DefaultOnThreadEnd; +} +} // namespace wpi::impl + void SafeThread::Stop() { m_active = false; m_cond.notify_all(); @@ -43,7 +73,13 @@ void detail::SafeThreadOwnerBase::Start(std::shared_ptr thr) { if (auto thr = m_thread.lock()) { return; } - m_stdThread = std::thread([=] { thr->Main(); }); + m_stdThread = std::thread([=] { + gSafeThreadRefcount++; + void* opaque = (gOnSafeThreadStart.load())(); + thr->Main(); + (gOnSafeThreadEnd.load())(opaque); + gSafeThreadRefcount--; + }); thr->m_threadId = m_stdThread.get_id(); m_thread = thr; }