// Copyright (c) FIRST and other WPILib contributors. // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. #ifndef WPIUTIL_WPI_CALLBACKMANAGER_H_ #define WPIUTIL_WPI_CALLBACKMANAGER_H_ #include #include #include #include #include #include #include #include "wpi/SafeThread.h" #include "wpi/UidVector.h" #include "wpi/condition_variable.h" #include "wpi/mutex.h" #include "wpi/raw_ostream.h" namespace wpi { template class CallbackListenerData { public: CallbackListenerData() = default; explicit CallbackListenerData(Callback callback_) : callback(callback_) {} explicit CallbackListenerData(unsigned int poller_uid_) : poller_uid(poller_uid_) {} explicit operator bool() const { return callback || poller_uid != UINT_MAX; } Callback callback; unsigned int poller_uid = UINT_MAX; }; // CRTP callback manager thread // @tparam Derived derived class // @tparam NotifierData data buffered for each callback // @tparam ListenerData data stored for each listener // Derived must define the following functions: // bool Matches(const ListenerData& listener, const NotifierData& data); // void SetListener(NotifierData* data, unsigned int listener_uid); // void DoCallback(Callback callback, const NotifierData& data); template >, typename TNotifierData = TUserInfo> class CallbackThread : public wpi::SafeThread { public: using UserInfo = TUserInfo; using NotifierData = TNotifierData; using ListenerData = TListenerData; ~CallbackThread() override { // Wake up any blocked pollers for (size_t i = 0; i < m_pollers.size(); ++i) { if (auto poller = m_pollers[i]) { poller->Terminate(); } } } void Main() override; wpi::UidVector m_listeners; std::queue> m_queue; wpi::condition_variable m_queue_empty; struct Poller { void Terminate() { { std::scoped_lock lock(poll_mutex); terminating = true; } poll_cond.notify_all(); } std::queue poll_queue; wpi::mutex poll_mutex; wpi::condition_variable poll_cond; bool terminating = false; bool canceling = false; }; wpi::UidVector, 64> m_pollers; // Must be called with m_mutex held template void SendPoller(unsigned int poller_uid, Args&&... args) { if (poller_uid > m_pollers.size()) { return; } auto poller = m_pollers[poller_uid]; if (!poller) { return; } { std::scoped_lock lock(poller->poll_mutex); poller->poll_queue.emplace(std::forward(args)...); } poller->poll_cond.notify_one(); } }; template void CallbackThread::Main() { std::unique_lock lock(m_mutex); while (m_active) { while (m_queue.empty()) { m_cond.wait(lock); if (!m_active) { return; } } while (!m_queue.empty()) { if (!m_active) { return; } auto item = std::move(m_queue.front()); if (item.first != UINT_MAX) { if (item.first < m_listeners.size()) { auto& listener = m_listeners[item.first]; if (listener && static_cast(this)->Matches(listener, item.second)) { static_cast(this)->SetListener(&item.second, item.first); if (listener.callback) { lock.unlock(); static_cast(this)->DoCallback(listener.callback, item.second); lock.lock(); } else if (listener.poller_uid != UINT_MAX) { SendPoller(listener.poller_uid, std::move(item.second)); } } } } else { // Use index because iterator might get invalidated. for (size_t i = 0; i < m_listeners.size(); ++i) { auto& listener = m_listeners[i]; if (!listener) { continue; } if (!static_cast(this)->Matches(listener, item.second)) { continue; } static_cast(this)->SetListener(&item.second, static_cast(i)); if (listener.callback) { lock.unlock(); static_cast(this)->DoCallback(listener.callback, item.second); lock.lock(); } else if (listener.poller_uid != UINT_MAX) { SendPoller(listener.poller_uid, item.second); } } } m_queue.pop(); } m_queue_empty.notify_all(); } } // CRTP callback manager // @tparam Derived derived class // @tparam Thread custom thread (must be derived from impl::CallbackThread) // // Derived must define the following functions: // void Start(); template class CallbackManager { friend class RpcServerTest; public: void Stop() { m_owner.Stop(); } void Remove(unsigned int listener_uid) { auto thr = m_owner.GetThread(); if (!thr) { return; } thr->m_listeners.erase(listener_uid); } unsigned int CreatePoller() { static_cast(this)->Start(); auto thr = m_owner.GetThread(); return thr->m_pollers.emplace_back( std::make_shared()); } void RemovePoller(unsigned int poller_uid) { auto thr = m_owner.GetThread(); if (!thr) { return; } // Remove any listeners that are associated with this poller for (size_t i = 0; i < thr->m_listeners.size(); ++i) { if (thr->m_listeners[i].poller_uid == poller_uid) { thr->m_listeners.erase(i); } } // Wake up any blocked pollers if (poller_uid >= thr->m_pollers.size()) { return; } auto poller = thr->m_pollers[poller_uid]; if (!poller) { return; } poller->Terminate(); thr->m_pollers.erase(poller_uid); } bool WaitForQueue(double timeout) { auto thr = m_owner.GetThread(); if (!thr) { return true; } auto& lock = thr.GetLock(); auto timeout_time = std::chrono::steady_clock::now() + std::chrono::duration(timeout); while (!thr->m_queue.empty()) { if (!thr->m_active) { return true; } if (timeout == 0) { return false; } if (timeout < 0) { thr->m_queue_empty.wait(lock); } else { auto cond_timed_out = thr->m_queue_empty.wait_until(lock, timeout_time); if (cond_timed_out == std::cv_status::timeout) { return false; } } } return true; } std::vector Poll(unsigned int poller_uid) { bool timed_out = false; return Poll(poller_uid, -1, &timed_out); } std::vector Poll(unsigned int poller_uid, double timeout, bool* timed_out) { std::vector infos; std::shared_ptr poller; { auto thr = m_owner.GetThread(); if (!thr) { return infos; } if (poller_uid > thr->m_pollers.size()) { return infos; } poller = thr->m_pollers[poller_uid]; if (!poller) { return infos; } } std::unique_lock lock(poller->poll_mutex); auto timeout_time = std::chrono::steady_clock::now() + std::chrono::duration(timeout); *timed_out = false; while (poller->poll_queue.empty()) { if (poller->terminating) { return infos; } if (poller->canceling) { // Note: this only works if there's a single thread calling this // function for any particular poller, but that's the intended use. poller->canceling = false; return infos; } if (timeout == 0) { *timed_out = true; return infos; } if (timeout < 0) { poller->poll_cond.wait(lock); } else { auto cond_timed_out = poller->poll_cond.wait_until(lock, timeout_time); if (cond_timed_out == std::cv_status::timeout) { *timed_out = true; return infos; } } } while (!poller->poll_queue.empty()) { infos.emplace_back(std::move(poller->poll_queue.front())); poller->poll_queue.pop(); } return infos; } void CancelPoll(unsigned int poller_uid) { std::shared_ptr poller; { auto thr = m_owner.GetThread(); if (!thr) { return; } if (poller_uid > thr->m_pollers.size()) { return; } poller = thr->m_pollers[poller_uid]; if (!poller) { return; } } { std::scoped_lock lock(poller->poll_mutex); poller->canceling = true; } poller->poll_cond.notify_one(); } protected: template void DoStart(Args&&... args) { m_owner.Start(std::forward(args)...); } template unsigned int DoAdd(Args&&... args) { static_cast(this)->Start(); auto thr = m_owner.GetThread(); return thr->m_listeners.emplace_back(std::forward(args)...); } template void Send(unsigned int only_listener, Args&&... args) { auto thr = m_owner.GetThread(); if (!thr || thr->m_listeners.empty()) { return; } thr->m_queue.emplace(std::piecewise_construct, std::make_tuple(only_listener), std::forward_as_tuple(std::forward(args)...)); thr->m_cond.notify_one(); } typename wpi::SafeThreadOwner::Proxy GetThread() const { return m_owner.GetThread(); } private: wpi::SafeThreadOwner m_owner; }; } // namespace wpi #endif // WPIUTIL_WPI_CALLBACKMANAGER_H_