Refactor TCP-specific pieces of Dispatcher.

This commit is contained in:
Peter Johnson
2015-08-02 00:33:41 -07:00
parent 0dcaf56ed1
commit 84ff80710c
4 changed files with 80 additions and 65 deletions

View File

@@ -18,7 +18,17 @@ using namespace nt;
ATOMIC_STATIC_INIT(Dispatcher)
Dispatcher::Dispatcher(Storage& storage)
void Dispatcher::StartServer(const char* listen_address, unsigned int port) {
DispatcherBase::StartServer(std::unique_ptr<NetworkAcceptor>(
new TCPAcceptor(static_cast<int>(port), listen_address)));
}
void Dispatcher::StartClient(const char* server_name, unsigned int port) {
DispatcherBase::StartClient(std::bind(&TCPConnector::connect, server_name,
static_cast<int>(port), 1));
}
DispatcherBase::DispatcherBase(Storage& storage)
: m_storage(storage),
m_server(false),
m_do_flush(false),
@@ -28,28 +38,29 @@ Dispatcher::Dispatcher(Storage& storage)
m_update_rate = 100;
}
Dispatcher::~Dispatcher() {
DispatcherBase::~DispatcherBase() {
Stop();
}
void Dispatcher::StartServer(const char* listen_address, unsigned int port) {
void DispatcherBase::StartServer(std::unique_ptr<NetworkAcceptor> acceptor) {
{
std::lock_guard<std::mutex> lock(m_user_mutex);
if (m_active) return;
m_active = true;
}
m_server = true;
m_server_acceptor = std::move(acceptor);
using namespace std::placeholders;
m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
m_server);
m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
m_clientserver_thread =
std::thread(&Dispatcher::ServerThreadMain, this, listen_address, port);
m_clientserver_thread = std::thread(&Dispatcher::ServerThreadMain, this);
}
void Dispatcher::StartClient(const char* server_name, unsigned int port) {
void DispatcherBase::StartClient(
std::function<std::unique_ptr<NetworkStream>()> connect) {
{
std::lock_guard<std::mutex> lock(m_user_mutex);
if (m_active) return;
@@ -63,10 +74,10 @@ void Dispatcher::StartClient(const char* server_name, unsigned int port) {
m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
m_clientserver_thread =
std::thread(&Dispatcher::ClientThreadMain, this, server_name, port);
std::thread(&Dispatcher::ClientThreadMain, this, connect);
}
void Dispatcher::Stop() {
void DispatcherBase::Stop() {
m_active = false;
// wake up dispatch thread with a flush
@@ -92,19 +103,19 @@ void Dispatcher::Stop() {
conns.resize(0);
}
void Dispatcher::SetUpdateRate(double interval) {
void DispatcherBase::SetUpdateRate(double interval) {
// don't allow update rates faster than 100 ms
if (interval < 0.1)
interval = 0.1;
m_update_rate = static_cast<unsigned int>(interval * 1000);
}
void Dispatcher::SetIdentity(llvm::StringRef name) {
void DispatcherBase::SetIdentity(llvm::StringRef name) {
std::lock_guard<std::mutex> lock(m_user_mutex);
m_identity = name;
}
void Dispatcher::Flush() {
void DispatcherBase::Flush() {
auto now = std::chrono::steady_clock::now();
{
std::lock_guard<std::mutex> lock(m_flush_mutex);
@@ -117,7 +128,7 @@ void Dispatcher::Flush() {
m_flush_cv.notify_one();
}
void Dispatcher::DispatchThreadMain() {
void DispatcherBase::DispatchThreadMain() {
// local copy of active m_connections
struct ConnectionRef {
NetworkConnection* net;
@@ -177,9 +188,9 @@ void Dispatcher::DispatchThreadMain() {
}
}
void Dispatcher::QueueOutgoing(std::shared_ptr<Message> msg,
NetworkConnection* only,
NetworkConnection* except) {
void DispatcherBase::QueueOutgoing(std::shared_ptr<Message> msg,
NetworkConnection* only,
NetworkConnection* except) {
std::lock_guard<std::mutex> user_lock(m_user_mutex);
for (auto& conn : m_connections) {
if (conn.net.get() == except) continue;
@@ -191,10 +202,7 @@ void Dispatcher::QueueOutgoing(std::shared_ptr<Message> msg,
}
}
void Dispatcher::ServerThreadMain(const char* listen_address,
unsigned int port) {
m_server_acceptor.reset(
new TCPAcceptor(static_cast<int>(port), listen_address));
void DispatcherBase::ServerThreadMain() {
if (m_server_acceptor->start() != 0) {
m_active = false;
return;
@@ -224,14 +232,15 @@ void Dispatcher::ServerThreadMain(const char* listen_address,
}
}
void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
void DispatcherBase::ClientThreadMain(
std::function<std::unique_ptr<NetworkStream>()> connect) {
while (m_active) {
// sleep between retries
std::this_thread::sleep_for(std::chrono::milliseconds(500));
// try to connect (with timeout)
DEBUG("client trying to connect");
auto stream = TCPConnector::connect(server_name, static_cast<int>(port), 1);
auto stream = connect();
if (!stream) continue; // keep retrying
DEBUG("client connected");
@@ -254,7 +263,7 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
}
}
bool Dispatcher::ClientHandshake(
bool DispatcherBase::ClientHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
@@ -326,7 +335,7 @@ bool Dispatcher::ClientHandshake(
return true;
}
bool Dispatcher::ServerHandshake(
bool DispatcherBase::ServerHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
@@ -405,7 +414,7 @@ bool Dispatcher::ServerHandshake(
return true;
}
void Dispatcher::ClientReconnect(unsigned int proto_rev) {
void DispatcherBase::ClientReconnect(unsigned int proto_rev) {
if (m_server) return;
{
std::lock_guard<std::mutex> lock(m_user_mutex);

View File

@@ -23,21 +23,18 @@
#include "NetworkConnection.h"
#include "Storage.h"
class TCPAcceptor;
class NetworkAcceptor;
class NetworkStream;
namespace nt {
class Dispatcher {
class DispatcherBase {
friend class DispatcherTest;
public:
static Dispatcher& GetInstance() {
ATOMIC_STATIC(Dispatcher, instance);
return instance;
}
~Dispatcher();
virtual ~DispatcherBase();
void StartServer(const char* listen_address, unsigned int port);
void StartClient(const char* server_name, unsigned int port);
void StartServer(std::unique_ptr<NetworkAcceptor> acceptor);
void StartClient(std::function<std::unique_ptr<NetworkStream>()> connect);
void Stop();
void SetUpdateRate(double interval);
void SetIdentity(llvm::StringRef name);
@@ -45,16 +42,17 @@ class Dispatcher {
bool active() const { return m_active; }
Dispatcher(const Dispatcher&) = delete;
Dispatcher& operator=(const Dispatcher&) = delete;
DispatcherBase(const DispatcherBase&) = delete;
DispatcherBase& operator=(const DispatcherBase&) = delete;
protected:
DispatcherBase(Storage& storage);
private:
Dispatcher() : Dispatcher(Storage::GetInstance()) {}
Dispatcher(Storage& storage);
void DispatchThreadMain();
void ServerThreadMain(const char* listen_address, unsigned int port);
void ClientThreadMain(const char* server_name, unsigned int port);
void ServerThreadMain();
void ClientThreadMain(
std::function<std::unique_ptr<NetworkStream>()> connect);
bool ClientHandshake(
NetworkConnection& conn,
@@ -76,7 +74,7 @@ class Dispatcher {
std::thread m_clientserver_thread;
std::thread m_notifier_thread;
std::unique_ptr<TCPAcceptor> m_server_acceptor;
std::unique_ptr<NetworkAcceptor> m_server_acceptor;
// Mutex for user-accessible items
std::mutex m_user_mutex;
@@ -103,10 +101,27 @@ class Dispatcher {
std::condition_variable m_reconnect_cv;
unsigned int m_reconnect_proto_rev;
bool m_do_reconnect;
};
class Dispatcher : public DispatcherBase {
friend class DispatcherTest;
public:
static Dispatcher& GetInstance() {
ATOMIC_STATIC(Dispatcher, instance);
return instance;
}
void StartServer(const char* listen_address, unsigned int port);
void StartClient(const char* server_name, unsigned int port);
private:
Dispatcher() : Dispatcher(Storage::GetInstance()) {}
Dispatcher(Storage& storage) : DispatcherBase(storage) {}
ATOMIC_STATIC_DECL(Dispatcher)
};
} // namespace nt
#endif // NT_DISPATCHER_H_

View File

@@ -31,6 +31,8 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include "TCPStream.h"
#include "Log.h"
static int ResolveHostName(const char* hostname, struct in_addr* addr) {
@@ -45,7 +47,8 @@ static int ResolveHostName(const char* hostname, struct in_addr* addr) {
return result;
}
std::unique_ptr<TCPStream> TCPConnector::connect(const char* server, int port) {
std::unique_ptr<NetworkStream> TCPConnector::connect(const char* server,
int port, int timeout) {
struct sockaddr_in address;
std::memset(&address, 0, sizeof(address));
@@ -54,25 +57,14 @@ std::unique_ptr<TCPStream> TCPConnector::connect(const char* server, int port) {
if (ResolveHostName(server, &(address.sin_addr)) != 0) {
inet_pton(PF_INET, server, &(address.sin_addr));
}
int sd = socket(AF_INET, SOCK_STREAM, 0);
if (::connect(sd, (struct sockaddr*)&address, sizeof(address)) != 0) {
DEBUG("connect() failed: " << strerror(errno));
return nullptr;
}
return std::unique_ptr<TCPStream>(new TCPStream(sd, &address));
}
std::unique_ptr<TCPStream> TCPConnector::connect(const char* server, int port,
int timeout) {
if (timeout == 0) return connect(server, port);
struct sockaddr_in address;
std::memset(&address, 0, sizeof(address));
address.sin_family = AF_INET;
address.sin_port = htons(port);
if (ResolveHostName(server, &(address.sin_addr)) != 0) {
inet_pton(PF_INET, server, &(address.sin_addr));
if (timeout == 0) {
int sd = socket(AF_INET, SOCK_STREAM, 0);
if (::connect(sd, (struct sockaddr*)&address, sizeof(address)) != 0) {
DEBUG("connect() failed: " << strerror(errno));
return nullptr;
}
return std::unique_ptr<NetworkStream>(new TCPStream(sd, &address));
}
long arg;
@@ -117,5 +109,5 @@ std::unique_ptr<TCPStream> TCPConnector::connect(const char* server, int port,
// Create stream object if connected
if (result == -1) return nullptr;
return std::unique_ptr<TCPStream>(new TCPStream(sd, &address));
return std::unique_ptr<NetworkStream>(new TCPStream(sd, &address));
}

View File

@@ -26,13 +26,12 @@
#include <memory>
#include "TCPStream.h"
#include "NetworkStream.h"
class TCPConnector {
public:
static std::unique_ptr<TCPStream> connect(const char* server, int port);
static std::unique_ptr<TCPStream> connect(const char* server, int port,
int timeout);
static std::unique_ptr<NetworkStream> connect(const char* server, int port,
int timeout = 0);
};
#endif