Implement client/server handshaking.

This commit is contained in:
Peter Johnson
2015-07-31 20:32:52 -07:00
parent 98ad6d1b43
commit ead125555c
6 changed files with 251 additions and 136 deletions

View File

@@ -15,13 +15,20 @@
using namespace nt;
#define DEBUG(str) puts(str)
inline void DEBUG(const char* str, ...) {
va_list args;
va_start(args, str);
vfprintf(stderr, str, args);
fputc('\n', stderr);
va_end(args);
}
ATOMIC_STATIC_INIT(Dispatcher)
Dispatcher::Dispatcher()
: m_server(false),
m_do_flush(false),
m_reconnect_proto_rev(0x0300),
m_do_reconnect(false) {
m_active = false;
m_update_rate = 100;
@@ -86,8 +93,6 @@ void Dispatcher::Stop() {
// join threads
if (m_dispatch_thread.joinable()) m_dispatch_thread.join();
if (m_clientserver_thread.joinable()) m_clientserver_thread.join();
Storage::GetInstance().ClearOutgoing();
}
void Dispatcher::SetUpdateRate(double interval) {
@@ -199,6 +204,7 @@ void Dispatcher::ServerThreadMain(const char* listen_address,
Storage& storage = Storage::GetInstance();
std::unique_ptr<NetworkConnection> conn_unique(new NetworkConnection(
std::move(stream),
std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3),
std::bind(&Storage::GetEntryType, &storage, _1),
std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3)));
auto conn = conn_unique.get();
@@ -211,16 +217,8 @@ void Dispatcher::ServerThreadMain(const char* listen_address,
}
void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
#if 0
unsigned int proto_rev = 0x0300;
while (m_active) {
// get identity
std::string self_id;
{
std::lock_guard<std::mutex> lock(m_user_mutex);
self_id = m_identity;
}
// sleep between retries
std::this_thread::sleep_for(std::chrono::milliseconds(500));
@@ -230,91 +228,106 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
if (!stream) continue; // keep retrying
DEBUG("client connected");
std::unique_ptr<NetworkConnection> conn(new NetworkConnection(
using namespace std::placeholders;
Storage& storage = Storage::GetInstance();
std::unique_ptr<NetworkConnection> conn_unique(new NetworkConnection(
std::move(stream),
[this](unsigned int id) { return GetEntryType(id); }));
std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3),
std::bind(&Storage::GetEntryType, &storage, _1),
std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3)));
auto conn = conn_unique.get();
{
std::lock_guard<std::mutex> lock(m_user_mutex);
m_connections.resize(0); // disconnect any current
m_connections.emplace_back(std::move(conn_unique));
}
conn->set_proto_rev(proto_rev);
conn->Start();
// send client hello
DEBUG("client: sending hello");
conn->outgoing().push(
NetworkConnection::Outgoing{Message::ClientHello(self_id)});
// wait for response
auto msg = conn->incoming().pop();
if (!msg) {
// disconnected, retry
DEBUG("client: server disconnected before first response");
proto_rev = 0x0300;
continue;
}
if (msg->Is(Message::kProtoUnsup)) {
// reconnect with lower protocol (if possible)
if (proto_rev <= 0x0200) {
// no more options, abort (but keep trying to connect)
proto_rev = 0x0300;
continue;
}
proto_rev = 0x0200;
continue;
}
if (proto_rev >= 0x0300) {
// should be server hello; if not, disconnect, but keep trying to connect
// TODO: do something with initial connection flag
if (!msg->Is(Message::kServerHello)) continue;
conn->set_remote_id(msg->str());
// get the next message (blocks)
msg = conn->incoming().pop();
}
// receive initial assignments
std::vector<std::shared_ptr<Message>> incoming;
for (;;) {
if (!msg) {
// disconnected, retry
DEBUG("client: server disconnected during initial entries");
proto_rev = 0x0300;
continue;
}
if (msg->Is(Message::kServerHelloDone)) break;
if (!msg->Is(Message::kEntryAssign)) {
// unexpected message
DEBUG("client: received message other than entry assignment during initial handshake");
proto_rev = 0x0300;
continue;
}
incoming.push_back(msg);
// get the next message (blocks)
msg = conn->incoming().pop();
}
// generate outgoing assignments
NetworkConnection::Outgoing outgoing;
if (proto_rev >= 0x0300)
outgoing.push_back(Message::ClientHelloDone());
if (!outgoing.empty())
conn->outgoing().push(std::move(outgoing));
// add to connections list (the dispatcher thread will handle from here)
AddConnection(std::move(conn));
// block until told to reconnect
std::unique_lock<std::mutex> lock(m_reconnect_mutex);
m_reconnect_cv.wait(lock, [&] { return m_do_reconnect; });
proto_rev = m_reconnect_proto_rev;
m_do_reconnect = false;
lock.unlock();
}
#endif
}
bool Dispatcher::ClientHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
// get identity
std::string self_id;
{
std::lock_guard<std::mutex> lock(m_user_mutex);
self_id = m_identity;
}
// send client hello
DEBUG("client: sending hello");
send_msgs(Message::ClientHello(self_id));
// wait for response
auto msg = get_msg();
if (!msg) {
// disconnected, retry
DEBUG("client: server disconnected before first response");
return false;
}
if (msg->Is(Message::kProtoUnsup)) {
if (msg->id() == 0x0200) ClientReconnect(0x0200);
return false;
}
bool new_server = true;
if (conn.proto_rev() >= 0x0300) {
// should be server hello; if not, disconnect.
if (!msg->Is(Message::kServerHello)) return false;
conn.set_remote_id(msg->str());
if ((msg->flags() & 1) != 0) new_server = false;
// get the next message
msg = get_msg();
}
// receive initial assignments
std::vector<std::shared_ptr<Message>> incoming;
for (;;) {
if (!msg) {
// disconnected, retry
DEBUG("client: server disconnected during initial entries");
return false;
}
if (msg->Is(Message::kServerHelloDone)) break;
if (!msg->Is(Message::kEntryAssign)) {
// unexpected message
DEBUG("client: received message (%d) other than entry assignment during initial handshake", msg->type());
return false;
}
incoming.emplace_back(std::move(msg));
// get the next message
msg = get_msg();
}
// generate outgoing assignments
NetworkConnection::Outgoing outgoing;
Storage::GetInstance().ApplyInitialAssignments(incoming, new_server,
conn.proto_rev(), &outgoing);
if (conn.proto_rev() >= 0x0300)
outgoing.emplace_back(Message::ClientHelloDone());
if (!outgoing.empty()) send_msgs(outgoing);
return true;
}
bool Dispatcher::ServerHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg) {
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
// Wait for the client to send us a hello.
auto msg = get_msg();
if (!msg) {
@@ -330,48 +343,34 @@ bool Dispatcher::ServerHandshake(
unsigned int proto_rev = msg->id();
if (proto_rev > 0x0300) {
DEBUG("server: client requested proto > 0x0300");
conn.outgoing().push(NetworkConnection::Outgoing{Message::ProtoUnsup()});
send_msgs(Message::ProtoUnsup());
return false;
}
if (proto_rev >= 0x0300) conn.set_remote_id(msg->str());
// Set the proto version to the client requested version.
// Set the proto version to the client requested version
conn.set_proto_rev(proto_rev);
#if 0
// We need to copy the ID map. This is inefficient, but is necessary
// because we need to get a "snapshot" of the current server state. The
// dispatch thread will create outgoing assignments as necessary as the idmap
// changes, but we don't want duplicate assignments or (worse) missing
// assignments by iterating one entry at a time.
IdMap id_map;
{
std::lock_guard<std::mutex> lock(m_idmap_mutex);
id_map = m_idmap;
conn.set_state(NetworkConnection::kHandshake);
}
#endif
// send initial set of assignments
// Send initial set of assignments
NetworkConnection::Outgoing outgoing;
// Server hello. TODO: initial connection flag
// Start with server hello. TODO: initial connection flag
if (proto_rev >= 0x0300) {
std::lock_guard<std::mutex> lock(m_user_mutex);
outgoing.push_back(Message::ServerHello(0u, m_identity));
outgoing.emplace_back(Message::ServerHello(0u, m_identity));
}
#if 0
Storage& storage = Storage::GetInstance();
{
// take storage mutex as we must have a snapshot of the current values.
std::lock_guard<std::mutex> lock(storage.mutex());
std::lock_guard<std::mutex> lock(m_idmap_mutex);
outgoing.push_back(Message::EntryAssign(
}
#endif
outgoing.push_back(Message::ServerHelloDone());
conn.outgoing().push(std::move(outgoing));
#if 0
// Get snapshot of initial assignments
Storage::GetInstance().GetInitialAssignments(&outgoing);
// Finish with server hello done
outgoing.emplace_back(Message::ServerHelloDone());
// Batch transmit
DEBUG("server: sending initial assignments");
send_msgs(outgoing);
// In proto rev 3.0 and later, the handshake concludes with a client hello
// done message, so we can batch the assigns before marking the connection
// active. In pre-3.0, we need to just immediately mark it active and hand
@@ -379,32 +378,35 @@ bool Dispatcher::ServerHandshake(
if (proto_rev >= 0x0300) {
// receive client initial assignments
std::vector<std::shared_ptr<Message>> incoming;
msg = get_msg();
for (;;) {
if (!msg) {
// disconnected, retry
DEBUG("disconnected waiting for initial entries");
DEBUG("server: disconnected waiting for initial entries");
return false;
}
if (msg->Is(Message::kClientHelloDone)) break;
if (!msg->Is(Message::kEntryAssign)) {
// unexpected message
DEBUG("received message other than entry assignment during initial handshake");
DEBUG("server: received message (%d) other than entry assignment during initial handshake", msg->type());
return false;
}
incoming.push_back(msg);
// get the next message (blocks)
msg = get_msg();
}
Storage& storage = Storage::GetInstance();
for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn, proto_rev);
}
#endif
conn.set_state(NetworkConnection::kActive);
return true;
}
void Dispatcher::ClientReconnect() {
void Dispatcher::ClientReconnect(unsigned int proto_rev) {
if (m_server) return;
{
std::lock_guard<std::mutex> lock(m_reconnect_mutex);
m_reconnect_proto_rev = proto_rev;
m_do_reconnect = true;
}
m_reconnect_cv.notify_one();

View File

@@ -54,10 +54,16 @@ class Dispatcher {
void ServerThreadMain(const char* listen_address, unsigned int port);
void ClientThreadMain(const char* server_name, unsigned int port);
bool ServerHandshake(NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg);
bool ClientHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs);
bool ServerHandshake(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs);
void ClientReconnect();
void ClientReconnect(unsigned int proto_rev = 0x0300);
void QueueOutgoing(std::shared_ptr<Message> msg, NetworkConnection* only,
NetworkConnection* except);
@@ -93,6 +99,7 @@ class Dispatcher {
// Condition variable for client reconnect
std::mutex m_reconnect_mutex;
std::condition_variable m_reconnect_cv;
unsigned int m_reconnect_proto_rev;
bool m_do_reconnect;
ATOMIC_STATIC_DECL(Dispatcher)

View File

@@ -14,10 +14,20 @@
using namespace nt;
inline void DEBUG(const char* str, ...) {
va_list args;
va_start(args, str);
vfprintf(stderr, str, args);
fputc('\n', stderr);
va_end(args);
}
NetworkConnection::NetworkConnection(std::unique_ptr<TCPStream> stream,
HandshakeFunc handshake,
Message::GetEntryTypeFunc get_entry_type,
ProcessIncomingFunc process_incoming)
: m_stream(std::move(stream)),
m_handshake(handshake),
m_get_entry_type(get_entry_type),
m_process_incoming(process_incoming) {
m_active = false;
@@ -66,6 +76,24 @@ void NetworkConnection::ReadThreadMain() {
raw_socket_istream is(*m_stream);
WireDecoder decoder(is, m_proto_rev);
m_state = static_cast<int>(kHandshake);
if (!m_handshake(*this,
[&] {
decoder.set_proto_rev(m_proto_rev);
auto msg = Message::Read(decoder, m_get_entry_type);
if (!msg)
DEBUG("error reading in handshake: %s", decoder.error());
return msg;
},
[&](llvm::ArrayRef<std::shared_ptr<Message>> msgs) {
m_outgoing.emplace(msgs);
})) {
m_state = static_cast<int>(kDead);
m_active = false;
return;
}
m_state = static_cast<int>(kActive);
while (m_active) {
if (!m_stream)
break;
@@ -88,15 +116,19 @@ void NetworkConnection::WriteThreadMain() {
while (m_active) {
auto msgs = m_outgoing.pop();
DEBUG("write thread woke up");
if (msgs.empty()) break;
encoder.set_proto_rev(m_proto_rev);
encoder.Reset();
DEBUG("sending %d messages", msgs.size());
for (auto& msg : msgs) {
if (msg) msg->Write(encoder);
}
TCPStream::Error err;
if (!m_stream) break;
if (encoder.size() == 0) continue;
if (m_stream->send(encoder.data(), encoder.size(), &err) == 0) break;
DEBUG("sent %d bytes", encoder.size());
}
m_state = static_cast<int>(kDead);
m_active = false;

View File

@@ -23,6 +23,11 @@ class NetworkConnection {
public:
enum State { kCreated, kInit, kHandshake, kActive, kDead };
typedef std::function<bool(
NetworkConnection& conn,
std::function<std::shared_ptr<Message>()> get_msg,
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs)>
HandshakeFunc;
typedef std::function<void(std::shared_ptr<Message> msg,
NetworkConnection* conn, unsigned int proto_rev)>
ProcessIncomingFunc;
@@ -30,6 +35,7 @@ class NetworkConnection {
typedef ConcurrentQueue<Outgoing> OutgoingQueue;
NetworkConnection(std::unique_ptr<TCPStream> stream,
HandshakeFunc handshake,
Message::GetEntryTypeFunc get_entry_type,
ProcessIncomingFunc process_incoming);
~NetworkConnection();
@@ -59,6 +65,7 @@ class NetworkConnection {
std::unique_ptr<TCPStream> m_stream;
OutgoingQueue m_outgoing;
HandshakeFunc m_handshake;
Message::GetEntryTypeFunc m_get_entry_type;
ProcessIncomingFunc m_process_incoming;
std::thread m_read_thread;

View File

@@ -261,21 +261,86 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
}
}
void Storage::SendAssignments(
std::function<void(std::shared_ptr<Message>)> send_msg, bool reset_ids) {
std::vector<std::shared_ptr<Message>> msgs;
{
std::lock_guard<std::mutex> lock(m_mutex);
for (auto& i : m_entries) {
auto entry = i.getValue();
msgs.emplace_back(Message::EntryAssign(i.getKey(), entry->id,
entry->seq_num.value(),
entry->value, entry->flags));
if (!m_server && reset_ids) entry->id = 0xffff;
}
if (!m_server && reset_ids) m_idmap.resize(0);
void Storage::GetInitialAssignments(
std::vector<std::shared_ptr<Message>>* msgs) {
std::lock_guard<std::mutex> lock(m_mutex);
for (auto& i : m_entries) {
auto entry = i.getValue();
msgs->emplace_back(Message::EntryAssign(i.getKey(), entry->id,
entry->seq_num.value(),
entry->value, entry->flags));
}
for (auto& msg : msgs) send_msg(std::move(msg));
}
void Storage::ApplyInitialAssignments(
llvm::ArrayRef<std::shared_ptr<Message>> msgs, bool new_server,
unsigned int proto_rev, std::vector<std::shared_ptr<Message>>* out_msgs) {
std::unique_lock<std::mutex> lock(m_mutex);
if (m_server) return; // should not do this on server
std::vector<std::shared_ptr<Message>> update_msgs;
// clear existing id's
for (auto& i : m_entries) i.getValue()->id = 0xffff;
// clear existing idmap
m_idmap.resize(0);
// apply assignments
for (auto& msg : msgs) {
if (!msg->Is(Message::kEntryAssign)) {
DEBUG("client: received non-entry assignment request?");
continue;
}
unsigned int id = msg->id();
if (id == 0xffff) {
DEBUG("client: received entry assignment request?");
continue;
}
SequenceNumber seq_num(msg->seq_num_uid());
StringRef name = msg->str();
auto& entry = m_entries[name];
if (!entry) {
// doesn't currently exist
entry = std::make_shared<Entry>(name);
entry->value = msg->value();
entry->flags = msg->flags();
entry->seq_num = seq_num;
} else {
// if reconnect and sequence number not higher than local, then we
// don't update the local value and instead send it back to the server
// as an update message
if (!new_server && seq_num <= entry->seq_num) {
update_msgs.emplace_back(Message::EntryUpdate(
entry->id, entry->seq_num.value(), entry->value));
} else {
entry->value = msg->value();
entry->seq_num = seq_num;
// don't update flags from a <3.0 remote (not part of message)
if (proto_rev >= 0x0300) entry->flags = msg->flags();
}
}
// set id and save to idmap
entry->id = id;
if (id >= m_idmap.size()) m_idmap.resize(id+1);
m_idmap[id] = entry;
}
// generate assign messages for unassigned local entries
for (auto& i : m_entries) {
auto entry = i.getValue();
if (entry->id != 0xffff) continue;
out_msgs->emplace_back(Message::EntryAssign(entry->name, entry->id,
entry->seq_num.value(),
entry->value, entry->flags));
}
auto queue_outgoing = m_queue_outgoing;
lock.unlock();
for (auto& msg : update_msgs) queue_outgoing(msg, nullptr, nullptr);
}
std::shared_ptr<Value> Storage::GetEntryValue(StringRef name) const {

View File

@@ -46,8 +46,10 @@ class Storage {
void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn,
unsigned int proto_rev);
void SendAssignments(std::function<void(std::shared_ptr<Message>)> send_msg,
bool reset_ids);
void GetInitialAssignments(std::vector<std::shared_ptr<Message>>* msgs);
void ApplyInitialAssignments(llvm::ArrayRef<std::shared_ptr<Message>> msgs,
bool new_server, unsigned int proto_rev,
std::vector<std::shared_ptr<Message>>* out_msgs);
std::mutex& mutex() { return m_mutex; }