mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-07-04 03:11:43 +00:00
Ensure initial synchronization is atomic.
This commit is contained in:
@@ -180,8 +180,10 @@ void Dispatcher::QueueOutgoing(std::shared_ptr<Message> msg,
|
||||
for (auto& conn : m_connections) {
|
||||
if (conn.net.get() == except) continue;
|
||||
if (only && conn.net.get() != only) continue;
|
||||
if (conn.net->state() != NetworkConnection::kDead)
|
||||
conn.outgoing.push_back(msg);
|
||||
auto state = conn.net->state();
|
||||
if (state != NetworkConnection::kSynchronized &&
|
||||
state != NetworkConnection::kActive) continue;
|
||||
conn.outgoing.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +211,7 @@ void Dispatcher::ServerThreadMain(const char* listen_address,
|
||||
std::move(stream),
|
||||
std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3),
|
||||
std::bind(&Storage::GetEntryType, &storage, _1),
|
||||
std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3)));
|
||||
std::bind(&Storage::ProcessIncoming, &storage, _1, _2)));
|
||||
auto conn = conn_unique.get();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_user_mutex);
|
||||
@@ -237,7 +239,7 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
|
||||
std::move(stream),
|
||||
std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3),
|
||||
std::bind(&Storage::GetEntryType, &storage, _1),
|
||||
std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3)));
|
||||
std::bind(&Storage::ProcessIncoming, &storage, _1, _2)));
|
||||
auto conn = conn_unique.get();
|
||||
m_connections.resize(0); // disconnect any current
|
||||
m_connections.emplace_back(std::move(conn_unique));
|
||||
@@ -310,8 +312,8 @@ bool Dispatcher::ClientHandshake(
|
||||
// generate outgoing assignments
|
||||
NetworkConnection::Outgoing outgoing;
|
||||
|
||||
Storage::GetInstance().ApplyInitialAssignments(incoming, new_server,
|
||||
conn.proto_rev(), &outgoing);
|
||||
Storage::GetInstance().ApplyInitialAssignments(conn, incoming, new_server,
|
||||
&outgoing);
|
||||
|
||||
if (conn.proto_rev() >= 0x0300)
|
||||
outgoing.emplace_back(Message::ClientHelloDone());
|
||||
@@ -361,7 +363,7 @@ bool Dispatcher::ServerHandshake(
|
||||
}
|
||||
|
||||
// Get snapshot of initial assignments
|
||||
Storage::GetInstance().GetInitialAssignments(&outgoing);
|
||||
Storage::GetInstance().GetInitialAssignments(conn, &outgoing);
|
||||
|
||||
// Finish with server hello done
|
||||
outgoing.emplace_back(Message::ServerHelloDone());
|
||||
@@ -395,7 +397,7 @@ bool Dispatcher::ServerHandshake(
|
||||
msg = get_msg();
|
||||
}
|
||||
Storage& storage = Storage::GetInstance();
|
||||
for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn, proto_rev);
|
||||
for (auto& msg : incoming) storage.ProcessIncoming(msg, &conn);
|
||||
}
|
||||
|
||||
INFO("server: client CONNECTED: " << conn.stream().getPeerIP() << " port "
|
||||
|
||||
@@ -98,7 +98,7 @@ void NetworkConnection::ReadThreadMain() {
|
||||
if (m_stream) m_stream->close();
|
||||
break;
|
||||
}
|
||||
m_process_incoming(std::move(msg), this, m_proto_rev);
|
||||
m_process_incoming(std::move(msg), this);
|
||||
}
|
||||
DEBUG3("read thread died");
|
||||
m_state = static_cast<int>(kDead);
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace nt {
|
||||
|
||||
class NetworkConnection {
|
||||
public:
|
||||
enum State { kCreated, kInit, kHandshake, kActive, kDead };
|
||||
enum State { kCreated, kInit, kHandshake, kSynchronized, kActive, kDead };
|
||||
|
||||
typedef std::function<bool(
|
||||
NetworkConnection& conn,
|
||||
@@ -29,8 +29,7 @@ class NetworkConnection {
|
||||
std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs)>
|
||||
HandshakeFunc;
|
||||
typedef std::function<void(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* conn, unsigned int proto_rev)>
|
||||
ProcessIncomingFunc;
|
||||
NetworkConnection* conn)> ProcessIncomingFunc;
|
||||
typedef std::vector<std::shared_ptr<Message>> Outgoing;
|
||||
typedef ConcurrentQueue<Outgoing> OutgoingQueue;
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "llvm/StringExtras.h"
|
||||
#include "Base64.h"
|
||||
#include "Log.h"
|
||||
#include "NetworkConnection.h"
|
||||
|
||||
using namespace nt;
|
||||
|
||||
@@ -41,7 +42,7 @@ NT_Type Storage::GetEntryType(unsigned int id) const {
|
||||
}
|
||||
|
||||
void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* conn, unsigned int proto_rev) {
|
||||
NetworkConnection* conn) {
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
switch (msg->type()) {
|
||||
case Message::kKeepAlive:
|
||||
@@ -148,7 +149,7 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
|
||||
entry->seq_num = seq_num;
|
||||
|
||||
// don't update flags from a <3.0 remote (not part of message)
|
||||
if (proto_rev >= 0x0300) entry->flags = msg->flags();
|
||||
if (conn->proto_rev() >= 0x0300) entry->flags = msg->flags();
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
@@ -261,8 +262,9 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
|
||||
}
|
||||
|
||||
void Storage::GetInitialAssignments(
|
||||
std::vector<std::shared_ptr<Message>>* msgs) {
|
||||
NetworkConnection& conn, std::vector<std::shared_ptr<Message>>* msgs) {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
conn.set_state(NetworkConnection::kSynchronized);
|
||||
for (auto& i : m_entries) {
|
||||
auto entry = i.getValue();
|
||||
msgs->emplace_back(Message::EntryAssign(i.getKey(), entry->id,
|
||||
@@ -272,11 +274,13 @@ void Storage::GetInitialAssignments(
|
||||
}
|
||||
|
||||
void Storage::ApplyInitialAssignments(
|
||||
llvm::ArrayRef<std::shared_ptr<Message>> msgs, bool new_server,
|
||||
unsigned int proto_rev, std::vector<std::shared_ptr<Message>>* out_msgs) {
|
||||
NetworkConnection& conn, llvm::ArrayRef<std::shared_ptr<Message>> msgs,
|
||||
bool new_server, std::vector<std::shared_ptr<Message>>* out_msgs) {
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
if (m_server) return; // should not do this on server
|
||||
|
||||
conn.set_state(NetworkConnection::kSynchronized);
|
||||
|
||||
std::vector<std::shared_ptr<Message>> update_msgs;
|
||||
|
||||
// clear existing id's
|
||||
@@ -319,7 +323,7 @@ void Storage::ApplyInitialAssignments(
|
||||
entry->value = msg->value();
|
||||
entry->seq_num = seq_num;
|
||||
// don't update flags from a <3.0 remote (not part of message)
|
||||
if (proto_rev >= 0x0300) entry->flags = msg->flags();
|
||||
if (conn.proto_rev() >= 0x0300) entry->flags = msg->flags();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,11 +44,12 @@ class Storage {
|
||||
|
||||
NT_Type GetEntryType(unsigned int id) const;
|
||||
|
||||
void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn,
|
||||
unsigned int proto_rev);
|
||||
void GetInitialAssignments(std::vector<std::shared_ptr<Message>>* msgs);
|
||||
void ApplyInitialAssignments(llvm::ArrayRef<std::shared_ptr<Message>> msgs,
|
||||
bool new_server, unsigned int proto_rev,
|
||||
void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn);
|
||||
void GetInitialAssignments(NetworkConnection& conn,
|
||||
std::vector<std::shared_ptr<Message>>* msgs);
|
||||
void ApplyInitialAssignments(NetworkConnection& conn,
|
||||
llvm::ArrayRef<std::shared_ptr<Message>> msgs,
|
||||
bool new_server,
|
||||
std::vector<std::shared_ptr<Message>>* out_msgs);
|
||||
|
||||
std::mutex& mutex() { return m_mutex; }
|
||||
|
||||
Reference in New Issue
Block a user