[ntcore] Networking improvements (#5659)

- Utilize TrySend to properly backpressure network traffic
- Split updates into reasonable sized frames using WS fragmentation
- Use WS pings for network aliveness (requires 4.1 protocol revision)
- Measure RTT only at start of connection, rather than periodically
  (this avoids them being affected by other network traffic)
- Refactor network queue
- Refactor network ping, ping from server as well
- Improve meta topic performance
- Implement unified approach for network value updates (currently client and server use very different approaches) that respects requested subscriber update frequency

This adds a new protocol version (4.1) due to WS bugs in prior versions.
This commit is contained in:
Peter Johnson
2023-10-04 22:02:42 -07:00
committed by GitHub
parent 1d19e09ca9
commit 8b7c6852cf
21 changed files with 1369 additions and 950 deletions

View File

@@ -125,8 +125,8 @@ void NetworkClientBase::DoDisconnect(std::string_view reason) {
if (m_readLocalTimer) {
m_readLocalTimer->Stop();
}
if (m_sendValuesTimer) {
m_sendValuesTimer->Stop();
if (m_sendOutgoingTimer) {
m_sendOutgoingTimer->Stop();
}
m_localStorage.ClearNetwork();
m_localQueue.ClearQueue();
@@ -150,9 +150,9 @@ NetworkClient3::NetworkClient3(int inst, std::string_view id,
loop, kReconnectRate, m_logger,
[this](uv::Tcp& tcp) { TcpConnected(tcp); });
m_sendValuesTimer = uv::Timer::Create(loop);
if (m_sendValuesTimer) {
m_sendValuesTimer->timeout.connect([this] {
m_sendOutgoingTimer = uv::Timer::Create(loop);
if (m_sendOutgoingTimer) {
m_sendOutgoingTimer->timeout.connect([this] {
if (m_clientImpl) {
HandleLocal();
m_clientImpl->SendPeriodic(m_loop.Now().count(), false);
@@ -206,9 +206,9 @@ void NetworkClient3::TcpConnected(uv::Tcp& tcp) {
auto clientImpl = std::make_shared<net3::ClientImpl3>(
m_loop.Now().count(), m_inst, *wire, m_logger, [this](uint32_t repeatMs) {
DEBUG4("Setting periodic timer to {}", repeatMs);
if (m_sendValuesTimer) {
m_sendValuesTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
if (m_sendOutgoingTimer) {
m_sendOutgoingTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
}
});
clientImpl->Start(
@@ -302,18 +302,18 @@ NetworkClient::NetworkClient(
m_readLocalTimer->timeout.connect([this] {
if (m_clientImpl) {
HandleLocal();
m_clientImpl->SendControl(m_loop.Now().count());
m_clientImpl->SendOutgoing(m_loop.Now().count(), false);
}
});
m_readLocalTimer->Start(uv::Timer::Time{100}, uv::Timer::Time{100});
}
m_sendValuesTimer = uv::Timer::Create(loop);
if (m_sendValuesTimer) {
m_sendValuesTimer->timeout.connect([this] {
m_sendOutgoingTimer = uv::Timer::Create(loop);
if (m_sendOutgoingTimer) {
m_sendOutgoingTimer->timeout.connect([this] {
if (m_clientImpl) {
HandleLocal();
m_clientImpl->SendValues(m_loop.Now().count(), false);
m_clientImpl->SendOutgoing(m_loop.Now().count(), false);
}
});
}
@@ -324,7 +324,7 @@ NetworkClient::NetworkClient(
m_flush->wakeup.connect([this] {
if (m_clientImpl) {
HandleLocal();
m_clientImpl->SendValues(m_loop.Now().count(), true);
m_clientImpl->SendOutgoing(m_loop.Now().count(), true);
}
});
}
@@ -369,37 +369,41 @@ void NetworkClient::TcpConnected(uv::Tcp& tcp) {
wpi::SmallString<128> idBuf;
auto ws = wpi::WebSocket::CreateClient(
tcp, fmt::format("/nt/{}", wpi::EscapeURI(m_id, idBuf)), "",
{{"networktables.first.wpi.edu"}}, options);
{"v4.1.networktables.first.wpi.edu", "networktables.first.wpi.edu"},
options);
ws->SetMaxMessageSize(kMaxMessageSize);
ws->open.connect([this, &tcp, ws = ws.get()](std::string_view) {
ws->open.connect([this, &tcp, ws = ws.get()](std::string_view protocol) {
if (m_connList.IsConnected()) {
ws->Terminate(1006, "no longer needed");
return;
}
WsConnected(*ws, tcp);
WsConnected(*ws, tcp, protocol);
});
}
void NetworkClient::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp) {
void NetworkClient::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp,
std::string_view protocol) {
if (m_parallelConnect) {
m_parallelConnect->Succeeded(tcp);
}
ConnectionInfo connInfo;
uv::AddrToName(tcp.GetPeer(), &connInfo.remote_ip, &connInfo.remote_port);
connInfo.protocol_version = 0x0400;
connInfo.protocol_version =
protocol == "v4.1.networktables.first.wpi.edu" ? 0x0401 : 0x0400;
INFO("CONNECTED NT4 to {} port {}", connInfo.remote_ip, connInfo.remote_port);
m_connHandle = m_connList.AddConnection(connInfo);
m_wire = std::make_shared<net::WebSocketConnection>(ws);
m_wire =
std::make_shared<net::WebSocketConnection>(ws, connInfo.protocol_version);
m_clientImpl = std::make_unique<net::ClientImpl>(
m_loop.Now().count(), m_inst, *m_wire, m_logger, m_timeSyncUpdated,
[this](uint32_t repeatMs) {
DEBUG4("Setting periodic timer to {}", repeatMs);
if (m_sendValuesTimer) {
m_sendValuesTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
if (m_sendOutgoingTimer) {
m_sendOutgoingTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
}
});
m_clientImpl->SetLocal(&m_localStorage);

View File

@@ -76,7 +76,7 @@ class NetworkClientBase : public INetworkClient {
// used only from loop
std::shared_ptr<wpi::ParallelTcpConnector> m_parallelConnect;
std::shared_ptr<wpi::uv::Timer> m_readLocalTimer;
std::shared_ptr<wpi::uv::Timer> m_sendValuesTimer;
std::shared_ptr<wpi::uv::Timer> m_sendOutgoingTimer;
std::shared_ptr<wpi::uv::Async<>> m_flushLocal;
std::shared_ptr<wpi::uv::Async<>> m_flush;
@@ -138,7 +138,8 @@ class NetworkClient final : public NetworkClientBase {
private:
void HandleLocal();
void TcpConnected(wpi::uv::Tcp& tcp) final;
void WsConnected(wpi::WebSocket& ws, wpi::uv::Tcp& tcp);
void WsConnected(wpi::WebSocket& ws, wpi::uv::Tcp& tcp,
std::string_view protocol);
void ForceDisconnect(std::string_view reason) override;
void DoDisconnect(std::string_view reason) override;

View File

@@ -50,8 +50,8 @@ class NetworkServer::ServerConnection {
int GetClientId() const { return m_clientId; }
protected:
void SetupPeriodicTimer();
void UpdatePeriodicTimer(uint32_t repeatMs);
void SetupOutgoingTimer();
void UpdateOutgoingTimer(uint32_t repeatMs);
void ConnectionClosed();
NetworkServer& m_server;
@@ -61,7 +61,7 @@ class NetworkServer::ServerConnection {
int m_clientId;
private:
std::shared_ptr<uv::Timer> m_sendValuesTimer;
std::shared_ptr<uv::Timer> m_outgoingTimer;
};
class NetworkServer::ServerConnection3 : public ServerConnection {
@@ -82,7 +82,9 @@ class NetworkServer::ServerConnection4 final
std::string_view addr, unsigned int port,
wpi::Logger& logger)
: ServerConnection{server, addr, port, logger},
HttpWebSocketServerConnection(stream, {"networktables.first.wpi.edu"}) {
HttpWebSocketServerConnection(stream,
{"v4.1.networktables.first.wpi.edu",
"networktables.first.wpi.edu"}) {
m_info.protocol_version = 0x0400;
}
@@ -93,30 +95,32 @@ class NetworkServer::ServerConnection4 final
std::shared_ptr<net::WebSocketConnection> m_wire;
};
void NetworkServer::ServerConnection::SetupPeriodicTimer() {
m_sendValuesTimer = uv::Timer::Create(m_server.m_loop);
m_sendValuesTimer->timeout.connect([this] {
void NetworkServer::ServerConnection::SetupOutgoingTimer() {
m_outgoingTimer = uv::Timer::Create(m_server.m_loop);
m_outgoingTimer->timeout.connect([this] {
m_server.HandleLocal();
m_server.m_serverImpl.SendValues(m_clientId, m_server.m_loop.Now().count());
m_server.m_serverImpl.SendOutgoing(m_clientId,
m_server.m_loop.Now().count());
});
}
void NetworkServer::ServerConnection::UpdatePeriodicTimer(uint32_t repeatMs) {
void NetworkServer::ServerConnection::UpdateOutgoingTimer(uint32_t repeatMs) {
DEBUG4("Setting periodic timer to {}", repeatMs);
if (repeatMs == UINT32_MAX) {
m_sendValuesTimer->Stop();
m_outgoingTimer->Stop();
} else {
m_sendValuesTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
m_outgoingTimer->Start(uv::Timer::Time{repeatMs},
uv::Timer::Time{repeatMs});
}
}
void NetworkServer::ServerConnection::ConnectionClosed() {
// don't call back into m_server if it's being destroyed
if (!m_sendValuesTimer->IsLoopClosing()) {
if (!m_outgoingTimer->IsLoopClosing()) {
m_server.m_serverImpl.RemoveClient(m_clientId);
m_server.RemoveConnection(this);
}
m_sendValuesTimer->Close();
m_outgoingTimer->Close();
}
NetworkServer::ServerConnection3::ServerConnection3(
@@ -136,7 +140,7 @@ NetworkServer::ServerConnection3::ServerConnection3(
m_server.AddConnection(this, m_info);
INFO("CONNECTED NT3 client '{}' (from {})", name, m_connInfo);
},
[this](uint32_t repeatMs) { UpdatePeriodicTimer(repeatMs); });
[this](uint32_t repeatMs) { UpdateOutgoingTimer(repeatMs); });
stream->error.connect([this](uv::Error err) {
if (!m_wire->GetDisconnectReason().empty()) {
@@ -163,7 +167,7 @@ NetworkServer::ServerConnection3::ServerConnection3(
});
stream->StartRead();
SetupPeriodicTimer();
SetupOutgoingTimer();
}
void NetworkServer::ServerConnection4::ProcessRequest() {
@@ -228,13 +232,17 @@ void NetworkServer::ServerConnection4::ProcessWsUpgrade() {
m_websocket->SetMaxMessageSize(kMaxMessageSize);
m_websocket->open.connect([this, name = std::string{name}](std::string_view) {
m_wire = std::make_shared<net::WebSocketConnection>(*m_websocket);
m_websocket->open.connect([this, name = std::string{name}](
std::string_view protocol) {
m_info.protocol_version =
protocol == "v4.1.networktables.first.wpi.edu" ? 0x0401 : 0x0400;
m_wire = std::make_shared<net::WebSocketConnection>(
*m_websocket, m_info.protocol_version);
// TODO: set local flag appropriately
std::string dedupName;
std::tie(dedupName, m_clientId) = m_server.m_serverImpl.AddClient(
name, m_connInfo, false, *m_wire,
[this](uint32_t repeatMs) { UpdatePeriodicTimer(repeatMs); });
[this](uint32_t repeatMs) { UpdateOutgoingTimer(repeatMs); });
INFO("CONNECTED NT4 client '{}' (from {})", dedupName, m_connInfo);
m_info.remote_id = dedupName;
m_server.AddConnection(this, m_info);
@@ -251,7 +259,7 @@ void NetworkServer::ServerConnection4::ProcessWsUpgrade() {
m_server.m_serverImpl.ProcessIncomingBinary(m_clientId, data);
});
SetupPeriodicTimer();
SetupOutgoingTimer();
});
}
@@ -372,7 +380,7 @@ void NetworkServer::Init() {
if (m_readLocalTimer) {
m_readLocalTimer->timeout.connect([this] {
HandleLocal();
m_serverImpl.SendControl(m_loop.Now().count());
m_serverImpl.SendAllOutgoing(m_loop.Now().count(), false);
});
m_readLocalTimer->Start(uv::Timer::Time{100}, uv::Timer::Time{100});
}
@@ -398,9 +406,7 @@ void NetworkServer::Init() {
if (m_flush) {
m_flush->wakeup.connect([this] {
HandleLocal();
for (auto&& conn : m_connections) {
m_serverImpl.SendValues(conn.conn->GetClientId(), m_loop.Now().count());
}
m_serverImpl.SendAllOutgoing(m_loop.Now().count(), true);
});
}
m_flushAtomic = m_flush.get();

View File

@@ -20,17 +20,12 @@
#include "NetworkInterface.h"
#include "WireConnection.h"
#include "WireEncoder.h"
#include "net/NetworkOutgoingQueue.h"
#include "networktables/NetworkTableValue.h"
using namespace nt;
using namespace nt::net;
static constexpr uint32_t kMinPeriodMs = 5;
// maximum amount of time the wire can be not ready to send another
// transmission before we close the connection
static constexpr uint32_t kWireMaxNotReadyUs = 1000000;
ClientImpl::ClientImpl(
uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger,
std::function<void(int64_t serverTimeOffset, int64_t rtt2, bool valid)>
@@ -41,13 +36,16 @@ ClientImpl::ClientImpl(
m_logger{logger},
m_timeSyncUpdated{std::move(timeSyncUpdated)},
m_setPeriodic{std::move(setPeriodic)},
m_nextPingTimeMs{curTimeMs + kPingIntervalMs} {
m_ping{wire},
m_nextPingTimeMs{curTimeMs + (wire.GetVersion() >= 0x0401
? NetworkPing::kPingIntervalMs
: kRttIntervalMs)},
m_outgoing{wire, false} {
// immediately send RTT ping
auto out = m_wire.SendBinary();
auto now = wpi::Now();
DEBUG4("Sending initial RTT ping {}", now);
WireEncodeBinary(out.Add(), -1, 0, Value::MakeInteger(now));
m_wire.Flush();
m_wire.SendBinary(
[&](auto& os) { WireEncodeBinary(os, -1, 0, Value::MakeInteger(now)); });
m_setPeriodic(m_periodMs);
}
@@ -62,14 +60,15 @@ void ClientImpl::ProcessIncomingBinary(uint64_t curTimeMs,
int64_t id;
Value value;
std::string error;
if (!WireDecodeBinary(&data, &id, &value, &error, -m_serverTimeOffsetUs)) {
if (!WireDecodeBinary(&data, &id, &value, &error,
-m_outgoing.GetTimeOffset())) {
ERR("binary decode error: {}", error);
break; // FIXME
}
DEBUG4("BinaryMessage({})", id);
// handle RTT ping response
if (id == -1) {
// handle RTT ping response (only use first one)
if (!m_haveTimeOffset && id == -1) {
if (!value.IsInteger()) {
WARN("RTT ping response with non-integer type {}",
static_cast<int>(value.type()));
@@ -77,15 +76,18 @@ void ClientImpl::ProcessIncomingBinary(uint64_t curTimeMs,
}
DEBUG4("RTT ping response time {} value {}", value.time(),
value.GetInteger());
m_pongTimeMs = curTimeMs;
if (m_wire.GetVersion() < 0x0401) {
m_pongTimeMs = curTimeMs;
}
int64_t now = wpi::Now();
int64_t rtt2 = (now - value.GetInteger()) / 2;
if (rtt2 < m_rtt2Us) {
m_rtt2Us = rtt2;
m_serverTimeOffsetUs = value.server_time() + rtt2 - now;
DEBUG3("Time offset: {}", m_serverTimeOffsetUs);
int64_t serverTimeOffsetUs = value.server_time() + rtt2 - now;
DEBUG3("Time offset: {}", serverTimeOffsetUs);
m_outgoing.SetTimeOffset(serverTimeOffsetUs);
m_haveTimeOffset = true;
m_timeSyncUpdated(m_serverTimeOffsetUs, m_rtt2Us, true);
m_timeSyncUpdated(serverTimeOffsetUs, m_rtt2Us, true);
}
continue;
}
@@ -110,152 +112,65 @@ void ClientImpl::HandleLocal(std::vector<ClientMessage>&& msgs) {
// common case is value
if (auto msg = std::get_if<ClientValueMsg>(&elem.contents)) {
SetValue(msg->pubHandle, msg->value);
// setvalue puts on individual publish outgoing queues
} else if (auto msg = std::get_if<PublishMsg>(&elem.contents)) {
Publish(msg->pubHandle, msg->topicHandle, msg->name, msg->typeStr,
msg->properties, msg->options);
m_outgoing.emplace_back(std::move(elem));
m_outgoing.SendMessage(msg->pubHandle, std::move(elem));
} else if (auto msg = std::get_if<UnpublishMsg>(&elem.contents)) {
if (Unpublish(msg->pubHandle, msg->topicHandle)) {
m_outgoing.emplace_back(std::move(elem));
m_outgoing.SendMessage(msg->pubHandle, std::move(elem));
}
} else {
m_outgoing.emplace_back(std::move(elem));
m_outgoing.SendMessage(0, std::move(elem));
}
}
}
bool ClientImpl::DoSendControl(uint64_t curTimeMs) {
DEBUG4("SendControl({})", curTimeMs);
void ClientImpl::SendOutgoing(uint64_t curTimeMs, bool flush) {
DEBUG4("SendOutgoing({}, {})", curTimeMs, flush);
// rate limit sends
if (curTimeMs < (m_lastSendMs + kMinPeriodMs)) {
return false;
}
// start a timestamp RTT ping if it's time to do one
if (curTimeMs >= m_nextPingTimeMs) {
// if we didn't receive a response to our last ping, disconnect
if (m_nextPingTimeMs != 0 && m_pongTimeMs == 0) {
m_wire.Disconnect("timed out");
return false;
if (m_wire.GetVersion() >= 0x0401) {
// Use WS pings
if (!m_ping.Send(curTimeMs)) {
return;
}
if (!CheckNetworkReady(curTimeMs)) {
return false;
}
auto now = wpi::Now();
DEBUG4("Sending RTT ping {}", now);
WireEncodeBinary(m_wire.SendBinary().Add(), -1, 0, Value::MakeInteger(now));
// drift isn't critical here, so just go from current time
m_nextPingTimeMs = curTimeMs + kPingIntervalMs;
m_pongTimeMs = 0;
}
if (!m_outgoing.empty()) {
if (!CheckNetworkReady(curTimeMs)) {
return false;
}
auto writer = m_wire.SendText();
for (auto&& msg : m_outgoing) {
auto& stream = writer.Add();
if (!WireEncodeText(stream, msg)) {
// shouldn't happen, but just in case...
stream << "{}";
} else {
// Use RTT pings; it's unsafe to use WS pings due to bugs in WS message
// fragmentation in earlier NT4 implementations
if (curTimeMs >= m_nextPingTimeMs) {
// if we didn't receive a response to our last ping, disconnect
if (m_nextPingTimeMs != 0 && m_pongTimeMs == 0) {
m_wire.Disconnect("connection timed out");
return;
}
auto now = wpi::Now();
DEBUG4("Sending RTT ping {}", now);
m_wire.SendBinary([&](auto& os) {
WireEncodeBinary(os, -1, 0, Value::MakeInteger(now));
});
// drift isn't critical here, so just go from current time
m_nextPingTimeMs = curTimeMs + kRttIntervalMs;
m_pongTimeMs = 0;
}
m_outgoing.resize(0);
}
m_lastSendMs = curTimeMs;
return true;
}
void ClientImpl::DoSendValues(uint64_t curTimeMs, bool flush) {
DEBUG4("SendValues({})", curTimeMs);
// can't send value updates until we have a RTT
// wait until we have a RTT measurement before sending messages
if (!m_haveTimeOffset) {
return;
}
// ensure all control messages are sent ahead of value updates
if (!DoSendControl(curTimeMs)) {
return;
}
// send any pending updates due to be sent
bool checkedNetwork = false;
auto writer = m_wire.SendBinary();
for (auto&& pub : m_publishers) {
if (pub && !pub->outValues.empty() &&
(flush || curTimeMs >= pub->nextSendMs)) {
for (auto&& val : pub->outValues) {
if (!checkedNetwork) {
if (!CheckNetworkReady(curTimeMs)) {
return;
}
checkedNetwork = true;
}
DEBUG4("Sending {} value time={} server_time={} st_off={}", pub->handle,
val.time(), val.server_time(), m_serverTimeOffsetUs);
int64_t time = val.time();
if (time != 0) {
time += m_serverTimeOffsetUs;
// make sure resultant time isn't exactly 0
if (time == 0) {
time = 1;
}
}
WireEncodeBinary(writer.Add(), Handle{pub->handle}.GetIndex(), time,
val);
}
pub->outValues.resize(0);
pub->nextSendMs = curTimeMs + pub->periodMs;
}
}
m_outgoing.SendOutgoing(curTimeMs, flush);
}
void ClientImpl::SendInitialValues() {
DEBUG4("SendInitialValues()");
// ensure all control messages are sent ahead of value updates
if (!DoSendControl(0)) {
return;
void ClientImpl::UpdatePeriodic() {
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
// only send time=0 values (as we don't have a RTT yet)
auto writer = m_wire.SendBinary();
for (auto&& pub : m_publishers) {
if (pub && !pub->outValues.empty()) {
bool sent = false;
for (auto&& val : pub->outValues) {
if (val.server_time() == 0) {
DEBUG4("Sending {} value time={} server_time={}", pub->handle,
val.time(), val.server_time());
WireEncodeBinary(writer.Add(), Handle{pub->handle}.GetIndex(), 0,
val);
sent = true;
}
}
if (sent) {
std::erase_if(pub->outValues,
[](const auto& v) { return v.server_time() == 0; });
}
}
if (m_periodMs > kMaxPeriodMs) {
m_periodMs = kMaxPeriodMs;
}
}
bool ClientImpl::CheckNetworkReady(uint64_t curTimeMs) {
if (!m_wire.Ready()) {
uint64_t lastFlushTime = m_wire.GetLastFlushTime();
uint64_t now = wpi::Now();
if (lastFlushTime != 0 && now > (lastFlushTime + kWireMaxNotReadyUs)) {
m_wire.Disconnect("transmit stalled");
}
return false;
}
return true;
m_setPeriodic(m_periodMs);
}
void ClientImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle,
@@ -276,13 +191,11 @@ void ClientImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle,
if (publisher->periodMs < kMinPeriodMs) {
publisher->periodMs = kMinPeriodMs;
}
m_outgoing.SetPeriod(pubHandle, publisher->periodMs);
// update period
m_periodMs = std::gcd(m_periodMs, publisher->periodMs);
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
m_setPeriodic(m_periodMs);
m_periodMs = UpdatePeriodCalc(m_periodMs, publisher->periodMs);
UpdatePeriodic();
}
bool ClientImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) {
@@ -291,53 +204,34 @@ bool ClientImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) {
return false;
}
bool doSend = true;
if (m_publishers[index]) {
// Look through outgoing queue to see if the publish hasn't been sent yet;
// if it hasn't, delete it and don't send the server a message.
// The outgoing queue doesn't contain values; those are deleted with the
// publisher object.
auto it = std::find_if(
m_outgoing.begin(), m_outgoing.end(), [&](const auto& elem) {
if (auto msg = std::get_if<PublishMsg>(&elem.contents)) {
return msg->pubHandle == pubHandle;
}
return false;
});
if (it != m_outgoing.end()) {
m_outgoing.erase(it);
doSend = false;
}
}
m_publishers[index].reset();
// loop over all publishers to update period
m_periodMs = kPingIntervalMs + 10;
m_periodMs = kMaxPeriodMs;
for (auto&& pub : m_publishers) {
if (pub) {
m_periodMs = std::gcd(m_periodMs, pub->periodMs);
}
}
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
m_setPeriodic(m_periodMs);
UpdatePeriodic();
// remove from outgoing handle map
m_outgoing.EraseHandle(pubHandle);
return doSend;
}
void ClientImpl::SetValue(NT_Publisher pubHandle, const Value& value) {
DEBUG4("SetValue({}, time={}, server_time={}, st_off={})", pubHandle,
value.time(), value.server_time(), m_serverTimeOffsetUs);
DEBUG4("SetValue({}, time={}, server_time={})", pubHandle, value.time(),
value.server_time());
unsigned int index = Handle{pubHandle}.GetIndex();
if (index >= m_publishers.size() || !m_publishers[index]) {
return;
}
auto& publisher = *m_publishers[index];
if (publisher.outValues.empty() || publisher.options.sendAll) {
publisher.outValues.emplace_back(value);
} else {
publisher.outValues.back() = value;
}
m_outgoing.SendValue(
pubHandle, value,
publisher.options.sendAll ? ValueSendMode::kAll : ValueSendMode::kNormal);
}
void ClientImpl::ServerAnnounce(std::string_view name, int64_t id,
@@ -375,17 +269,4 @@ void ClientImpl::ProcessIncomingText(std::string_view data) {
WireDecodeText(data, *this, m_logger);
}
void ClientImpl::SendControl(uint64_t curTimeMs) {
DoSendControl(curTimeMs);
m_wire.Flush();
}
void ClientImpl::SendValues(uint64_t curTimeMs, bool flush) {
DoSendValues(curTimeMs, flush);
m_wire.Flush();
}
void ClientImpl::SendInitial() {
SendInitialValues();
m_wire.Flush();
}
void ClientImpl::SendInitial() {}

View File

@@ -16,6 +16,8 @@
#include <wpi/DenseMap.h>
#include "NetworkInterface.h"
#include "NetworkOutgoingQueue.h"
#include "NetworkPing.h"
#include "PubSubOptions.h"
#include "WireConnection.h"
#include "WireDecoder.h"
@@ -46,8 +48,7 @@ class ClientImpl final : private ServerMessageHandler {
void ProcessIncomingBinary(uint64_t curTimeMs, std::span<const uint8_t> data);
void HandleLocal(std::vector<ClientMessage>&& msgs);
void SendControl(uint64_t curTimeMs);
void SendValues(uint64_t curTimeMs, bool flush);
void SendOutgoing(uint64_t curTimeMs, bool flush);
void SetLocal(LocalInterface* local) { m_local = local; }
void SendInitial();
@@ -59,14 +60,9 @@ class ClientImpl final : private ServerMessageHandler {
// in options as double, but copy here as integer; rounded to the nearest
// 10 ms
uint32_t periodMs;
uint64_t nextSendMs{0};
std::vector<Value> outValues; // outgoing values
};
bool DoSendControl(uint64_t curTimeMs);
void DoSendValues(uint64_t curTimeMs, bool flush);
void SendInitialValues();
bool CheckNetworkReady(uint64_t curTimeMs);
void UpdatePeriodic();
// ServerMessageHandler interface
void ServerAnnounce(std::string_view name, int64_t id,
@@ -96,20 +92,23 @@ class ClientImpl final : private ServerMessageHandler {
// indexed by server-provided topic id
wpi::DenseMap<int64_t, NT_Topic> m_topicMap;
// ping
NetworkPing m_ping;
// timestamp handling
static constexpr uint32_t kPingIntervalMs = 3000;
static constexpr uint32_t kRttIntervalMs = 3000;
uint64_t m_nextPingTimeMs{0};
uint64_t m_pongTimeMs{0};
uint32_t m_rtt2Us{UINT32_MAX};
bool m_haveTimeOffset{false};
int64_t m_serverTimeOffsetUs{0};
// periodic sweep handling
uint32_t m_periodMs{kPingIntervalMs + 10};
uint64_t m_lastSendMs{0};
static constexpr uint32_t kMinPeriodMs = 5;
static constexpr uint32_t kMaxPeriodMs = NetworkPing::kPingIntervalMs;
uint32_t m_periodMs{kMaxPeriodMs};
// outgoing queue
std::vector<ClientMessage> m_outgoing;
NetworkOutgoingQueue<ClientMessage> m_outgoing;
};
} // namespace nt::net

View File

@@ -70,6 +70,7 @@ struct ClientMessage {
using Contents =
std::variant<std::monostate, PublishMsg, UnpublishMsg, SetPropertiesMsg,
SubscribeMsg, UnsubscribeMsg, ClientValueMsg>;
using ValueMsg = ClientValueMsg;
Contents contents;
};
@@ -103,6 +104,7 @@ struct ServerValueMsg {
struct ServerMessage {
using Contents = std::variant<std::monostate, AnnounceMsg, UnannounceMsg,
PropertiesUpdateMsg, ServerValueMsg>;
using ValueMsg = ServerValueMsg;
Contents contents;
};

View File

@@ -0,0 +1,335 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <stdint.h>
#include <algorithm>
#include <concepts>
#include <numeric>
#include <optional>
#include <span>
#include <utility>
#include <vector>
#include <wpi/DenseMap.h>
#include "Handle.h"
#include "Message.h"
#include "WireConnection.h"
#include "WireEncoder.h"
#include "networktables/NetworkTableValue.h"
#include "ntcore_c.h"
namespace nt::net {
static constexpr uint32_t kMinPeriodMs = 5;
inline uint32_t UpdatePeriodCalc(uint32_t period, uint32_t aPeriod) {
uint32_t newPeriod;
if (period == UINT32_MAX) {
newPeriod = aPeriod;
} else {
newPeriod = std::gcd(period, aPeriod);
}
if (newPeriod < kMinPeriodMs) {
return kMinPeriodMs;
}
return newPeriod;
}
template <typename T, typename F>
uint32_t CalculatePeriod(const T& container, F&& getPeriod) {
uint32_t period = UINT32_MAX;
for (auto&& item : container) {
if (period == UINT32_MAX) {
period = getPeriod(item);
} else {
period = std::gcd(period, getPeriod(item));
}
}
if (period < kMinPeriodMs) {
return kMinPeriodMs;
}
return period;
}
template <typename MessageType>
concept NetworkMessage =
std::same_as<typename MessageType::ValueMsg, ServerValueMsg> ||
std::same_as<typename MessageType::ValueMsg, ClientValueMsg>;
enum class ValueSendMode { kDisabled = 0, kAll, kNormal, kImm };
template <NetworkMessage MessageType>
class NetworkOutgoingQueue {
public:
NetworkOutgoingQueue(WireConnection& wire, bool local)
: m_wire{wire}, m_local{local} {
m_queues.emplace_back(100); // default queue is 100 ms period
}
void SetPeriod(NT_Handle handle, uint32_t periodMs);
void EraseHandle(NT_Handle handle) { m_handleMap.erase(handle); }
template <typename T>
void SendMessage(NT_Handle handle, T&& msg) {
m_queues[m_handleMap[handle].queueIndex].Append(handle,
std::forward<T>(msg));
m_totalSize += sizeof(Message);
}
void SendValue(NT_Handle handle, const Value& value, ValueSendMode mode);
void SendOutgoing(uint64_t curTimeMs, bool flush);
void SetTimeOffset(int64_t offsetUs) { m_timeOffsetUs = offsetUs; }
int64_t GetTimeOffset() const { return m_timeOffsetUs; }
public:
WireConnection& m_wire;
private:
using ValueMsg = typename MessageType::ValueMsg;
void EncodeValue(wpi::raw_ostream& os, NT_Handle handle, const Value& value);
struct Message {
Message() = default;
template <typename T>
Message(T&& msg, NT_Handle handle)
: msg{std::forward<T>(msg)}, handle{handle} {}
MessageType msg;
NT_Handle handle;
};
struct Queue {
explicit Queue(uint32_t periodMs) : periodMs{periodMs} {}
template <typename T>
void Append(NT_Handle handle, T&& msg) {
msgs.emplace_back(std::forward<T>(msg), handle);
}
std::vector<Message> msgs;
uint64_t nextSendMs = 0;
uint32_t periodMs;
};
std::vector<Queue> m_queues;
struct HandleInfo {
unsigned int queueIndex = 0;
int valuePos = -1; // -1 if not in queue
};
wpi::DenseMap<NT_Handle, HandleInfo> m_handleMap;
size_t m_totalSize{0};
uint64_t m_lastSendMs{0};
int64_t m_timeOffsetUs{0};
unsigned int m_lastSetPeriodQueueIndex = 0;
unsigned int m_lastSetPeriod = 100;
bool m_local;
// maximum total size of outgoing queues in bytes (approximate)
static constexpr size_t kOutgoingLimit = 1024 * 1024;
};
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SetPeriod(NT_Handle handle,
uint32_t periodMs) {
// it's quite common to set a lot of things in a row with the same period
unsigned int queueIndex;
if (m_lastSetPeriod == periodMs) {
queueIndex = m_lastSetPeriodQueueIndex;
} else {
// find and possibly create queue for this period
auto it =
std::find_if(m_queues.begin(), m_queues.end(),
[&](const auto& q) { return q.periodMs == periodMs; });
if (it == m_queues.end()) {
queueIndex = m_queues.size();
m_queues.emplace_back(periodMs);
} else {
queueIndex = it - m_queues.begin();
}
m_lastSetPeriodQueueIndex = queueIndex;
m_lastSetPeriod = periodMs;
}
// map the handle to the queue
auto [infoIt, created] = m_handleMap.try_emplace(handle);
if (!created && infoIt->getSecond().queueIndex != queueIndex) {
// need to move any items from old queue to new queue
auto& oldMsgs = m_queues[infoIt->getSecond().queueIndex].msgs;
auto it = std::remove_if(oldMsgs.begin(), oldMsgs.end(),
[&](const auto& e) { return e.handle == handle; });
auto& newMsgs = m_queues[queueIndex].msgs;
for (auto i = it, end = oldMsgs.end(); i != end; ++i) {
newMsgs.emplace_back(std::move(*i));
}
oldMsgs.erase(it, oldMsgs.end());
}
infoIt->getSecond().queueIndex = queueIndex;
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SendValue(NT_Handle handle,
const Value& value,
ValueSendMode mode) {
if (m_local) {
mode = ValueSendMode::kImm; // always send local immediately
}
// backpressure by stopping sending all if the buffer is too full
if (mode == ValueSendMode::kAll && m_totalSize >= kOutgoingLimit) {
mode = ValueSendMode::kNormal;
}
switch (mode) {
case ValueSendMode::kDisabled: // do nothing
break;
case ValueSendMode::kImm: // send immediately
m_wire.SendBinary([&](auto& os) { EncodeValue(os, handle, value); });
break;
case ValueSendMode::kAll: { // append to outgoing
auto& info = m_handleMap[handle];
auto& queue = m_queues[info.queueIndex];
info.valuePos = queue.msgs.size();
queue.Append(handle, ValueMsg{handle, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
case ValueSendMode::kNormal: {
// replace, or append if not present
auto& info = m_handleMap[handle];
auto& queue = m_queues[info.queueIndex];
if (info.valuePos != -1 &&
static_cast<unsigned int>(info.valuePos) < queue.msgs.size()) {
auto& elem = queue.msgs[info.valuePos];
if (auto m = std::get_if<ValueMsg>(&elem.msg.contents)) {
// double-check handle, and only replace if timestamp newer
if (elem.handle == handle &&
(m->value.time() == 0 || value.time() >= m->value.time())) {
int delta = value.size() - m->value.size();
m->value = value;
m_totalSize += delta;
return;
}
}
}
info.valuePos = queue.msgs.size();
queue.Append(handle, ValueMsg{handle, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
}
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SendOutgoing(uint64_t curTimeMs,
bool flush) {
if (m_totalSize == 0) {
return; // nothing to do
}
// rate limit frequency of transmissions
if (curTimeMs < (m_lastSendMs + kMinPeriodMs)) {
return;
}
if (!m_wire.Ready()) {
return; // don't bother, still sending the last batch
}
// what queues are ready to send?
wpi::SmallVector<unsigned int, 16> queues;
for (unsigned int i = 0; i < m_queues.size(); ++i) {
if (!m_queues[i].msgs.empty() &&
(flush || curTimeMs >= m_queues[i].nextSendMs)) {
queues.emplace_back(i);
}
}
if (queues.empty()) {
return; // nothing needs to be sent yet
}
// Sort transmission order by what queue has been waiting the longest time.
// XXX: byte-weighted fair queueing might be better, but is much more complex
// to implement.
std::sort(queues.begin(), queues.end(), [&](const auto& a, const auto& b) {
return m_queues[a].nextSendMs < m_queues[b].nextSendMs;
});
for (unsigned int queueIndex : queues) {
auto& queue = m_queues[queueIndex];
auto& msgs = queue.msgs;
auto it = msgs.begin();
auto end = msgs.end();
int unsent = 0;
for (; it != end && unsent == 0; ++it) {
if (auto m = std::get_if<ValueMsg>(&it->msg.contents)) {
unsent = m_wire.WriteBinary(
[&](auto& os) { EncodeValue(os, it->handle, m->value); });
} else {
unsent = m_wire.WriteText([&](auto& os) {
if (!WireEncodeText(os, it->msg)) {
os << "{}";
}
});
}
}
if (unsent == 0) {
// finish writing any partial buffers
unsent = m_wire.Flush();
}
int delta = it - msgs.begin() - unsent;
for (auto&& msg : std::span{msgs}.subspan(0, delta)) {
if (auto m = std::get_if<ValueMsg>(&msg.msg.contents)) {
m_totalSize -= sizeof(Message) + m->value.size();
} else {
m_totalSize -= sizeof(Message);
}
}
msgs.erase(msgs.begin(), it - unsent);
for (auto&& kv : m_handleMap) {
auto& info = kv.getSecond();
if (info.queueIndex == queueIndex) {
if (info.valuePos < delta) {
info.valuePos = -1;
} else {
info.valuePos -= delta;
}
}
}
// try to stay on periodic timing, unless it's falling behind current time
if (unsent == 0) {
queue.nextSendMs += queue.periodMs;
if (queue.nextSendMs < curTimeMs) {
queue.nextSendMs = curTimeMs + queue.periodMs;
}
}
}
m_lastSendMs = curTimeMs;
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::EncodeValue(wpi::raw_ostream& os,
NT_Handle handle,
const Value& value) {
int64_t time = value.time();
if constexpr (std::same_as<ValueMsg, ClientValueMsg>) {
if (time != 0) {
time += m_timeOffsetUs;
// make sure resultant time isn't exactly 0
if (time == 0) {
time = 1;
}
}
}
WireEncodeBinary(os, Handle{handle}.GetIndex(), time, value);
}
} // namespace nt::net

View File

@@ -0,0 +1,30 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include "NetworkPing.h"
#include "WireConnection.h"
using namespace nt::net;
bool NetworkPing::Send(uint64_t curTimeMs) {
if (curTimeMs < m_nextPingTimeMs) {
return true;
}
// if we didn't receive a timely response to our last ping, disconnect
uint64_t lastPing = m_wire.GetLastPingResponse();
// DEBUG4("WS ping: lastPing={} curTime={} pongTimeMs={}\n", lastPing,
// curTimeMs, m_pongTimeMs);
if (lastPing == 0) {
lastPing = m_pongTimeMs;
}
if (m_pongTimeMs != 0 && curTimeMs > (lastPing + kPingTimeoutMs)) {
m_wire.Disconnect("connection timed out");
return false;
}
m_wire.SendPing(curTimeMs);
m_nextPingTimeMs = curTimeMs + kPingIntervalMs;
m_pongTimeMs = curTimeMs;
return true;
}

View File

@@ -0,0 +1,28 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <stdint.h>
namespace nt::net {
class WireConnection;
class NetworkPing {
public:
static constexpr uint32_t kPingIntervalMs = 200;
static constexpr uint32_t kPingTimeoutMs = 1000;
explicit NetworkPing(WireConnection& wire) : m_wire{wire} {}
bool Send(uint64_t curTimeMs);
private:
WireConnection& m_wire;
uint64_t m_nextPingTimeMs{0};
uint64_t m_pongTimeMs{0};
};
} // namespace nt::net

View File

@@ -25,6 +25,7 @@
#include "Log.h"
#include "NetworkInterface.h"
#include "Types_internal.h"
#include "net/WireEncoder.h"
#include "net3/WireConnection3.h"
#include "net3/WireEncoder3.h"
#include "networktables/NetworkTableValue.h"
@@ -80,6 +81,76 @@ static void WriteOptions(mpack_writer_t& w, const PubSubOptionsImpl& options) {
mpack_finish_map(&w);
}
void ServerImpl::PublisherData::UpdateMeta() {
{
Writer w;
mpack_start_map(&w, 2);
mpack_write_str(&w, "uid");
mpack_write_int(&w, pubuid);
mpack_write_str(&w, "topic");
mpack_write_str(&w, topic->name);
mpack_finish_map(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
metaClient = std::move(w.bytes);
}
}
{
Writer w;
mpack_start_map(&w, 2);
mpack_write_str(&w, "client");
if (client) {
mpack_write_str(&w, client->GetName());
} else {
mpack_write_str(&w, "");
}
mpack_write_str(&w, "pubuid");
mpack_write_int(&w, pubuid);
mpack_finish_map(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
metaTopic = std::move(w.bytes);
}
}
}
void ServerImpl::SubscriberData::UpdateMeta() {
{
Writer w;
mpack_start_map(&w, 3);
mpack_write_str(&w, "uid");
mpack_write_int(&w, subuid);
mpack_write_str(&w, "topics");
mpack_start_array(&w, topicNames.size());
for (auto&& name : topicNames) {
mpack_write_str(&w, name);
}
mpack_finish_array(&w);
mpack_write_str(&w, "options");
WriteOptions(w, options);
mpack_finish_map(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
metaClient = std::move(w.bytes);
}
}
{
Writer w;
mpack_start_map(&w, 3);
mpack_write_str(&w, "client");
if (client) {
mpack_write_str(&w, client->GetName());
} else {
mpack_write_str(&w, "");
}
mpack_write_str(&w, "subuid");
mpack_write_int(&w, subuid);
mpack_write_str(&w, "options");
WriteOptions(w, options);
mpack_finish_map(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
metaTopic = std::move(w.bytes);
}
}
}
void ServerImpl::ClientData::UpdateMetaClientPub() {
if (!m_metaPub) {
return;
@@ -87,12 +158,9 @@ void ServerImpl::ClientData::UpdateMetaClientPub() {
Writer w;
mpack_start_array(&w, m_publishers.size());
for (auto&& pub : m_publishers) {
mpack_start_map(&w, 2);
mpack_write_str(&w, "uid");
mpack_write_int(&w, pub.first);
mpack_write_str(&w, "topic");
mpack_write_str(&w, pub.second->topic->name);
mpack_finish_map(&w);
mpack_write_object_bytes(
&w, reinterpret_cast<const char*>(pub.second->metaClient.data()),
pub.second->metaClient.size());
}
mpack_finish_array(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
@@ -107,18 +175,9 @@ void ServerImpl::ClientData::UpdateMetaClientSub() {
Writer w;
mpack_start_array(&w, m_subscribers.size());
for (auto&& sub : m_subscribers) {
mpack_start_map(&w, 3);
mpack_write_str(&w, "uid");
mpack_write_int(&w, sub.first);
mpack_write_str(&w, "topics");
mpack_start_array(&w, sub.second->topicNames.size());
for (auto&& name : sub.second->topicNames) {
mpack_write_str(&w, name);
}
mpack_finish_array(&w);
mpack_write_str(&w, "options");
WriteOptions(w, sub.second->options);
mpack_finish_map(&w);
mpack_write_object_bytes(
&w, reinterpret_cast<const char*>(sub.second->metaClient.data()),
sub.second->metaClient.size());
}
mpack_finish_array(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
@@ -154,11 +213,10 @@ void ServerImpl::ClientData4Base::ClientPublish(int64_t pubuid,
}
// add publisher to topic
topic->publishers.Add(publisherIt->getSecond().get());
topic->AddPublisher(this, publisherIt->getSecond().get());
// update meta data
m_server.UpdateMetaTopicPub(topic);
UpdateMetaClientPub();
// respond with announce with pubuid to client
DEBUG4("client {}: announce {} pubuid {}", m_id, topic->name, pubuid);
@@ -175,14 +233,13 @@ void ServerImpl::ClientData4Base::ClientUnpublish(int64_t pubuid) {
auto topic = publisher->topic;
// remove publisher from topic
topic->publishers.Remove(publisher);
topic->RemovePublisher(this, publisher);
// remove publisher from client
m_publishers.erase(publisherIt);
// update meta data
m_server.UpdateMetaTopicPub(topic);
UpdateMetaClientPub();
// delete topic if no longer published
if (!topic->IsPublished()) {
@@ -234,14 +291,7 @@ void ServerImpl::ClientData4Base::ClientSubscribe(
// update periodic sender (if not local)
if (!m_local) {
if (m_periodMs == UINT32_MAX) {
m_periodMs = sub->periodMs;
} else {
m_periodMs = std::gcd(m_periodMs, sub->periodMs);
}
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
m_periodMs = UpdatePeriodCalc(m_periodMs, sub->periodMs);
m_setPeriodic(m_periodMs);
}
@@ -252,30 +302,28 @@ void ServerImpl::ClientData4Base::ClientSubscribe(
std::vector<TopicData*> dataToSend;
dataToSend.reserve(m_server.m_topics.size());
for (auto&& topic : m_server.m_topics) {
bool removed = false;
if (replace) {
removed = topic->subscribers.Remove(sub.get());
}
auto tcdIt = topic->clients.find(this);
bool removed = tcdIt != topic->clients.end() && replace &&
tcdIt->second.subscribers.erase(sub.get());
// is client already subscribed?
bool wasSubscribed = false;
bool wasSubscribedValue = false;
for (auto subscriber : topic->subscribers) {
if (subscriber->client == this) {
wasSubscribed = true;
if (!subscriber->options.topicsOnly) {
wasSubscribedValue = true;
}
}
}
bool wasSubscribed =
tcdIt != topic->clients.end() && !tcdIt->second.subscribers.empty();
bool wasSubscribedValue =
wasSubscribed ? tcdIt->second.sendMode != ValueSendMode::kDisabled
: false;
bool added = false;
if (sub->Matches(topic->name, topic->special)) {
topic->subscribers.Add(sub.get());
if (tcdIt == topic->clients.end()) {
tcdIt = topic->clients.try_emplace(this).first;
}
tcdIt->second.AddSubscriber(sub.get());
added = true;
}
if (added ^ removed) {
UpdatePeriod(tcdIt->second, topic.get());
m_server.UpdateMetaTopicSub(topic.get());
}
@@ -294,13 +342,8 @@ void ServerImpl::ClientData4Base::ClientSubscribe(
for (auto topic : dataToSend) {
DEBUG4("send last value for {} to client {}", topic->name, m_id);
SendValue(topic, topic->lastValue, kSendAll);
SendValue(topic, topic->lastValue, ValueSendMode::kAll);
}
// update meta data
UpdateMetaClientSub();
Flush();
}
void ServerImpl::ClientData4Base::ClientUnsubscribe(int64_t subuid) {
@@ -313,28 +356,24 @@ void ServerImpl::ClientData4Base::ClientUnsubscribe(int64_t subuid) {
// remove from topics
for (auto&& topic : m_server.m_topics) {
if (topic->subscribers.Remove(sub)) {
m_server.UpdateMetaTopicSub(topic.get());
auto tcdIt = topic->clients.find(this);
if (tcdIt != topic->clients.end()) {
if (tcdIt->second.subscribers.erase(sub)) {
UpdatePeriod(tcdIt->second, topic.get());
m_server.UpdateMetaTopicSub(topic.get());
}
}
}
// delete it from client (future value sets will be ignored)
m_subscribers.erase(subIt);
UpdateMetaClientSub();
// loop over all publishers to update period
m_periodMs = UINT32_MAX;
for (auto&& sub : m_subscribers) {
if (m_periodMs == UINT32_MAX) {
m_periodMs = sub.getSecond()->periodMs;
} else {
m_periodMs = std::gcd(m_periodMs, sub.getSecond()->periodMs);
}
// loop over all subscribers to update period
if (!m_local) {
m_periodMs = CalculatePeriod(
m_subscribers, [](auto& x) { return x.getSecond()->periodMs; });
m_setPeriodic(m_periodMs);
}
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
m_setPeriodic(m_periodMs);
}
void ServerImpl::ClientData4Base::ClientSetValue(int64_t pubuid,
@@ -350,7 +389,8 @@ void ServerImpl::ClientData4Base::ClientSetValue(int64_t pubuid,
}
void ServerImpl::ClientDataLocal::SendValue(TopicData* topic,
const Value& value, SendMode mode) {
const Value& value,
ValueSendMode mode) {
if (m_server.m_local) {
m_server.m_local->NetworkSetValue(topic->localHandle, value);
}
@@ -395,27 +435,45 @@ void ServerImpl::ClientDataLocal::SendPropertiesUpdate(TopicData* topic,
void ServerImpl::ClientDataLocal::HandleLocal(
std::span<const ClientMessage> msgs) {
DEBUG4("HandleLocal()");
if (msgs.empty()) {
return;
}
// just map as a normal client into client=0 calls
bool updatepub = false;
bool updatesub = false;
for (const auto& elem : msgs) { // NOLINT
// common case is value, so check that first
if (auto msg = std::get_if<ClientValueMsg>(&elem.contents)) {
ClientSetValue(msg->pubHandle, msg->value);
} else if (auto msg = std::get_if<PublishMsg>(&elem.contents)) {
ClientPublish(msg->pubHandle, msg->name, msg->typeStr, msg->properties);
updatepub = true;
} else if (auto msg = std::get_if<UnpublishMsg>(&elem.contents)) {
ClientUnpublish(msg->pubHandle);
updatepub = true;
} else if (auto msg = std::get_if<SetPropertiesMsg>(&elem.contents)) {
ClientSetProperties(msg->name, msg->update);
} else if (auto msg = std::get_if<SubscribeMsg>(&elem.contents)) {
ClientSubscribe(msg->subHandle, msg->topicNames, msg->options);
updatesub = true;
} else if (auto msg = std::get_if<UnsubscribeMsg>(&elem.contents)) {
ClientUnsubscribe(msg->subHandle);
updatesub = true;
}
}
if (updatepub) {
UpdateMetaClientPub();
}
if (updatesub) {
UpdateMetaClientSub();
}
}
void ServerImpl::ClientData4::ProcessIncomingText(std::string_view data) {
WireDecodeText(data, *this, m_logger);
if (WireDecodeText(data, *this, m_logger)) {
UpdateMetaClientPub();
UpdateMetaClientSub();
}
}
void ServerImpl::ClientData4::ProcessIncomingBinary(
@@ -438,11 +496,8 @@ void ServerImpl::ClientData4::ProcessIncomingBinary(
if (pubuid == -1) {
auto now = wpi::Now();
DEBUG4("RTT ping from {}, responding with time={}", m_id, now);
{
auto out = m_wire.SendBinary();
WireEncodeBinary(out.Add(), -1, now, value);
}
m_wire.Flush();
m_wire.SendBinary(
[&](auto& os) { WireEncodeBinary(os, -1, now, value); });
continue;
}
@@ -452,40 +507,8 @@ void ServerImpl::ClientData4::ProcessIncomingBinary(
}
void ServerImpl::ClientData4::SendValue(TopicData* topic, const Value& value,
SendMode mode) {
if (m_local) {
mode = ClientData::kSendImmNoFlush; // always send local immediately
}
switch (mode) {
case ClientData::kSendDisabled: // do nothing
break;
case ClientData::kSendImmNoFlush: // send immediately
WriteBinary(topic->id, value.time(), value);
if (m_local) {
Flush();
}
break;
case ClientData::kSendAll: // append to outgoing
m_outgoingValueMap[topic->id] = m_outgoing.size();
m_outgoing.emplace_back(ServerMessage{ServerValueMsg{topic->id, value}});
break;
case ClientData::kSendNormal: {
// replace, or append if not present
auto [it, added] =
m_outgoingValueMap.try_emplace(topic->id, m_outgoing.size());
if (!added && it->second < m_outgoing.size()) {
if (auto m =
std::get_if<ServerValueMsg>(&m_outgoing[it->second].contents)) {
if (m->topic == topic->id) { // should always be true
m->value = value;
break;
}
}
}
m_outgoing.emplace_back(ServerMessage{ServerValueMsg{topic->id, value}});
break;
}
}
ValueSendMode mode) {
m_outgoing.SendValue(topic->GetIdHandle(), value, mode);
}
void ServerImpl::ClientData4::SendAnnounce(TopicData* topic,
@@ -497,14 +520,18 @@ void ServerImpl::ClientData4::SendAnnounce(TopicData* topic,
sent = true;
if (m_local) {
WireEncodeAnnounce(SendText().Add(), topic->name, topic->id, topic->typeStr,
topic->properties, pubuid);
Flush();
} else {
m_outgoing.emplace_back(ServerMessage{AnnounceMsg{
topic->name, topic->id, topic->typeStr, pubuid, topic->properties}});
m_server.m_controlReady = true;
int unsent = m_wire.WriteText([&](auto& os) {
WireEncodeAnnounce(os, topic->name, topic->id, topic->typeStr,
topic->properties, pubuid);
});
if (unsent == 0 && m_wire.Flush() == 0) {
return;
}
}
m_outgoing.SendMessage(topic->GetIdHandle(),
AnnounceMsg{topic->name, topic->id, topic->typeStr,
pubuid, topic->properties});
m_server.m_controlReady = true;
}
void ServerImpl::ClientData4::SendUnannounce(TopicData* topic) {
@@ -515,13 +542,16 @@ void ServerImpl::ClientData4::SendUnannounce(TopicData* topic) {
sent = false;
if (m_local) {
WireEncodeUnannounce(SendText().Add(), topic->name, topic->id);
Flush();
} else {
m_outgoing.emplace_back(
ServerMessage{UnannounceMsg{topic->name, topic->id}});
m_server.m_controlReady = true;
int unsent = m_wire.WriteText(
[&](auto& os) { WireEncodeUnannounce(os, topic->name, topic->id); });
if (unsent == 0 && m_wire.Flush() == 0) {
return;
}
}
m_outgoing.SendMessage(topic->GetIdHandle(),
UnannounceMsg{topic->name, topic->id});
m_outgoing.EraseHandle(topic->GetIdHandle());
m_server.m_controlReady = true;
}
void ServerImpl::ClientData4::SendPropertiesUpdate(TopicData* topic,
@@ -532,50 +562,33 @@ void ServerImpl::ClientData4::SendPropertiesUpdate(TopicData* topic,
}
if (m_local) {
WireEncodePropertiesUpdate(SendText().Add(), topic->name, update, ack);
Flush();
} else {
m_outgoing.emplace_back(
ServerMessage{PropertiesUpdateMsg{topic->name, update, ack}});
m_server.m_controlReady = true;
}
}
void ServerImpl::ClientData4::SendOutgoing(uint64_t curTimeMs) {
if (m_outgoing.empty()) {
return; // nothing to do
}
// rate limit frequency of transmissions
if (curTimeMs < (m_lastSendMs + kMinPeriodMs)) {
return;
}
if (!m_wire.Ready()) {
uint64_t lastFlushTime = m_wire.GetLastFlushTime();
uint64_t now = wpi::Now();
if (lastFlushTime != 0 && now > (lastFlushTime + kWireMaxNotReadyUs)) {
m_wire.Disconnect("transmit stalled");
}
return;
}
for (auto&& msg : m_outgoing) {
if (auto m = std::get_if<ServerValueMsg>(&msg.contents)) {
WriteBinary(m->topic, m->value.time(), m->value);
} else {
WireEncodeText(SendText().Add(), msg);
int unsent = m_wire.WriteText([&](auto& os) {
WireEncodePropertiesUpdate(os, topic->name, update, ack);
});
if (unsent == 0 && m_wire.Flush() == 0) {
return;
}
}
m_outgoing.resize(0);
m_outgoingValueMap.clear();
m_lastSendMs = curTimeMs;
m_outgoing.SendMessage(topic->GetIdHandle(),
PropertiesUpdateMsg{topic->name, update, ack});
m_server.m_controlReady = true;
}
void ServerImpl::ClientData4::Flush() {
m_outText.reset();
m_outBinary.reset();
m_wire.Flush();
void ServerImpl::ClientData4::SendOutgoing(uint64_t curTimeMs, bool flush) {
if (m_wire.GetVersion() >= 0x0401) {
if (!m_ping.Send(curTimeMs)) {
return;
}
}
m_outgoing.SendOutgoing(curTimeMs, flush);
}
void ServerImpl::ClientData4::UpdatePeriod(TopicData::TopicClientData& tcd,
TopicData* topic) {
uint32_t period =
CalculatePeriod(tcd.subscribers, [](auto& x) { return x->periodMs; });
DEBUG4("updating {} period to {} ms", topic->name, period);
m_outgoing.SetPeriod(topic->GetIdHandle(), period);
}
bool ServerImpl::ClientData3::TopicData3::UpdateFlags(TopicData* topic) {
@@ -593,21 +606,21 @@ void ServerImpl::ClientData3::ProcessIncomingBinary(
}
void ServerImpl::ClientData3::SendValue(TopicData* topic, const Value& value,
SendMode mode) {
ValueSendMode mode) {
if (m_state != kStateRunning) {
if (mode == kSendImmNoFlush) {
mode = kSendAll;
if (mode == ValueSendMode::kImm) {
mode = ValueSendMode::kAll;
}
} else if (m_local) {
mode = ClientData::kSendImmNoFlush; // always send local immediately
mode = ValueSendMode::kImm; // always send local immediately
}
TopicData3* topic3 = GetTopic3(topic);
bool added = false;
switch (mode) {
case ClientData::kSendDisabled: // do nothing
case ValueSendMode::kDisabled: // do nothing
break;
case ClientData::kSendImmNoFlush: // send immediately and flush
case ValueSendMode::kImm: // send immediately
++topic3->seqNum;
if (topic3->sentAssign) {
net3::WireEncodeEntryUpdate(m_wire.Send().stream(), topic->id,
@@ -622,7 +635,7 @@ void ServerImpl::ClientData3::SendValue(TopicData* topic, const Value& value,
Flush();
}
break;
case ClientData::kSendNormal: {
case ValueSendMode::kNormal: {
// replace, or append if not present
wpi::DenseMap<NT_Topic, size_t>::iterator it;
std::tie(it, added) =
@@ -639,7 +652,7 @@ void ServerImpl::ClientData3::SendValue(TopicData* topic, const Value& value,
}
}
// fallthrough
case ClientData::kSendAll: // append to outgoing
case ValueSendMode::kAll: // append to outgoing
if (!added) {
m_outgoingValueMap[topic->id] = m_outgoing.size();
}
@@ -666,7 +679,7 @@ void ServerImpl::ClientData3::SendAnnounce(TopicData* topic,
// subscribe to all non-special topics
if (!topic->special) {
topic->subscribers.Add(m_subscribers[0].get());
topic->clients[this].AddSubscriber(m_subscribers[0].get());
m_server.UpdateMetaTopicSub(topic);
}
@@ -720,7 +733,7 @@ void ServerImpl::ClientData3::SendPropertiesUpdate(TopicData* topic,
}
}
void ServerImpl::ClientData3::SendOutgoing(uint64_t curTimeMs) {
void ServerImpl::ClientData3::SendOutgoing(uint64_t curTimeMs, bool flush) {
if (m_outgoing.empty() || m_state != kStateRunning) {
return; // nothing to do
}
@@ -743,6 +756,7 @@ void ServerImpl::ClientData3::SendOutgoing(uint64_t curTimeMs) {
for (auto&& msg : m_outgoing) {
net3::WireEncode(out.stream(), msg);
}
m_wire.Flush();
m_outgoing.resize(0);
m_outgoingValueMap.clear();
m_lastSendMs = curTimeMs;
@@ -790,7 +804,7 @@ void ServerImpl::ClientData3::ClearEntries() {
auto publisherIt = m_publishers.find(topic3it.second.pubuid);
if (publisherIt != m_publishers.end()) {
// remove publisher from topic
topic->publishers.Remove(publisherIt->second.get());
topic->RemovePublisher(this, publisherIt->second.get());
// remove publisher from client
m_publishers.erase(publisherIt);
@@ -841,10 +855,7 @@ void ServerImpl::ClientData3::ClientHello(std::string_view self_id,
options.prefixMatch = true;
sub = std::make_unique<SubscriberData>(
this, std::span<const std::string>{{prefix}}, 0, options);
m_periodMs = std::gcd(m_periodMs, sub->periodMs);
if (m_periodMs < kMinPeriodMs) {
m_periodMs = kMinPeriodMs;
}
m_periodMs = UpdatePeriodCalc(m_periodMs, sub->periodMs);
m_setPeriodic(m_periodMs);
{
@@ -855,7 +866,7 @@ void ServerImpl::ClientData3::ClientHello(std::string_view self_id,
topic->lastValue) {
DEBUG4("client {}: initial announce of '{}' (id {})", m_id, topic->name,
topic->id);
topic->subscribers.Add(sub.get());
topic->clients[this].AddSubscriber(sub.get());
m_server.UpdateMetaTopicSub(topic.get());
TopicData3* topic3 = GetTopic3(topic.get());
@@ -922,7 +933,7 @@ void ServerImpl::ClientData3::EntryAssign(std::string_view name,
}
// add publisher to topic
topic->publishers.Add(publisherIt->getSecond().get());
topic->AddPublisher(this, publisherIt->getSecond().get());
// update meta data
m_server.UpdateMetaTopicPub(topic);
@@ -972,7 +983,7 @@ void ServerImpl::ClientData3::EntryUpdate(unsigned int id, unsigned int seq_num,
std::make_unique<PublisherData>(this, topic, topic3->pubuid));
if (isNew) {
// add publisher to topic
topic->publishers.Add(publisherIt->getSecond().get());
topic->AddPublisher(this, publisherIt->getSecond().get());
// update meta data
m_server.UpdateMetaTopicPub(topic);
@@ -1037,7 +1048,7 @@ void ServerImpl::ClientData3::EntryDelete(unsigned int id) {
auto publisherIt = m_publishers.find(topic3it->second.pubuid);
if (publisherIt != m_publishers.end()) {
// remove publisher from topic
topic->publishers.Remove(publisherIt->second.get());
topic->RemovePublisher(this, publisherIt->second.get());
// remove publisher from client
m_publishers.erase(publisherIt);
@@ -1159,8 +1170,6 @@ std::pair<std::string, int> ServerImpl::AddClient(
clientData->UpdateMetaClientPub();
clientData->UpdateMetaClientSub();
wire.Flush();
DEBUG3("AddClient('{}', '{}') -> {}", name, connInfo, index);
return {std::move(dedupName), index};
}
@@ -1197,17 +1206,14 @@ void ServerImpl::RemoveClient(int clientId) {
// remove all publishers and subscribers for this client
wpi::SmallVector<TopicData*, 16> toDelete;
for (auto&& topic : m_topics) {
auto pubRemove =
std::remove_if(topic->publishers.begin(), topic->publishers.end(),
[&](auto&& e) { return e->client == client.get(); });
bool pubChanged = pubRemove != topic->publishers.end();
topic->publishers.erase(pubRemove, topic->publishers.end());
auto subRemove =
std::remove_if(topic->subscribers.begin(), topic->subscribers.end(),
[&](auto&& e) { return e->client == client.get(); });
bool subChanged = subRemove != topic->subscribers.end();
topic->subscribers.erase(subRemove, topic->subscribers.end());
bool pubChanged = false;
bool subChanged = false;
auto tcdIt = topic->clients.find(client.get());
if (tcdIt != topic->clients.end()) {
pubChanged = !tcdIt->second.publishers.empty();
subChanged = !tcdIt->second.subscribers.empty();
topic->clients.erase(tcdIt);
}
if (!topic->IsPublished()) {
toDelete.push_back(topic.get());
@@ -1641,15 +1647,17 @@ ServerImpl::TopicData* ServerImpl::CreateTopic(ClientData* client,
wpi::SmallVector<SubscriberData*, 16> subscribersBuf;
auto subscribers =
aClient->GetSubscribers(name, topic->special, subscribersBuf);
for (auto subscriber : subscribers) {
topic->subscribers.Add(subscriber);
}
// don't announce to this client if no subscribers
if (subscribers.empty()) {
continue;
}
auto& tcd = topic->clients[aClient.get()];
for (auto subscriber : subscribers) {
tcd.AddSubscriber(subscriber);
}
if (aClient.get() == client) {
continue; // don't announce to requesting client again
}
@@ -1688,17 +1696,9 @@ void ServerImpl::DeleteTopic(TopicData* topic) {
}
// unannounce to all subscribers
wpi::SmallVector<bool, 16> clients;
clients.resize(m_clients.size());
for (auto&& sub : topic->subscribers) {
clients[sub->client->GetId()] = true;
}
for (size_t i = 0, iend = clients.size(); i < iend; ++i) {
if (!clients[i]) {
continue;
}
if (auto aClient = m_clients[i].get()) {
aClient->SendUnannounce(topic);
for (auto&& tcd : topic->clients) {
if (!tcd.second.subscribers.empty()) {
tcd.first->SendUnannounce(topic);
}
}
@@ -1755,32 +1755,9 @@ void ServerImpl::SetValue(ClientData* client, TopicData* topic,
}
}
// propagate to subscribers; as each client may have multiple subscribers,
// but we only want to send the value once, first map to clients and then
// take action based on union of subscriptions
// indexed by clientId
wpi::SmallVector<ClientData::SendMode, 16> toSend;
toSend.resize(m_clients.size());
for (auto&& subscriber : topic->subscribers) {
int id = subscriber->client->GetId();
if (subscriber->options.topicsOnly) {
continue;
} else if (subscriber->options.sendAll) {
toSend[id] = ClientData::kSendAll;
} else if (toSend[id] != ClientData::kSendAll) {
toSend[id] = ClientData::kSendNormal;
}
}
for (size_t i = 0, iend = toSend.size(); i < iend; ++i) {
auto aClient = m_clients[i].get();
if (!aClient || client == aClient) {
continue; // don't echo back
}
if (toSend[i] != ClientData::kSendDisabled) {
aClient->SendValue(topic, value, toSend[i]);
for (auto&& tcd : topic->clients) {
if (tcd.second.sendMode != ValueSendMode::kDisabled) {
tcd.first->SendValue(topic, value, tcd.second.sendMode);
}
}
}
@@ -1811,18 +1788,17 @@ void ServerImpl::UpdateMetaTopicPub(TopicData* topic) {
return;
}
Writer w;
mpack_start_array(&w, topic->publishers.size());
for (auto&& pub : topic->publishers) {
mpack_start_map(&w, 2);
mpack_write_str(&w, "client");
if (pub->client) {
mpack_write_str(&w, pub->client->GetName());
} else {
mpack_write_str(&w, "");
uint32_t count = 0;
for (auto&& tcd : topic->clients) {
count += tcd.second.publishers.size();
}
mpack_start_array(&w, count);
for (auto&& tcd : topic->clients) {
for (auto&& pub : tcd.second.publishers) {
mpack_write_object_bytes(
&w, reinterpret_cast<const char*>(pub->metaTopic.data()),
pub->metaTopic.size());
}
mpack_write_str(&w, "pubuid");
mpack_write_int(&w, pub->pubuid);
mpack_finish_map(&w);
}
mpack_finish_array(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
@@ -1835,20 +1811,17 @@ void ServerImpl::UpdateMetaTopicSub(TopicData* topic) {
return;
}
Writer w;
mpack_start_array(&w, topic->subscribers.size());
for (auto&& sub : topic->subscribers) {
mpack_start_map(&w, 3);
mpack_write_str(&w, "client");
if (sub->client) {
mpack_write_str(&w, sub->client->GetName());
} else {
mpack_write_str(&w, "");
uint32_t count = 0;
for (auto&& tcd : topic->clients) {
count += tcd.second.subscribers.size();
}
mpack_start_array(&w, count);
for (auto&& tcd : topic->clients) {
for (auto&& sub : tcd.second.subscribers) {
mpack_write_object_bytes(
&w, reinterpret_cast<const char*>(sub->metaTopic.data()),
sub->metaTopic.size());
}
mpack_write_str(&w, "subuid");
mpack_write_int(&w, sub->subuid);
mpack_write_str(&w, "options");
WriteOptions(w, sub->options);
mpack_finish_map(&w);
}
mpack_finish_array(&w);
if (mpack_writer_destroy(&w) == mpack_ok) {
@@ -1863,41 +1836,23 @@ void ServerImpl::PropertiesChanged(ClientData* client, TopicData* topic,
DeleteTopic(topic);
} else {
// send updated announcement to all subscribers
wpi::SmallVector<bool, 16> clients;
clients.resize(m_clients.size());
for (auto&& sub : topic->subscribers) {
clients[sub->client->GetId()] = true;
}
for (size_t i = 0, iend = clients.size(); i < iend; ++i) {
if (!clients[i]) {
continue;
}
if (auto aClient = m_clients[i].get()) {
aClient->SendPropertiesUpdate(topic, update, aClient == client);
}
for (auto&& tcd : topic->clients) {
tcd.first->SendPropertiesUpdate(topic, update, tcd.first == client);
}
}
}
void ServerImpl::SendControl(uint64_t curTimeMs) {
if (!m_controlReady) {
return;
}
m_controlReady = false;
void ServerImpl::SendAllOutgoing(uint64_t curTimeMs, bool flush) {
for (auto&& client : m_clients) {
if (client) {
// to ensure ordering, just send everything
client->SendOutgoing(curTimeMs);
client->Flush();
client->SendOutgoing(curTimeMs, flush);
}
}
}
void ServerImpl::SendValues(int clientId, uint64_t curTimeMs) {
void ServerImpl::SendOutgoing(int clientId, uint64_t curTimeMs) {
if (auto client = m_clients[clientId].get()) {
client->SendOutgoing(curTimeMs);
client->Flush();
client->SendOutgoing(curTimeMs, false);
}
}

View File

@@ -16,12 +16,17 @@
#include <vector>
#include <wpi/DenseMap.h>
#include <wpi/SmallPtrSet.h>
#include <wpi/StringMap.h>
#include <wpi/UidVector.h>
#include <wpi/json.h>
#include "Handle.h"
#include "Log.h"
#include "Message.h"
#include "NetworkInterface.h"
#include "NetworkOutgoingQueue.h"
#include "NetworkPing.h"
#include "PubSubOptions.h"
#include "VectorSet.h"
#include "WireConnection.h"
@@ -57,8 +62,8 @@ class ServerImpl final {
explicit ServerImpl(wpi::Logger& logger);
void SendControl(uint64_t curTimeMs);
void SendValues(int clientId, uint64_t curTimeMs);
void SendAllOutgoing(uint64_t curTimeMs, bool flush);
void SendOutgoing(int clientId, uint64_t curTimeMs);
void HandleLocal(std::span<const ClientMessage> msgs);
void SetLocal(LocalInterface* local);
@@ -88,9 +93,75 @@ class ServerImpl final {
private:
static constexpr uint32_t kMinPeriodMs = 5;
class ClientData;
struct PublisherData;
struct SubscriberData;
struct TopicData;
struct TopicData {
TopicData(std::string_view name, std::string_view typeStr)
: name{name}, typeStr{typeStr} {}
TopicData(std::string_view name, std::string_view typeStr,
wpi::json properties)
: name{name}, typeStr{typeStr}, properties(std::move(properties)) {
RefreshProperties();
}
bool IsPublished() const {
return persistent || retained || publisherCount != 0;
}
// returns true if properties changed
bool SetProperties(const wpi::json& update);
void RefreshProperties();
bool SetFlags(unsigned int flags_);
NT_Handle GetIdHandle() const { return Handle(0, id, Handle::kTopic); }
std::string name;
unsigned int id;
Value lastValue;
ClientData* lastValueClient = nullptr;
std::string typeStr;
wpi::json properties = wpi::json::object();
unsigned int publisherCount{0};
bool persistent{false};
bool retained{false};
bool special{false};
NT_Topic localHandle{0};
void AddPublisher(ClientData* client, PublisherData* pub) {
if (clients[client].publishers.insert(pub).second) {
++publisherCount;
}
}
void RemovePublisher(ClientData* client, PublisherData* pub) {
if (clients[client].publishers.erase(pub)) {
--publisherCount;
}
}
struct TopicClientData {
wpi::SmallPtrSet<PublisherData*, 2> publishers;
wpi::SmallPtrSet<SubscriberData*, 2> subscribers;
ValueSendMode sendMode = ValueSendMode::kDisabled;
bool AddSubscriber(SubscriberData* sub) {
bool added = subscribers.insert(sub).second;
if (!sub->options.topicsOnly && sendMode == ValueSendMode::kDisabled) {
sendMode = ValueSendMode::kNormal;
} else if (sub->options.sendAll) {
sendMode = ValueSendMode::kAll;
}
return added;
}
};
wpi::SmallDenseMap<ClientData*, TopicClientData, 4> clients;
// meta topics
TopicData* metaPub = nullptr;
TopicData* metaSub = nullptr;
};
class ClientData {
public:
@@ -109,16 +180,14 @@ class ServerImpl final {
virtual void ProcessIncomingText(std::string_view data) = 0;
virtual void ProcessIncomingBinary(std::span<const uint8_t> data) = 0;
enum SendMode { kSendDisabled = 0, kSendAll, kSendNormal, kSendImmNoFlush };
virtual void SendValue(TopicData* topic, const Value& value,
SendMode mode) = 0;
ValueSendMode mode) = 0;
virtual void SendAnnounce(TopicData* topic,
std::optional<int64_t> pubuid) = 0;
virtual void SendUnannounce(TopicData* topic) = 0;
virtual void SendPropertiesUpdate(TopicData* topic, const wpi::json& update,
bool ack) = 0;
virtual void SendOutgoing(uint64_t curTimeMs) = 0;
virtual void SendOutgoing(uint64_t curTimeMs, bool flush) = 0;
virtual void Flush() = 0;
void UpdateMetaClientPub();
@@ -132,13 +201,14 @@ class ServerImpl final {
int GetId() const { return m_id; }
protected:
virtual void UpdatePeriodic(TopicData* topic) {}
std::string m_name;
std::string m_connInfo;
bool m_local; // local to machine
ServerImpl::SetPeriodicFunc m_setPeriodic;
// TODO: make this per-topic?
uint32_t m_periodMs{UINT32_MAX};
uint64_t m_lastSendMs{0};
ServerImpl& m_server;
int m_id;
@@ -175,6 +245,9 @@ class ServerImpl final {
void ClientSetValue(int64_t pubuid, const Value& value);
virtual void UpdatePeriod(TopicData::TopicClientData& tcd,
TopicData* topic) {}
wpi::DenseMap<TopicData*, bool> m_announceSent;
};
@@ -186,12 +259,13 @@ class ServerImpl final {
void ProcessIncomingText(std::string_view data) final {}
void ProcessIncomingBinary(std::span<const uint8_t> data) final {}
void SendValue(TopicData* topic, const Value& value, SendMode mode) final;
void SendValue(TopicData* topic, const Value& value,
ValueSendMode mode) final;
void SendAnnounce(TopicData* topic, std::optional<int64_t> pubuid) final;
void SendUnannounce(TopicData* topic) final;
void SendPropertiesUpdate(TopicData* topic, const wpi::json& update,
bool ack) final;
void SendOutgoing(uint64_t curTimeMs) final {}
void SendOutgoing(uint64_t curTimeMs, bool flush) final {}
void Flush() final {}
void HandleLocal(std::span<const ClientMessage> msgs);
@@ -204,50 +278,31 @@ class ServerImpl final {
ServerImpl& server, int id, wpi::Logger& logger)
: ClientData4Base{name, connInfo, local, setPeriodic,
server, id, logger},
m_wire{wire} {}
m_wire{wire},
m_ping{wire},
m_outgoing{wire, local} {}
void ProcessIncomingText(std::string_view data) final;
void ProcessIncomingBinary(std::span<const uint8_t> data) final;
void SendValue(TopicData* topic, const Value& value, SendMode mode) final;
void SendValue(TopicData* topic, const Value& value,
ValueSendMode mode) final;
void SendAnnounce(TopicData* topic, std::optional<int64_t> pubuid) final;
void SendUnannounce(TopicData* topic) final;
void SendPropertiesUpdate(TopicData* topic, const wpi::json& update,
bool ack) final;
void SendOutgoing(uint64_t curTimeMs) final;
void SendOutgoing(uint64_t curTimeMs, bool flush) final;
void Flush() final;
void Flush() final {}
void UpdatePeriod(TopicData::TopicClientData& tcd, TopicData* topic) final;
public:
WireConnection& m_wire;
private:
std::vector<ServerMessage> m_outgoing;
wpi::DenseMap<NT_Topic, size_t> m_outgoingValueMap;
bool WriteBinary(int64_t id, int64_t time, const Value& value) {
return WireEncodeBinary(SendBinary().Add(), id, time, value);
}
TextWriter& SendText() {
m_outBinary.reset(); // ensure proper interleaving of text and binary
if (!m_outText) {
m_outText = m_wire.SendText();
}
return *m_outText;
}
BinaryWriter& SendBinary() {
m_outText.reset(); // ensure proper interleaving of text and binary
if (!m_outBinary) {
m_outBinary = m_wire.SendBinary();
}
return *m_outBinary;
}
// valid when we are actively writing to this client
std::optional<TextWriter> m_outText;
std::optional<BinaryWriter> m_outBinary;
NetworkPing m_ping;
NetworkOutgoingQueue<ServerMessage> m_outgoing;
};
class ClientData3 final : public ClientData, private net3::MessageHandler3 {
@@ -265,12 +320,13 @@ class ServerImpl final {
void ProcessIncomingText(std::string_view data) final {}
void ProcessIncomingBinary(std::span<const uint8_t> data) final;
void SendValue(TopicData* topic, const Value& value, SendMode mode) final;
void SendValue(TopicData* topic, const Value& value,
ValueSendMode mode) final;
void SendAnnounce(TopicData* topic, std::optional<int64_t> pubuid) final;
void SendUnannounce(TopicData* topic) final;
void SendPropertiesUpdate(TopicData* topic, const wpi::json& update,
bool ack) final;
void SendOutgoing(uint64_t curTimeMs) final;
void SendOutgoing(uint64_t curTimeMs, bool flush) final;
void Flush() final { m_wire.Flush(); }
@@ -305,6 +361,7 @@ class ServerImpl final {
std::vector<net3::Message3> m_outgoing;
wpi::DenseMap<NT_Topic, size_t> m_outgoingValueMap;
int64_t m_nextPubUid{1};
uint64_t m_lastSendMs{0};
struct TopicData3 {
explicit TopicData3(TopicData* topic) { UpdateFlags(topic); }
@@ -323,50 +380,19 @@ class ServerImpl final {
}
};
struct TopicData {
TopicData(std::string_view name, std::string_view typeStr)
: name{name}, typeStr{typeStr} {}
TopicData(std::string_view name, std::string_view typeStr,
wpi::json properties)
: name{name}, typeStr{typeStr}, properties(std::move(properties)) {
RefreshProperties();
}
bool IsPublished() const {
return persistent || retained || !publishers.empty();
}
// returns true if properties changed
bool SetProperties(const wpi::json& update);
void RefreshProperties();
bool SetFlags(unsigned int flags_);
std::string name;
unsigned int id;
Value lastValue;
ClientData* lastValueClient = nullptr;
std::string typeStr;
wpi::json properties = wpi::json::object();
bool persistent{false};
bool retained{false};
bool special{false};
NT_Topic localHandle{0};
VectorSet<PublisherData*> publishers;
VectorSet<SubscriberData*> subscribers;
// meta topics
TopicData* metaPub = nullptr;
TopicData* metaSub = nullptr;
};
struct PublisherData {
PublisherData(ClientData* client, TopicData* topic, int64_t pubuid)
: client{client}, topic{topic}, pubuid{pubuid} {}
: client{client}, topic{topic}, pubuid{pubuid} {
UpdateMeta();
}
void UpdateMeta();
ClientData* client;
TopicData* topic;
int64_t pubuid;
std::vector<uint8_t> metaClient;
std::vector<uint8_t> metaTopic;
};
struct SubscriberData {
@@ -377,6 +403,7 @@ class ServerImpl final {
subuid{subuid},
options{options},
periodMs(std::lround(options.periodicMs / 10.0) * 10) {
UpdateMeta();
if (periodMs < kMinPeriodMs) {
periodMs = kMinPeriodMs;
}
@@ -386,6 +413,7 @@ class ServerImpl final {
const PubSubOptionsImpl& options_) {
topicNames = {topicNames_.begin(), topicNames_.end()};
options = options_;
UpdateMeta();
periodMs = std::lround(options_.periodicMs / 10.0) * 10;
if (periodMs < kMinPeriodMs) {
periodMs = kMinPeriodMs;
@@ -394,10 +422,15 @@ class ServerImpl final {
bool Matches(std::string_view name, bool special);
void UpdateMeta();
ClientData* client;
std::vector<std::string> topicNames;
int64_t subuid;
PubSubOptionsImpl options;
std::vector<uint8_t> metaClient;
std::vector<uint8_t> metaTopic;
wpi::DenseMap<TopicData*, bool> topics;
// in options as double, but copy here as integer; rounded to the nearest
// 10 ms
uint32_t periodMs;

View File

@@ -6,134 +6,243 @@
#include <span>
#include <wpi/Endian.h>
#include <wpi/SpanExtras.h>
#include <wpi/raw_ostream.h>
#include <wpi/timestamp.h>
#include <wpinet/WebSocket.h>
#include <wpinet/raw_uv_ostream.h>
using namespace nt;
using namespace nt::net;
static constexpr size_t kAllocSize = 4096;
static constexpr size_t kTextFrameRolloverSize = 4096;
static constexpr size_t kBinaryFrameRolloverSize = 8192;
static constexpr size_t kMaxPoolSize = 16;
// MTU - assume Ethernet, IPv6, TCP; does not include WS frame header (max 10)
static constexpr size_t kMTU = 1500 - 40 - 20;
static constexpr size_t kAllocSize = kMTU - 10;
// leave enough room for a "typical" message size so we don't create lots of
// fragmented frames
static constexpr size_t kNewFrameThresholdBytes = kAllocSize - 50;
static constexpr size_t kFlushThresholdFrames = 32;
static constexpr size_t kFlushThresholdBytes = 16384;
static constexpr size_t kMaxPoolSize = 32;
WebSocketConnection::WebSocketConnection(wpi::WebSocket& ws)
: m_ws{ws},
m_text_os{m_text_buffers, [this] { return AllocBuf(); }},
m_binary_os{m_binary_buffers, [this] { return AllocBuf(); }} {}
class WebSocketConnection::Stream final : public wpi::raw_ostream {
public:
explicit Stream(WebSocketConnection& conn) : m_conn{conn} {
auto& buf = conn.m_bufs.back();
SetBuffer(buf.base + buf.len, kAllocSize - buf.len);
}
~Stream() final {
m_disableAlloc = true;
flush();
}
private:
size_t preferred_buffer_size() const final { return 0; }
void write_impl(const char* data, size_t len) final;
uint64_t current_pos() const final { return m_conn.m_framePos; }
WebSocketConnection& m_conn;
bool m_disableAlloc = false;
};
void WebSocketConnection::Stream::write_impl(const char* data, size_t len) {
if (len > kAllocSize) {
// only called by raw_ostream::write() when the buffer is empty and a large
// thing is being written; called with a length that's a multiple of the
// alloc size
assert((len % kAllocSize) == 0);
assert(m_conn.m_bufs.back().len == 0);
while (len > 0) {
auto& buf = m_conn.m_bufs.back();
std::memcpy(buf.base, data, kAllocSize);
buf.len = kAllocSize;
m_conn.m_framePos += kAllocSize;
m_conn.m_written += kAllocSize;
data += kAllocSize;
len -= kAllocSize;
// fragment the current frame and start a new one
m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin;
m_conn.StartFrame(wpi::WebSocket::Frame::kFragment);
}
SetBuffer(m_conn.m_bufs.back().base, kAllocSize);
[[unlikely]] return;
}
auto& buf = m_conn.m_bufs.back();
buf.len += len;
m_conn.m_framePos += len;
m_conn.m_written += len;
if (!m_disableAlloc && buf.len >= kAllocSize) {
// fragment the current frame and start a new one
[[unlikely]] m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin;
m_conn.StartFrame(wpi::WebSocket::Frame::kFragment);
SetBuffer(m_conn.m_bufs.back().base, kAllocSize);
}
}
WebSocketConnection::WebSocketConnection(wpi::WebSocket& ws,
unsigned int version)
: m_ws{ws}, m_version{version} {
m_ws.pong.connect([this](auto data) {
if (data.size() != 8) {
return;
}
m_lastPingResponse =
wpi::support::endian::read64<wpi::support::native>(data.data());
});
}
WebSocketConnection::~WebSocketConnection() {
for (auto&& buf : m_bufs) {
buf.Deallocate();
}
for (auto&& buf : m_buf_pool) {
buf.Deallocate();
}
for (auto&& buf : m_text_buffers) {
buf.Deallocate();
}
for (auto&& buf : m_binary_buffers) {
buf.Deallocate();
}
}
void WebSocketConnection::Flush() {
FinishSendText();
FinishSendBinary();
if (m_frames.empty()) {
return;
}
// convert internal frames into WS frames
m_ws_frames.clear();
m_ws_frames.reserve(m_frames.size());
for (auto&& frame : m_frames) {
m_ws_frames.emplace_back(frame.opcode,
std::span{frame.bufs->begin() + frame.start,
frame.bufs->begin() + frame.end});
}
++m_sendsActive;
m_ws.SendFrames(m_ws_frames, [selfweak = weak_from_this()](auto bufs, auto) {
void WebSocketConnection::SendPing(uint64_t time) {
auto buf = AllocBuf();
buf.len = 8;
wpi::support::endian::write64<wpi::support::native>(buf.base, time);
m_ws.SendPing({buf}, [selfweak = weak_from_this()](auto bufs, auto err) {
if (auto self = selfweak.lock()) {
size_t numToPool =
(std::min)(bufs.size(), kMaxPoolSize - self->m_buf_pool.size());
self->m_buf_pool.insert(self->m_buf_pool.end(), bufs.begin(),
bufs.begin() + numToPool);
for (auto&& buf : bufs.subspan(numToPool)) {
buf.Deallocate();
}
if (self->m_sendsActive > 0) {
--self->m_sendsActive;
}
self->m_err = err;
self->ReleaseBufs(bufs);
} else {
for (auto&& buf : bufs) {
buf.Deallocate();
}
}
});
m_frames.clear();
m_text_buffers.clear();
m_binary_buffers.clear();
m_text_pos = 0;
m_binary_pos = 0;
}
void WebSocketConnection::StartFrame(uint8_t opcode) {
m_frames.emplace_back(opcode, m_bufs.size(), m_bufs.size() + 1);
m_bufs.emplace_back(AllocBuf());
m_bufs.back().len = 0;
}
void WebSocketConnection::FinishText() {
assert(!m_bufs.empty());
auto& buf = m_bufs.back();
assert(buf.len < kAllocSize + 1); // safe because we alloc one more byte
buf.base[buf.len++] = ']';
}
int WebSocketConnection::Write(
State kind, wpi::function_ref<void(wpi::raw_ostream& os)> writer) {
bool first = false;
if (m_state != kind ||
(m_state == kind && m_framePos >= kNewFrameThresholdBytes)) {
// start a new frame
if (m_state == kText) {
FinishText();
}
m_state = kind;
if (!m_frames.empty()) {
m_frames.back().opcode |= wpi::WebSocket::kFlagFin;
}
StartFrame(m_state == kText ? wpi::WebSocket::Frame::kText
: wpi::WebSocket::Frame::kBinary);
m_framePos = 0;
first = true;
}
{
Stream os{*this};
if (kind == kText) {
os << (first ? '[' : ',');
}
writer(os);
}
++m_frames.back().count;
if (m_frames.size() > kFlushThresholdFrames ||
m_written >= kFlushThresholdBytes) {
return Flush();
}
return 0;
}
int WebSocketConnection::Flush() {
m_lastFlushTime = wpi::Now();
if (m_state == kEmpty) {
return 0;
}
if (m_state == kText) {
FinishText();
}
m_state = kEmpty;
m_written = 0;
if (m_frames.empty()) {
return 0;
}
m_frames.back().opcode |= wpi::WebSocket::kFlagFin;
// convert internal frames into WS frames
m_ws_frames.clear();
m_ws_frames.reserve(m_frames.size());
for (auto&& frame : m_frames) {
m_ws_frames.emplace_back(
frame.opcode, std::span{&m_bufs[frame.start], &m_bufs[frame.end]});
}
auto unsentFrames = m_ws.TrySendFrames(
m_ws_frames, [selfweak = weak_from_this()](auto bufs, auto err) {
if (auto self = selfweak.lock()) {
self->m_err = err;
self->ReleaseBufs(bufs);
} else {
for (auto&& buf : bufs) {
buf.Deallocate();
}
}
});
m_ws_frames.clear();
if (m_err) {
m_frames.clear();
m_bufs.clear();
return m_err.code();
}
int count = 0;
for (auto&& frame :
wpi::take_back(std::span{m_frames}, unsentFrames.size())) {
count += frame.count;
}
m_frames.clear();
m_bufs.clear();
return count;
}
void WebSocketConnection::Send(
uint8_t opcode, wpi::function_ref<void(wpi::raw_ostream& os)> writer) {
wpi::SmallVector<wpi::uv::Buffer, 4> bufs;
wpi::raw_uv_ostream os{bufs, [this] { return AllocBuf(); }};
if (opcode == wpi::WebSocket::Frame::kText) {
os << '[';
}
writer(os);
if (opcode == wpi::WebSocket::Frame::kText) {
os << ']';
}
wpi::WebSocket::Frame frame{opcode, os.bufs()};
m_ws.SendFrames({{frame}}, [selfweak = weak_from_this()](auto bufs, auto) {
if (auto self = selfweak.lock()) {
self->ReleaseBufs(bufs);
} else {
for (auto&& buf : bufs) {
buf.Deallocate();
}
}
});
}
void WebSocketConnection::Disconnect(std::string_view reason) {
m_reason = reason;
m_ws.Close(1005, reason);
}
void WebSocketConnection::StartSendText() {
// limit amount per single frame
size_t total = 0;
for (size_t i = m_text_pos; i < m_text_buffers.size(); ++i) {
total += m_text_buffers[i].len;
}
if (total >= kTextFrameRolloverSize) {
FinishSendText();
}
if (m_in_text) {
m_text_os << ',';
} else {
m_text_os << '[';
m_in_text = true;
}
}
void WebSocketConnection::FinishSendText() {
if (m_in_text) {
m_text_os << ']';
m_in_text = false;
}
if (m_text_pos >= m_text_buffers.size()) {
return;
}
m_frames.emplace_back(wpi::WebSocket::Frame::kText, &m_text_buffers,
m_text_pos, m_text_buffers.size());
m_text_pos = m_text_buffers.size();
m_text_os.reset();
}
void WebSocketConnection::StartSendBinary() {
// limit amount per single frame
size_t total = 0;
for (size_t i = m_binary_pos; i < m_binary_buffers.size(); ++i) {
total += m_binary_buffers[i].len;
}
if (total >= kBinaryFrameRolloverSize) {
FinishSendBinary();
}
}
void WebSocketConnection::FinishSendBinary() {
if (m_binary_pos >= m_binary_buffers.size()) {
return;
}
m_frames.emplace_back(wpi::WebSocket::Frame::kBinary, &m_binary_buffers,
m_binary_pos, m_binary_buffers.size());
m_binary_pos = m_binary_buffers.size();
m_binary_os.reset();
m_ws.Fail(1005, reason);
}
wpi::uv::Buffer WebSocketConnection::AllocBuf() {
@@ -142,5 +251,13 @@ wpi::uv::Buffer WebSocketConnection::AllocBuf() {
m_buf_pool.pop_back();
return buf;
}
return wpi::uv::Buffer::Allocate(kAllocSize);
return wpi::uv::Buffer::Allocate(kAllocSize + 1); // leave space for ']'
}
void WebSocketConnection::ReleaseBufs(std::span<wpi::uv::Buffer> bufs) {
size_t numToPool = (std::min)(bufs.size(), kMaxPoolSize - m_buf_pool.size());
m_buf_pool.insert(m_buf_pool.end(), bufs.begin(), bufs.begin() + numToPool);
for (auto&& buf : bufs.subspan(numToPool)) {
buf.Deallocate();
}
}

View File

@@ -9,9 +9,8 @@
#include <string_view>
#include <vector>
#include <wpi/SmallVector.h>
#include <wpi/function_ref.h>
#include <wpinet/WebSocket.h>
#include <wpinet/raw_uv_ostream.h>
#include <wpinet/uv/Buffer.h>
#include "WireConnection.h"
@@ -22,56 +21,77 @@ class WebSocketConnection final
: public WireConnection,
public std::enable_shared_from_this<WebSocketConnection> {
public:
explicit WebSocketConnection(wpi::WebSocket& ws);
WebSocketConnection(wpi::WebSocket& ws, unsigned int version);
~WebSocketConnection() override;
WebSocketConnection(const WebSocketConnection&) = delete;
WebSocketConnection& operator=(const WebSocketConnection&) = delete;
bool Ready() const final { return m_sendsActive == 0; }
unsigned int GetVersion() const final { return m_version; }
TextWriter SendText() final { return {m_text_os, *this}; }
BinaryWriter SendBinary() final { return {m_binary_os, *this}; }
void SendPing(uint64_t time) final;
void Flush() final;
bool Ready() const final { return !m_ws.IsWriteInProgress(); }
int WriteText(wpi::function_ref<void(wpi::raw_ostream& os)> writer) final {
return Write(kText, writer);
}
int WriteBinary(wpi::function_ref<void(wpi::raw_ostream& os)> writer) final {
return Write(kBinary, writer);
}
int Flush() final;
void SendText(wpi::function_ref<void(wpi::raw_ostream& os)> writer) final {
Send(wpi::WebSocket::Frame::kText, writer);
}
void SendBinary(wpi::function_ref<void(wpi::raw_ostream& os)> writer) final {
Send(wpi::WebSocket::Frame::kBinary, writer);
}
uint64_t GetLastFlushTime() const final { return m_lastFlushTime; }
uint64_t GetLastPingResponse() const final { return m_lastPingResponse; }
void Disconnect(std::string_view reason) final;
std::string_view GetDisconnectReason() const { return m_reason; }
private:
void StartSendText() final;
void FinishSendText() final;
void StartSendBinary() final;
void FinishSendBinary() final;
enum State { kEmpty, kText, kBinary };
int Write(State kind, wpi::function_ref<void(wpi::raw_ostream& os)> writer);
void Send(uint8_t opcode,
wpi::function_ref<void(wpi::raw_ostream& os)> writer);
void StartFrame(uint8_t opcode);
void FinishText();
wpi::uv::Buffer AllocBuf();
void ReleaseBufs(std::span<wpi::uv::Buffer> bufs);
wpi::WebSocket& m_ws;
class Stream;
// Can't use WS frames directly as span could have dangling pointers
struct Frame {
Frame(uint8_t opcode, wpi::SmallVectorImpl<wpi::uv::Buffer>* bufs,
size_t start, size_t end)
: opcode{opcode}, bufs{bufs}, start{start}, end{end} {}
uint8_t opcode;
wpi::SmallVectorImpl<wpi::uv::Buffer>* bufs;
Frame(uint8_t opcode, size_t start, size_t end)
: start{start}, end{end}, opcode{opcode} {}
size_t start;
size_t end;
unsigned int count = 0;
uint8_t opcode;
};
std::vector<Frame> m_frames;
std::vector<wpi::WebSocket::Frame> m_ws_frames; // to reduce allocs
wpi::SmallVector<wpi::uv::Buffer, 4> m_text_buffers;
wpi::SmallVector<wpi::uv::Buffer, 4> m_binary_buffers;
std::vector<Frame> m_frames;
std::vector<wpi::uv::Buffer> m_bufs;
std::vector<wpi::uv::Buffer> m_buf_pool;
wpi::raw_uv_ostream m_text_os;
wpi::raw_uv_ostream m_binary_os;
size_t m_text_pos = 0;
size_t m_binary_pos = 0;
bool m_in_text = false;
int m_sendsActive = 0;
size_t m_framePos = 0;
size_t m_written = 0;
wpi::uv::Error m_err;
State m_state = kEmpty;
std::string m_reason;
uint64_t m_lastFlushTime = 0;
uint64_t m_lastPingResponse = 0;
unsigned int m_version;
};
} // namespace nt::net

View File

@@ -8,105 +8,53 @@
#include <string_view>
#include <wpi/raw_ostream.h>
#include <wpi/function_ref.h>
namespace wpi {
class raw_ostream;
} // namespace wpi
namespace nt::net {
class BinaryWriter;
class TextWriter;
class WireConnection {
friend class TextWriter;
friend class BinaryWriter;
public:
virtual ~WireConnection() = default;
virtual unsigned int GetVersion() const = 0;
virtual void SendPing(uint64_t time) = 0;
virtual bool Ready() const = 0;
virtual TextWriter SendText() = 0;
// These return <0 on error, 0 on success. On buffer full, a positive number
// is is returned indicating the number of previous messages (including this
// call) that were NOT sent, e.g. 1 if just this call to WriteText or
// WriteBinary was not sent, 2 if the this call and the *previous* call were
// not sent.
[[nodiscard]]
virtual int WriteText(
wpi::function_ref<void(wpi::raw_ostream& os)> writer) = 0;
[[nodiscard]]
virtual int WriteBinary(
wpi::function_ref<void(wpi::raw_ostream& os)> writer) = 0;
virtual BinaryWriter SendBinary() = 0;
// Flushes any pending buffers. Return value equivalent to
// WriteText/WriteBinary (e.g. 1 means the last WriteX call was not sent).
[[nodiscard]]
virtual int Flush() = 0;
virtual void Flush() = 0;
// These immediately send the data even if the buffer is full.
virtual void SendText(
wpi::function_ref<void(wpi::raw_ostream& os)> writer) = 0;
virtual void SendBinary(
wpi::function_ref<void(wpi::raw_ostream& os)> writer) = 0;
virtual uint64_t GetLastFlushTime() const = 0; // in microseconds
// Gets the timestamp of the last ping we got a reply to
virtual uint64_t GetLastPingResponse() const = 0; // in microseconds
virtual void Disconnect(std::string_view reason) = 0;
protected:
virtual void StartSendText() = 0;
virtual void FinishSendText() = 0;
virtual void StartSendBinary() = 0;
virtual void FinishSendBinary() = 0;
};
class TextWriter {
public:
TextWriter(wpi::raw_ostream& os, WireConnection& wire)
: m_os{&os}, m_wire{&wire} {}
TextWriter(const TextWriter&) = delete;
TextWriter(TextWriter&& rhs) : m_os{rhs.m_os}, m_wire{rhs.m_wire} {
rhs.m_os = nullptr;
rhs.m_wire = nullptr;
}
TextWriter& operator=(const TextWriter&) = delete;
TextWriter& operator=(TextWriter&& rhs) {
m_os = rhs.m_os;
m_wire = rhs.m_wire;
rhs.m_os = nullptr;
rhs.m_wire = nullptr;
return *this;
}
~TextWriter() {
if (m_os) {
m_wire->FinishSendText();
}
}
wpi::raw_ostream& Add() {
m_wire->StartSendText();
return *m_os;
}
WireConnection& wire() { return *m_wire; }
private:
wpi::raw_ostream* m_os;
WireConnection* m_wire;
};
class BinaryWriter {
public:
BinaryWriter(wpi::raw_ostream& os, WireConnection& wire)
: m_os{&os}, m_wire{&wire} {}
BinaryWriter(const BinaryWriter&) = delete;
BinaryWriter(BinaryWriter&& rhs) : m_os{rhs.m_os}, m_wire{rhs.m_wire} {
rhs.m_os = nullptr;
rhs.m_wire = nullptr;
}
BinaryWriter& operator=(const BinaryWriter&) = delete;
BinaryWriter& operator=(BinaryWriter&& rhs) {
m_os = rhs.m_os;
m_wire = rhs.m_wire;
rhs.m_os = nullptr;
rhs.m_wire = nullptr;
return *this;
}
~BinaryWriter() {
if (m_wire) {
m_wire->FinishSendBinary();
}
}
wpi::raw_ostream& Add() {
m_wire->StartSendBinary();
return *m_os;
}
WireConnection& wire() { return *m_wire; }
private:
wpi::raw_ostream* m_os;
WireConnection* m_wire;
};
} // namespace nt::net

View File

@@ -107,21 +107,22 @@ static bool ObjGetStringArray(wpi::json::object_t& obj, std::string_view key,
template <typename T>
requires(std::same_as<T, ClientMessageHandler> ||
std::same_as<T, ServerMessageHandler>)
static void WireDecodeTextImpl(std::string_view in, T& out,
static bool WireDecodeTextImpl(std::string_view in, T& out,
wpi::Logger& logger) {
wpi::json j;
try {
j = wpi::json::parse(in);
} catch (wpi::json::parse_error& err) {
WPI_WARNING(logger, "could not decode JSON message: {}", err.what());
return;
return false;
}
if (!j.is_array()) {
WPI_WARNING(logger, "expected JSON array at top level");
return;
return false;
}
bool rv = false;
int i = -1;
for (auto&& jmsg : j) {
++i;
@@ -187,6 +188,7 @@ static void WireDecodeTextImpl(std::string_view in, T& out,
// complete
out.ClientPublish(pubuid, *name, *typeStr, *properties);
rv = true;
} else if (*method == UnpublishMsg::kMethodStr) {
// pubuid
int64_t pubuid;
@@ -196,6 +198,7 @@ static void WireDecodeTextImpl(std::string_view in, T& out,
// complete
out.ClientUnpublish(pubuid);
rv = true;
} else if (*method == SetPropertiesMsg::kMethodStr) {
// name
auto name = ObjGetString(*params, "name", &error);
@@ -288,6 +291,7 @@ static void WireDecodeTextImpl(std::string_view in, T& out,
// complete
out.ClientSubscribe(subuid, topicNames, options);
rv = true;
} else if (*method == UnsubscribeMsg::kMethodStr) {
// subuid
int64_t subuid;
@@ -297,6 +301,7 @@ static void WireDecodeTextImpl(std::string_view in, T& out,
// complete
out.ClientUnsubscribe(subuid);
rv = true;
} else {
error = fmt::format("unrecognized method '{}'", *method);
goto err;
@@ -404,15 +409,17 @@ static void WireDecodeTextImpl(std::string_view in, T& out,
err:
WPI_WARNING(logger, "{}: {}", i, error);
}
return rv;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
void nt::net::WireDecodeText(std::string_view in, ClientMessageHandler& out,
bool nt::net::WireDecodeText(std::string_view in, ClientMessageHandler& out,
wpi::Logger& logger) {
::WireDecodeTextImpl(in, out, logger);
return ::WireDecodeTextImpl(in, out, logger);
}
void nt::net::WireDecodeText(std::string_view in, ServerMessageHandler& out,

View File

@@ -52,7 +52,8 @@ class ServerMessageHandler {
const wpi::json& update, bool ack) = 0;
};
void WireDecodeText(std::string_view in, ClientMessageHandler& out,
// return true if client pub/sub metadata needs updating
bool WireDecodeText(std::string_view in, ClientMessageHandler& out,
wpi::Logger& logger);
void WireDecodeText(std::string_view in, ServerMessageHandler& out,
wpi::Logger& logger);