diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp index ad2460a098..8a9091934f 100644 --- a/src/Dispatcher.cpp +++ b/src/Dispatcher.cpp @@ -180,8 +180,10 @@ void Dispatcher::QueueOutgoing(std::shared_ptr msg, for (auto& conn : m_connections) { if (conn.net.get() == except) continue; if (only && conn.net.get() != only) continue; - if (conn.net->state() != NetworkConnection::kDead) - conn.outgoing.push_back(msg); + auto state = conn.net->state(); + if (state != NetworkConnection::kSynchronized && + state != NetworkConnection::kActive) continue; + conn.outgoing.push_back(msg); } } @@ -209,7 +211,7 @@ void Dispatcher::ServerThreadMain(const char* listen_address, std::move(stream), std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3), std::bind(&Storage::GetEntryType, &storage, _1), - std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3))); + std::bind(&Storage::ProcessIncoming, &storage, _1, _2))); auto conn = conn_unique.get(); { std::lock_guard lock(m_user_mutex); @@ -237,7 +239,7 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) { std::move(stream), std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3), std::bind(&Storage::GetEntryType, &storage, _1), - std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3))); + std::bind(&Storage::ProcessIncoming, &storage, _1, _2))); auto conn = conn_unique.get(); m_connections.resize(0); // disconnect any current m_connections.emplace_back(std::move(conn_unique)); @@ -310,8 +312,8 @@ bool Dispatcher::ClientHandshake( // generate outgoing assignments NetworkConnection::Outgoing outgoing; - Storage::GetInstance().ApplyInitialAssignments(incoming, new_server, - conn.proto_rev(), &outgoing); + Storage::GetInstance().ApplyInitialAssignments(conn, incoming, new_server, + &outgoing); if (conn.proto_rev() >= 0x0300) outgoing.emplace_back(Message::ClientHelloDone()); @@ -361,7 +363,7 @@ bool Dispatcher::ServerHandshake( } // Get snapshot of initial assignments - Storage::GetInstance().GetInitialAssignments(&outgoing); + Storage::GetInstance().GetInitialAssignments(conn, &outgoing); // Finish with server hello done outgoing.emplace_back(Message::ServerHelloDone()); @@ -395,7 +397,7 @@ bool Dispatcher::ServerHandshake( msg = get_msg(); } Storage& storage = Storage::GetInstance(); - for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn, proto_rev); + for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn); } INFO("server: client CONNECTED: " << conn.stream().getPeerIP() << " port " diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp index d9526159a2..5ed20063af 100644 --- a/src/NetworkConnection.cpp +++ b/src/NetworkConnection.cpp @@ -98,7 +98,7 @@ void NetworkConnection::ReadThreadMain() { if (m_stream) m_stream->close(); break; } - m_process_incoming(std::move(msg), this, m_proto_rev); + m_process_incoming(std::move(msg), this); } DEBUG3("read thread died"); m_state = static_cast(kDead); diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h index 9e59e67261..28e062c0ad 100644 --- a/src/NetworkConnection.h +++ b/src/NetworkConnection.h @@ -21,7 +21,7 @@ namespace nt { class NetworkConnection { public: - enum State { kCreated, kInit, kHandshake, kActive, kDead }; + enum State { kCreated, kInit, kHandshake, kSynchronized, kActive, kDead }; typedef std::function>)> send_msgs)> HandshakeFunc; typedef std::function msg, - NetworkConnection* conn, unsigned int proto_rev)> - ProcessIncomingFunc; + NetworkConnection* conn)> ProcessIncomingFunc; typedef std::vector> Outgoing; typedef ConcurrentQueue OutgoingQueue; diff --git a/src/Storage.cpp b/src/Storage.cpp index cf7fc862aa..a8fac63d71 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -13,6 +13,7 @@ #include "llvm/StringExtras.h" #include "Base64.h" #include "Log.h" +#include "NetworkConnection.h" using namespace nt; @@ -41,7 +42,7 @@ NT_Type Storage::GetEntryType(unsigned int id) const { } void Storage::ProcessIncoming(std::shared_ptr msg, - NetworkConnection* conn, unsigned int proto_rev) { + NetworkConnection* conn) { std::unique_lock lock(m_mutex); switch (msg->type()) { case Message::kKeepAlive: @@ -148,7 +149,7 @@ void Storage::ProcessIncoming(std::shared_ptr msg, entry->seq_num = seq_num; // don't update flags from a <3.0 remote (not part of message) - if (proto_rev >= 0x0300) entry->flags = msg->flags(); + if (conn->proto_rev() >= 0x0300) entry->flags = msg->flags(); // broadcast to all other connections (note for client there won't // be any other connections, so don't bother) @@ -261,8 +262,9 @@ void Storage::ProcessIncoming(std::shared_ptr msg, } void Storage::GetInitialAssignments( - std::vector>* msgs) { + NetworkConnection& conn, std::vector>* msgs) { std::lock_guard lock(m_mutex); + conn.set_state(NetworkConnection::kSynchronized); for (auto& i : m_entries) { auto entry = i.getValue(); msgs->emplace_back(Message::EntryAssign(i.getKey(), entry->id, @@ -272,11 +274,13 @@ void Storage::GetInitialAssignments( } void Storage::ApplyInitialAssignments( - llvm::ArrayRef> msgs, bool new_server, - unsigned int proto_rev, std::vector>* out_msgs) { + NetworkConnection& conn, llvm::ArrayRef> msgs, + bool new_server, std::vector>* out_msgs) { std::unique_lock lock(m_mutex); if (m_server) return; // should not do this on server + conn.set_state(NetworkConnection::kSynchronized); + std::vector> update_msgs; // clear existing id's @@ -319,7 +323,7 @@ void Storage::ApplyInitialAssignments( entry->value = msg->value(); entry->seq_num = seq_num; // don't update flags from a <3.0 remote (not part of message) - if (proto_rev >= 0x0300) entry->flags = msg->flags(); + if (conn.proto_rev() >= 0x0300) entry->flags = msg->flags(); } } diff --git a/src/Storage.h b/src/Storage.h index ab3c4f9328..9ff92ad045 100644 --- a/src/Storage.h +++ b/src/Storage.h @@ -44,11 +44,12 @@ class Storage { NT_Type GetEntryType(unsigned int id) const; - void ProcessIncoming(std::shared_ptr msg, NetworkConnection* conn, - unsigned int proto_rev); - void GetInitialAssignments(std::vector>* msgs); - void ApplyInitialAssignments(llvm::ArrayRef> msgs, - bool new_server, unsigned int proto_rev, + void ProcessIncoming(std::shared_ptr msg, NetworkConnection* conn); + void GetInitialAssignments(NetworkConnection& conn, + std::vector>* msgs); + void ApplyInitialAssignments(NetworkConnection& conn, + llvm::ArrayRef> msgs, + bool new_server, std::vector>* out_msgs); std::mutex& mutex() { return m_mutex; }