[ntcore] Allow duplicate client IDs on server (#4676)

Currently, the server rejects duplicate client IDs. As we want to make
the client implementation as simple as possible, instead deduplicate the
name on the server side by appending "@" and a count.

NT4 spec has been updated for this change.
This commit is contained in:
Peter Johnson
2022-11-22 10:27:49 -08:00
committed by GitHub
parent 52d2c53888
commit b44034dadc
4 changed files with 69 additions and 45 deletions

View File

@@ -245,17 +245,12 @@ void ServerConnection4::ProcessWsUpgrade() {
m_websocket->open.connect([this, name = std::string{name}](std::string_view) {
m_wire = std::make_shared<net::WebSocketConnection>(*m_websocket);
// TODO: set local flag appropriately
m_clientId = m_server.m_serverImpl.AddClient(
std::string dedupName;
std::tie(dedupName, m_clientId) = m_server.m_serverImpl.AddClient(
name, m_connInfo, false, *m_wire,
[this](uint32_t repeatMs) { UpdatePeriodicTimer(repeatMs); });
if (m_clientId < 0) {
INFO("duplicate connection name '{}' (from {}), closing", name,
m_connInfo);
m_websocket->Fail(409, fmt::format("duplicate name '{}'", name));
return;
}
INFO("CONNECTED NT4 client '{}' (from {})", name, m_connInfo);
m_info.remote_id = name;
INFO("CONNECTED NT4 client '{}' (from {})", dedupName, m_connInfo);
m_info.remote_id = dedupName;
m_server.AddConnection(this, m_info);
m_websocket->closed.connect([this](uint16_t, std::string_view reason) {
INFO("DISCONNECTED NT4 client '{}' (from {}): {}", m_info.remote_id,

View File

@@ -78,10 +78,12 @@ class SImpl;
class ClientData {
public:
ClientData(std::string_view name, std::string_view connInfo, bool local,
ClientData(std::string_view originalName, std::string_view name,
std::string_view connInfo, bool local,
ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, int id,
wpi::Logger& logger)
: m_name{name},
: m_originalName{originalName},
m_name{name},
m_connInfo{connInfo},
m_local{local},
m_setPeriodic{std::move(setPeriodic)},
@@ -111,10 +113,12 @@ class ClientData {
// returns nullptr if there is no subscriber for that topic name
SubscriberData* GetSubscriber(std::string_view name, bool special);
std::string_view GetOriginalName() const { return m_originalName; }
std::string_view GetName() const { return m_name; }
int GetId() const { return m_id; }
protected:
std::string m_originalName;
std::string m_name;
std::string m_connInfo;
bool m_local; // local to machine
@@ -138,10 +142,12 @@ class ClientData {
class ClientData4Base : public ClientData, protected ClientMessageHandler {
public:
ClientData4Base(std::string_view name, std::string_view connInfo, bool local,
ClientData4Base(std::string_view originalName, std::string_view name,
std::string_view connInfo, bool local,
ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server,
int id, wpi::Logger& logger)
: ClientData{name, connInfo, local, setPeriodic, server, id, logger} {}
: ClientData{originalName, name, connInfo, local,
setPeriodic, server, id, logger} {}
protected:
// ClientMessageHandler interface
@@ -165,7 +171,8 @@ class ClientDataLocal final : public ClientData4Base {
public:
ClientDataLocal(SImpl& server, int id, wpi::Logger& logger)
: ClientData4Base{"", "", true, [](uint32_t) {}, server, id, logger} {}
: ClientData4Base{"", "", "", true, [](uint32_t) {}, server, id, logger} {
}
void ProcessIncomingText(std::string_view data) final {}
void ProcessIncomingBinary(std::span<const uint8_t> data) final {}
@@ -183,10 +190,12 @@ class ClientDataLocal final : public ClientData4Base {
class ClientData4 final : public ClientData4Base {
public:
ClientData4(std::string_view name, std::string_view connInfo, bool local,
WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic,
SImpl& server, int id, wpi::Logger& logger)
: ClientData4Base{name, connInfo, local, setPeriodic, server, id, logger},
ClientData4(std::string_view originalName, std::string_view name,
std::string_view connInfo, bool local, WireConnection& wire,
ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, int id,
wpi::Logger& logger)
: ClientData4Base{originalName, name, connInfo, local,
setPeriodic, server, id, logger},
m_wire{wire} {}
void ProcessIncomingText(std::string_view data) final;
@@ -239,7 +248,7 @@ class ClientData3 final : public ClientData, private net3::MessageHandler3 {
net3::WireConnection3& wire, ServerImpl::Connected3Func connected,
ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, int id,
wpi::Logger& logger)
: ClientData{"", connInfo, local, setPeriodic, server, id, logger},
: ClientData{"", "", connInfo, local, setPeriodic, server, id, logger},
m_connected{std::move(connected)},
m_wire{wire},
m_decoder{*this} {}
@@ -403,8 +412,9 @@ class SImpl {
TopicData* m_metaClients;
// ServerImpl interface
int AddClient(std::string_view name, std::string_view connInfo, bool local,
WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic);
std::pair<std::string, int> AddClient(
std::string_view name, std::string_view connInfo, bool local,
WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic);
int AddClient3(std::string_view connInfo, bool local,
net3::WireConnection3& wire,
ServerImpl::Connected3Func connected,
@@ -1189,11 +1199,8 @@ void ClientData3::ClientHello(std::string_view self_id,
fmt::format("unsupported protocol version {:04x}", proto_rev));
return;
}
m_name = self_id;
// create a unique name if none provided
if (m_name.empty()) {
m_name = fmt::format("NT3@{}", m_connInfo);
}
// create a unique name (just ignore provided client id)
m_name = fmt::format("NT3@{}", m_connInfo);
m_connected(m_name, 0x0300);
m_connected = nullptr; // no longer required
@@ -1487,16 +1494,22 @@ SImpl::SImpl(wpi::Logger& logger) : m_logger{logger} {
m_localClient = static_cast<ClientDataLocal*>(m_clients.back().get());
}
int SImpl::AddClient(std::string_view name, std::string_view connInfo,
bool local, WireConnection& wire,
ServerImpl::SetPeriodicFunc setPeriodic) {
std::pair<std::string, int> SImpl::AddClient(
std::string_view name, std::string_view connInfo, bool local,
WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic) {
// strip anything after @ in the name
name = wpi::split(name, '@').first;
if (name.empty()) {
name = "NT4";
}
size_t index = m_clients.size();
// find an empty slot and ensure there's no duplicates
// find an empty slot and check for duplicates
// just do a linear search as number of clients is typically small (<10)
int duplicateName = 0;
for (size_t i = 0, end = index; i < end; ++i) {
auto& clientData = m_clients[i];
if (clientData && clientData->GetName() == name) {
return -1; // don't allow duplicate client names
if (clientData && clientData->GetOriginalName() == name) {
++duplicateName;
} else if (!clientData && index == end) {
index = i;
}
@@ -1505,14 +1518,24 @@ int SImpl::AddClient(std::string_view name, std::string_view connInfo,
m_clients.emplace_back();
}
// if duplicate name, de-duplicate
std::string dedupName;
if (duplicateName > 0) {
dedupName = fmt::format("{}@{}", name, duplicateName);
} else {
dedupName = name;
}
auto& clientData = m_clients[index];
clientData = std::make_unique<ClientData4>(name, connInfo, local, wire,
std::move(setPeriodic), *this,
index, m_logger);
clientData = std::make_unique<ClientData4>(name, dedupName, connInfo, local,
wire, std::move(setPeriodic),
*this, index, m_logger);
// create client meta topics
clientData->m_metaPub = CreateMetaTopic(fmt::format("$clientpub${}", name));
clientData->m_metaSub = CreateMetaTopic(fmt::format("$clientsub${}", name));
clientData->m_metaPub =
CreateMetaTopic(fmt::format("$clientpub${}", dedupName));
clientData->m_metaSub =
CreateMetaTopic(fmt::format("$clientsub${}", dedupName));
// update meta topics
clientData->UpdateMetaClientPub();
@@ -1521,7 +1544,7 @@ int SImpl::AddClient(std::string_view name, std::string_view connInfo,
wire.Flush();
DEBUG3("AddClient('{}', '{}') -> {}", name, connInfo, index);
return index;
return {std::move(dedupName), index};
}
int SImpl::AddClient3(std::string_view connInfo, bool local,
@@ -1532,8 +1555,9 @@ int SImpl::AddClient3(std::string_view connInfo, bool local,
// find an empty slot; we can't check for duplicates until we get a hello.
// just do a linear search as number of clients is typically small (<10)
for (size_t i = 0, end = index; i < end; ++i) {
if (!m_clients[i] && index == end) {
if (!m_clients[i]) {
index = i;
break;
}
}
if (index == m_clients.size()) {
@@ -2292,9 +2316,11 @@ void ServerImpl::ProcessIncomingBinary(int clientId,
m_impl->m_clients[clientId]->ProcessIncomingBinary(data);
}
int ServerImpl::AddClient(std::string_view name, std::string_view connInfo,
bool local, WireConnection& wire,
SetPeriodicFunc setPeriodic) {
std::pair<std::string, int> ServerImpl::AddClient(std::string_view name,
std::string_view connInfo,
bool local,
WireConnection& wire,
SetPeriodicFunc setPeriodic) {
return m_impl->AddClient(name, connInfo, local, wire, std::move(setPeriodic));
}

View File

@@ -11,6 +11,7 @@
#include <span>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "NetworkInterface.h"
@@ -53,8 +54,10 @@ class ServerImpl final {
// Returns -1 if cannot add client (e.g. due to duplicate name).
// Caller must ensure WireConnection lifetime lasts until RemoveClient() call.
int AddClient(std::string_view name, std::string_view connInfo, bool local,
WireConnection& wire, SetPeriodicFunc setPeriodic);
std::pair<std::string, int> AddClient(std::string_view name,
std::string_view connInfo, bool local,
WireConnection& wire,
SetPeriodicFunc setPeriodic);
int AddClient3(std::string_view connInfo, bool local,
net3::WireConnection3& wire, Connected3Func connected,
SetPeriodicFunc setPeriodic);