diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp index 8fcfd4772e..4bd9d0877e 100644 --- a/src/Dispatcher.cpp +++ b/src/Dispatcher.cpp @@ -15,13 +15,20 @@ using namespace nt; -#define DEBUG(str) puts(str) +inline void DEBUG(const char* str, ...) { + va_list args; + va_start(args, str); + vfprintf(stderr, str, args); + fputc('\n', stderr); + va_end(args); +} ATOMIC_STATIC_INIT(Dispatcher) Dispatcher::Dispatcher() : m_server(false), m_do_flush(false), + m_reconnect_proto_rev(0x0300), m_do_reconnect(false) { m_active = false; m_update_rate = 100; @@ -86,8 +93,6 @@ void Dispatcher::Stop() { // join threads if (m_dispatch_thread.joinable()) m_dispatch_thread.join(); if (m_clientserver_thread.joinable()) m_clientserver_thread.join(); - - Storage::GetInstance().ClearOutgoing(); } void Dispatcher::SetUpdateRate(double interval) { @@ -199,6 +204,7 @@ void Dispatcher::ServerThreadMain(const char* listen_address, Storage& storage = Storage::GetInstance(); std::unique_ptr conn_unique(new NetworkConnection( std::move(stream), + std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3), std::bind(&Storage::GetEntryType, &storage, _1), std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3))); auto conn = conn_unique.get(); @@ -211,16 +217,8 @@ void Dispatcher::ServerThreadMain(const char* listen_address, } void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) { -#if 0 unsigned int proto_rev = 0x0300; while (m_active) { - // get identity - std::string self_id; - { - std::lock_guard lock(m_user_mutex); - self_id = m_identity; - } - // sleep between retries std::this_thread::sleep_for(std::chrono::milliseconds(500)); @@ -230,91 +228,106 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) { if (!stream) continue; // keep retrying DEBUG("client connected"); - std::unique_ptr conn(new NetworkConnection( + using namespace std::placeholders; + Storage& storage = Storage::GetInstance(); + std::unique_ptr conn_unique(new NetworkConnection( std::move(stream), - [this](unsigned int id) { return GetEntryType(id); })); + std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3), + std::bind(&Storage::GetEntryType, &storage, _1), + std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3))); + auto conn = conn_unique.get(); + { + std::lock_guard lock(m_user_mutex); + m_connections.resize(0); // disconnect any current + m_connections.emplace_back(std::move(conn_unique)); + } conn->set_proto_rev(proto_rev); conn->Start(); - // send client hello - DEBUG("client: sending hello"); - conn->outgoing().push( - NetworkConnection::Outgoing{Message::ClientHello(self_id)}); - - // wait for response - auto msg = conn->incoming().pop(); - if (!msg) { - // disconnected, retry - DEBUG("client: server disconnected before first response"); - proto_rev = 0x0300; - continue; - } - - if (msg->Is(Message::kProtoUnsup)) { - // reconnect with lower protocol (if possible) - if (proto_rev <= 0x0200) { - // no more options, abort (but keep trying to connect) - proto_rev = 0x0300; - continue; - } - proto_rev = 0x0200; - continue; - } - - if (proto_rev >= 0x0300) { - // should be server hello; if not, disconnect, but keep trying to connect - // TODO: do something with initial connection flag - if (!msg->Is(Message::kServerHello)) continue; - conn->set_remote_id(msg->str()); - // get the next message (blocks) - msg = conn->incoming().pop(); - } - - // receive initial assignments - std::vector> incoming; - for (;;) { - if (!msg) { - // disconnected, retry - DEBUG("client: server disconnected during initial entries"); - proto_rev = 0x0300; - continue; - } - if (msg->Is(Message::kServerHelloDone)) break; - if (!msg->Is(Message::kEntryAssign)) { - // unexpected message - DEBUG("client: received message other than entry assignment during initial handshake"); - proto_rev = 0x0300; - continue; - } - incoming.push_back(msg); - // get the next message (blocks) - msg = conn->incoming().pop(); - } - - // generate outgoing assignments - NetworkConnection::Outgoing outgoing; - - if (proto_rev >= 0x0300) - outgoing.push_back(Message::ClientHelloDone()); - - if (!outgoing.empty()) - conn->outgoing().push(std::move(outgoing)); - - // add to connections list (the dispatcher thread will handle from here) - AddConnection(std::move(conn)); - // block until told to reconnect std::unique_lock lock(m_reconnect_mutex); m_reconnect_cv.wait(lock, [&] { return m_do_reconnect; }); + proto_rev = m_reconnect_proto_rev; m_do_reconnect = false; lock.unlock(); } -#endif +} + +bool Dispatcher::ClientHandshake( + NetworkConnection& conn, + std::function()> get_msg, + std::function>)> send_msgs) { + // get identity + std::string self_id; + { + std::lock_guard lock(m_user_mutex); + self_id = m_identity; + } + + // send client hello + DEBUG("client: sending hello"); + send_msgs(Message::ClientHello(self_id)); + + // wait for response + auto msg = get_msg(); + if (!msg) { + // disconnected, retry + DEBUG("client: server disconnected before first response"); + return false; + } + + if (msg->Is(Message::kProtoUnsup)) { + if (msg->id() == 0x0200) ClientReconnect(0x0200); + return false; + } + + bool new_server = true; + if (conn.proto_rev() >= 0x0300) { + // should be server hello; if not, disconnect. + if (!msg->Is(Message::kServerHello)) return false; + conn.set_remote_id(msg->str()); + if ((msg->flags() & 1) != 0) new_server = false; + // get the next message + msg = get_msg(); + } + + // receive initial assignments + std::vector> incoming; + for (;;) { + if (!msg) { + // disconnected, retry + DEBUG("client: server disconnected during initial entries"); + return false; + } + if (msg->Is(Message::kServerHelloDone)) break; + if (!msg->Is(Message::kEntryAssign)) { + // unexpected message + DEBUG("client: received message (%d) other than entry assignment during initial handshake", msg->type()); + return false; + } + incoming.emplace_back(std::move(msg)); + // get the next message + msg = get_msg(); + } + + // generate outgoing assignments + NetworkConnection::Outgoing outgoing; + + Storage::GetInstance().ApplyInitialAssignments(incoming, new_server, + conn.proto_rev(), &outgoing); + + if (conn.proto_rev() >= 0x0300) + outgoing.emplace_back(Message::ClientHelloDone()); + + if (!outgoing.empty()) send_msgs(outgoing); + + return true; } bool Dispatcher::ServerHandshake( NetworkConnection& conn, - std::function()> get_msg) { + std::function()> get_msg, + std::function>)> send_msgs) { // Wait for the client to send us a hello. auto msg = get_msg(); if (!msg) { @@ -330,48 +343,34 @@ bool Dispatcher::ServerHandshake( unsigned int proto_rev = msg->id(); if (proto_rev > 0x0300) { DEBUG("server: client requested proto > 0x0300"); - conn.outgoing().push(NetworkConnection::Outgoing{Message::ProtoUnsup()}); + send_msgs(Message::ProtoUnsup()); return false; } if (proto_rev >= 0x0300) conn.set_remote_id(msg->str()); - // Set the proto version to the client requested version. + // Set the proto version to the client requested version conn.set_proto_rev(proto_rev); -#if 0 - // We need to copy the ID map. This is inefficient, but is necessary - // because we need to get a "snapshot" of the current server state. The - // dispatch thread will create outgoing assignments as necessary as the idmap - // changes, but we don't want duplicate assignments or (worse) missing - // assignments by iterating one entry at a time. - IdMap id_map; - { - std::lock_guard lock(m_idmap_mutex); - id_map = m_idmap; - conn.set_state(NetworkConnection::kHandshake); - } -#endif - // send initial set of assignments + + // Send initial set of assignments NetworkConnection::Outgoing outgoing; - // Server hello. TODO: initial connection flag + // Start with server hello. TODO: initial connection flag if (proto_rev >= 0x0300) { std::lock_guard lock(m_user_mutex); - outgoing.push_back(Message::ServerHello(0u, m_identity)); + outgoing.emplace_back(Message::ServerHello(0u, m_identity)); } -#if 0 - Storage& storage = Storage::GetInstance(); - { - // take storage mutex as we must have a snapshot of the current values. - std::lock_guard lock(storage.mutex()); - std::lock_guard lock(m_idmap_mutex); - outgoing.push_back(Message::EntryAssign( - } -#endif - outgoing.push_back(Message::ServerHelloDone()); - conn.outgoing().push(std::move(outgoing)); -#if 0 + // Get snapshot of initial assignments + Storage::GetInstance().GetInitialAssignments(&outgoing); + + // Finish with server hello done + outgoing.emplace_back(Message::ServerHelloDone()); + + // Batch transmit + DEBUG("server: sending initial assignments"); + send_msgs(outgoing); + // In proto rev 3.0 and later, the handshake concludes with a client hello // done message, so we can batch the assigns before marking the connection // active. In pre-3.0, we need to just immediately mark it active and hand @@ -379,32 +378,35 @@ bool Dispatcher::ServerHandshake( if (proto_rev >= 0x0300) { // receive client initial assignments std::vector> incoming; + msg = get_msg(); for (;;) { if (!msg) { // disconnected, retry - DEBUG("disconnected waiting for initial entries"); + DEBUG("server: disconnected waiting for initial entries"); return false; } if (msg->Is(Message::kClientHelloDone)) break; if (!msg->Is(Message::kEntryAssign)) { // unexpected message - DEBUG("received message other than entry assignment during initial handshake"); + DEBUG("server: received message (%d) other than entry assignment during initial handshake", msg->type()); return false; } incoming.push_back(msg); // get the next message (blocks) msg = get_msg(); } + Storage& storage = Storage::GetInstance(); + for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn, proto_rev); } -#endif - conn.set_state(NetworkConnection::kActive); + return true; } -void Dispatcher::ClientReconnect() { +void Dispatcher::ClientReconnect(unsigned int proto_rev) { if (m_server) return; { std::lock_guard lock(m_reconnect_mutex); + m_reconnect_proto_rev = proto_rev; m_do_reconnect = true; } m_reconnect_cv.notify_one(); diff --git a/src/Dispatcher.h b/src/Dispatcher.h index 7461dd0029..c9b027ee1d 100644 --- a/src/Dispatcher.h +++ b/src/Dispatcher.h @@ -54,10 +54,16 @@ class Dispatcher { void ServerThreadMain(const char* listen_address, unsigned int port); void ClientThreadMain(const char* server_name, unsigned int port); - bool ServerHandshake(NetworkConnection& conn, - std::function()> get_msg); + bool ClientHandshake( + NetworkConnection& conn, + std::function()> get_msg, + std::function>)> send_msgs); + bool ServerHandshake( + NetworkConnection& conn, + std::function()> get_msg, + std::function>)> send_msgs); - void ClientReconnect(); + void ClientReconnect(unsigned int proto_rev = 0x0300); void QueueOutgoing(std::shared_ptr msg, NetworkConnection* only, NetworkConnection* except); @@ -93,6 +99,7 @@ class Dispatcher { // Condition variable for client reconnect std::mutex m_reconnect_mutex; std::condition_variable m_reconnect_cv; + unsigned int m_reconnect_proto_rev; bool m_do_reconnect; ATOMIC_STATIC_DECL(Dispatcher) diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp index 3f4345c96d..f7885e2ab6 100644 --- a/src/NetworkConnection.cpp +++ b/src/NetworkConnection.cpp @@ -14,10 +14,20 @@ using namespace nt; +inline void DEBUG(const char* str, ...) { + va_list args; + va_start(args, str); + vfprintf(stderr, str, args); + fputc('\n', stderr); + va_end(args); +} + NetworkConnection::NetworkConnection(std::unique_ptr stream, + HandshakeFunc handshake, Message::GetEntryTypeFunc get_entry_type, ProcessIncomingFunc process_incoming) : m_stream(std::move(stream)), + m_handshake(handshake), m_get_entry_type(get_entry_type), m_process_incoming(process_incoming) { m_active = false; @@ -66,6 +76,24 @@ void NetworkConnection::ReadThreadMain() { raw_socket_istream is(*m_stream); WireDecoder decoder(is, m_proto_rev); + m_state = static_cast(kHandshake); + if (!m_handshake(*this, + [&] { + decoder.set_proto_rev(m_proto_rev); + auto msg = Message::Read(decoder, m_get_entry_type); + if (!msg) + DEBUG("error reading in handshake: %s", decoder.error()); + return msg; + }, + [&](llvm::ArrayRef> msgs) { + m_outgoing.emplace(msgs); + })) { + m_state = static_cast(kDead); + m_active = false; + return; + } + + m_state = static_cast(kActive); while (m_active) { if (!m_stream) break; @@ -88,15 +116,19 @@ void NetworkConnection::WriteThreadMain() { while (m_active) { auto msgs = m_outgoing.pop(); + DEBUG("write thread woke up"); if (msgs.empty()) break; encoder.set_proto_rev(m_proto_rev); encoder.Reset(); + DEBUG("sending %d messages", msgs.size()); for (auto& msg : msgs) { if (msg) msg->Write(encoder); } TCPStream::Error err; if (!m_stream) break; + if (encoder.size() == 0) continue; if (m_stream->send(encoder.data(), encoder.size(), &err) == 0) break; + DEBUG("sent %d bytes", encoder.size()); } m_state = static_cast(kDead); m_active = false; diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h index 3e653aeaa8..9e59e67261 100644 --- a/src/NetworkConnection.h +++ b/src/NetworkConnection.h @@ -23,6 +23,11 @@ class NetworkConnection { public: enum State { kCreated, kInit, kHandshake, kActive, kDead }; + typedef std::function()> get_msg, + std::function>)> send_msgs)> + HandshakeFunc; typedef std::function msg, NetworkConnection* conn, unsigned int proto_rev)> ProcessIncomingFunc; @@ -30,6 +35,7 @@ class NetworkConnection { typedef ConcurrentQueue OutgoingQueue; NetworkConnection(std::unique_ptr stream, + HandshakeFunc handshake, Message::GetEntryTypeFunc get_entry_type, ProcessIncomingFunc process_incoming); ~NetworkConnection(); @@ -59,6 +65,7 @@ class NetworkConnection { std::unique_ptr m_stream; OutgoingQueue m_outgoing; + HandshakeFunc m_handshake; Message::GetEntryTypeFunc m_get_entry_type; ProcessIncomingFunc m_process_incoming; std::thread m_read_thread; diff --git a/src/Storage.cpp b/src/Storage.cpp index 419d444480..3cebf5a749 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -261,21 +261,86 @@ void Storage::ProcessIncoming(std::shared_ptr msg, } } -void Storage::SendAssignments( - std::function)> send_msg, bool reset_ids) { - std::vector> msgs; - { - std::lock_guard lock(m_mutex); - for (auto& i : m_entries) { - auto entry = i.getValue(); - msgs.emplace_back(Message::EntryAssign(i.getKey(), entry->id, - entry->seq_num.value(), - entry->value, entry->flags)); - if (!m_server && reset_ids) entry->id = 0xffff; - } - if (!m_server && reset_ids) m_idmap.resize(0); +void Storage::GetInitialAssignments( + std::vector>* msgs) { + std::lock_guard lock(m_mutex); + for (auto& i : m_entries) { + auto entry = i.getValue(); + msgs->emplace_back(Message::EntryAssign(i.getKey(), entry->id, + entry->seq_num.value(), + entry->value, entry->flags)); } - for (auto& msg : msgs) send_msg(std::move(msg)); +} + +void Storage::ApplyInitialAssignments( + llvm::ArrayRef> msgs, bool new_server, + unsigned int proto_rev, std::vector>* out_msgs) { + std::unique_lock lock(m_mutex); + if (m_server) return; // should not do this on server + + std::vector> update_msgs; + + // clear existing id's + for (auto& i : m_entries) i.getValue()->id = 0xffff; + + // clear existing idmap + m_idmap.resize(0); + + // apply assignments + for (auto& msg : msgs) { + if (!msg->Is(Message::kEntryAssign)) { + DEBUG("client: received non-entry assignment request?"); + continue; + } + + unsigned int id = msg->id(); + if (id == 0xffff) { + DEBUG("client: received entry assignment request?"); + continue; + } + + SequenceNumber seq_num(msg->seq_num_uid()); + StringRef name = msg->str(); + + auto& entry = m_entries[name]; + if (!entry) { + // doesn't currently exist + entry = std::make_shared(name); + entry->value = msg->value(); + entry->flags = msg->flags(); + entry->seq_num = seq_num; + } else { + // if reconnect and sequence number not higher than local, then we + // don't update the local value and instead send it back to the server + // as an update message + if (!new_server && seq_num <= entry->seq_num) { + update_msgs.emplace_back(Message::EntryUpdate( + entry->id, entry->seq_num.value(), entry->value)); + } else { + entry->value = msg->value(); + entry->seq_num = seq_num; + // don't update flags from a <3.0 remote (not part of message) + if (proto_rev >= 0x0300) entry->flags = msg->flags(); + } + } + + // set id and save to idmap + entry->id = id; + if (id >= m_idmap.size()) m_idmap.resize(id+1); + m_idmap[id] = entry; + } + + // generate assign messages for unassigned local entries + for (auto& i : m_entries) { + auto entry = i.getValue(); + if (entry->id != 0xffff) continue; + out_msgs->emplace_back(Message::EntryAssign(entry->name, entry->id, + entry->seq_num.value(), + entry->value, entry->flags)); + } + auto queue_outgoing = m_queue_outgoing; + lock.unlock(); + for (auto& msg : update_msgs) queue_outgoing(msg, nullptr, nullptr); } std::shared_ptr Storage::GetEntryValue(StringRef name) const { diff --git a/src/Storage.h b/src/Storage.h index 434adbbe75..ab3c4f9328 100644 --- a/src/Storage.h +++ b/src/Storage.h @@ -46,8 +46,10 @@ class Storage { void ProcessIncoming(std::shared_ptr msg, NetworkConnection* conn, unsigned int proto_rev); - void SendAssignments(std::function)> send_msg, - bool reset_ids); + void GetInitialAssignments(std::vector>* msgs); + void ApplyInitialAssignments(llvm::ArrayRef> msgs, + bool new_server, unsigned int proto_rev, + std::vector>* out_msgs); std::mutex& mutex() { return m_mutex; }