From 84ff80710ca67cdd4391b0d2339ffef64fa741ad Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Sun, 2 Aug 2015 00:33:41 -0700 Subject: [PATCH] Refactor TCP-specific pieces of Dispatcher. --- src/Dispatcher.cpp | 57 +++++++++++++++++++-------------- src/Dispatcher.h | 49 ++++++++++++++++++---------- src/tcpsockets/TCPConnector.cpp | 32 +++++++----------- src/tcpsockets/TCPConnector.h | 7 ++-- 4 files changed, 80 insertions(+), 65 deletions(-) diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp index 4c44342628..7fe48e50df 100644 --- a/src/Dispatcher.cpp +++ b/src/Dispatcher.cpp @@ -18,7 +18,17 @@ using namespace nt; ATOMIC_STATIC_INIT(Dispatcher) -Dispatcher::Dispatcher(Storage& storage) +void Dispatcher::StartServer(const char* listen_address, unsigned int port) { + DispatcherBase::StartServer(std::unique_ptr( + new TCPAcceptor(static_cast(port), listen_address))); +} + +void Dispatcher::StartClient(const char* server_name, unsigned int port) { + DispatcherBase::StartClient(std::bind(&TCPConnector::connect, server_name, + static_cast(port), 1)); +} + +DispatcherBase::DispatcherBase(Storage& storage) : m_storage(storage), m_server(false), m_do_flush(false), @@ -28,28 +38,29 @@ Dispatcher::Dispatcher(Storage& storage) m_update_rate = 100; } -Dispatcher::~Dispatcher() { +DispatcherBase::~DispatcherBase() { Stop(); } -void Dispatcher::StartServer(const char* listen_address, unsigned int port) { +void DispatcherBase::StartServer(std::unique_ptr acceptor) { { std::lock_guard lock(m_user_mutex); if (m_active) return; m_active = true; } m_server = true; + m_server_acceptor = std::move(acceptor); using namespace std::placeholders; m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3), m_server); m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this); - m_clientserver_thread = - std::thread(&Dispatcher::ServerThreadMain, this, listen_address, port); + m_clientserver_thread = std::thread(&Dispatcher::ServerThreadMain, this); } -void Dispatcher::StartClient(const char* server_name, unsigned int port) { +void DispatcherBase::StartClient( + std::function()> connect) { { std::lock_guard lock(m_user_mutex); if (m_active) return; @@ -63,10 +74,10 @@ void Dispatcher::StartClient(const char* server_name, unsigned int port) { m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this); m_clientserver_thread = - std::thread(&Dispatcher::ClientThreadMain, this, server_name, port); + std::thread(&Dispatcher::ClientThreadMain, this, connect); } -void Dispatcher::Stop() { +void DispatcherBase::Stop() { m_active = false; // wake up dispatch thread with a flush @@ -92,19 +103,19 @@ void Dispatcher::Stop() { conns.resize(0); } -void Dispatcher::SetUpdateRate(double interval) { +void DispatcherBase::SetUpdateRate(double interval) { // don't allow update rates faster than 100 ms if (interval < 0.1) interval = 0.1; m_update_rate = static_cast(interval * 1000); } -void Dispatcher::SetIdentity(llvm::StringRef name) { +void DispatcherBase::SetIdentity(llvm::StringRef name) { std::lock_guard lock(m_user_mutex); m_identity = name; } -void Dispatcher::Flush() { +void DispatcherBase::Flush() { auto now = std::chrono::steady_clock::now(); { std::lock_guard lock(m_flush_mutex); @@ -117,7 +128,7 @@ void Dispatcher::Flush() { m_flush_cv.notify_one(); } -void Dispatcher::DispatchThreadMain() { +void DispatcherBase::DispatchThreadMain() { // local copy of active m_connections struct ConnectionRef { NetworkConnection* net; @@ -177,9 +188,9 @@ void Dispatcher::DispatchThreadMain() { } } -void Dispatcher::QueueOutgoing(std::shared_ptr msg, - NetworkConnection* only, - NetworkConnection* except) { +void DispatcherBase::QueueOutgoing(std::shared_ptr msg, + NetworkConnection* only, + NetworkConnection* except) { std::lock_guard user_lock(m_user_mutex); for (auto& conn : m_connections) { if (conn.net.get() == except) continue; @@ -191,10 +202,7 @@ void Dispatcher::QueueOutgoing(std::shared_ptr msg, } } -void Dispatcher::ServerThreadMain(const char* listen_address, - unsigned int port) { - m_server_acceptor.reset( - new TCPAcceptor(static_cast(port), listen_address)); +void DispatcherBase::ServerThreadMain() { if (m_server_acceptor->start() != 0) { m_active = false; return; @@ -224,14 +232,15 @@ void Dispatcher::ServerThreadMain(const char* listen_address, } } -void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) { +void DispatcherBase::ClientThreadMain( + std::function()> connect) { while (m_active) { // sleep between retries std::this_thread::sleep_for(std::chrono::milliseconds(500)); // try to connect (with timeout) DEBUG("client trying to connect"); - auto stream = TCPConnector::connect(server_name, static_cast(port), 1); + auto stream = connect(); if (!stream) continue; // keep retrying DEBUG("client connected"); @@ -254,7 +263,7 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) { } } -bool Dispatcher::ClientHandshake( +bool DispatcherBase::ClientHandshake( NetworkConnection& conn, std::function()> get_msg, std::function>)> send_msgs) { @@ -326,7 +335,7 @@ bool Dispatcher::ClientHandshake( return true; } -bool Dispatcher::ServerHandshake( +bool DispatcherBase::ServerHandshake( NetworkConnection& conn, std::function()> get_msg, std::function>)> send_msgs) { @@ -405,7 +414,7 @@ bool Dispatcher::ServerHandshake( return true; } -void Dispatcher::ClientReconnect(unsigned int proto_rev) { +void DispatcherBase::ClientReconnect(unsigned int proto_rev) { if (m_server) return; { std::lock_guard lock(m_user_mutex); diff --git a/src/Dispatcher.h b/src/Dispatcher.h index d702205aed..15fcfe8c22 100644 --- a/src/Dispatcher.h +++ b/src/Dispatcher.h @@ -23,21 +23,18 @@ #include "NetworkConnection.h" #include "Storage.h" -class TCPAcceptor; +class NetworkAcceptor; +class NetworkStream; namespace nt { -class Dispatcher { +class DispatcherBase { friend class DispatcherTest; public: - static Dispatcher& GetInstance() { - ATOMIC_STATIC(Dispatcher, instance); - return instance; - } - ~Dispatcher(); + virtual ~DispatcherBase(); - void StartServer(const char* listen_address, unsigned int port); - void StartClient(const char* server_name, unsigned int port); + void StartServer(std::unique_ptr acceptor); + void StartClient(std::function()> connect); void Stop(); void SetUpdateRate(double interval); void SetIdentity(llvm::StringRef name); @@ -45,16 +42,17 @@ class Dispatcher { bool active() const { return m_active; } - Dispatcher(const Dispatcher&) = delete; - Dispatcher& operator=(const Dispatcher&) = delete; + DispatcherBase(const DispatcherBase&) = delete; + DispatcherBase& operator=(const DispatcherBase&) = delete; + + protected: + DispatcherBase(Storage& storage); private: - Dispatcher() : Dispatcher(Storage::GetInstance()) {} - Dispatcher(Storage& storage); - void DispatchThreadMain(); - void ServerThreadMain(const char* listen_address, unsigned int port); - void ClientThreadMain(const char* server_name, unsigned int port); + void ServerThreadMain(); + void ClientThreadMain( + std::function()> connect); bool ClientHandshake( NetworkConnection& conn, @@ -76,7 +74,7 @@ class Dispatcher { std::thread m_clientserver_thread; std::thread m_notifier_thread; - std::unique_ptr m_server_acceptor; + std::unique_ptr m_server_acceptor; // Mutex for user-accessible items std::mutex m_user_mutex; @@ -103,10 +101,27 @@ class Dispatcher { std::condition_variable m_reconnect_cv; unsigned int m_reconnect_proto_rev; bool m_do_reconnect; +}; + +class Dispatcher : public DispatcherBase { + friend class DispatcherTest; + public: + static Dispatcher& GetInstance() { + ATOMIC_STATIC(Dispatcher, instance); + return instance; + } + + void StartServer(const char* listen_address, unsigned int port); + void StartClient(const char* server_name, unsigned int port); + + private: + Dispatcher() : Dispatcher(Storage::GetInstance()) {} + Dispatcher(Storage& storage) : DispatcherBase(storage) {} ATOMIC_STATIC_DECL(Dispatcher) }; + } // namespace nt #endif // NT_DISPATCHER_H_ diff --git a/src/tcpsockets/TCPConnector.cpp b/src/tcpsockets/TCPConnector.cpp index 6f08eec3dd..3f58a2e558 100644 --- a/src/tcpsockets/TCPConnector.cpp +++ b/src/tcpsockets/TCPConnector.cpp @@ -31,6 +31,8 @@ #include #include +#include "TCPStream.h" + #include "Log.h" static int ResolveHostName(const char* hostname, struct in_addr* addr) { @@ -45,7 +47,8 @@ static int ResolveHostName(const char* hostname, struct in_addr* addr) { return result; } -std::unique_ptr TCPConnector::connect(const char* server, int port) { +std::unique_ptr TCPConnector::connect(const char* server, + int port, int timeout) { struct sockaddr_in address; std::memset(&address, 0, sizeof(address)); @@ -54,25 +57,14 @@ std::unique_ptr TCPConnector::connect(const char* server, int port) { if (ResolveHostName(server, &(address.sin_addr)) != 0) { inet_pton(PF_INET, server, &(address.sin_addr)); } - int sd = socket(AF_INET, SOCK_STREAM, 0); - if (::connect(sd, (struct sockaddr*)&address, sizeof(address)) != 0) { - DEBUG("connect() failed: " << strerror(errno)); - return nullptr; - } - return std::unique_ptr(new TCPStream(sd, &address)); -} -std::unique_ptr TCPConnector::connect(const char* server, int port, - int timeout) { - if (timeout == 0) return connect(server, port); - - struct sockaddr_in address; - - std::memset(&address, 0, sizeof(address)); - address.sin_family = AF_INET; - address.sin_port = htons(port); - if (ResolveHostName(server, &(address.sin_addr)) != 0) { - inet_pton(PF_INET, server, &(address.sin_addr)); + if (timeout == 0) { + int sd = socket(AF_INET, SOCK_STREAM, 0); + if (::connect(sd, (struct sockaddr*)&address, sizeof(address)) != 0) { + DEBUG("connect() failed: " << strerror(errno)); + return nullptr; + } + return std::unique_ptr(new TCPStream(sd, &address)); } long arg; @@ -117,5 +109,5 @@ std::unique_ptr TCPConnector::connect(const char* server, int port, // Create stream object if connected if (result == -1) return nullptr; - return std::unique_ptr(new TCPStream(sd, &address)); + return std::unique_ptr(new TCPStream(sd, &address)); } diff --git a/src/tcpsockets/TCPConnector.h b/src/tcpsockets/TCPConnector.h index 8c32b79754..ebac8590ec 100644 --- a/src/tcpsockets/TCPConnector.h +++ b/src/tcpsockets/TCPConnector.h @@ -26,13 +26,12 @@ #include -#include "TCPStream.h" +#include "NetworkStream.h" class TCPConnector { public: - static std::unique_ptr connect(const char* server, int port); - static std::unique_ptr connect(const char* server, int port, - int timeout); + static std::unique_ptr connect(const char* server, int port, + int timeout = 0); }; #endif