diff --git a/wpiutil/src/main/native/cpp/SafeThread.cpp b/wpiutil/src/main/native/cpp/SafeThread.cpp index bbecc5c232..5dd4d44e8d 100644 --- a/wpiutil/src/main/native/cpp/SafeThread.cpp +++ b/wpiutil/src/main/native/cpp/SafeThread.cpp @@ -6,8 +6,18 @@ using namespace wpi; +void SafeThread::Stop() { + m_active = false; + m_cond.notify_all(); +} + +void SafeThreadEvent::Stop() { + m_active = false; + m_stopEvent.Set(); +} + detail::SafeThreadProxyBase::SafeThreadProxyBase( - std::shared_ptr thr) + std::shared_ptr thr) : m_thread(std::move(thr)) { if (!m_thread) { return; @@ -28,7 +38,7 @@ detail::SafeThreadOwnerBase::~SafeThreadOwnerBase() { } } -void detail::SafeThreadOwnerBase::Start(std::shared_ptr thr) { +void detail::SafeThreadOwnerBase::Start(std::shared_ptr thr) { std::scoped_lock lock(m_mutex); if (auto thr = m_thread.lock()) { return; @@ -41,8 +51,7 @@ void detail::SafeThreadOwnerBase::Start(std::shared_ptr thr) { void detail::SafeThreadOwnerBase::Stop() { std::scoped_lock lock(m_mutex); if (auto thr = m_thread.lock()) { - thr->m_active = false; - thr->m_cond.notify_all(); + thr->Stop(); m_thread.reset(); } if (m_stdThread.joinable()) { @@ -56,8 +65,7 @@ void detail::SafeThreadOwnerBase::Join() { auto stdThread = std::move(m_stdThread); m_thread.reset(); lock.unlock(); - thr->m_active = false; - thr->m_cond.notify_all(); + thr->Stop(); stdThread.join(); } else if (m_stdThread.joinable()) { m_stdThread.detach(); @@ -85,8 +93,8 @@ detail::SafeThreadOwnerBase::GetNativeThreadHandle() { return m_stdThread.native_handle(); } -std::shared_ptr detail::SafeThreadOwnerBase::GetThreadSharedPtr() - const { +std::shared_ptr +detail::SafeThreadOwnerBase::GetThreadSharedPtr() const { std::scoped_lock lock(m_mutex); return m_thread.lock(); } diff --git a/wpiutil/src/main/native/include/wpi/SafeThread.h b/wpiutil/src/main/native/include/wpi/SafeThread.h index bf3773243c..753dc6c877 100644 --- a/wpiutil/src/main/native/include/wpi/SafeThread.h +++ b/wpiutil/src/main/native/include/wpi/SafeThread.h @@ -10,6 +10,7 @@ #include #include +#include "wpi/Synchronization.h" #include "wpi/condition_variable.h" #include "wpi/mutex.h" @@ -18,17 +19,33 @@ namespace wpi { /** * Base class for SafeThreadOwner threads. */ -class SafeThread { +class SafeThreadBase { public: - virtual ~SafeThread() = default; + virtual ~SafeThreadBase() = default; virtual void Main() = 0; + virtual void Stop() = 0; mutable wpi::mutex m_mutex; std::atomic_bool m_active{true}; - wpi::condition_variable m_cond; std::thread::id m_threadId; }; +class SafeThread : public SafeThreadBase { + public: + void Stop() override; + + wpi::condition_variable m_cond; +}; + +class SafeThreadEvent : public SafeThreadBase { + public: + SafeThreadEvent() : m_stopEvent{true} {} + + void Stop() override; + + Event m_stopEvent; +}; + namespace detail { /** @@ -36,12 +53,12 @@ namespace detail { */ class SafeThreadProxyBase { public: - explicit SafeThreadProxyBase(std::shared_ptr thr); + explicit SafeThreadProxyBase(std::shared_ptr thr); explicit operator bool() const { return m_thread != nullptr; } std::unique_lock& GetLock() { return m_lock; } protected: - std::shared_ptr m_thread; + std::shared_ptr m_thread; std::unique_lock m_lock; }; @@ -53,7 +70,7 @@ class SafeThreadProxyBase { template class SafeThreadProxy : public SafeThreadProxyBase { public: - explicit SafeThreadProxy(std::shared_ptr thr) + explicit SafeThreadProxy(std::shared_ptr thr) : SafeThreadProxyBase(std::move(thr)) {} T& operator*() const { return *static_cast(m_thread.get()); } T* operator->() const { return static_cast(m_thread.get()); } @@ -89,13 +106,13 @@ class SafeThreadOwnerBase { void SetJoinAtExit(bool joinAtExit) { m_joinAtExit = joinAtExit; } protected: - void Start(std::shared_ptr thr); - std::shared_ptr GetThreadSharedPtr() const; + void Start(std::shared_ptr thr); + std::shared_ptr GetThreadSharedPtr() const; private: mutable wpi::mutex m_mutex; std::thread m_stdThread; - std::weak_ptr m_thread; + std::weak_ptr m_thread; std::atomic_bool m_joinAtExit{true}; };