mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-26 01:51:41 +00:00
Message: use shared_ptr.
NetworkConnection: Own the input and output queues.
This commit is contained in:
committed by
Peter Johnson
parent
beb92e6cbf
commit
13cbf4e288
178
src/Message.cpp
178
src/Message.cpp
@@ -14,210 +14,216 @@ static constexpr unsigned long kClearAllMagic = 0xD06CB27Aul;
|
||||
|
||||
using namespace ntimpl;
|
||||
|
||||
bool Message::Read(WireDecoder& decoder, GetEntryTypeFunc get_entry_type,
|
||||
Message* msg) {
|
||||
std::shared_ptr<Message> Message::Read(WireDecoder& decoder,
|
||||
GetEntryTypeFunc get_entry_type) {
|
||||
unsigned int msg_type;
|
||||
if (!decoder.Read8(&msg_type)) return false;
|
||||
*msg = Message(static_cast<MsgType>(msg_type));
|
||||
if (!decoder.Read8(&msg_type)) return nullptr;
|
||||
auto msg =
|
||||
std::make_shared<Message>(static_cast<MsgType>(msg_type), private_init());
|
||||
switch (msg_type) {
|
||||
case kKeepAlive:
|
||||
break;
|
||||
case kClientHello: {
|
||||
unsigned int proto_rev;
|
||||
if (!decoder.Read16(&proto_rev)) return false;
|
||||
if (!decoder.Read16(&proto_rev)) return nullptr;
|
||||
msg->m_id = proto_rev;
|
||||
// This intentionally uses the provided proto_rev instead of
|
||||
// decoder.proto_rev().
|
||||
if (proto_rev >= 0x0300u) {
|
||||
if (!decoder.ReadString(&msg->m_str)) return false;
|
||||
if (!decoder.ReadString(&msg->m_str)) return nullptr;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kProtoUnsup: {
|
||||
if (!decoder.Read16(&msg->m_id)) return false; // proto rev
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr; // proto rev
|
||||
break;
|
||||
}
|
||||
case kServerHelloDone:
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received SERVER_HELLO_DONE in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
break;
|
||||
case kClientHelloDone:
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received CLIENT_HELLO_DONE in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
break;
|
||||
case kEntryAssign: {
|
||||
if (!decoder.ReadString(&msg->m_str)) return false;
|
||||
if (!decoder.ReadString(&msg->m_str)) return nullptr;
|
||||
NT_Type type;
|
||||
if (!decoder.ReadType(&type)) return false; // name
|
||||
if (!decoder.Read16(&msg->m_id)) return false; // id
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // seq num
|
||||
if (!decoder.ReadType(&type)) return nullptr; // name
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr; // id
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // seq num
|
||||
if (decoder.proto_rev() >= 0x0300u) {
|
||||
if (!decoder.Read8(&msg->m_flags)) return false; // flags
|
||||
if (!decoder.Read8(&msg->m_flags)) return nullptr; // flags
|
||||
}
|
||||
msg->m_value = std::make_shared<Value>();
|
||||
if (!decoder.ReadValue(type, &(*msg->m_value))) return false;
|
||||
if (!decoder.ReadValue(type, &(*msg->m_value))) return nullptr;
|
||||
break;
|
||||
}
|
||||
case kEntryUpdate: {
|
||||
if (!decoder.Read16(&msg->m_id)) return false; // id
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // seq num
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr; // id
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // seq num
|
||||
NT_Type type;
|
||||
if (decoder.proto_rev() >= 0x0300u) {
|
||||
unsigned int itype;
|
||||
if (!decoder.Read8(&itype)) return false;
|
||||
if (!decoder.Read8(&itype)) return nullptr;
|
||||
type = static_cast<NT_Type>(itype);
|
||||
} else
|
||||
type = get_entry_type(msg->m_id);
|
||||
msg->m_value = std::make_shared<Value>();
|
||||
if (!decoder.ReadValue(type, &(*msg->m_value))) return false;
|
||||
if (!decoder.ReadValue(type, &(*msg->m_value))) return nullptr;
|
||||
break;
|
||||
}
|
||||
case kFlagsUpdate: {
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received FLAGS_UPDATE in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
if (!decoder.Read16(&msg->m_id)) return false;
|
||||
if (!decoder.Read8(&msg->m_flags)) return false;
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr;
|
||||
if (!decoder.Read8(&msg->m_flags)) return nullptr;
|
||||
break;
|
||||
}
|
||||
case kEntryDelete: {
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received ENTRY_DELETE in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
if (!decoder.Read16(&msg->m_id)) return false;
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr;
|
||||
break;
|
||||
}
|
||||
case kClearEntries: {
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received CLEAR_ENTRIES in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
unsigned long magic;
|
||||
if (!decoder.Read32(&magic)) return false;
|
||||
if (!decoder.Read32(&magic)) return nullptr;
|
||||
if (magic != kClearAllMagic) {
|
||||
decoder.set_error(
|
||||
"received incorrect CLEAR_ENTRIES magic value, ignoring");
|
||||
return true;
|
||||
return nullptr;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kExecuteRpc: {
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received EXECUTE_RPC in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
if (!decoder.Read16(&msg->m_id)) return false;
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // uid
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr;
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // uid
|
||||
unsigned long size;
|
||||
if (!decoder.ReadUleb128(&size)) return false;
|
||||
if (!decoder.ReadUleb128(&size)) return nullptr;
|
||||
const char* params;
|
||||
if (!decoder.Read(¶ms, size)) return false;
|
||||
if (!decoder.Read(¶ms, size)) return nullptr;
|
||||
msg->m_str = llvm::StringRef(params, size);
|
||||
break;
|
||||
}
|
||||
case kRpcResponse: {
|
||||
if (decoder.proto_rev() < 0x0300u) {
|
||||
decoder.set_error("received RPC_RESPONSE in protocol < 3.0");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
if (!decoder.Read16(&msg->m_id)) return false;
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // uid
|
||||
if (!decoder.Read16(&msg->m_id)) return nullptr;
|
||||
if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // uid
|
||||
unsigned long size;
|
||||
if (!decoder.ReadUleb128(&size)) return false;
|
||||
if (!decoder.ReadUleb128(&size)) return nullptr;
|
||||
const char* results;
|
||||
if (!decoder.Read(&results, size)) return false;
|
||||
if (!decoder.Read(&results, size)) return nullptr;
|
||||
msg->m_str = llvm::StringRef(results, size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
decoder.set_error("unrecognized message type");
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Message Message::ClientHello(llvm::StringRef self_id) {
|
||||
Message msg(kClientHello);
|
||||
msg.m_str = self_id;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::ServerHello(unsigned int flags, llvm::StringRef self_id) {
|
||||
Message msg(kServerHello);
|
||||
msg.m_str = self_id;
|
||||
msg.m_flags = flags;
|
||||
std::shared_ptr<Message> Message::ClientHello(llvm::StringRef self_id) {
|
||||
auto msg = std::make_shared<Message>(kClientHello, private_init());
|
||||
msg->m_str = self_id;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::EntryAssign(llvm::StringRef name, unsigned int id,
|
||||
unsigned int seq_num, std::shared_ptr<Value> value,
|
||||
unsigned int flags) {
|
||||
Message msg(kEntryAssign);
|
||||
msg.m_str = name;
|
||||
msg.m_value = value;
|
||||
msg.m_id = id;
|
||||
msg.m_flags = flags;
|
||||
msg.m_seq_num_uid = seq_num;
|
||||
std::shared_ptr<Message> Message::ServerHello(unsigned int flags,
|
||||
llvm::StringRef self_id) {
|
||||
auto msg = std::make_shared<Message>(kServerHello, private_init());
|
||||
msg->m_str = self_id;
|
||||
msg->m_flags = flags;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::EntryUpdate(unsigned int id, unsigned int seq_num,
|
||||
std::shared_ptr<Value> value) {
|
||||
Message msg(kEntryUpdate);
|
||||
msg.m_value = value;
|
||||
msg.m_id = id;
|
||||
msg.m_seq_num_uid = seq_num;
|
||||
std::shared_ptr<Message> Message::EntryAssign(llvm::StringRef name,
|
||||
unsigned int id,
|
||||
unsigned int seq_num,
|
||||
std::shared_ptr<Value> value,
|
||||
unsigned int flags) {
|
||||
auto msg = std::make_shared<Message>(kEntryAssign, private_init());
|
||||
msg->m_str = name;
|
||||
msg->m_value = value;
|
||||
msg->m_id = id;
|
||||
msg->m_flags = flags;
|
||||
msg->m_seq_num_uid = seq_num;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::FlagsUpdate(unsigned int id, unsigned int flags) {
|
||||
Message msg(kFlagsUpdate);
|
||||
msg.m_id = id;
|
||||
msg.m_flags = flags;
|
||||
std::shared_ptr<Message> Message::EntryUpdate(unsigned int id,
|
||||
unsigned int seq_num,
|
||||
std::shared_ptr<Value> value) {
|
||||
auto msg = std::make_shared<Message>(kEntryUpdate, private_init());
|
||||
msg->m_value = value;
|
||||
msg->m_id = id;
|
||||
msg->m_seq_num_uid = seq_num;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::EntryDelete(unsigned int id) {
|
||||
Message msg(kEntryDelete);
|
||||
msg.m_id = id;
|
||||
std::shared_ptr<Message> Message::FlagsUpdate(unsigned int id,
|
||||
unsigned int flags) {
|
||||
auto msg = std::make_shared<Message>(kFlagsUpdate, private_init());
|
||||
msg->m_id = id;
|
||||
msg->m_flags = flags;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> params) {
|
||||
std::shared_ptr<Message> Message::EntryDelete(unsigned int id) {
|
||||
auto msg = std::make_shared<Message>(kEntryDelete, private_init());
|
||||
msg->m_id = id;
|
||||
return msg;
|
||||
}
|
||||
|
||||
std::shared_ptr<Message> Message::ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> params) {
|
||||
WireEncoder enc(0x0300);
|
||||
for (auto& param : params) enc.WriteValue(param);
|
||||
return ExecuteRpc(id, uid, enc.ToStringRef());
|
||||
}
|
||||
|
||||
Message Message::ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef params) {
|
||||
Message msg(kExecuteRpc);
|
||||
msg.m_str = params;
|
||||
msg.m_id = id;
|
||||
msg.m_seq_num_uid = uid;
|
||||
std::shared_ptr<Message> Message::ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef params) {
|
||||
auto msg = std::make_shared<Message>(kExecuteRpc, private_init());
|
||||
msg->m_str = params;
|
||||
msg->m_id = id;
|
||||
msg->m_seq_num_uid = uid;
|
||||
return msg;
|
||||
}
|
||||
|
||||
Message Message::RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> results) {
|
||||
std::shared_ptr<Message> Message::RpcResponse(
|
||||
unsigned int id, unsigned int uid, llvm::ArrayRef<NT_Value> results) {
|
||||
WireEncoder enc(0x0300);
|
||||
for (auto& result : results) enc.WriteValue(result);
|
||||
return RpcResponse(id, uid, enc.ToStringRef());
|
||||
}
|
||||
|
||||
Message Message::RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef results) {
|
||||
Message msg(kRpcResponse);
|
||||
msg.m_str = results;
|
||||
msg.m_id = id;
|
||||
msg.m_seq_num_uid = uid;
|
||||
std::shared_ptr<Message> Message::RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef results) {
|
||||
auto msg = std::make_shared<Message>(kRpcResponse, private_init());
|
||||
msg->m_str = results;
|
||||
msg->m_id = id;
|
||||
msg->m_seq_num_uid = uid;
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ class WireDecoder;
|
||||
class WireEncoder;
|
||||
|
||||
class Message {
|
||||
struct private_init {};
|
||||
|
||||
public:
|
||||
enum MsgType {
|
||||
kUnknown = -1,
|
||||
@@ -38,40 +40,57 @@ class Message {
|
||||
typedef NT_Type (*GetEntryTypeFunc)(unsigned int id);
|
||||
|
||||
Message() : m_type(kUnknown), m_id(0), m_flags(0), m_seq_num_uid(0) {}
|
||||
Message(MsgType type, const private_init&)
|
||||
: m_type(type), m_id(0), m_flags(0), m_seq_num_uid(0) {}
|
||||
|
||||
MsgType type() const { return m_type; }
|
||||
bool Is(MsgType type) const { return type == m_type; }
|
||||
|
||||
// Read and write from wire representation
|
||||
void Write(WireEncoder& encoder) const;
|
||||
static bool Read(WireDecoder& decoder, GetEntryTypeFunc get_entry_type,
|
||||
Message* msg);
|
||||
static std::shared_ptr<Message> Read(WireDecoder& decoder,
|
||||
GetEntryTypeFunc get_entry_type);
|
||||
|
||||
// Create messages without data
|
||||
static Message KeepAlive() { return Message(kKeepAlive); }
|
||||
static Message ProtoUnsup() { return Message(kProtoUnsup); }
|
||||
static Message ServerHelloDone() { return Message(kServerHelloDone); }
|
||||
static Message ClientHelloDone() { return Message(kClientHelloDone); }
|
||||
static Message ClearEntries() { return Message(kClearEntries); }
|
||||
static std::shared_ptr<Message> KeepAlive() {
|
||||
return std::make_shared<Message>(kKeepAlive, private_init());
|
||||
}
|
||||
static std::shared_ptr<Message> ProtoUnsup() {
|
||||
return std::make_shared<Message>(kProtoUnsup, private_init());
|
||||
}
|
||||
static std::shared_ptr<Message> ServerHelloDone() {
|
||||
return std::make_shared<Message>(kServerHelloDone, private_init());
|
||||
}
|
||||
static std::shared_ptr<Message> ClientHelloDone() {
|
||||
return std::make_shared<Message>(kClientHelloDone, private_init());
|
||||
}
|
||||
static std::shared_ptr<Message> ClearEntries() {
|
||||
return std::make_shared<Message>(kClearEntries, private_init());
|
||||
}
|
||||
|
||||
// Create messages with data
|
||||
static Message ClientHello(llvm::StringRef self_id);
|
||||
static Message ServerHello(unsigned int flags, llvm::StringRef self_id);
|
||||
static Message EntryAssign(llvm::StringRef name, unsigned int id,
|
||||
unsigned int seq_num, std::shared_ptr<Value> value,
|
||||
unsigned int flags);
|
||||
static Message EntryUpdate(unsigned int id, unsigned int seq_num,
|
||||
std::shared_ptr<Value> value);
|
||||
static Message FlagsUpdate(unsigned int id, unsigned int flags);
|
||||
static Message EntryDelete(unsigned int id);
|
||||
static Message ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> params);
|
||||
static Message ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef params);
|
||||
static Message RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> results);
|
||||
static Message RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef results);
|
||||
static std::shared_ptr<Message> ClientHello(llvm::StringRef self_id);
|
||||
static std::shared_ptr<Message> ServerHello(unsigned int flags,
|
||||
llvm::StringRef self_id);
|
||||
static std::shared_ptr<Message> EntryAssign(llvm::StringRef name,
|
||||
unsigned int id,
|
||||
unsigned int seq_num,
|
||||
std::shared_ptr<Value> value,
|
||||
unsigned int flags);
|
||||
static std::shared_ptr<Message> EntryUpdate(unsigned int id,
|
||||
unsigned int seq_num,
|
||||
std::shared_ptr<Value> value);
|
||||
static std::shared_ptr<Message> FlagsUpdate(unsigned int id,
|
||||
unsigned int flags);
|
||||
static std::shared_ptr<Message> EntryDelete(unsigned int id);
|
||||
static std::shared_ptr<Message> ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> params);
|
||||
static std::shared_ptr<Message> ExecuteRpc(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef params);
|
||||
static std::shared_ptr<Message> RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::ArrayRef<NT_Value> results);
|
||||
static std::shared_ptr<Message> RpcResponse(unsigned int id, unsigned int uid,
|
||||
llvm::StringRef results);
|
||||
|
||||
Message(const Message&) = delete;
|
||||
Message& operator=(const Message&) = delete;
|
||||
@@ -79,8 +98,6 @@ class Message {
|
||||
Message& operator=(Message&&) = default;
|
||||
|
||||
private:
|
||||
Message(MsgType type) : m_type(type), m_id(0), m_flags(0), m_seq_num_uid(0) {}
|
||||
|
||||
MsgType m_type;
|
||||
|
||||
// Message data. Use varies by message type.
|
||||
|
||||
@@ -15,11 +15,8 @@
|
||||
using namespace ntimpl;
|
||||
|
||||
NetworkConnection::NetworkConnection(std::unique_ptr<TCPStream> stream,
|
||||
BatchQueue& outgoing, Queue& incoming,
|
||||
Message::GetEntryTypeFunc get_entry_type)
|
||||
: m_stream(std::move(stream)),
|
||||
m_outgoing(outgoing),
|
||||
m_incoming(incoming),
|
||||
m_get_entry_type(get_entry_type),
|
||||
m_active(false),
|
||||
m_proto_rev(0x0300) {}
|
||||
@@ -49,8 +46,12 @@ void NetworkConnection::ReadThreadMain() {
|
||||
break;
|
||||
decoder.set_proto_rev(m_proto_rev);
|
||||
decoder.Reset();
|
||||
auto msg = std::make_shared<Message>();
|
||||
if (!Message::Read(decoder, m_get_entry_type, &(*msg))) break;
|
||||
auto msg = Message::Read(decoder, m_get_entry_type);
|
||||
if (!msg) {
|
||||
// terminate connection on bad message
|
||||
m_stream->close();
|
||||
break;
|
||||
}
|
||||
m_incoming.push(msg);
|
||||
}
|
||||
m_active = false;
|
||||
|
||||
@@ -21,11 +21,13 @@ namespace ntimpl {
|
||||
|
||||
class NetworkConnection {
|
||||
public:
|
||||
typedef ConcurrentQueue<std::shared_ptr<Message>> Queue;
|
||||
typedef ConcurrentQueue<std::vector<std::shared_ptr<Message>>> BatchQueue;
|
||||
typedef std::shared_ptr<Message> Incoming;
|
||||
typedef ConcurrentQueue<Incoming> IncomingQueue;
|
||||
typedef std::vector<std::shared_ptr<Message>> Outgoing;
|
||||
typedef ConcurrentQueue<Outgoing> OutgoingQueue;
|
||||
|
||||
NetworkConnection(std::unique_ptr<TCPStream> stream, BatchQueue& outgoing,
|
||||
Queue& incoming, Message::GetEntryTypeFunc get_entry_type);
|
||||
NetworkConnection(std::unique_ptr<TCPStream> stream,
|
||||
Message::GetEntryTypeFunc get_entry_type);
|
||||
~NetworkConnection();
|
||||
|
||||
void Start();
|
||||
@@ -33,6 +35,8 @@ class NetworkConnection {
|
||||
|
||||
bool active() const { return m_active; }
|
||||
TCPStream& stream() { return *m_stream; }
|
||||
OutgoingQueue& outgoing() { return m_outgoing; }
|
||||
IncomingQueue& incoming() { return m_incoming; }
|
||||
void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; }
|
||||
|
||||
NetworkConnection(const NetworkConnection&) = delete;
|
||||
@@ -45,8 +49,8 @@ class NetworkConnection {
|
||||
void WriteThreadMain();
|
||||
|
||||
std::unique_ptr<TCPStream> m_stream;
|
||||
BatchQueue& m_outgoing;
|
||||
Queue& m_incoming;
|
||||
OutgoingQueue m_outgoing;
|
||||
IncomingQueue m_incoming;
|
||||
Message::GetEntryTypeFunc m_get_entry_type;
|
||||
std::thread m_read_thread;
|
||||
std::thread m_write_thread;
|
||||
|
||||
Reference in New Issue
Block a user