From 13cbf4e2884b061cfc30297093efe20d28b6624f Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Wed, 15 Jul 2015 21:20:18 -0700 Subject: [PATCH] Message: use shared_ptr. NetworkConnection: Own the input and output queues. --- src/Message.cpp | 178 ++++++++++++++++++++------------------ src/Message.h | 69 +++++++++------ src/NetworkConnection.cpp | 11 +-- src/NetworkConnection.h | 16 ++-- 4 files changed, 151 insertions(+), 123 deletions(-) diff --git a/src/Message.cpp b/src/Message.cpp index 8ea63e600b..ca0a452003 100644 --- a/src/Message.cpp +++ b/src/Message.cpp @@ -14,210 +14,216 @@ static constexpr unsigned long kClearAllMagic = 0xD06CB27Aul; using namespace ntimpl; -bool Message::Read(WireDecoder& decoder, GetEntryTypeFunc get_entry_type, - Message* msg) { +std::shared_ptr Message::Read(WireDecoder& decoder, + GetEntryTypeFunc get_entry_type) { unsigned int msg_type; - if (!decoder.Read8(&msg_type)) return false; - *msg = Message(static_cast(msg_type)); + if (!decoder.Read8(&msg_type)) return nullptr; + auto msg = + std::make_shared(static_cast(msg_type), private_init()); switch (msg_type) { case kKeepAlive: break; case kClientHello: { unsigned int proto_rev; - if (!decoder.Read16(&proto_rev)) return false; + if (!decoder.Read16(&proto_rev)) return nullptr; msg->m_id = proto_rev; // This intentionally uses the provided proto_rev instead of // decoder.proto_rev(). if (proto_rev >= 0x0300u) { - if (!decoder.ReadString(&msg->m_str)) return false; + if (!decoder.ReadString(&msg->m_str)) return nullptr; } break; } case kProtoUnsup: { - if (!decoder.Read16(&msg->m_id)) return false; // proto rev + if (!decoder.Read16(&msg->m_id)) return nullptr; // proto rev break; } case kServerHelloDone: if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received SERVER_HELLO_DONE in protocol < 3.0"); - return false; + return nullptr; } break; case kClientHelloDone: if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received CLIENT_HELLO_DONE in protocol < 3.0"); - return false; + return nullptr; } break; case kEntryAssign: { - if (!decoder.ReadString(&msg->m_str)) return false; + if (!decoder.ReadString(&msg->m_str)) return nullptr; NT_Type type; - if (!decoder.ReadType(&type)) return false; // name - if (!decoder.Read16(&msg->m_id)) return false; // id - if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // seq num + if (!decoder.ReadType(&type)) return nullptr; // name + if (!decoder.Read16(&msg->m_id)) return nullptr; // id + if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // seq num if (decoder.proto_rev() >= 0x0300u) { - if (!decoder.Read8(&msg->m_flags)) return false; // flags + if (!decoder.Read8(&msg->m_flags)) return nullptr; // flags } msg->m_value = std::make_shared(); - if (!decoder.ReadValue(type, &(*msg->m_value))) return false; + if (!decoder.ReadValue(type, &(*msg->m_value))) return nullptr; break; } case kEntryUpdate: { - if (!decoder.Read16(&msg->m_id)) return false; // id - if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // seq num + if (!decoder.Read16(&msg->m_id)) return nullptr; // id + if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // seq num NT_Type type; if (decoder.proto_rev() >= 0x0300u) { unsigned int itype; - if (!decoder.Read8(&itype)) return false; + if (!decoder.Read8(&itype)) return nullptr; type = static_cast(itype); } else type = get_entry_type(msg->m_id); msg->m_value = std::make_shared(); - if (!decoder.ReadValue(type, &(*msg->m_value))) return false; + if (!decoder.ReadValue(type, &(*msg->m_value))) return nullptr; break; } case kFlagsUpdate: { if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received FLAGS_UPDATE in protocol < 3.0"); - return false; + return nullptr; } - if (!decoder.Read16(&msg->m_id)) return false; - if (!decoder.Read8(&msg->m_flags)) return false; + if (!decoder.Read16(&msg->m_id)) return nullptr; + if (!decoder.Read8(&msg->m_flags)) return nullptr; break; } case kEntryDelete: { if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received ENTRY_DELETE in protocol < 3.0"); - return false; + return nullptr; } - if (!decoder.Read16(&msg->m_id)) return false; + if (!decoder.Read16(&msg->m_id)) return nullptr; break; } case kClearEntries: { if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received CLEAR_ENTRIES in protocol < 3.0"); - return false; + return nullptr; } unsigned long magic; - if (!decoder.Read32(&magic)) return false; + if (!decoder.Read32(&magic)) return nullptr; if (magic != kClearAllMagic) { decoder.set_error( "received incorrect CLEAR_ENTRIES magic value, ignoring"); - return true; + return nullptr; } break; } case kExecuteRpc: { if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received EXECUTE_RPC in protocol < 3.0"); - return false; + return nullptr; } - if (!decoder.Read16(&msg->m_id)) return false; - if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // uid + if (!decoder.Read16(&msg->m_id)) return nullptr; + if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // uid unsigned long size; - if (!decoder.ReadUleb128(&size)) return false; + if (!decoder.ReadUleb128(&size)) return nullptr; const char* params; - if (!decoder.Read(¶ms, size)) return false; + if (!decoder.Read(¶ms, size)) return nullptr; msg->m_str = llvm::StringRef(params, size); break; } case kRpcResponse: { if (decoder.proto_rev() < 0x0300u) { decoder.set_error("received RPC_RESPONSE in protocol < 3.0"); - return false; + return nullptr; } - if (!decoder.Read16(&msg->m_id)) return false; - if (!decoder.Read16(&msg->m_seq_num_uid)) return false; // uid + if (!decoder.Read16(&msg->m_id)) return nullptr; + if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr; // uid unsigned long size; - if (!decoder.ReadUleb128(&size)) return false; + if (!decoder.ReadUleb128(&size)) return nullptr; const char* results; - if (!decoder.Read(&results, size)) return false; + if (!decoder.Read(&results, size)) return nullptr; msg->m_str = llvm::StringRef(results, size); break; } default: decoder.set_error("unrecognized message type"); - return false; + return nullptr; } - return true; -} - -Message Message::ClientHello(llvm::StringRef self_id) { - Message msg(kClientHello); - msg.m_str = self_id; return msg; } -Message Message::ServerHello(unsigned int flags, llvm::StringRef self_id) { - Message msg(kServerHello); - msg.m_str = self_id; - msg.m_flags = flags; +std::shared_ptr Message::ClientHello(llvm::StringRef self_id) { + auto msg = std::make_shared(kClientHello, private_init()); + msg->m_str = self_id; return msg; } -Message Message::EntryAssign(llvm::StringRef name, unsigned int id, - unsigned int seq_num, std::shared_ptr value, - unsigned int flags) { - Message msg(kEntryAssign); - msg.m_str = name; - msg.m_value = value; - msg.m_id = id; - msg.m_flags = flags; - msg.m_seq_num_uid = seq_num; +std::shared_ptr Message::ServerHello(unsigned int flags, + llvm::StringRef self_id) { + auto msg = std::make_shared(kServerHello, private_init()); + msg->m_str = self_id; + msg->m_flags = flags; return msg; } -Message Message::EntryUpdate(unsigned int id, unsigned int seq_num, - std::shared_ptr value) { - Message msg(kEntryUpdate); - msg.m_value = value; - msg.m_id = id; - msg.m_seq_num_uid = seq_num; +std::shared_ptr Message::EntryAssign(llvm::StringRef name, + unsigned int id, + unsigned int seq_num, + std::shared_ptr value, + unsigned int flags) { + auto msg = std::make_shared(kEntryAssign, private_init()); + msg->m_str = name; + msg->m_value = value; + msg->m_id = id; + msg->m_flags = flags; + msg->m_seq_num_uid = seq_num; return msg; } -Message Message::FlagsUpdate(unsigned int id, unsigned int flags) { - Message msg(kFlagsUpdate); - msg.m_id = id; - msg.m_flags = flags; +std::shared_ptr Message::EntryUpdate(unsigned int id, + unsigned int seq_num, + std::shared_ptr value) { + auto msg = std::make_shared(kEntryUpdate, private_init()); + msg->m_value = value; + msg->m_id = id; + msg->m_seq_num_uid = seq_num; return msg; } -Message Message::EntryDelete(unsigned int id) { - Message msg(kEntryDelete); - msg.m_id = id; +std::shared_ptr Message::FlagsUpdate(unsigned int id, + unsigned int flags) { + auto msg = std::make_shared(kFlagsUpdate, private_init()); + msg->m_id = id; + msg->m_flags = flags; return msg; } -Message Message::ExecuteRpc(unsigned int id, unsigned int uid, - llvm::ArrayRef params) { +std::shared_ptr Message::EntryDelete(unsigned int id) { + auto msg = std::make_shared(kEntryDelete, private_init()); + msg->m_id = id; + return msg; +} + +std::shared_ptr Message::ExecuteRpc(unsigned int id, unsigned int uid, + llvm::ArrayRef params) { WireEncoder enc(0x0300); for (auto& param : params) enc.WriteValue(param); return ExecuteRpc(id, uid, enc.ToStringRef()); } -Message Message::ExecuteRpc(unsigned int id, unsigned int uid, - llvm::StringRef params) { - Message msg(kExecuteRpc); - msg.m_str = params; - msg.m_id = id; - msg.m_seq_num_uid = uid; +std::shared_ptr Message::ExecuteRpc(unsigned int id, unsigned int uid, + llvm::StringRef params) { + auto msg = std::make_shared(kExecuteRpc, private_init()); + msg->m_str = params; + msg->m_id = id; + msg->m_seq_num_uid = uid; return msg; } -Message Message::RpcResponse(unsigned int id, unsigned int uid, - llvm::ArrayRef results) { +std::shared_ptr Message::RpcResponse( + unsigned int id, unsigned int uid, llvm::ArrayRef results) { WireEncoder enc(0x0300); for (auto& result : results) enc.WriteValue(result); return RpcResponse(id, uid, enc.ToStringRef()); } -Message Message::RpcResponse(unsigned int id, unsigned int uid, - llvm::StringRef results) { - Message msg(kRpcResponse); - msg.m_str = results; - msg.m_id = id; - msg.m_seq_num_uid = uid; +std::shared_ptr Message::RpcResponse(unsigned int id, unsigned int uid, + llvm::StringRef results) { + auto msg = std::make_shared(kRpcResponse, private_init()); + msg->m_str = results; + msg->m_id = id; + msg->m_seq_num_uid = uid; return msg; } diff --git a/src/Message.h b/src/Message.h index f1e05f62ff..0bc0c4cb2b 100644 --- a/src/Message.h +++ b/src/Message.h @@ -18,6 +18,8 @@ class WireDecoder; class WireEncoder; class Message { + struct private_init {}; + public: enum MsgType { kUnknown = -1, @@ -38,40 +40,57 @@ class Message { typedef NT_Type (*GetEntryTypeFunc)(unsigned int id); Message() : m_type(kUnknown), m_id(0), m_flags(0), m_seq_num_uid(0) {} + Message(MsgType type, const private_init&) + : m_type(type), m_id(0), m_flags(0), m_seq_num_uid(0) {} MsgType type() const { return m_type; } bool Is(MsgType type) const { return type == m_type; } // Read and write from wire representation void Write(WireEncoder& encoder) const; - static bool Read(WireDecoder& decoder, GetEntryTypeFunc get_entry_type, - Message* msg); + static std::shared_ptr Read(WireDecoder& decoder, + GetEntryTypeFunc get_entry_type); // Create messages without data - static Message KeepAlive() { return Message(kKeepAlive); } - static Message ProtoUnsup() { return Message(kProtoUnsup); } - static Message ServerHelloDone() { return Message(kServerHelloDone); } - static Message ClientHelloDone() { return Message(kClientHelloDone); } - static Message ClearEntries() { return Message(kClearEntries); } + static std::shared_ptr KeepAlive() { + return std::make_shared(kKeepAlive, private_init()); + } + static std::shared_ptr ProtoUnsup() { + return std::make_shared(kProtoUnsup, private_init()); + } + static std::shared_ptr ServerHelloDone() { + return std::make_shared(kServerHelloDone, private_init()); + } + static std::shared_ptr ClientHelloDone() { + return std::make_shared(kClientHelloDone, private_init()); + } + static std::shared_ptr ClearEntries() { + return std::make_shared(kClearEntries, private_init()); + } // Create messages with data - static Message ClientHello(llvm::StringRef self_id); - static Message ServerHello(unsigned int flags, llvm::StringRef self_id); - static Message EntryAssign(llvm::StringRef name, unsigned int id, - unsigned int seq_num, std::shared_ptr value, - unsigned int flags); - static Message EntryUpdate(unsigned int id, unsigned int seq_num, - std::shared_ptr value); - static Message FlagsUpdate(unsigned int id, unsigned int flags); - static Message EntryDelete(unsigned int id); - static Message ExecuteRpc(unsigned int id, unsigned int uid, - llvm::ArrayRef params); - static Message ExecuteRpc(unsigned int id, unsigned int uid, - llvm::StringRef params); - static Message RpcResponse(unsigned int id, unsigned int uid, - llvm::ArrayRef results); - static Message RpcResponse(unsigned int id, unsigned int uid, - llvm::StringRef results); + static std::shared_ptr ClientHello(llvm::StringRef self_id); + static std::shared_ptr ServerHello(unsigned int flags, + llvm::StringRef self_id); + static std::shared_ptr EntryAssign(llvm::StringRef name, + unsigned int id, + unsigned int seq_num, + std::shared_ptr value, + unsigned int flags); + static std::shared_ptr EntryUpdate(unsigned int id, + unsigned int seq_num, + std::shared_ptr value); + static std::shared_ptr FlagsUpdate(unsigned int id, + unsigned int flags); + static std::shared_ptr EntryDelete(unsigned int id); + static std::shared_ptr ExecuteRpc(unsigned int id, unsigned int uid, + llvm::ArrayRef params); + static std::shared_ptr ExecuteRpc(unsigned int id, unsigned int uid, + llvm::StringRef params); + static std::shared_ptr RpcResponse(unsigned int id, unsigned int uid, + llvm::ArrayRef results); + static std::shared_ptr RpcResponse(unsigned int id, unsigned int uid, + llvm::StringRef results); Message(const Message&) = delete; Message& operator=(const Message&) = delete; @@ -79,8 +98,6 @@ class Message { Message& operator=(Message&&) = default; private: - Message(MsgType type) : m_type(type), m_id(0), m_flags(0), m_seq_num_uid(0) {} - MsgType m_type; // Message data. Use varies by message type. diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp index 2fe3d9a49b..c985373484 100644 --- a/src/NetworkConnection.cpp +++ b/src/NetworkConnection.cpp @@ -15,11 +15,8 @@ using namespace ntimpl; NetworkConnection::NetworkConnection(std::unique_ptr stream, - BatchQueue& outgoing, Queue& incoming, Message::GetEntryTypeFunc get_entry_type) : m_stream(std::move(stream)), - m_outgoing(outgoing), - m_incoming(incoming), m_get_entry_type(get_entry_type), m_active(false), m_proto_rev(0x0300) {} @@ -49,8 +46,12 @@ void NetworkConnection::ReadThreadMain() { break; decoder.set_proto_rev(m_proto_rev); decoder.Reset(); - auto msg = std::make_shared(); - if (!Message::Read(decoder, m_get_entry_type, &(*msg))) break; + auto msg = Message::Read(decoder, m_get_entry_type); + if (!msg) { + // terminate connection on bad message + m_stream->close(); + break; + } m_incoming.push(msg); } m_active = false; diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h index 79aa57c6c6..abd4a7ebd2 100644 --- a/src/NetworkConnection.h +++ b/src/NetworkConnection.h @@ -21,11 +21,13 @@ namespace ntimpl { class NetworkConnection { public: - typedef ConcurrentQueue> Queue; - typedef ConcurrentQueue>> BatchQueue; + typedef std::shared_ptr Incoming; + typedef ConcurrentQueue IncomingQueue; + typedef std::vector> Outgoing; + typedef ConcurrentQueue OutgoingQueue; - NetworkConnection(std::unique_ptr stream, BatchQueue& outgoing, - Queue& incoming, Message::GetEntryTypeFunc get_entry_type); + NetworkConnection(std::unique_ptr stream, + Message::GetEntryTypeFunc get_entry_type); ~NetworkConnection(); void Start(); @@ -33,6 +35,8 @@ class NetworkConnection { bool active() const { return m_active; } TCPStream& stream() { return *m_stream; } + OutgoingQueue& outgoing() { return m_outgoing; } + IncomingQueue& incoming() { return m_incoming; } void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; } NetworkConnection(const NetworkConnection&) = delete; @@ -45,8 +49,8 @@ class NetworkConnection { void WriteThreadMain(); std::unique_ptr m_stream; - BatchQueue& m_outgoing; - Queue& m_incoming; + OutgoingQueue m_outgoing; + IncomingQueue m_incoming; Message::GetEntryTypeFunc m_get_entry_type; std::thread m_read_thread; std::thread m_write_thread;