[wpinet] Add WebSocket::TrySendFrames() (#5607)

This takes advantage of the underlying byte-level TryWrite() functionality to minimize blocking behavior and enable higher layers to do things smartly when the network blocks.

Also:
- Fix handling of control packets in middle of fragmented
- Clean up debugging features
This commit is contained in:
Peter Johnson
2023-09-18 19:49:54 -07:00
committed by GitHub
parent c4643ba047
commit c395b29fb4
9 changed files with 1104 additions and 164 deletions

View File

@@ -5,6 +5,9 @@
#include "wpinet/WebSocket.h"
#include <random>
#include <span>
#include <string>
#include <string_view>
#include <fmt/format.h>
#include <wpi/Base64.h>
@@ -14,37 +17,70 @@
#include <wpi/raw_ostream.h>
#include <wpi/sha1.h>
#include "WebSocketDebug.h"
#include "WebSocketSerializer.h"
#include "wpinet/HttpParser.h"
#include "wpinet/raw_uv_ostream.h"
#include "wpinet/uv/Stream.h"
using namespace wpi;
namespace {
class WebSocketWriteReq : public uv::WriteReq {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
static std::string DebugBinary(std::span<const uint8_t> val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
std::string str;
wpi::raw_string_ostream stros{str};
for (auto ch : val) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
return str;
#else
return "";
#endif
}
static inline std::string_view DebugText(std::string_view val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
return val;
#else
return "";
#endif
}
#endif // WPINET_WEBSOCKET_VERBOSE_DEBUG
class WebSocket::WriteReq : public uv::WriteReq,
public detail::WebSocketWriteReqBase {
public:
explicit WebSocketWriteReq(
explicit WriteReq(
std::weak_ptr<WebSocket> ws,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback)
: m_callback{std::move(callback)} {
: m_ws{std::move(ws)}, m_callback{std::move(callback)} {
finish.connect([this](uv::Error err) {
for (auto&& buf : m_internalBufs) {
buf.Deallocate();
int result = Continue(GetStream(), shared_from_this());
WS_DEBUG("Continue() -> {}\n", result);
if (result <= 0) {
m_frames.ReleaseBufs();
auto ws = m_ws.lock();
if (ws) {
ws->m_writeInProgress = false;
}
m_callback(m_userBufs, err);
if (result == 0 && m_cont && ws) {
WS_DEBUG("Continuing with another write\n");
ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont);
}
}
m_callback(m_userBufs, err);
});
}
std::weak_ptr<WebSocket> m_ws;
std::function<void(std::span<uv::Buffer>, uv::Error)> m_callback;
SmallVector<uv::Buffer, 4> m_internalBufs;
SmallVector<uv::Buffer, 4> m_userBufs;
// for server
size_t m_internalBufPos = 0;
std::shared_ptr<WriteReq> m_cont;
};
} // namespace
static constexpr uint8_t kFlagMasking = 0x80;
static constexpr uint8_t kLenMask = 0x7f;
static constexpr size_t kWriteAllocSize = 4096;
class WebSocket::ClientHandshakeData {
public:
@@ -157,7 +193,7 @@ void WebSocket::StartClient(std::string_view uri, std::string_view host,
// Build client request
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
os << "GET " << uri << " HTTP/1.1\r\n";
os << "Host: " << host << "\r\n";
@@ -276,7 +312,7 @@ void WebSocket::StartServer(std::string_view key, std::string_view version,
// Build server response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
// Handle unsupported version
if (version != "13") {
@@ -324,13 +360,13 @@ void WebSocket::StartServer(std::string_view key, std::string_view version,
void WebSocket::SendClose(uint16_t code, std::string_view reason) {
SmallVector<uv::Buffer, 4> bufs;
if (code != 1005) {
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
static_cast<uint8_t>(code & 0xff)};
os << std::span{codeMsb};
os << reason;
}
Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
SendControl(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
for (auto&& buf : bufs) {
buf.Deallocate();
}
@@ -349,6 +385,17 @@ void WebSocket::Shutdown() {
m_stream.Shutdown([this] { m_stream.Close(); });
}
static inline void Unmask(std::span<uint8_t> data,
std::span<const uint8_t, 4> key) {
int n = 0;
for (uint8_t& ch : data) {
ch ^= key[n++];
if (n >= 4) {
n = 0;
}
}
}
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
// ignore incoming data if we're failed or closed
if (m_state == FAILED || m_state == CLOSED) {
@@ -449,32 +496,37 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
// limit maximum size
if ((m_payload.size() + m_frameSize) > m_maxMessageSize) {
bool control = (m_header[0] & kFlagControl) != 0;
if (((control ? m_controlPayload.size() : m_payload.size()) +
m_frameSize) > m_maxMessageSize) {
return Fail(1009, "message too large");
}
}
}
if (m_frameSize != UINT64_MAX) {
size_t need = m_frameStart + m_frameSize - m_payload.size();
bool control = (m_header[0] & kFlagControl) != 0;
size_t need;
if (control) {
need = m_frameSize - m_controlPayload.size();
} else {
need = m_frameStart + m_frameSize - m_payload.size();
}
size_t toCopy = (std::min)(need, data.size());
m_payload.append(data.data(), data.data() + toCopy);
if (control) {
m_controlPayload.append(data.data(), data.data() + toCopy);
} else {
m_payload.append(data.data(), data.data() + toCopy);
}
data.remove_prefix(toCopy);
need -= toCopy;
if (need == 0) {
// We have a complete frame
// If the message had masking, unmask it
if ((m_header[1] & kFlagMasking) != 0) {
uint8_t key[4] = {
m_header[m_headerSize - 4], m_header[m_headerSize - 3],
m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
int n = 0;
for (uint8_t& ch : std::span{m_payload}.subspan(m_frameStart)) {
ch ^= key[n++];
if (n >= 4) {
n = 0;
}
}
Unmask(control ? std::span{m_controlPayload}
: std::span{m_payload}.subspan(m_frameStart),
std::span<const uint8_t, 4>{&m_header[m_headerSize - 4], 4});
}
// Handle message
@@ -482,17 +534,23 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
uint8_t opcode = m_header[0] & kOpMask;
switch (opcode) {
case kOpCont:
WS_DEBUG("WS Fragment {} [{}]\n", m_payload.size(),
DebugBinary(m_payload));
switch (m_fragmentOpcode) {
case kOpText:
if (!m_combineFragments || fin) {
text(std::string_view{reinterpret_cast<char*>(
m_payload.data()),
m_payload.size()},
fin);
std::string_view content{
reinterpret_cast<char*>(m_payload.data()),
m_payload.size()};
WS_DEBUG("WS RecvText(Defrag) {} ({})\n", m_payload.size(),
DebugText(content));
text(content, fin);
}
break;
case kOpBinary:
if (!m_combineFragments || fin) {
WS_DEBUG("WS RecvBinary(Defrag) {} ({})\n", m_payload.size(),
DebugBinary(m_payload));
binary(m_payload, fin);
}
break;
@@ -504,42 +562,38 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
m_fragmentOpcode = 0;
}
break;
case kOpText:
case kOpText: {
std::string_view content{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()};
if (m_fragmentOpcode != 0) {
WS_DEBUG("WS RecvText {} ({}) -> INCOMPLETE FRAGMENT\n",
m_payload.size(), DebugText(content));
return Fail(1002, "incomplete fragment");
}
if (!m_combineFragments || fin) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
fmt::print(
"WS RecvText({})\n",
std::string_view{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()});
#endif
text(std::string_view{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()},
fin);
WS_DEBUG("WS RecvText {} ({})\n", m_payload.size(),
DebugText(content));
text(content, fin);
}
if (!fin) {
WS_DEBUG("WS RecvText {} StartFrag\n", m_payload.size());
m_fragmentOpcode = opcode;
}
break;
}
case kOpBinary:
if (m_fragmentOpcode != 0) {
WS_DEBUG("WS RecvBinary {} ({}) -> INCOMPLETE FRAGMENT\n",
m_payload.size(), DebugBinary(m_payload));
return Fail(1002, "incomplete fragment");
}
if (!m_combineFragments || fin) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
SmallString<128> str;
raw_svector_ostream stros{str};
for (auto ch : m_payload) {
stros << fmt::format("{:02x},",
static_cast<unsigned int>(ch) & 0xff);
}
fmt::print("WS RecvBinary({})\n", str.str());
#endif
WS_DEBUG("WS RecvBinary {} ({})\n", m_payload.size(),
DebugBinary(m_payload));
binary(m_payload, fin);
}
if (!fin) {
WS_DEBUG("WS RecvBinary {} StartFrag\n", m_payload.size());
m_fragmentOpcode = opcode;
}
break;
@@ -549,14 +603,15 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
if (!fin) {
code = 1002;
reason = "cannot fragment control frames";
} else if (m_payload.size() < 2) {
} else if (m_controlPayload.size() < 2) {
code = 1005;
} else {
code = (static_cast<uint16_t>(m_payload[0]) << 8) |
static_cast<uint16_t>(m_payload[1]);
reason = drop_front(
{reinterpret_cast<char*>(m_payload.data()), m_payload.size()},
2);
code = (static_cast<uint16_t>(m_controlPayload[0]) << 8) |
static_cast<uint16_t>(m_controlPayload[1]);
reason =
drop_front({reinterpret_cast<char*>(m_controlPayload.data()),
m_controlPayload.size()},
2);
}
// Echo the close if we didn't previously send it
if (m_state != CLOSING) {
@@ -577,8 +632,8 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
if (m_state == OPEN) {
SmallVector<uv::Buffer, 4> bufs;
{
raw_uv_ostream os{bufs, 4096};
os << m_payload;
raw_uv_ostream os{bufs, kWriteAllocSize};
os << m_controlPayload;
}
SendPong(bufs, [](auto bufs, uv::Error) {
for (auto&& buf : bufs) {
@@ -586,13 +641,17 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
});
}
ping(m_payload);
WS_DEBUG("WS RecvPing() {} ({})\n", m_controlPayload.size(),
DebugBinary(m_controlPayload));
ping(m_controlPayload);
break;
case kOpPong:
if (!fin) {
return Fail(1002, "cannot fragment control frames");
}
pong(m_payload);
WS_DEBUG("WS RecvPong() {} ({})\n", m_controlPayload.size(),
DebugBinary(m_controlPayload));
pong(m_controlPayload);
break;
default:
return Fail(1002, "invalid message opcode");
@@ -602,7 +661,11 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
m_header.clear();
m_headerSize = 0;
if (!m_combineFragments || fin) {
m_payload.clear();
if (control) {
m_controlPayload.clear();
} else {
m_payload.clear();
}
}
m_frameStart = m_payload.size();
m_frameSize = UINT64_MAX;
@@ -611,127 +674,132 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
}
static void WriteFrame(WebSocketWriteReq& req,
SmallVectorImpl<uv::Buffer>& bufs, bool server,
uint8_t opcode, std::span<const uv::Buffer> data) {
static void VerboseDebug(const WebSocket::Frame& frame) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
if ((opcode & 0x7f) == 0x01) {
if ((frame.opcode & 0x7f) == 0x01) {
SmallString<128> str;
for (auto&& d : data) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
for (auto&& d : frame.data) {
str.append(std::string_view(d.base, d.len));
}
#endif
fmt::print("WS SendText({})\n", str.str());
} else if ((opcode & 0x7f) == 0x02) {
} else if ((frame.opcode & 0x7f) == 0x02) {
SmallString<128> str;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
raw_svector_ostream stros{str};
for (auto&& d : data) {
for (auto&& d : frame.data) {
for (auto ch : d.data()) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
}
fmt::print("WS SendBinary({})\n", str.str());
}
#endif
uint8_t header[10];
uint8_t* pHeader = header;
// opcode (includes FIN bit)
*pHeader++ = opcode;
// payload length
uint64_t size = 0;
for (auto&& buf : data) {
size += buf.len;
}
if (size < 126) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | size;
} else if (size <= 0xffff) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 126;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
fmt::print("WS SendBinary({})\n", str.str());
} else {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 127;
*pHeader++ = (size >> 56) & 0xff;
*pHeader++ = (size >> 48) & 0xff;
*pHeader++ = (size >> 40) & 0xff;
*pHeader++ = (size >> 32) & 0xff;
*pHeader++ = (size >> 24) & 0xff;
*pHeader++ = (size >> 16) & 0xff;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
}
size_t headerSize = pHeader - header;
// clients need to mask the input data
if (!server) {
SmallVector<uv::Buffer, 4> internalBufs;
raw_uv_ostream os{internalBufs, 4096};
os << std::span<const uint8_t>{header, headerSize};
// generate masking key
static std::random_device rd;
static std::default_random_engine gen{rd()};
std::uniform_int_distribution<unsigned int> dist(0, 255);
uint8_t key[4];
for (uint8_t& v : key) {
v = dist(gen);
}
os << std::span<const uint8_t>{key, 4};
// copy and mask data
int n = 0;
for (auto&& buf : data) {
for (auto&& ch : buf.data()) {
os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
if (n >= 4) {
n = 0;
}
SmallString<128> str;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
raw_svector_ostream stros{str};
for (auto&& d : frame.data) {
for (auto ch : d.data()) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
}
bufs.append(internalBufs.begin(), internalBufs.end());
req.m_internalBufs.append(internalBufs.begin(), internalBufs.end());
// don't send the user bufs as we copied their data
} else {
// manage m_internalBufs to efficiently store header
if (req.m_internalBufs.empty() ||
(req.m_internalBufPos + headerSize) > 4096) {
req.m_internalBufs.emplace_back(uv::Buffer::Allocate(4096));
req.m_internalBufPos = 0;
}
char* internalBuf =
req.m_internalBufs.back().data().data() + req.m_internalBufPos;
std::memcpy(internalBuf, header, headerSize);
bufs.emplace_back(internalBuf, headerSize);
req.m_internalBufPos += headerSize;
// servers can just send the buffers directly without masking
bufs.append(data.begin(), data.end());
#endif
fmt::print("WS SendOp({}, {})\n", frame.opcode, str.str());
}
req.m_userBufs.append(data.begin(), data.end());
#endif
}
void WebSocket::SendFrames(
std::span<const Frame> frames,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
// If we're not open, emit an error and don't send the data
WS_DEBUG("SendFrames({})\n", frames.size());
if (m_state != OPEN) {
int err;
if (m_state == CONNECTING) {
err = UV_EAGAIN;
} else {
err = UV_ESHUTDOWN;
}
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
bufs.append(frame.data.begin(), frame.data.end());
}
callback(bufs, uv::Error{err});
SendError(frames, callback);
return;
}
auto req = std::make_shared<WebSocketWriteReq>(std::move(callback));
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
for (auto&& frame : frames) {
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
}
req->m_continueBufPos = req->m_frames.m_bufs.size();
if (m_writeInProgress) {
if (auto curReq = m_writeReq.lock()) {
// if write currently in progress, process as a continuation of that
m_writeReq = req;
curReq->m_cont = std::move(req);
return;
}
}
WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size());
m_stream.Write(req->m_frames.m_bufs, req);
}
std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
std::span<const Frame> frames,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
// If we're not open, emit an error and don't send the data
if (m_state != WebSocket::OPEN) {
SendError(frames, callback);
return {};
}
// If something else is still in flight, don't send anything
if (m_writeInProgress) {
return frames;
}
return detail::TrySendFrames(
m_server, m_stream, frames,
[this](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(cb));
m_writeInProgress = true;
m_writeReq = req;
return req;
},
std::move(callback));
}
void WebSocket::SendControl(
uint8_t opcode, std::span<const uv::Buffer> data,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
Frame frame{opcode, data};
// If we're not open, emit an error and don't send the data
if (m_state != WebSocket::OPEN) {
SendError({{frame}}, callback);
return;
}
// Control messages always send immediately with their own request, whether
// or not something else is in flight. The protocol allows control messages
// interspersed with fragmented frames, and we otherwise make sure that any
// pending Write() is already at a frame boundary.
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
req->m_continueBufPos = req->m_frames.m_bufs.size();
m_stream.Write(req->m_frames.m_bufs, req);
}
void WebSocket::SendError(
std::span<const Frame> frames,
const std::function<void(std::span<uv::Buffer>, uv::Error)>& callback) {
int err;
if (m_state == WebSocket::CONNECTING) {
err = UV_EAGAIN;
} else {
err = UV_ESHUTDOWN;
}
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
WriteFrame(*req, bufs, m_server, frame.opcode, frame.data);
bufs.append(frame.data.begin(), frame.data.end());
}
m_stream.Write(bufs, req);
callback(bufs, uv::Error{err});
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <fmt/format.h>
// #define WPINET_WEBSOCKET_VERBOSE_DEBUG
// #define WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
#ifdef __clang__
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#endif
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
#define WS_DEBUG(format, ...) \
::fmt::print(FMT_STRING(format) __VA_OPT__(, ) __VA_ARGS__)
#else
#define WS_DEBUG(fmt, ...)
#endif

View File

@@ -0,0 +1,108 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include "WebSocketSerializer.h"
#include <random>
using namespace wpi::detail;
static constexpr uint8_t kFlagMasking = 0x80;
static constexpr size_t kWriteAllocSize = 4096;
static std::span<uint8_t> BuildHeader(std::span<uint8_t, 10> header,
bool server,
const wpi::WebSocket::Frame& frame) {
uint8_t* pHeader = header.data();
// opcode (includes FIN bit)
*pHeader++ = frame.opcode;
// payload length
uint64_t size = 0;
for (auto&& buf : frame.data) {
size += buf.len;
}
if (size < 126) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | size;
} else if (size <= 0xffff) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 126;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
} else {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 127;
*pHeader++ = (size >> 56) & 0xff;
*pHeader++ = (size >> 48) & 0xff;
*pHeader++ = (size >> 40) & 0xff;
*pHeader++ = (size >> 32) & 0xff;
*pHeader++ = (size >> 24) & 0xff;
*pHeader++ = (size >> 16) & 0xff;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
}
return header.subspan(0, pHeader - header.data());
}
size_t SerializedFrames::AddClientFrame(const WebSocket::Frame& frame) {
uint8_t headerBuf[10];
auto header = BuildHeader(headerBuf, false, frame);
// allocate a buffer per frame
size_t size = header.size() + 4;
for (auto&& buf : frame.data) {
size += buf.len;
}
m_allocBufs.emplace_back(uv::Buffer::Allocate(size));
m_bufs.emplace_back(m_allocBufs.back());
char* internalBuf = m_allocBufs.back().data().data();
std::memcpy(internalBuf, header.data(), header.size());
internalBuf += header.size();
// generate masking key
static std::random_device rd;
static std::default_random_engine gen{rd()};
std::uniform_int_distribution<unsigned int> dist(0, 255);
uint8_t key[4];
for (uint8_t& v : key) {
v = dist(gen);
}
std::memcpy(internalBuf, key, 4);
internalBuf += 4;
// copy and mask data
int n = 0;
for (auto&& buf : frame.data) {
for (auto&& ch : buf.data()) {
*internalBuf++ = static_cast<uint8_t>(ch) ^ key[n++];
if (n >= 4) {
n = 0;
}
}
}
return size;
}
size_t SerializedFrames::AddServerFrame(const WebSocket::Frame& frame) {
uint8_t headerBuf[10];
auto header = BuildHeader(headerBuf, true, frame);
// manage allocBufs to efficiently store header
if (m_allocBufs.empty() ||
(m_allocBufPos + header.size()) > kWriteAllocSize) {
m_allocBufs.emplace_back(uv::Buffer::Allocate(kWriteAllocSize));
m_allocBufPos = 0;
}
char* internalBuf = m_allocBufs.back().data().data() + m_allocBufPos;
std::memcpy(internalBuf, header.data(), header.size());
m_bufs.emplace_back(internalBuf, header.size());
m_allocBufPos += header.size();
// servers can just send the buffers directly without masking
m_bufs.append(frame.data.begin(), frame.data.end());
size_t sent = header.size();
for (auto&& buf : frame.data) {
sent += buf.len;
}
return sent;
}

View File

@@ -0,0 +1,286 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <functional>
#include <memory>
#include <utility>
#include <wpi/SmallVector.h>
#include <wpi/SpanExtras.h>
#include "WebSocketDebug.h"
#include "wpinet/WebSocket.h"
#include "wpinet/uv/Buffer.h"
namespace wpi::detail {
class SerializedFrames {
public:
SerializedFrames() = default;
SerializedFrames(const SerializedFrames&) = delete;
SerializedFrames& operator=(const SerializedFrames&) = delete;
~SerializedFrames() { ReleaseBufs(); }
size_t AddFrame(const WebSocket::Frame& frame, bool server) {
if (server) {
return AddServerFrame(frame);
} else {
return AddClientFrame(frame);
}
}
size_t AddClientFrame(const WebSocket::Frame& frame);
size_t AddServerFrame(const WebSocket::Frame& frame);
void ReleaseBufs() {
for (auto&& buf : m_allocBufs) {
buf.Deallocate();
}
m_allocBufs.clear();
}
SmallVector<uv::Buffer, 4> m_allocBufs;
SmallVector<uv::Buffer, 4> m_bufs;
size_t m_allocBufPos = 0;
};
class WebSocketWriteReqBase {
public:
template <typename Stream, typename Req>
int Continue(Stream& stream, std::shared_ptr<Req> req);
SmallVector<uv::Buffer, 4> m_userBufs;
SerializedFrames m_frames;
SmallVector<int, 0> m_continueFrameOffs;
size_t m_continueBufPos = 0;
size_t m_continueFramePos = 0;
};
template <typename Stream, typename Req>
int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr<Req> req) {
if (m_continueBufPos >= m_frames.m_bufs.size()) {
return 0; // nothing more to send
}
// try writing everything remaining
std::span bufs = std::span{m_frames.m_bufs}.subspan(m_continueBufPos);
int numBytes = 0;
for (auto&& buf : bufs) {
numBytes += buf.len;
}
int sentBytes = stream.TryWrite(bufs);
WS_DEBUG("TryWrite({}) -> {} (expected {})\n", bufs.size(), sentBytes,
numBytes);
if (sentBytes < 0) {
return sentBytes; // error
}
if (sentBytes == numBytes) {
m_continueBufPos = m_frames.m_bufs.size();
return 0; // nothing more to send
}
// we didn't send everything; deal with the leftovers
// figure out what the last (partially) frame sent actually was
auto offIt = m_continueFrameOffs.begin() + m_continueFramePos;
auto offEnd = m_continueFrameOffs.end();
while (offIt != offEnd && *offIt < sentBytes) {
++offIt;
}
assert(offIt != offEnd);
// build a list of buffers to send as a normal write:
SmallVector<uv::Buffer, 4> writeBufs;
auto bufIt = bufs.begin();
auto bufEnd = bufs.end();
// start with the remaining portion of the last buffer actually sent
int pos = 0;
while (bufIt != bufEnd && pos < sentBytes) {
pos += (bufIt++)->len;
}
if (bufIt != bufs.begin() && pos != sentBytes) {
writeBufs.emplace_back(
wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes));
}
// continue through the last buffer of the last partial frame
while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
pos += bufIt->len;
writeBufs.emplace_back(*bufIt++);
}
if (offIt != offEnd) {
++offIt;
}
// if writeBufs is still empty, write all of the next frame
if (writeBufs.empty()) {
while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
pos += bufIt->len;
writeBufs.emplace_back(*bufIt++);
}
if (offIt != offEnd) {
++offIt;
}
}
m_continueFramePos = offIt - m_continueFrameOffs.begin();
m_continueBufPos = &(*bufIt) - &(*m_frames.m_bufs.begin());
if (writeBufs.empty()) {
WS_DEBUG("Write Done\n");
return 0;
}
WS_DEBUG("Write({})\n", writeBufs.size());
stream.Write(writeBufs, req);
return 1;
}
template <typename MakeReq, typename Stream>
std::span<const WebSocket::Frame> TrySendFrames(
bool server, Stream& stream, std::span<const WebSocket::Frame> frames,
MakeReq&& makeReq,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
WS_DEBUG("TrySendFrames({})\n", frames.size());
auto frameIt = frames.begin();
auto frameEnd = frames.end();
while (frameIt != frameEnd) {
auto frameStart = frameIt;
// build buffers to send
SerializedFrames sendFrames;
SmallVector<int, 32> frameOffs;
int numBytes = 0;
while (frameIt != frameEnd) {
frameOffs.emplace_back(numBytes);
numBytes += sendFrames.AddFrame(*frameIt++, server);
if ((server && (numBytes >= 65536 || frameOffs.size() > 32)) ||
(!server && numBytes >= 8192)) {
// don't waste too much memory or effort on header generation or masking
break;
}
}
// try to send
int sentBytes = stream.TryWrite(sendFrames.m_bufs);
WS_DEBUG("TryWrite({}) -> {} (expected {})\n", sendFrames.m_bufs.size(),
sentBytes, numBytes);
if (sentBytes == 0) {
// we haven't started a frame yet; clean up any bufs that have actually
// sent, and return unsent frames
SmallVector<uv::Buffer, 4> bufs;
for (auto it = frames.begin(); it != frameStart; ++it) {
bufs.append(it->data.begin(), it->data.end());
}
callback(bufs, {});
return {&*frameStart, &*frameEnd};
} else if (sentBytes < 0) {
// error
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
bufs.append(frame.data.begin(), frame.data.end());
}
callback(bufs, uv::Error{sentBytes});
return frames;
} else if (sentBytes != numBytes) {
// we didn't send everything; deal with the leftovers
// figure out what the last (partially) frame sent actually was
auto offIt = frameOffs.begin();
auto offEnd = frameOffs.end();
bool isFin = true;
while (offIt != offEnd && *offIt < sentBytes) {
++offIt;
isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0;
++frameStart;
}
if (offIt != offEnd && *offIt == sentBytes && isFin) {
// we finished at a normal FIN frame boundary; no need for a Write()
SmallVector<uv::Buffer, 4> bufs;
for (auto it = frames.begin(); it != frameStart; ++it) {
bufs.append(it->data.begin(), it->data.end());
}
callback(bufs, {});
return {&*frameStart, &*frameEnd};
}
// build a list of buffers to send as a normal write:
SmallVector<uv::Buffer, 4> writeBufs;
auto bufIt = sendFrames.m_bufs.begin();
auto bufEnd = sendFrames.m_bufs.end();
// start with the remaining portion of the last buffer actually sent
int pos = 0;
while (bufIt != bufEnd && pos < sentBytes) {
pos += (bufIt++)->len;
}
if (bufIt != sendFrames.m_bufs.begin() && pos != sentBytes) {
writeBufs.emplace_back(
wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes));
}
// continue through the last buffer of the last partial frame
while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
pos += bufIt->len;
writeBufs.emplace_back(*bufIt++);
}
if (offIt != offEnd) {
++offIt;
}
// move allocated buffers into request
auto req = makeReq(std::move(callback));
req->m_frames.m_allocBufs = std::move(sendFrames.m_allocBufs);
req->m_frames.m_allocBufPos = sendFrames.m_allocBufPos;
// if partial frame was non-FIN, put any additional non-FIN frames into
// continuation (so the caller isn't responsible for doing this)
size_t continuePos = 0;
while (frameStart != frameEnd && !isFin) {
req->m_continueFrameOffs.emplace_back(continuePos);
if (offIt != offEnd) {
// we already generated the wire buffers for this frame, use them
while (pos < *offIt && bufIt != bufEnd) {
pos += bufIt->len;
continuePos += bufIt->len;
req->m_frames.m_bufs.emplace_back(*bufIt++);
}
++offIt;
} else {
// WS_DEBUG("generating frame for continuation {} {}\n",
// frameStart->opcode, frameStart->data.size());
// need to generate and add this frame
continuePos += req->m_frames.AddFrame(*frameStart, server);
}
isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0;
++frameStart;
}
// only the non-returned user buffers are added to the request
for (auto it = frames.begin(); it != frameStart; ++it) {
req->m_userBufs.append(it->data.begin(), it->data.end());
}
WS_DEBUG("Write({})\n", writeBufs.size());
stream.Write(writeBufs, req);
return {&*frameStart, &*frameEnd};
}
}
// nothing left to send
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
bufs.append(frame.data.begin(), frame.data.end());
}
callback(bufs, {});
return {};
}
} // namespace wpi::detail