From 86c43df8d125a35a4aaa5029403dcae0ed7a6f84 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Fri, 21 Oct 2016 19:40:56 -0700 Subject: [PATCH] Fix connection notification races. (#130) Use a mutex on the connection state and one-shot all change notifications. Fixes #127. --- src/Dispatcher.cpp | 5 +---- src/NetworkConnection.cpp | 44 ++++++++++++++++++++++++++++----------- src/NetworkConnection.h | 8 ++++--- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp index e987ff38d6..672a87eb0f 100644 --- a/src/Dispatcher.cpp +++ b/src/Dispatcher.cpp @@ -193,10 +193,7 @@ std::vector DispatcherBase::GetConnections() const { void DispatcherBase::NotifyConnections( ConnectionListenerCallback callback) const { std::lock_guard lock(m_user_mutex); - for (auto& conn : m_connections) { - if (conn->state() != NetworkConnection::kActive) continue; - m_notifier.NotifyConnection(true, conn->info(), callback); - } + for (const auto& conn : m_connections) conn->NotifyIfActive(callback); } void DispatcherBase::DispatchThreadMain() { diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp index ddf882ed25..794f5fe558 100644 --- a/src/NetworkConnection.cpp +++ b/src/NetworkConnection.cpp @@ -27,10 +27,10 @@ NetworkConnection::NetworkConnection(std::unique_ptr stream, m_stream(std::move(stream)), m_notifier(notifier), m_handshake(handshake), - m_get_entry_type(get_entry_type) { + m_get_entry_type(get_entry_type), + m_state(kCreated) { m_active = false; m_proto_rev = 0x0300; - m_state = static_cast(kCreated); m_last_update = 0; // turn off Nagle algorithm; we bundle packets for transmission @@ -42,7 +42,7 @@ NetworkConnection::~NetworkConnection() { Stop(); } void NetworkConnection::Start() { if (m_active) return; m_active = true; - m_state = static_cast(kInit); + set_state(kInit); // clear queue while (!m_outgoing.empty()) m_outgoing.pop(); // reset shutdown flags @@ -58,7 +58,7 @@ void NetworkConnection::Start() { void NetworkConnection::Stop() { DEBUG2("NetworkConnection stopping (" << this << ")"); - m_state = static_cast(kDead); + set_state(kDead); m_active = false; // closing the stream so the read thread terminates if (m_stream) m_stream->close(); @@ -95,6 +95,29 @@ ConnectionInfo NetworkConnection::info() const { m_last_update, m_proto_rev}; } +NetworkConnection::State NetworkConnection::state() const { + std::lock_guard lock(m_state_mutex); + return m_state; +} + +void NetworkConnection::set_state(State state) { + std::lock_guard lock(m_state_mutex); + // Don't update state any more once we've died + if (m_state == kDead) return; + // One-shot notify state changes + if (m_state != kActive && state == kActive) + m_notifier.NotifyConnection(true, info()); + if (m_state != kDead && state == kDead) + m_notifier.NotifyConnection(false, info()); + m_state = state; +} + +void NetworkConnection::NotifyIfActive( + ConnectionListenerCallback callback) const { + std::lock_guard lock(m_state_mutex); + if (m_state == kActive) m_notifier.NotifyConnection(true, info(), callback); +} + std::string NetworkConnection::remote_id() const { std::lock_guard lock(m_remote_id_mutex); return m_remote_id; @@ -109,7 +132,7 @@ void NetworkConnection::ReadThreadMain() { wpi::raw_socket_istream is(*m_stream); WireDecoder decoder(is, m_proto_rev); - m_state = static_cast(kHandshake); + set_state(kHandshake); if (!m_handshake(*this, [&] { decoder.set_proto_rev(m_proto_rev); @@ -121,13 +144,12 @@ void NetworkConnection::ReadThreadMain() { [&](llvm::ArrayRef> msgs) { m_outgoing.emplace(msgs); })) { - m_state = static_cast(kDead); + set_state(kDead); m_active = false; goto done; } - m_state = static_cast(kActive); - m_notifier.NotifyConnection(true, info()); + set_state(kActive); while (m_active) { if (!m_stream) break; @@ -147,8 +169,7 @@ void NetworkConnection::ReadThreadMain() { m_process_incoming(std::move(msg), this); } DEBUG2("read thread died (" << this << ")"); - if (m_state != kDead) m_notifier.NotifyConnection(false, info()); - m_state = static_cast(kDead); + set_state(kDead); m_active = false; m_outgoing.push(Outgoing()); // also kill write thread @@ -186,8 +207,7 @@ void NetworkConnection::WriteThreadMain() { DEBUG4("sent " << encoder.size() << " bytes"); } DEBUG2("write thread died (" << this << ")"); - if (m_state != kDead) m_notifier.NotifyConnection(false, info()); - m_state = static_cast(kDead); + set_state(kDead); m_active = false; if (m_stream) m_stream->close(); // also kill read thread diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h index 80465cf73e..14d282d91e 100644 --- a/src/NetworkConnection.h +++ b/src/NetworkConnection.h @@ -60,14 +60,15 @@ class NetworkConnection { void QueueOutgoing(std::shared_ptr msg); void PostOutgoing(bool keep_alive); + void NotifyIfActive(ConnectionListenerCallback callback) const; unsigned int uid() const { return m_uid; } unsigned int proto_rev() const { return m_proto_rev; } void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; } - State state() const { return static_cast(m_state.load()); } - void set_state(State state) { m_state = static_cast(state); } + State state() const; + void set_state(State state); std::string remote_id() const; void set_remote_id(StringRef remote_id); @@ -94,7 +95,8 @@ class NetworkConnection { std::thread m_write_thread; std::atomic_bool m_active; std::atomic_uint m_proto_rev; - std::atomic_int m_state; + mutable std::mutex m_state_mutex; + State m_state; mutable std::mutex m_remote_id_mutex; std::string m_remote_id; std::atomic_ullong m_last_update;