mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-23 01:21:42 +00:00
Immediately process incoming messages.
This required moving message processing into the Storage class.
This commit is contained in:
@@ -38,6 +38,11 @@ void Dispatcher::StartServer(const char* listen_address, unsigned int port) {
|
||||
m_active = true;
|
||||
}
|
||||
m_server = true;
|
||||
|
||||
using namespace std::placeholders;
|
||||
Storage::GetInstance().SetOutgoing(
|
||||
std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3), m_server);
|
||||
|
||||
m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
|
||||
m_clientserver_thread =
|
||||
std::thread(&Dispatcher::ServerThreadMain, this, listen_address, port);
|
||||
@@ -50,6 +55,11 @@ void Dispatcher::StartClient(const char* server_name, unsigned int port) {
|
||||
m_active = true;
|
||||
}
|
||||
m_server = false;
|
||||
|
||||
using namespace std::placeholders;
|
||||
Storage::GetInstance().SetOutgoing(
|
||||
std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3), m_server);
|
||||
|
||||
m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
|
||||
m_clientserver_thread =
|
||||
std::thread(&Dispatcher::ClientThreadMain, this, server_name, port);
|
||||
@@ -61,7 +71,7 @@ void Dispatcher::Stop() {
|
||||
m_active = false;
|
||||
|
||||
// close all connections
|
||||
for (auto& conn : m_connections) conn->Stop();
|
||||
for (auto& conn : m_connections) conn.net->Stop();
|
||||
}
|
||||
|
||||
// wake up dispatch thread with a flush
|
||||
@@ -76,6 +86,8 @@ void Dispatcher::Stop() {
|
||||
// join threads
|
||||
if (m_dispatch_thread.joinable()) m_dispatch_thread.join();
|
||||
if (m_clientserver_thread.joinable()) m_clientserver_thread.join();
|
||||
|
||||
Storage::GetInstance().ClearOutgoing();
|
||||
}
|
||||
|
||||
void Dispatcher::SetUpdateRate(double interval) {
|
||||
@@ -107,10 +119,11 @@ void Dispatcher::DispatchThreadMain() {
|
||||
Storage& storage = Storage::GetInstance();
|
||||
|
||||
// local copy of active m_connections
|
||||
std::vector<NetworkConnection*> connections;
|
||||
|
||||
// Outgoing messages for each remote (indexed the same as connections).
|
||||
std::vector<NetworkConnection::Outgoing> outgoing;
|
||||
struct ConnectionRef {
|
||||
NetworkConnection* net;
|
||||
NetworkConnection::Outgoing outgoing;
|
||||
};
|
||||
std::vector<ConnectionRef> connections;
|
||||
|
||||
auto timeout_time = std::chrono::steady_clock::now();
|
||||
int count = 0;
|
||||
@@ -134,45 +147,24 @@ void Dispatcher::DispatchThreadMain() {
|
||||
count = 0;
|
||||
}
|
||||
|
||||
// clear outgoing
|
||||
outgoing.resize(0);
|
||||
|
||||
// Everything from this point forward needs to be treated as an atomic
|
||||
// operation on idmap. The user code never needs access to this, so
|
||||
// this is really a dispatcher-internal lock that only affects network
|
||||
// side code.
|
||||
std::unique_lock<std::mutex> idmap_lock(m_idmap_mutex);
|
||||
|
||||
// make a local copy of the connections list (so we don't hold the lock)
|
||||
connections.resize(0);
|
||||
{
|
||||
std::lock_guard<std::mutex> user_lock(m_user_mutex);
|
||||
for (auto& conn : m_connections) {
|
||||
if (conn->state() == NetworkConnection::kActive)
|
||||
connections.push_back(conn.get());
|
||||
if (conn.net->state() == NetworkConnection::kActive) {
|
||||
connections.push_back(ConnectionRef());
|
||||
connections.back().net = conn.net.get();
|
||||
connections.back().outgoing.swap(conn.outgoing);
|
||||
}
|
||||
}
|
||||
}
|
||||
outgoing.resize(connections.size());
|
||||
|
||||
// grab local storage updates
|
||||
Storage::UpdateMap updates;
|
||||
bool delete_all;
|
||||
storage.GetUpdates(&updates, &delete_all);
|
||||
|
||||
// special handling of delete all operation: we ignore all in-flight
|
||||
// messages
|
||||
if (delete_all) {
|
||||
// send it to all remotes
|
||||
auto outmsg = Message::ClearEntries();
|
||||
for (auto& q : outgoing) q.push_back(outmsg);
|
||||
|
||||
// empty all incoming messages
|
||||
for (auto conn : connections) {
|
||||
auto& incoming = conn->incoming();
|
||||
while (!incoming.empty()) incoming.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// local entry updates
|
||||
for (auto& update_entry : updates) {
|
||||
auto update = update_entry.getValue();
|
||||
@@ -182,205 +174,23 @@ void Dispatcher::DispatchThreadMain() {
|
||||
}
|
||||
}
|
||||
|
||||
// read all incoming messages
|
||||
for (std::size_t i=0; i<connections.size(); ++i) {
|
||||
auto conn = connections[i];
|
||||
auto& incoming = conn->incoming();
|
||||
while (!incoming.empty()) {
|
||||
auto msg = incoming.pop();
|
||||
if (!msg) continue; // should never happen, but just in case...
|
||||
switch (msg->type()) {
|
||||
case Message::kKeepAlive:
|
||||
break; // ignore
|
||||
case Message::kClientHello:
|
||||
case Message::kProtoUnsup:
|
||||
case Message::kServerHelloDone:
|
||||
case Message::kServerHello:
|
||||
case Message::kClientHelloDone:
|
||||
// shouldn't get these, but ignore if we do
|
||||
break;
|
||||
case Message::kEntryAssign: {
|
||||
unsigned int id = msg->id();
|
||||
std::shared_ptr<StorageEntry> entry;
|
||||
if (m_server) {
|
||||
// if we're a server, id=0xffff requests are requests for an id
|
||||
// to be assigned, and we need to send the new assignment back to
|
||||
// the sender as well as all other connections.
|
||||
if (id == 0xffff) {
|
||||
// see if it was already assigned; ignore if so.
|
||||
if (!storage.FindEntry(msg->str())) continue;
|
||||
|
||||
// create it locally
|
||||
id = m_idmap.size();
|
||||
entry = storage.DispatchCreateEntry(msg->str(), msg->value(),
|
||||
msg->flags());
|
||||
m_idmap.push_back(entry);
|
||||
entry->set_id(id);
|
||||
|
||||
// send the assignment to everyone (including the originator)
|
||||
auto outmsg = Message::EntryAssign(msg->str(), id,
|
||||
entry->seq_num().value(),
|
||||
msg->value(), msg->flags());
|
||||
for (auto& q : outgoing) q.push_back(outmsg);
|
||||
continue;
|
||||
}
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry assignments
|
||||
// this can happen due to e.g. assignment to deleted entry
|
||||
DEBUG("server: received assignment to unknown entry");
|
||||
continue;
|
||||
}
|
||||
entry = m_idmap[id];
|
||||
} else {
|
||||
// clients simply accept new assignments
|
||||
if (id == 0xffff) {
|
||||
DEBUG("client: received entry assignment request?");
|
||||
continue;
|
||||
}
|
||||
if (id >= m_idmap.size()) m_idmap.resize(id+1);
|
||||
entry = m_idmap[id];
|
||||
if (!entry) {
|
||||
// create local
|
||||
entry = storage.DispatchCreateEntry(msg->str(), msg->value(),
|
||||
msg->flags());
|
||||
m_idmap[id] = entry;
|
||||
entry->set_id(id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// common client and server handling
|
||||
|
||||
// already exists; ignore if sequence number not higher than local
|
||||
SequenceNumber seq_num(msg->seq_num_uid());
|
||||
if (seq_num <= entry->seq_num()) continue;
|
||||
|
||||
// sanity check: name should match id
|
||||
if (msg->str() != entry->name()) {
|
||||
DEBUG("entry assignment for same id with different name?");
|
||||
continue;
|
||||
}
|
||||
|
||||
// update local
|
||||
entry->set_value(msg->value());
|
||||
entry->set_seq_num(seq_num);
|
||||
|
||||
// don't update flags from a <3.0 remote (not part of message)
|
||||
if (conn->proto_rev() >= 0x0300) entry->set_flags(msg->flags());
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
if (m_server) {
|
||||
auto outmsg =
|
||||
Message::EntryAssign(entry->name(), id, msg->seq_num_uid(),
|
||||
msg->value(), entry->flags());
|
||||
for (std::size_t j = 0; j < connections.size(); ++j) {
|
||||
if (j != i) outgoing[j].push_back(outmsg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kEntryUpdate: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
DEBUG("received update to unknown entry");
|
||||
continue;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// ignore if sequence number not higher than local
|
||||
SequenceNumber seq_num(msg->seq_num_uid());
|
||||
if (seq_num <= entry->seq_num()) continue;
|
||||
|
||||
// update local
|
||||
entry->set_value(msg->value());
|
||||
entry->set_seq_num(seq_num);
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
if (m_server) {
|
||||
for (std::size_t j = 0; j < connections.size(); ++j) {
|
||||
if (j != i) outgoing[j].push_back(msg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kFlagsUpdate: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
DEBUG("received flags update to unknown entry");
|
||||
continue;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// update local
|
||||
entry->set_flags(msg->flags());
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
if (m_server) {
|
||||
for (std::size_t j = 0; j < connections.size(); ++j) {
|
||||
if (j != i) outgoing[j].push_back(msg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kEntryDelete: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
DEBUG("received delete to unknown entry");
|
||||
continue;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// update local
|
||||
storage.DispatchDeleteEntry(entry->name());
|
||||
entry.reset(); // delete it from idmap too
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
if (m_server) {
|
||||
for (std::size_t j = 0; j < connections.size(); ++j) {
|
||||
if (j != i) outgoing[j].push_back(msg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kClearEntries: {
|
||||
// update local
|
||||
storage.DispatchDeleteAllEntries();
|
||||
m_idmap.resize(0);
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
if (m_server) {
|
||||
for (std::size_t j = 0; j < connections.size(); ++j) {
|
||||
if (j != i) outgoing[j].push_back(msg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kExecuteRpc:
|
||||
case Message::kRpcResponse:
|
||||
// TODO
|
||||
break;
|
||||
default:
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
idmap_lock.unlock();
|
||||
// scan outgoing messages to remove unnecessary updates
|
||||
|
||||
// send outgoing messages
|
||||
for (std::size_t i = 0; i < connections.size(); ++i)
|
||||
connections[i]->outgoing().emplace(std::move(outgoing[i]));
|
||||
for (auto& conn : connections)
|
||||
conn.net->outgoing().emplace(std::move(conn.outgoing));
|
||||
}
|
||||
}
|
||||
|
||||
void Dispatcher::QueueOutgoing(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* only,
|
||||
NetworkConnection* except) {
|
||||
std::lock_guard<std::mutex> user_lock(m_user_mutex);
|
||||
for (auto& conn : m_connections) {
|
||||
if (conn.net.get() == except) continue;
|
||||
if (only && conn.net.get() != only) continue;
|
||||
if (conn.net->state() != NetworkConnection::kDead)
|
||||
conn.outgoing.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,15 +211,23 @@ void Dispatcher::ServerThreadMain(const char* listen_address,
|
||||
DEBUG("server got a connection");
|
||||
|
||||
// add to connections list
|
||||
std::unique_ptr<NetworkConnection> conn(new NetworkConnection(
|
||||
using namespace std::placeholders;
|
||||
Storage& storage = Storage::GetInstance();
|
||||
std::unique_ptr<NetworkConnection> conn_unique(new NetworkConnection(
|
||||
std::move(stream),
|
||||
[this](unsigned int id) { return GetEntryType(id); }));
|
||||
std::bind(&Storage::GetEntryType, &storage, _1),
|
||||
std::bind(&Storage::ProcessIncoming, &storage, _1, _2, _3)));
|
||||
auto conn = conn_unique.get();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_user_mutex);
|
||||
m_connections.emplace_back(std::move(conn_unique));
|
||||
}
|
||||
conn->Start();
|
||||
AddConnection(std::move(conn));
|
||||
}
|
||||
}
|
||||
|
||||
void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
|
||||
#if 0
|
||||
unsigned int proto_rev = 0x0300;
|
||||
while (m_active) {
|
||||
// get identity
|
||||
@@ -507,6 +325,7 @@ void Dispatcher::ClientThreadMain(const char* server_name, unsigned int port) {
|
||||
m_do_reconnect = false;
|
||||
lock.unlock();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
bool Dispatcher::ServerHandshake(
|
||||
@@ -606,16 +425,3 @@ void Dispatcher::ClientReconnect() {
|
||||
}
|
||||
m_reconnect_cv.notify_one();
|
||||
}
|
||||
|
||||
void Dispatcher::AddConnection(std::unique_ptr<NetworkConnection> conn) {
|
||||
std::lock_guard<std::mutex> lock(m_user_mutex);
|
||||
m_connections.push_back(std::move(conn));
|
||||
}
|
||||
|
||||
NT_Type Dispatcher::GetEntryType(unsigned int id) const {
|
||||
std::lock_guard<std::mutex> lock(m_idmap_mutex);
|
||||
if (id >= m_idmap.size()) return NT_UNASSIGNED;
|
||||
auto value = m_idmap[id]->value();
|
||||
if (!value) return NT_UNASSIGNED;
|
||||
return value->type();
|
||||
}
|
||||
|
||||
@@ -59,9 +59,8 @@ class Dispatcher {
|
||||
|
||||
void ClientReconnect();
|
||||
|
||||
NT_Type GetEntryType(unsigned int id) const;
|
||||
|
||||
void AddConnection(std::unique_ptr<NetworkConnection> conn);
|
||||
void QueueOutgoing(std::shared_ptr<Message> msg, NetworkConnection* only,
|
||||
NetworkConnection* except);
|
||||
|
||||
bool m_server;
|
||||
std::thread m_dispatch_thread;
|
||||
@@ -72,7 +71,14 @@ class Dispatcher {
|
||||
|
||||
// Mutex for user-accessible items
|
||||
std::mutex m_user_mutex;
|
||||
std::vector<std::unique_ptr<NetworkConnection>> m_connections;
|
||||
struct Connection {
|
||||
Connection() = default;
|
||||
explicit Connection(std::unique_ptr<NetworkConnection> net_)
|
||||
: net(std::move(net_)) {}
|
||||
std::unique_ptr<NetworkConnection> net;
|
||||
NetworkConnection::Outgoing outgoing;
|
||||
};
|
||||
std::vector<Connection> m_connections;
|
||||
std::string m_identity;
|
||||
|
||||
std::atomic_bool m_active; // set to false to terminate threads
|
||||
@@ -89,11 +95,6 @@ class Dispatcher {
|
||||
std::condition_variable m_reconnect_cv;
|
||||
bool m_do_reconnect;
|
||||
|
||||
// Map from integer id to storage entry. Id is 16-bit, so just use a vector.
|
||||
typedef std::vector<std::shared_ptr<StorageEntry>> IdMap;
|
||||
mutable std::mutex m_idmap_mutex;
|
||||
IdMap m_idmap;
|
||||
|
||||
ATOMIC_STATIC_DECL(Dispatcher)
|
||||
};
|
||||
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
using namespace nt;
|
||||
|
||||
NetworkConnection::NetworkConnection(std::unique_ptr<TCPStream> stream,
|
||||
Message::GetEntryTypeFunc get_entry_type)
|
||||
Message::GetEntryTypeFunc get_entry_type,
|
||||
ProcessIncomingFunc process_incoming)
|
||||
: m_stream(std::move(stream)),
|
||||
m_get_entry_type(get_entry_type) {
|
||||
m_get_entry_type(get_entry_type),
|
||||
m_process_incoming(process_incoming) {
|
||||
m_active = false;
|
||||
m_proto_rev = 0x0300;
|
||||
m_state = static_cast<int>(kCreated);
|
||||
@@ -29,8 +31,7 @@ void NetworkConnection::Start() {
|
||||
if (m_active) return;
|
||||
m_active = true;
|
||||
m_state = static_cast<int>(kInit);
|
||||
// clear queues
|
||||
while (!m_incoming.empty()) m_incoming.pop();
|
||||
// clear queue
|
||||
while (!m_outgoing.empty()) m_outgoing.pop();
|
||||
// start threads
|
||||
m_write_thread = std::thread(&NetworkConnection::WriteThreadMain, this);
|
||||
@@ -47,8 +48,7 @@ void NetworkConnection::Stop() {
|
||||
// wait for threads to terminate
|
||||
if (m_write_thread.joinable()) m_write_thread.join();
|
||||
if (m_read_thread.joinable()) m_read_thread.join();
|
||||
// clear queues
|
||||
while (!m_incoming.empty()) m_incoming.pop();
|
||||
// clear queue
|
||||
while (!m_outgoing.empty()) m_outgoing.pop();
|
||||
}
|
||||
|
||||
@@ -77,9 +77,8 @@ void NetworkConnection::ReadThreadMain() {
|
||||
if (m_stream) m_stream->close();
|
||||
break;
|
||||
}
|
||||
m_incoming.emplace(std::move(msg));
|
||||
m_process_incoming(std::move(msg), this, m_proto_rev);
|
||||
}
|
||||
m_incoming.emplace(nullptr); // notify anyone waiting that we disconnected
|
||||
m_state = static_cast<int>(kDead);
|
||||
m_active = false;
|
||||
}
|
||||
@@ -92,7 +91,9 @@ void NetworkConnection::WriteThreadMain() {
|
||||
if (msgs.empty()) break;
|
||||
encoder.set_proto_rev(m_proto_rev);
|
||||
encoder.Reset();
|
||||
for (auto& msg : msgs) msg->Write(encoder);
|
||||
for (auto& msg : msgs) {
|
||||
if (msg) msg->Write(encoder);
|
||||
}
|
||||
TCPStream::Error err;
|
||||
if (!m_stream) break;
|
||||
if (m_stream->send(encoder.data(), encoder.size(), &err) == 0) break;
|
||||
|
||||
@@ -23,13 +23,15 @@ class NetworkConnection {
|
||||
public:
|
||||
enum State { kCreated, kInit, kHandshake, kActive, kDead };
|
||||
|
||||
typedef std::shared_ptr<Message> Incoming;
|
||||
typedef ConcurrentQueue<Incoming> IncomingQueue;
|
||||
typedef std::function<void(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* conn, unsigned int proto_rev)>
|
||||
ProcessIncomingFunc;
|
||||
typedef std::vector<std::shared_ptr<Message>> Outgoing;
|
||||
typedef ConcurrentQueue<Outgoing> OutgoingQueue;
|
||||
|
||||
NetworkConnection(std::unique_ptr<TCPStream> stream,
|
||||
Message::GetEntryTypeFunc get_entry_type);
|
||||
Message::GetEntryTypeFunc get_entry_type,
|
||||
ProcessIncomingFunc process_incoming);
|
||||
~NetworkConnection();
|
||||
|
||||
void Start();
|
||||
@@ -38,7 +40,6 @@ class NetworkConnection {
|
||||
bool active() const { return m_active; }
|
||||
TCPStream& stream() { return *m_stream; }
|
||||
OutgoingQueue& outgoing() { return m_outgoing; }
|
||||
IncomingQueue& incoming() { return m_incoming; }
|
||||
|
||||
unsigned int proto_rev() const { return m_proto_rev; }
|
||||
void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; }
|
||||
@@ -58,8 +59,8 @@ class NetworkConnection {
|
||||
|
||||
std::unique_ptr<TCPStream> m_stream;
|
||||
OutgoingQueue m_outgoing;
|
||||
IncomingQueue m_incoming;
|
||||
Message::GetEntryTypeFunc m_get_entry_type;
|
||||
ProcessIncomingFunc m_process_incoming;
|
||||
std::thread m_read_thread;
|
||||
std::thread m_write_thread;
|
||||
std::atomic_bool m_active;
|
||||
|
||||
219
src/Storage.cpp
219
src/Storage.cpp
@@ -15,6 +15,8 @@
|
||||
|
||||
using namespace nt;
|
||||
|
||||
#define DEBUG(str) puts(str)
|
||||
|
||||
ATOMIC_STATIC_INIT(Storage)
|
||||
|
||||
Storage::Storage() {
|
||||
@@ -23,28 +25,213 @@ Storage::Storage() {
|
||||
|
||||
Storage::~Storage() {}
|
||||
|
||||
std::shared_ptr<StorageEntry> Storage::DispatchCreateEntry(
|
||||
StringRef name, std::shared_ptr<Value> value, unsigned int flags) {
|
||||
void Storage::SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server) {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
auto& entry = m_entries[name];
|
||||
if (!entry) entry = std::make_shared<StorageEntry>(name);
|
||||
entry->set_value(value);
|
||||
entry->set_flags(flags);
|
||||
return entry;
|
||||
m_queue_outgoing = queue_outgoing;
|
||||
m_server = server;
|
||||
}
|
||||
|
||||
void Storage::DispatchDeleteEntry(StringRef name) {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
auto i = m_entries.find(name);
|
||||
if (i == m_entries.end()) return;
|
||||
auto entry = i->getValue();
|
||||
m_entries.erase(i); // erase from map
|
||||
void Storage::ClearOutgoing() {
|
||||
m_queue_outgoing = nullptr;
|
||||
}
|
||||
|
||||
void Storage::DispatchDeleteAllEntries() {
|
||||
NT_Type Storage::GetEntryType(unsigned int id) const {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
if (m_entries.empty()) return;
|
||||
m_entries.clear();
|
||||
if (id >= m_idmap.size()) return NT_UNASSIGNED;
|
||||
auto value = m_idmap[id]->value();
|
||||
if (!value) return NT_UNASSIGNED;
|
||||
return value->type();
|
||||
}
|
||||
|
||||
void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* conn, unsigned int proto_rev) {
|
||||
if (!m_queue_outgoing) return; // sanity check
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
switch (msg->type()) {
|
||||
case Message::kKeepAlive:
|
||||
break; // ignore
|
||||
case Message::kClientHello:
|
||||
case Message::kProtoUnsup:
|
||||
case Message::kServerHelloDone:
|
||||
case Message::kServerHello:
|
||||
case Message::kClientHelloDone:
|
||||
// shouldn't get these, but ignore if we do
|
||||
break;
|
||||
case Message::kEntryAssign: {
|
||||
unsigned int id = msg->id();
|
||||
StringRef name = msg->str();
|
||||
std::shared_ptr<StorageEntry> entry;
|
||||
if (m_server) {
|
||||
// if we're a server, id=0xffff requests are requests for an id
|
||||
// to be assigned, and we need to send the new assignment back to
|
||||
// the sender as well as all other connections.
|
||||
if (id == 0xffff) {
|
||||
// see if it was already assigned; ignore if so.
|
||||
if (m_entries.count(name) != 0) return;
|
||||
|
||||
// create it locally
|
||||
id = m_idmap.size();
|
||||
auto& new_entry = m_entries[name];
|
||||
if (!new_entry) new_entry = std::make_shared<StorageEntry>(name);
|
||||
entry = new_entry;
|
||||
entry->set_value(msg->value());
|
||||
entry->set_flags(msg->flags());
|
||||
entry->set_id(id);
|
||||
m_idmap.push_back(entry);
|
||||
|
||||
// send the assignment to everyone (including the originator)
|
||||
lock.unlock();
|
||||
m_queue_outgoing(
|
||||
Message::EntryAssign(name, id, entry->seq_num().value(),
|
||||
msg->value(), msg->flags()),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
}
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry assignments
|
||||
// this can happen due to e.g. assignment to deleted entry
|
||||
lock.unlock();
|
||||
DEBUG("server: received assignment to unknown entry");
|
||||
return;
|
||||
}
|
||||
entry = m_idmap[id];
|
||||
} else {
|
||||
// clients simply accept new assignments
|
||||
if (id == 0xffff) {
|
||||
lock.unlock();
|
||||
DEBUG("client: received entry assignment request?");
|
||||
return;
|
||||
}
|
||||
if (id >= m_idmap.size()) m_idmap.resize(id+1);
|
||||
entry = m_idmap[id];
|
||||
if (!entry) {
|
||||
// create local
|
||||
auto& new_entry = m_entries[name];
|
||||
if (!new_entry) new_entry = std::make_shared<StorageEntry>(name);
|
||||
entry = new_entry;
|
||||
entry->set_value(msg->value());
|
||||
entry->set_flags(msg->flags());
|
||||
entry->set_id(id);
|
||||
m_idmap[id] = entry;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// common client and server handling
|
||||
|
||||
// already exists; ignore if sequence number not higher than local
|
||||
SequenceNumber seq_num(msg->seq_num_uid());
|
||||
if (seq_num <= entry->seq_num()) return;
|
||||
|
||||
// sanity check: name should match id
|
||||
if (msg->str() != entry->name()) {
|
||||
lock.unlock();
|
||||
DEBUG("entry assignment for same id with different name?");
|
||||
return;
|
||||
}
|
||||
|
||||
// update local
|
||||
entry->set_value(msg->value());
|
||||
entry->set_seq_num(seq_num);
|
||||
|
||||
// don't update flags from a <3.0 remote (not part of message)
|
||||
if (proto_rev >= 0x0300) entry->set_flags(msg->flags());
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
lock.unlock();
|
||||
if (m_server) {
|
||||
m_queue_outgoing(
|
||||
Message::EntryAssign(entry->name(), id, msg->seq_num_uid(),
|
||||
msg->value(), entry->flags()),
|
||||
nullptr, conn);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Message::kEntryUpdate: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
lock.unlock();
|
||||
DEBUG("received update to unknown entry");
|
||||
return;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// ignore if sequence number not higher than local
|
||||
SequenceNumber seq_num(msg->seq_num_uid());
|
||||
if (seq_num <= entry->seq_num()) return;
|
||||
|
||||
// update local
|
||||
entry->set_value(msg->value());
|
||||
entry->set_seq_num(seq_num);
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
lock.unlock();
|
||||
if (m_server) m_queue_outgoing(msg, nullptr, conn);
|
||||
break;
|
||||
}
|
||||
case Message::kFlagsUpdate: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
lock.unlock();
|
||||
DEBUG("received flags update to unknown entry");
|
||||
return;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// update local
|
||||
entry->set_flags(msg->flags());
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
lock.unlock();
|
||||
if (m_server) m_queue_outgoing(msg, nullptr, conn);
|
||||
break;
|
||||
}
|
||||
case Message::kEntryDelete: {
|
||||
unsigned int id = msg->id();
|
||||
if (id >= m_idmap.size() || !m_idmap[id]) {
|
||||
// ignore arbitrary entry updates;
|
||||
// this can happen due to deleted entries
|
||||
lock.unlock();
|
||||
DEBUG("received delete to unknown entry");
|
||||
return;
|
||||
}
|
||||
auto& entry = m_idmap[id];
|
||||
|
||||
// update local
|
||||
m_entries.erase(entry->name()); // erase from map
|
||||
entry.reset(); // delete it from idmap too
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
lock.unlock();
|
||||
if (m_server) m_queue_outgoing(msg, nullptr, conn);
|
||||
break;
|
||||
}
|
||||
case Message::kClearEntries: {
|
||||
// update local
|
||||
m_entries.clear();
|
||||
m_idmap.resize(0);
|
||||
|
||||
// broadcast to all other connections (note for client there won't
|
||||
// be any other connections, so don't bother)
|
||||
lock.unlock();
|
||||
if (m_server) m_queue_outgoing(msg, nullptr, conn);
|
||||
break;
|
||||
}
|
||||
case Message::kExecuteRpc:
|
||||
case Message::kRpcResponse:
|
||||
// TODO
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Storage::GetUpdates(UpdateMap* updates, bool* delete_all) {
|
||||
|
||||
@@ -16,13 +16,14 @@
|
||||
#include <mutex>
|
||||
|
||||
#include "llvm/StringMap.h"
|
||||
#include "support/ConcurrentQueue.h"
|
||||
#include "atomic_static.h"
|
||||
#include "Message.h"
|
||||
#include "ntcore_cpp.h"
|
||||
#include "SequenceNumber.h"
|
||||
|
||||
namespace nt {
|
||||
|
||||
class NetworkConnection;
|
||||
class StorageTest;
|
||||
|
||||
class StorageEntry {
|
||||
@@ -106,14 +107,21 @@ class Storage {
|
||||
};
|
||||
typedef llvm::StringMap<Update> UpdateMap;
|
||||
|
||||
typedef std::function<void(std::shared_ptr<Message> msg,
|
||||
NetworkConnection* only,
|
||||
NetworkConnection* except)> QueueOutgoingFunc;
|
||||
void SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server);
|
||||
void ClearOutgoing();
|
||||
|
||||
NT_Type GetEntryType(unsigned int id) const;
|
||||
|
||||
void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn,
|
||||
unsigned int proto_rev);
|
||||
|
||||
// Finds, but does not create entry. Returns nullptr if not found.
|
||||
std::shared_ptr<StorageEntry> FindEntry(StringRef name) const;
|
||||
|
||||
// Accessors required by Dispatcher.
|
||||
std::shared_ptr<StorageEntry> DispatchCreateEntry(
|
||||
StringRef name, std::shared_ptr<Value> value, unsigned int flags);
|
||||
void DispatchDeleteEntry(StringRef name);
|
||||
void DispatchDeleteAllEntries();
|
||||
void GetUpdates(UpdateMap* updates, bool* delete_all);
|
||||
std::mutex& mutex() { return m_mutex; }
|
||||
|
||||
@@ -140,12 +148,17 @@ class Storage {
|
||||
void AddUpdate(std::shared_ptr<StorageEntry> entry, Update::Kind kind);
|
||||
|
||||
typedef llvm::StringMap<std::shared_ptr<StorageEntry>> EntriesMap;
|
||||
typedef std::vector<std::shared_ptr<StorageEntry>> IdMap;
|
||||
|
||||
mutable std::mutex m_mutex;
|
||||
EntriesMap m_entries;
|
||||
IdMap m_idmap;
|
||||
UpdateMap m_updates;
|
||||
bool m_updates_delete_all;
|
||||
|
||||
QueueOutgoingFunc m_queue_outgoing;
|
||||
bool m_server;
|
||||
|
||||
ATOMIC_STATIC_DECL(Storage)
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user