[wpinet] Add WebSocket::TrySendFrames() (#5607)

This takes advantage of the underlying byte-level TryWrite() functionality to minimize blocking behavior and enable higher layers to do things smartly when the network blocks.

Also:
- Fix handling of control packets in middle of fragmented
- Clean up debugging features
This commit is contained in:
Peter Johnson
2023-09-18 19:49:54 -07:00
committed by GitHub
parent c4643ba047
commit c395b29fb4
9 changed files with 1104 additions and 164 deletions

View File

@@ -5,6 +5,9 @@
#include "wpinet/WebSocket.h"
#include <random>
#include <span>
#include <string>
#include <string_view>
#include <fmt/format.h>
#include <wpi/Base64.h>
@@ -14,37 +17,70 @@
#include <wpi/raw_ostream.h>
#include <wpi/sha1.h>
#include "WebSocketDebug.h"
#include "WebSocketSerializer.h"
#include "wpinet/HttpParser.h"
#include "wpinet/raw_uv_ostream.h"
#include "wpinet/uv/Stream.h"
using namespace wpi;
namespace {
class WebSocketWriteReq : public uv::WriteReq {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
static std::string DebugBinary(std::span<const uint8_t> val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
std::string str;
wpi::raw_string_ostream stros{str};
for (auto ch : val) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
return str;
#else
return "";
#endif
}
static inline std::string_view DebugText(std::string_view val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
return val;
#else
return "";
#endif
}
#endif // WPINET_WEBSOCKET_VERBOSE_DEBUG
class WebSocket::WriteReq : public uv::WriteReq,
public detail::WebSocketWriteReqBase {
public:
explicit WebSocketWriteReq(
explicit WriteReq(
std::weak_ptr<WebSocket> ws,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback)
: m_callback{std::move(callback)} {
: m_ws{std::move(ws)}, m_callback{std::move(callback)} {
finish.connect([this](uv::Error err) {
for (auto&& buf : m_internalBufs) {
buf.Deallocate();
int result = Continue(GetStream(), shared_from_this());
WS_DEBUG("Continue() -> {}\n", result);
if (result <= 0) {
m_frames.ReleaseBufs();
auto ws = m_ws.lock();
if (ws) {
ws->m_writeInProgress = false;
}
m_callback(m_userBufs, err);
if (result == 0 && m_cont && ws) {
WS_DEBUG("Continuing with another write\n");
ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont);
}
}
m_callback(m_userBufs, err);
});
}
std::weak_ptr<WebSocket> m_ws;
std::function<void(std::span<uv::Buffer>, uv::Error)> m_callback;
SmallVector<uv::Buffer, 4> m_internalBufs;
SmallVector<uv::Buffer, 4> m_userBufs;
// for server
size_t m_internalBufPos = 0;
std::shared_ptr<WriteReq> m_cont;
};
} // namespace
static constexpr uint8_t kFlagMasking = 0x80;
static constexpr uint8_t kLenMask = 0x7f;
static constexpr size_t kWriteAllocSize = 4096;
class WebSocket::ClientHandshakeData {
public:
@@ -157,7 +193,7 @@ void WebSocket::StartClient(std::string_view uri, std::string_view host,
// Build client request
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
os << "GET " << uri << " HTTP/1.1\r\n";
os << "Host: " << host << "\r\n";
@@ -276,7 +312,7 @@ void WebSocket::StartServer(std::string_view key, std::string_view version,
// Build server response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
// Handle unsupported version
if (version != "13") {
@@ -324,13 +360,13 @@ void WebSocket::StartServer(std::string_view key, std::string_view version,
void WebSocket::SendClose(uint16_t code, std::string_view reason) {
SmallVector<uv::Buffer, 4> bufs;
if (code != 1005) {
raw_uv_ostream os{bufs, 4096};
raw_uv_ostream os{bufs, kWriteAllocSize};
const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
static_cast<uint8_t>(code & 0xff)};
os << std::span{codeMsb};
os << reason;
}
Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
SendControl(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
for (auto&& buf : bufs) {
buf.Deallocate();
}
@@ -349,6 +385,17 @@ void WebSocket::Shutdown() {
m_stream.Shutdown([this] { m_stream.Close(); });
}
static inline void Unmask(std::span<uint8_t> data,
std::span<const uint8_t, 4> key) {
int n = 0;
for (uint8_t& ch : data) {
ch ^= key[n++];
if (n >= 4) {
n = 0;
}
}
}
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
// ignore incoming data if we're failed or closed
if (m_state == FAILED || m_state == CLOSED) {
@@ -449,32 +496,37 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
// limit maximum size
if ((m_payload.size() + m_frameSize) > m_maxMessageSize) {
bool control = (m_header[0] & kFlagControl) != 0;
if (((control ? m_controlPayload.size() : m_payload.size()) +
m_frameSize) > m_maxMessageSize) {
return Fail(1009, "message too large");
}
}
}
if (m_frameSize != UINT64_MAX) {
size_t need = m_frameStart + m_frameSize - m_payload.size();
bool control = (m_header[0] & kFlagControl) != 0;
size_t need;
if (control) {
need = m_frameSize - m_controlPayload.size();
} else {
need = m_frameStart + m_frameSize - m_payload.size();
}
size_t toCopy = (std::min)(need, data.size());
m_payload.append(data.data(), data.data() + toCopy);
if (control) {
m_controlPayload.append(data.data(), data.data() + toCopy);
} else {
m_payload.append(data.data(), data.data() + toCopy);
}
data.remove_prefix(toCopy);
need -= toCopy;
if (need == 0) {
// We have a complete frame
// If the message had masking, unmask it
if ((m_header[1] & kFlagMasking) != 0) {
uint8_t key[4] = {
m_header[m_headerSize - 4], m_header[m_headerSize - 3],
m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
int n = 0;
for (uint8_t& ch : std::span{m_payload}.subspan(m_frameStart)) {
ch ^= key[n++];
if (n >= 4) {
n = 0;
}
}
Unmask(control ? std::span{m_controlPayload}
: std::span{m_payload}.subspan(m_frameStart),
std::span<const uint8_t, 4>{&m_header[m_headerSize - 4], 4});
}
// Handle message
@@ -482,17 +534,23 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
uint8_t opcode = m_header[0] & kOpMask;
switch (opcode) {
case kOpCont:
WS_DEBUG("WS Fragment {} [{}]\n", m_payload.size(),
DebugBinary(m_payload));
switch (m_fragmentOpcode) {
case kOpText:
if (!m_combineFragments || fin) {
text(std::string_view{reinterpret_cast<char*>(
m_payload.data()),
m_payload.size()},
fin);
std::string_view content{
reinterpret_cast<char*>(m_payload.data()),
m_payload.size()};
WS_DEBUG("WS RecvText(Defrag) {} ({})\n", m_payload.size(),
DebugText(content));
text(content, fin);
}
break;
case kOpBinary:
if (!m_combineFragments || fin) {
WS_DEBUG("WS RecvBinary(Defrag) {} ({})\n", m_payload.size(),
DebugBinary(m_payload));
binary(m_payload, fin);
}
break;
@@ -504,42 +562,38 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
m_fragmentOpcode = 0;
}
break;
case kOpText:
case kOpText: {
std::string_view content{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()};
if (m_fragmentOpcode != 0) {
WS_DEBUG("WS RecvText {} ({}) -> INCOMPLETE FRAGMENT\n",
m_payload.size(), DebugText(content));
return Fail(1002, "incomplete fragment");
}
if (!m_combineFragments || fin) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
fmt::print(
"WS RecvText({})\n",
std::string_view{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()});
#endif
text(std::string_view{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()},
fin);
WS_DEBUG("WS RecvText {} ({})\n", m_payload.size(),
DebugText(content));
text(content, fin);
}
if (!fin) {
WS_DEBUG("WS RecvText {} StartFrag\n", m_payload.size());
m_fragmentOpcode = opcode;
}
break;
}
case kOpBinary:
if (m_fragmentOpcode != 0) {
WS_DEBUG("WS RecvBinary {} ({}) -> INCOMPLETE FRAGMENT\n",
m_payload.size(), DebugBinary(m_payload));
return Fail(1002, "incomplete fragment");
}
if (!m_combineFragments || fin) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
SmallString<128> str;
raw_svector_ostream stros{str};
for (auto ch : m_payload) {
stros << fmt::format("{:02x},",
static_cast<unsigned int>(ch) & 0xff);
}
fmt::print("WS RecvBinary({})\n", str.str());
#endif
WS_DEBUG("WS RecvBinary {} ({})\n", m_payload.size(),
DebugBinary(m_payload));
binary(m_payload, fin);
}
if (!fin) {
WS_DEBUG("WS RecvBinary {} StartFrag\n", m_payload.size());
m_fragmentOpcode = opcode;
}
break;
@@ -549,14 +603,15 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
if (!fin) {
code = 1002;
reason = "cannot fragment control frames";
} else if (m_payload.size() < 2) {
} else if (m_controlPayload.size() < 2) {
code = 1005;
} else {
code = (static_cast<uint16_t>(m_payload[0]) << 8) |
static_cast<uint16_t>(m_payload[1]);
reason = drop_front(
{reinterpret_cast<char*>(m_payload.data()), m_payload.size()},
2);
code = (static_cast<uint16_t>(m_controlPayload[0]) << 8) |
static_cast<uint16_t>(m_controlPayload[1]);
reason =
drop_front({reinterpret_cast<char*>(m_controlPayload.data()),
m_controlPayload.size()},
2);
}
// Echo the close if we didn't previously send it
if (m_state != CLOSING) {
@@ -577,8 +632,8 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
if (m_state == OPEN) {
SmallVector<uv::Buffer, 4> bufs;
{
raw_uv_ostream os{bufs, 4096};
os << m_payload;
raw_uv_ostream os{bufs, kWriteAllocSize};
os << m_controlPayload;
}
SendPong(bufs, [](auto bufs, uv::Error) {
for (auto&& buf : bufs) {
@@ -586,13 +641,17 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
});
}
ping(m_payload);
WS_DEBUG("WS RecvPing() {} ({})\n", m_controlPayload.size(),
DebugBinary(m_controlPayload));
ping(m_controlPayload);
break;
case kOpPong:
if (!fin) {
return Fail(1002, "cannot fragment control frames");
}
pong(m_payload);
WS_DEBUG("WS RecvPong() {} ({})\n", m_controlPayload.size(),
DebugBinary(m_controlPayload));
pong(m_controlPayload);
break;
default:
return Fail(1002, "invalid message opcode");
@@ -602,7 +661,11 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
m_header.clear();
m_headerSize = 0;
if (!m_combineFragments || fin) {
m_payload.clear();
if (control) {
m_controlPayload.clear();
} else {
m_payload.clear();
}
}
m_frameStart = m_payload.size();
m_frameSize = UINT64_MAX;
@@ -611,127 +674,132 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
}
}
static void WriteFrame(WebSocketWriteReq& req,
SmallVectorImpl<uv::Buffer>& bufs, bool server,
uint8_t opcode, std::span<const uv::Buffer> data) {
static void VerboseDebug(const WebSocket::Frame& frame) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
if ((opcode & 0x7f) == 0x01) {
if ((frame.opcode & 0x7f) == 0x01) {
SmallString<128> str;
for (auto&& d : data) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
for (auto&& d : frame.data) {
str.append(std::string_view(d.base, d.len));
}
#endif
fmt::print("WS SendText({})\n", str.str());
} else if ((opcode & 0x7f) == 0x02) {
} else if ((frame.opcode & 0x7f) == 0x02) {
SmallString<128> str;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
raw_svector_ostream stros{str};
for (auto&& d : data) {
for (auto&& d : frame.data) {
for (auto ch : d.data()) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
}
fmt::print("WS SendBinary({})\n", str.str());
}
#endif
uint8_t header[10];
uint8_t* pHeader = header;
// opcode (includes FIN bit)
*pHeader++ = opcode;
// payload length
uint64_t size = 0;
for (auto&& buf : data) {
size += buf.len;
}
if (size < 126) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | size;
} else if (size <= 0xffff) {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 126;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
fmt::print("WS SendBinary({})\n", str.str());
} else {
*pHeader++ = (server ? 0x00 : kFlagMasking) | 127;
*pHeader++ = (size >> 56) & 0xff;
*pHeader++ = (size >> 48) & 0xff;
*pHeader++ = (size >> 40) & 0xff;
*pHeader++ = (size >> 32) & 0xff;
*pHeader++ = (size >> 24) & 0xff;
*pHeader++ = (size >> 16) & 0xff;
*pHeader++ = (size >> 8) & 0xff;
*pHeader++ = size & 0xff;
}
size_t headerSize = pHeader - header;
// clients need to mask the input data
if (!server) {
SmallVector<uv::Buffer, 4> internalBufs;
raw_uv_ostream os{internalBufs, 4096};
os << std::span<const uint8_t>{header, headerSize};
// generate masking key
static std::random_device rd;
static std::default_random_engine gen{rd()};
std::uniform_int_distribution<unsigned int> dist(0, 255);
uint8_t key[4];
for (uint8_t& v : key) {
v = dist(gen);
}
os << std::span<const uint8_t>{key, 4};
// copy and mask data
int n = 0;
for (auto&& buf : data) {
for (auto&& ch : buf.data()) {
os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
if (n >= 4) {
n = 0;
}
SmallString<128> str;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
raw_svector_ostream stros{str};
for (auto&& d : frame.data) {
for (auto ch : d.data()) {
stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
}
}
bufs.append(internalBufs.begin(), internalBufs.end());
req.m_internalBufs.append(internalBufs.begin(), internalBufs.end());
// don't send the user bufs as we copied their data
} else {
// manage m_internalBufs to efficiently store header
if (req.m_internalBufs.empty() ||
(req.m_internalBufPos + headerSize) > 4096) {
req.m_internalBufs.emplace_back(uv::Buffer::Allocate(4096));
req.m_internalBufPos = 0;
}
char* internalBuf =
req.m_internalBufs.back().data().data() + req.m_internalBufPos;
std::memcpy(internalBuf, header, headerSize);
bufs.emplace_back(internalBuf, headerSize);
req.m_internalBufPos += headerSize;
// servers can just send the buffers directly without masking
bufs.append(data.begin(), data.end());
#endif
fmt::print("WS SendOp({}, {})\n", frame.opcode, str.str());
}
req.m_userBufs.append(data.begin(), data.end());
#endif
}
void WebSocket::SendFrames(
std::span<const Frame> frames,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
// If we're not open, emit an error and don't send the data
WS_DEBUG("SendFrames({})\n", frames.size());
if (m_state != OPEN) {
int err;
if (m_state == CONNECTING) {
err = UV_EAGAIN;
} else {
err = UV_ESHUTDOWN;
}
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
bufs.append(frame.data.begin(), frame.data.end());
}
callback(bufs, uv::Error{err});
SendError(frames, callback);
return;
}
auto req = std::make_shared<WebSocketWriteReq>(std::move(callback));
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
for (auto&& frame : frames) {
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
}
req->m_continueBufPos = req->m_frames.m_bufs.size();
if (m_writeInProgress) {
if (auto curReq = m_writeReq.lock()) {
// if write currently in progress, process as a continuation of that
m_writeReq = req;
curReq->m_cont = std::move(req);
return;
}
}
WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size());
m_stream.Write(req->m_frames.m_bufs, req);
}
std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
std::span<const Frame> frames,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
// If we're not open, emit an error and don't send the data
if (m_state != WebSocket::OPEN) {
SendError(frames, callback);
return {};
}
// If something else is still in flight, don't send anything
if (m_writeInProgress) {
return frames;
}
return detail::TrySendFrames(
m_server, m_stream, frames,
[this](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(cb));
m_writeInProgress = true;
m_writeReq = req;
return req;
},
std::move(callback));
}
void WebSocket::SendControl(
uint8_t opcode, std::span<const uv::Buffer> data,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
Frame frame{opcode, data};
// If we're not open, emit an error and don't send the data
if (m_state != WebSocket::OPEN) {
SendError({{frame}}, callback);
return;
}
// Control messages always send immediately with their own request, whether
// or not something else is in flight. The protocol allows control messages
// interspersed with fragmented frames, and we otherwise make sure that any
// pending Write() is already at a frame boundary.
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
req->m_continueBufPos = req->m_frames.m_bufs.size();
m_stream.Write(req->m_frames.m_bufs, req);
}
void WebSocket::SendError(
std::span<const Frame> frames,
const std::function<void(std::span<uv::Buffer>, uv::Error)>& callback) {
int err;
if (m_state == WebSocket::CONNECTING) {
err = UV_EAGAIN;
} else {
err = UV_ESHUTDOWN;
}
SmallVector<uv::Buffer, 4> bufs;
for (auto&& frame : frames) {
WriteFrame(*req, bufs, m_server, frame.opcode, frame.data);
bufs.append(frame.data.begin(), frame.data.end());
}
m_stream.Write(bufs, req);
callback(bufs, uv::Error{err});
}