diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp index 7b962bd507..2ee60904dc 100644 --- a/src/NetworkConnection.cpp +++ b/src/NetworkConnection.cpp @@ -20,6 +20,7 @@ NetworkConnection::NetworkConnection(std::unique_ptr stream, m_get_entry_type(get_entry_type) { m_active = false; m_proto_rev = 0x0300; + m_state = static_cast(kCreated); } NetworkConnection::~NetworkConnection() { Stop(); } @@ -27,18 +28,38 @@ NetworkConnection::~NetworkConnection() { Stop(); } void NetworkConnection::Start() { if (m_active) return; m_active = true; + m_state = static_cast(kInit); + // clear queues + while (!m_incoming.empty()) m_incoming.pop(); + while (!m_outgoing.empty()) m_outgoing.pop(); + // start threads m_write_thread = std::thread(&NetworkConnection::WriteThreadMain, this); m_read_thread = std::thread(&NetworkConnection::ReadThreadMain, this); } void NetworkConnection::Stop() { + m_state = static_cast(kDead); m_active = false; // closing the stream so the read thread terminates if (m_stream) m_stream->close(); - // send a dummy outgoing message so the write thread terminates - m_outgoing.push(Outgoing{Message::KeepAlive()}); + // send an empty outgoing message set so the write thread terminates + m_outgoing.push(Outgoing()); + // wait for threads to terminate if (m_write_thread.joinable()) m_write_thread.join(); if (m_read_thread.joinable()) m_read_thread.join(); + // clear queues + while (!m_incoming.empty()) m_incoming.pop(); + while (!m_outgoing.empty()) m_outgoing.pop(); +} + +std::string NetworkConnection::remote_id() const { + std::lock_guard lock(m_remote_id_mutex); + return m_remote_id; +} + +void NetworkConnection::set_remote_id(StringRef remote_id) { + std::lock_guard lock(m_remote_id_mutex); + m_remote_id = remote_id; } void NetworkConnection::ReadThreadMain() { @@ -59,6 +80,7 @@ void NetworkConnection::ReadThreadMain() { m_incoming.emplace(std::move(msg)); } m_incoming.emplace(nullptr); // notify anyone waiting that we disconnected + m_state = static_cast(kDead); m_active = false; } @@ -67,7 +89,7 @@ void NetworkConnection::WriteThreadMain() { while (m_active) { auto msgs = m_outgoing.pop(); - if (!m_active) break; + if (msgs.empty()) break; encoder.set_proto_rev(m_proto_rev); encoder.Reset(); for (auto& msg : msgs) msg->Write(encoder); @@ -75,5 +97,6 @@ void NetworkConnection::WriteThreadMain() { if (!m_stream) break; if (m_stream->send(encoder.data(), encoder.size(), &err) == 0) break; } + m_state = static_cast(kDead); m_active = false; } diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h index 977b25561e..3ab9742ad8 100644 --- a/src/NetworkConnection.h +++ b/src/NetworkConnection.h @@ -21,6 +21,8 @@ namespace nt { class NetworkConnection { public: + enum State { kCreated, kInit, kHandshake, kActive, kDead }; + typedef std::shared_ptr Incoming; typedef ConcurrentQueue IncomingQueue; typedef std::vector> Outgoing; @@ -37,8 +39,16 @@ class NetworkConnection { TCPStream& stream() { return *m_stream; } OutgoingQueue& outgoing() { return m_outgoing; } IncomingQueue& incoming() { return m_incoming; } + + unsigned int proto_rev() const { return m_proto_rev; } void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; } + State state() const { return static_cast(m_state.load()); } + void set_state(State state) { m_state = static_cast(state); } + + std::string remote_id() const; + void set_remote_id(StringRef remote_id); + NetworkConnection(const NetworkConnection&) = delete; NetworkConnection& operator=(const NetworkConnection&) = delete; @@ -54,6 +64,9 @@ class NetworkConnection { std::thread m_write_thread; std::atomic_bool m_active; std::atomic_uint m_proto_rev; + std::atomic_int m_state; + mutable std::mutex m_remote_id_mutex; + std::string m_remote_id; }; } // namespace nt