[ntcore] NetworkOutgoingQueue: Move function defs inside class

This commit is contained in:
Peter Johnson
2024-10-15 16:48:10 -07:00
parent 876be30724
commit a0f38f83f9

View File

@@ -69,7 +69,43 @@ class NetworkOutgoingQueue {
m_queues.emplace_back(100); // default queue is 100 ms period
}
void SetPeriod(int id, uint32_t periodMs);
void SetPeriod(int id, uint32_t periodMs) {
// it's quite common to set a lot of things in a row with the same period
unsigned int queueIndex;
if (m_lastSetPeriod == periodMs) {
queueIndex = m_lastSetPeriodQueueIndex;
} else {
// find and possibly create queue for this period
auto it =
std::find_if(m_queues.begin(), m_queues.end(),
[&](const auto& q) { return q.periodMs == periodMs; });
if (it == m_queues.end()) {
queueIndex = m_queues.size();
m_queues.emplace_back(periodMs);
} else {
queueIndex = it - m_queues.begin();
}
m_lastSetPeriodQueueIndex = queueIndex;
m_lastSetPeriod = periodMs;
}
// map the handle to the queue
auto [infoIt, created] = m_idMap.try_emplace(id);
if (!created && infoIt->getSecond().queueIndex != queueIndex) {
// need to move any items from old queue to new queue
auto& oldMsgs = m_queues[infoIt->getSecond().queueIndex].msgs;
auto it =
std::stable_partition(oldMsgs.begin(), oldMsgs.end(),
[&](const auto& e) { return e.id != id; });
auto& newMsgs = m_queues[queueIndex].msgs;
for (auto i = it, end = oldMsgs.end(); i != end; ++i) {
newMsgs.emplace_back(std::move(*i));
}
oldMsgs.erase(it, oldMsgs.end());
}
infoIt->getSecond().queueIndex = queueIndex;
}
void EraseId(int id) { m_idMap.erase(id); }
@@ -79,9 +115,146 @@ class NetworkOutgoingQueue {
m_totalSize += sizeof(Message);
}
void SendValue(int id, const Value& value, ValueSendMode mode);
void SendValue(int id, const Value& value, ValueSendMode mode) {
if (m_local) {
mode = ValueSendMode::kImm; // always send local immediately
}
// backpressure by stopping sending all if the buffer is too full
if (mode == ValueSendMode::kAll && m_totalSize >= kOutgoingLimit) {
mode = ValueSendMode::kNormal;
}
switch (mode) {
case ValueSendMode::kDisabled: // do nothing
break;
case ValueSendMode::kImm: // send immediately
m_wire.SendBinary([&](auto& os) { EncodeValue(os, id, value); });
break;
case ValueSendMode::kAll: { // append to outgoing
auto& info = m_idMap[id];
auto& queue = m_queues[info.queueIndex];
info.valuePos = queue.msgs.size();
queue.Append(id, ValueMsg{id, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
case ValueSendMode::kNormal: {
// replace, or append if not present
auto& info = m_idMap[id];
auto& queue = m_queues[info.queueIndex];
if (info.valuePos != -1 &&
static_cast<unsigned int>(info.valuePos) < queue.msgs.size()) {
auto& elem = queue.msgs[info.valuePos];
if (auto m = std::get_if<ValueMsg>(&elem.msg.contents)) {
// double-check handle, and only replace if timestamp newer
if (elem.id == id &&
(m->value.time() == 0 || value.time() >= m->value.time())) {
int delta = value.size() - m->value.size();
m->value = value;
m_totalSize += delta;
return;
}
}
}
info.valuePos = queue.msgs.size();
queue.Append(id, ValueMsg{id, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
}
}
void SendOutgoing(uint64_t curTimeMs, bool flush);
void SendOutgoing(uint64_t curTimeMs, bool flush) {
if (m_totalSize == 0) {
return; // nothing to do
}
// rate limit frequency of transmissions
if (curTimeMs < (m_lastSendMs + kMinPeriodMs)) {
return;
}
if (!m_wire.Ready()) {
return; // don't bother, still sending the last batch
}
// what queues are ready to send?
wpi::SmallVector<unsigned int, 16> queues;
for (unsigned int i = 0; i < m_queues.size(); ++i) {
if (!m_queues[i].msgs.empty() &&
(flush || curTimeMs >= m_queues[i].nextSendMs)) {
queues.emplace_back(i);
}
}
if (queues.empty()) {
return; // nothing needs to be sent yet
}
// Sort transmission order by what queue has been waiting the longest time.
// XXX: byte-weighted fair queueing might be better, but is much more
// complex to implement.
std::sort(queues.begin(), queues.end(), [&](const auto& a, const auto& b) {
return m_queues[a].nextSendMs < m_queues[b].nextSendMs;
});
for (unsigned int queueIndex : queues) {
auto& queue = m_queues[queueIndex];
auto& msgs = queue.msgs;
auto it = msgs.begin();
auto end = msgs.end();
int unsent = 0;
for (; it != end && unsent == 0; ++it) {
if (auto m = std::get_if<ValueMsg>(&it->msg.contents)) {
unsent = m_wire.WriteBinary(
[&](auto& os) { EncodeValue(os, it->id, m->value); });
} else {
unsent = m_wire.WriteText([&](auto& os) {
if (!WireEncodeText(os, it->msg)) {
os << "{}";
}
});
}
}
if (unsent < 0) {
return; // error
}
if (unsent == 0) {
// finish writing any partial buffers
unsent = m_wire.Flush();
if (unsent < 0) {
return; // error
}
}
int delta = it - msgs.begin() - unsent;
for (auto&& msg : std::span{msgs}.subspan(0, delta)) {
if (auto m = std::get_if<ValueMsg>(&msg.msg.contents)) {
m_totalSize -= sizeof(Message) + m->value.size();
} else {
m_totalSize -= sizeof(Message);
}
}
msgs.erase(msgs.begin(), it - unsent);
for (auto&& kv : m_idMap) {
auto& info = kv.getSecond();
if (info.queueIndex == queueIndex) {
if (info.valuePos < delta) {
info.valuePos = -1;
} else {
info.valuePos -= delta;
}
}
}
// try to stay on periodic timing, unless it's falling behind current time
if (unsent == 0) {
queue.nextSendMs += queue.periodMs;
if (queue.nextSendMs < curTimeMs) {
queue.nextSendMs = curTimeMs + queue.periodMs;
}
}
}
m_lastSendMs = curTimeMs;
}
void SetTimeOffset(int64_t offsetUs) { m_timeOffsetUs = offsetUs; }
int64_t GetTimeOffset() const { return m_timeOffsetUs; }
@@ -92,7 +265,19 @@ class NetworkOutgoingQueue {
private:
using ValueMsg = typename MessageType::ValueMsg;
void EncodeValue(wpi::raw_ostream& os, int id, const Value& value);
void EncodeValue(wpi::raw_ostream& os, int id, const Value& value) {
int64_t time = value.time();
if constexpr (std::same_as<ValueMsg, ClientValueMsg>) {
if (time != 0) {
time += m_timeOffsetUs;
// make sure resultant time isn't exactly 0
if (time == 0) {
time = 1;
}
}
}
WireEncodeBinary(os, id, time, value);
}
struct Message {
Message() = default;
@@ -132,204 +317,4 @@ class NetworkOutgoingQueue {
static constexpr size_t kOutgoingLimit = 1024 * 1024;
};
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SetPeriod(int id, uint32_t periodMs) {
// it's quite common to set a lot of things in a row with the same period
unsigned int queueIndex;
if (m_lastSetPeriod == periodMs) {
queueIndex = m_lastSetPeriodQueueIndex;
} else {
// find and possibly create queue for this period
auto it =
std::find_if(m_queues.begin(), m_queues.end(),
[&](const auto& q) { return q.periodMs == periodMs; });
if (it == m_queues.end()) {
queueIndex = m_queues.size();
m_queues.emplace_back(periodMs);
} else {
queueIndex = it - m_queues.begin();
}
m_lastSetPeriodQueueIndex = queueIndex;
m_lastSetPeriod = periodMs;
}
// map the handle to the queue
auto [infoIt, created] = m_idMap.try_emplace(id);
if (!created && infoIt->getSecond().queueIndex != queueIndex) {
// need to move any items from old queue to new queue
auto& oldMsgs = m_queues[infoIt->getSecond().queueIndex].msgs;
auto it = std::stable_partition(oldMsgs.begin(), oldMsgs.end(),
[&](const auto& e) { return e.id != id; });
auto& newMsgs = m_queues[queueIndex].msgs;
for (auto i = it, end = oldMsgs.end(); i != end; ++i) {
newMsgs.emplace_back(std::move(*i));
}
oldMsgs.erase(it, oldMsgs.end());
}
infoIt->getSecond().queueIndex = queueIndex;
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SendValue(int id, const Value& value,
ValueSendMode mode) {
if (m_local) {
mode = ValueSendMode::kImm; // always send local immediately
}
// backpressure by stopping sending all if the buffer is too full
if (mode == ValueSendMode::kAll && m_totalSize >= kOutgoingLimit) {
mode = ValueSendMode::kNormal;
}
switch (mode) {
case ValueSendMode::kDisabled: // do nothing
break;
case ValueSendMode::kImm: // send immediately
m_wire.SendBinary([&](auto& os) { EncodeValue(os, id, value); });
break;
case ValueSendMode::kAll: { // append to outgoing
auto& info = m_idMap[id];
auto& queue = m_queues[info.queueIndex];
info.valuePos = queue.msgs.size();
queue.Append(id, ValueMsg{id, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
case ValueSendMode::kNormal: {
// replace, or append if not present
auto& info = m_idMap[id];
auto& queue = m_queues[info.queueIndex];
if (info.valuePos != -1 &&
static_cast<unsigned int>(info.valuePos) < queue.msgs.size()) {
auto& elem = queue.msgs[info.valuePos];
if (auto m = std::get_if<ValueMsg>(&elem.msg.contents)) {
// double-check handle, and only replace if timestamp newer
if (elem.id == id &&
(m->value.time() == 0 || value.time() >= m->value.time())) {
int delta = value.size() - m->value.size();
m->value = value;
m_totalSize += delta;
return;
}
}
}
info.valuePos = queue.msgs.size();
queue.Append(id, ValueMsg{id, value});
m_totalSize += sizeof(Message) + value.size();
break;
}
}
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::SendOutgoing(uint64_t curTimeMs,
bool flush) {
if (m_totalSize == 0) {
return; // nothing to do
}
// rate limit frequency of transmissions
if (curTimeMs < (m_lastSendMs + kMinPeriodMs)) {
return;
}
if (!m_wire.Ready()) {
return; // don't bother, still sending the last batch
}
// what queues are ready to send?
wpi::SmallVector<unsigned int, 16> queues;
for (unsigned int i = 0; i < m_queues.size(); ++i) {
if (!m_queues[i].msgs.empty() &&
(flush || curTimeMs >= m_queues[i].nextSendMs)) {
queues.emplace_back(i);
}
}
if (queues.empty()) {
return; // nothing needs to be sent yet
}
// Sort transmission order by what queue has been waiting the longest time.
// XXX: byte-weighted fair queueing might be better, but is much more complex
// to implement.
std::sort(queues.begin(), queues.end(), [&](const auto& a, const auto& b) {
return m_queues[a].nextSendMs < m_queues[b].nextSendMs;
});
for (unsigned int queueIndex : queues) {
auto& queue = m_queues[queueIndex];
auto& msgs = queue.msgs;
auto it = msgs.begin();
auto end = msgs.end();
int unsent = 0;
for (; it != end && unsent == 0; ++it) {
if (auto m = std::get_if<ValueMsg>(&it->msg.contents)) {
unsent = m_wire.WriteBinary(
[&](auto& os) { EncodeValue(os, it->id, m->value); });
} else {
unsent = m_wire.WriteText([&](auto& os) {
if (!WireEncodeText(os, it->msg)) {
os << "{}";
}
});
}
}
if (unsent < 0) {
return; // error
}
if (unsent == 0) {
// finish writing any partial buffers
unsent = m_wire.Flush();
if (unsent < 0) {
return; // error
}
}
int delta = it - msgs.begin() - unsent;
for (auto&& msg : std::span{msgs}.subspan(0, delta)) {
if (auto m = std::get_if<ValueMsg>(&msg.msg.contents)) {
m_totalSize -= sizeof(Message) + m->value.size();
} else {
m_totalSize -= sizeof(Message);
}
}
msgs.erase(msgs.begin(), it - unsent);
for (auto&& kv : m_idMap) {
auto& info = kv.getSecond();
if (info.queueIndex == queueIndex) {
if (info.valuePos < delta) {
info.valuePos = -1;
} else {
info.valuePos -= delta;
}
}
}
// try to stay on periodic timing, unless it's falling behind current time
if (unsent == 0) {
queue.nextSendMs += queue.periodMs;
if (queue.nextSendMs < curTimeMs) {
queue.nextSendMs = curTimeMs + queue.periodMs;
}
}
}
m_lastSendMs = curTimeMs;
}
template <NetworkMessage MessageType>
void NetworkOutgoingQueue<MessageType>::EncodeValue(wpi::raw_ostream& os,
int id,
const Value& value) {
int64_t time = value.time();
if constexpr (std::same_as<ValueMsg, ClientValueMsg>) {
if (time != 0) {
time += m_timeOffsetUs;
// make sure resultant time isn't exactly 0
if (time == 0) {
time = 1;
}
}
}
WireEncodeBinary(os, id, time, value);
}
} // namespace nt::net