diff --git a/wpiutil/src/main/native/cpp/SafeThread.cpp b/wpiutil/src/main/native/cpp/SafeThread.cpp index 284e75d376..13ba1a4908 100644 --- a/wpiutil/src/main/native/cpp/SafeThread.cpp +++ b/wpiutil/src/main/native/cpp/SafeThread.cpp @@ -21,20 +21,39 @@ detail::SafeThreadProxyBase::SafeThreadProxyBase( } } +detail::SafeThreadOwnerBase::~SafeThreadOwnerBase() { + if (m_joinAtExit) + Join(); + else + Stop(); +} + void detail::SafeThreadOwnerBase::Start(std::shared_ptr thr) { std::lock_guard lock(m_mutex); if (auto thr = m_thread.lock()) return; - std::thread stdThread([=] { thr->Main(); }); + m_stdThread = std::thread([=] { thr->Main(); }); m_thread = thr; - m_nativeHandle = stdThread.native_handle(); - stdThread.detach(); } void detail::SafeThreadOwnerBase::Stop() { std::lock_guard lock(m_mutex); if (auto thr = m_thread.lock()) { thr->m_active = false; - thr->m_cond.notify_one(); + thr->m_cond.notify_all(); + m_stdThread.detach(); + m_thread.reset(); + } +} + +void detail::SafeThreadOwnerBase::Join() { + std::unique_lock lock(m_mutex); + if (auto thr = m_thread.lock()) { + auto stdThread = std::move(m_stdThread); + m_thread.reset(); + lock.unlock(); + thr->m_active = false; + thr->m_cond.notify_all(); + stdThread.join(); } } @@ -44,8 +63,8 @@ void detail::swap(SafeThreadOwnerBase& lhs, SafeThreadOwnerBase& rhs) noexcept { std::lock(lhs.m_mutex, rhs.m_mutex); std::lock_guard lock_lhs(lhs.m_mutex, std::adopt_lock); std::lock_guard lock_rhs(rhs.m_mutex, std::adopt_lock); + std::swap(lhs.m_stdThread, rhs.m_stdThread); std::swap(lhs.m_thread, rhs.m_thread); - std::swap(lhs.m_nativeHandle, rhs.m_nativeHandle); } detail::SafeThreadOwnerBase::operator bool() const { @@ -54,9 +73,9 @@ detail::SafeThreadOwnerBase::operator bool() const { } std::thread::native_handle_type -detail::SafeThreadOwnerBase::GetNativeThreadHandle() const { +detail::SafeThreadOwnerBase::GetNativeThreadHandle() { std::lock_guard lock(m_mutex); - return m_nativeHandle; + return m_stdThread.native_handle(); } std::shared_ptr detail::SafeThreadOwnerBase::GetThread() const { diff --git a/wpiutil/src/main/native/include/wpi/SafeThread.h b/wpiutil/src/main/native/include/wpi/SafeThread.h index af6d90db89..ec4e2e15af 100644 --- a/wpiutil/src/main/native/include/wpi/SafeThread.h +++ b/wpiutil/src/main/native/include/wpi/SafeThread.h @@ -21,12 +21,11 @@ namespace wpi { // Base class for SafeThreadOwner threads. class SafeThread { public: - SafeThread() { m_active = true; } virtual ~SafeThread() = default; virtual void Main() = 0; mutable wpi::mutex m_mutex; - std::atomic_bool m_active; + std::atomic_bool m_active{true}; wpi::condition_variable m_cond; }; @@ -59,6 +58,7 @@ class SafeThreadProxy : public SafeThreadProxyBase { class SafeThreadOwnerBase { public: void Stop(); + void Join(); SafeThreadOwnerBase() noexcept = default; SafeThreadOwnerBase(const SafeThreadOwnerBase&) = delete; @@ -71,13 +71,15 @@ class SafeThreadOwnerBase { swap(*this, other); return *this; } - ~SafeThreadOwnerBase() { Stop(); } + ~SafeThreadOwnerBase(); friend void swap(SafeThreadOwnerBase& lhs, SafeThreadOwnerBase& rhs) noexcept; explicit operator bool() const; - std::thread::native_handle_type GetNativeThreadHandle() const; + std::thread::native_handle_type GetNativeThreadHandle(); + + void SetJoinAtExit(bool joinAtExit) { m_joinAtExit = joinAtExit; } protected: void Start(std::shared_ptr thr); @@ -85,8 +87,9 @@ class SafeThreadOwnerBase { private: mutable wpi::mutex m_mutex; + std::thread m_stdThread; std::weak_ptr m_thread; - std::thread::native_handle_type m_nativeHandle; + std::atomic_bool m_joinAtExit{true}; }; void swap(SafeThreadOwnerBase& lhs, SafeThreadOwnerBase& rhs) noexcept; diff --git a/wpiutil/src/main/native/include/wpi/jni_util.h b/wpiutil/src/main/native/include/wpi/jni_util.h index 9b20ad2e2a..7cb6c657b4 100644 --- a/wpiutil/src/main/native/include/wpi/jni_util.h +++ b/wpiutil/src/main/native/include/wpi/jni_util.h @@ -479,6 +479,7 @@ class JCallbackThread : public SafeThread { template class JCallbackManager : public SafeThreadOwner> { public: + JCallbackManager() { this->SetJoinAtExit(false); } void SetFunc(JNIEnv* env, jobject func, jmethodID mid); template