Files
allwpilib/wpinet/src/main/native/cpp/WebSocket.cpp
Peter Johnson 3e7ba2cc6f [wpinet] WebSocket: Fix write behavior (#5841)
On Windows, TryWrite will always return 0 if there is a Write in progress. The previous behavior for SendFrames and SendControl just used a normal Write, which caused issues with code that combined these with TrySendFrames. Instead, have SendFrames and SendControl also use TryWrite under the hood if possible, and create write requests if not. The implementation preserves the priority of SendControl against an existing write request with multiple frames.
2023-10-29 16:48:25 -07:00

852 lines
27 KiB
C++

// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include "wpinet/WebSocket.h"
#include <random>
#include <span>
#include <string>
#include <string_view>
#include <fmt/format.h>
#include <wpi/Base64.h>
#include <wpi/SmallString.h>
#include <wpi/SmallVector.h>
#include <wpi/StringExtras.h>
#include <wpi/raw_ostream.h>
#include <wpi/sha1.h>
#include "WebSocketDebug.h"
#include "WebSocketSerializer.h"
#include "wpinet/HttpParser.h"
#include "wpinet/raw_uv_ostream.h"
#include "wpinet/uv/Stream.h"
using namespace wpi;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
// Renders a byte sequence as comma-terminated two-digit hex values for verbose
// debug logging. Yields an empty string unless content logging is compiled in.
static std::string DebugBinary(std::span<const uint8_t> val) {
#ifndef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
  return "";
#else
  std::string hex;
  wpi::raw_string_ostream hexOs{hex};
  for (const uint8_t byte : val) {
    hexOs << fmt::format("{:02x},", static_cast<unsigned int>(byte) & 0xff);
  }
  return hex;
#endif
}
// Passes text through for verbose debug logging when content logging is
// compiled in; otherwise the text is suppressed and an empty view is returned.
static inline std::string_view DebugText(std::string_view val) {
#ifndef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
  return "";
#else
  return val;
#endif
}
#endif // WPINET_WEBSOCKET_VERBOSE_DEBUG
// Write request used for outgoing WebSocket frames. Requests form a chain:
// m_cont is the request to run after this one completes, and m_controlCont is
// a control-frame request that should be sent before this one resumes.
class WebSocket::WriteReq : public uv::WriteReq,
                            public detail::WebSocketWriteReqBase {
 public:
  explicit WriteReq(
      std::weak_ptr<WebSocket> ws,
      std::function<void(std::span<uv::Buffer>, uv::Error)> callback)
      : m_ws{std::move(ws)}, m_callback{std::move(callback)} {
    // Each completed stream write re-enters Send() to push the next piece.
    finish.connect([this](uv::Error err) { Send(err); });
  }

  // Advances this request: starts it, resumes it after a partial write, or
  // completes it (invoking the user callback) and hands off to the next one.
  void Send(uv::Error err) {
    auto ws = m_ws.lock();
    if (!ws || err) {
      // The socket is gone or the write failed: just give the buffers back.
      WS_DEBUG("no WS or error, calling callback\n");
      m_frames.ReleaseBufs();
      m_callback(m_userBufs, err);
      return;
    }

    // Continue() is designed so this is *only* called on frame boundaries
    if (m_controlCont) {
      // A control frame is waiting; detach it and let it go first. It gets
      // back here through its own m_cont once it has finished.
      WS_DEBUG("Continuing with a control write\n");
      std::shared_ptr<WriteReq> pendingControl;
      pendingControl.swap(m_controlCont);
      pendingControl->Send({});
      return;
    }

    const int result = Continue(ws->m_stream, shared_from_this());
    WS_DEBUG("Continue() -> {}\n", result);
    if (result > 0) {
      return;  // nothing to finalize yet; finish will call us again
    }

    // Finished (0) or failed (< 0): release and report to the user.
    m_frames.ReleaseBufs();
    m_callback(m_userBufs, uv::Error{result});

    if (result == 0 && m_cont) {
      // Success with a queued follow-up: make it the current request.
      WS_DEBUG("Continuing with another write\n");
      ws->m_curWriteReq = m_cont;
      m_cont->Send({});
      return;
    }

    // Nothing left to send (or an error): mark the socket as idle.
    ws->m_writeInProgress = false;
    ws->m_curWriteReq.reset();
    ws->m_lastWriteReq.reset();
  }

  std::weak_ptr<WebSocket> m_ws;
  std::function<void(std::span<uv::Buffer>, uv::Error)> m_callback;
  std::shared_ptr<WriteReq> m_cont;
  std::shared_ptr<WriteReq> m_controlCont;
};
// Bit tested in the second frame header byte to detect a masked payload.
static constexpr uint8_t kFlagMasking = 0x80;
// Mask applied to the second frame header byte to extract the length field
// (values 126 and 127 select the 2-byte and 8-byte extended lengths).
static constexpr uint8_t kLenMask = 0x7f;
// Passed as the second raw_uv_ostream constructor argument when building
// outgoing buffers (presumably the per-buffer allocation size -- confirm).
static constexpr size_t kWriteAllocSize = 4096;
// State that only exists while a client is performing the opening handshake:
// the generated key, the offered protocols, the HTTP response parser, flags
// for the response headers seen so far, and the optional timeout timer.
class WebSocket::ClientHandshakeData {
 public:
  ClientHandshakeData() {
    // The key is the base64 encoding of 16 random bytes.
    static std::random_device rd;
    static std::default_random_engine gen{rd()};
    std::uniform_int_distribution<unsigned int> byteDist(0, 255);
    char nonce[16];  // the nonce sent to the server
    for (size_t i = 0; i < sizeof(nonce); ++i) {
      nonce[i] = static_cast<char>(byteDist(gen));
    }
    raw_svector_ostream keyOs(key);
    Base64Encode(keyOs, {nonce, 16});
  }

  ~ClientHandshakeData() {
    // Stop and close the handshake timer if it is still alive.
    auto t = timer.lock();
    if (!t) {
      return;
    }
    t->Stop();
    t->Close();
  }

  SmallString<64> key;                       // the key sent to the server
  SmallVector<std::string, 2> protocols;     // valid protocols
  HttpParser parser{HttpParser::kResponse};  // server response parser
  bool hasUpgrade = false;
  bool hasConnection = false;
  bool hasAccept = false;
  bool hasProtocol = false;
  std::weak_ptr<uv::Timer> timer;
};
static std::string_view AcceptHash(std::string_view key,
SmallVectorImpl<char>& buf) {
SHA1 hash;
hash.Update(key);
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
SmallString<64> hashBuf;
return Base64Encode(hash.RawFinal(hashBuf), buf);
}
// Binds the WebSocket to a stream: hooks the stream's closed/error/data/end
// signals to this object and (re)starts reading from the stream.
WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
    : m_stream{stream}, m_server{server} {
  // A closed handle or a stream error ends the connection abnormally (1006)
  m_stream.closed.connect([this] { SetClosed(1006, "handle closed"); });
  m_stream.error.connect([this](uv::Error streamErr) {
    Terminate(1006, fmt::format("stream error: {}", streamErr.name()));
  });

  // Restart reading; the stream may already have been reading
  m_stream.StopRead();
  m_stream.StartRead();

  // Incoming bytes feed the frame parser; EOF from the peer terminates
  m_stream.data.connect([this](uv::Buffer& incoming, size_t len) {
    HandleIncoming(incoming, len);
  });
  m_stream.end.connect(
      [this] { Terminate(1006, "remote end closed connection"); });
}
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably placed here so that members referring to the
// file-local ClientHandshakeData type are destroyed where that type is
// complete -- confirm against the class declaration in the header.
WebSocket::~WebSocket() = default;
// Creates a client-side WebSocket on `stream`, stores it as the stream's
// data, and begins the client opening handshake.
std::shared_ptr<WebSocket> WebSocket::CreateClient(
    uv::Stream& stream, std::string_view uri, std::string_view host,
    std::span<const std::string_view> protocols, const ClientOptions& options) {
  std::shared_ptr<WebSocket> client =
      std::make_shared<WebSocket>(stream, false, private_init{});
  stream.SetData(client);
  client->StartClient(uri, host, protocols, options);
  return client;
}
// Creates a server-side WebSocket on `stream`, stores it as the stream's
// data, and sends the server handshake response.
std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
                                                   std::string_view key,
                                                   std::string_view version,
                                                   std::string_view protocol) {
  std::shared_ptr<WebSocket> server =
      std::make_shared<WebSocket>(stream, true, private_init{});
  stream.SetData(server);
  server->StartServer(key, version, protocol);
  return server;
}
// Starts a graceful close: sends a Close frame and, unless the connection has
// already failed or closed, moves to the CLOSING state.
void WebSocket::Close(uint16_t code, std::string_view reason) {
  SendClose(code, reason);
  const bool alreadyFinished = (m_state == FAILED || m_state == CLOSED);
  if (!alreadyFinished) {
    m_state = CLOSING;
  }
}
// Fails the connection: sends a Close frame, marks the socket FAILED, and
// shuts the stream down. Does nothing if already failed or closed.
void WebSocket::Fail(uint16_t code, std::string_view reason) {
  if (m_state != FAILED && m_state != CLOSED) {
    SendClose(code, reason);
    SetClosed(code, reason, true);
    Shutdown();
  }
}
// Ends the connection without sending a Close frame: marks the socket closed
// and shuts the stream down. Does nothing if already failed or closed.
void WebSocket::Terminate(uint16_t code, std::string_view reason) {
  if (m_state != FAILED && m_state != CLOSED) {
    SetClosed(code, reason);
    Shutdown();
  }
}
void WebSocket::StartClient(std::string_view uri, std::string_view host,
std::span<const std::string_view> protocols,
const ClientOptions& options) {
// Create client handshake data
m_clientHandshake = std::make_unique<ClientHandshakeData>();
// Build client request
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, kWriteAllocSize};
os << "GET " << uri << " HTTP/1.1\r\n";
os << "Host: " << host << "\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
os << "Sec-WebSocket-Version: 13\r\n";
// protocols (if provided)
if (!protocols.empty()) {
os << "Sec-WebSocket-Protocol: ";
bool first = true;
for (auto protocol : protocols) {
if (!first) {
os << ", ";
} else {
first = false;
}
os << protocol;
// also save for later checking against server response
m_clientHandshake->protocols.emplace_back(protocol);
}
os << "\r\n";
}
// other headers
for (auto&& header : options.extraHeaders) {
os << header.first << ": " << header.second << "\r\n";
}
// finish headers
os << "\r\n";
// Send client request
m_stream.Write(bufs, [](auto bufs, uv::Error) {
for (auto& buf : bufs) {
buf.Deallocate();
}
});
// Set up client response handling
m_clientHandshake->parser.status.connect([this](std::string_view status) {
unsigned int code = m_clientHandshake->parser.GetStatusCode();
if (code != 101) {
Terminate(code, status);
}
});
m_clientHandshake->parser.header.connect(
[this](std::string_view name, std::string_view value) {
value = trim(value);
if (equals_lower(name, "upgrade")) {
if (!equals_lower(value, "websocket")) {
return Terminate(1002, "invalid upgrade response value");
}
m_clientHandshake->hasUpgrade = true;
} else if (equals_lower(name, "connection")) {
if (!equals_lower(value, "upgrade")) {
return Terminate(1002, "invalid connection response value");
}
m_clientHandshake->hasConnection = true;
} else if (equals_lower(name, "sec-websocket-accept")) {
// Check against expected response
SmallString<64> acceptBuf;
if (!equals(value, AcceptHash(m_clientHandshake->key, acceptBuf))) {
return Terminate(1002, "invalid accept key");
}
m_clientHandshake->hasAccept = true;
} else if (equals_lower(name, "sec-websocket-extensions")) {
// No extensions are supported
if (!value.empty()) {
return Terminate(1010, "unsupported extension");
}
} else if (equals_lower(name, "sec-websocket-protocol")) {
// Make sure it was one of the provided protocols
bool match = false;
for (auto&& protocol : m_clientHandshake->protocols) {
if (equals_lower(value, protocol)) {
match = true;
break;
}
}
if (!match) {
return Terminate(1003, "unsupported protocol");
}
m_clientHandshake->hasProtocol = true;
m_protocol = value;
}
});
m_clientHandshake->parser.headersComplete.connect([this](bool) {
if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
!m_clientHandshake->hasAccept ||
(!m_clientHandshake->hasProtocol &&
!m_clientHandshake->protocols.empty())) {
return Terminate(1002, "invalid response");
}
if (m_state == CONNECTING) {
m_state = OPEN;
open(m_protocol);
}
});
// Start handshake timer if a timeout was specified
if (options.handshakeTimeout != (uv::Timer::Time::max)()) {
if (auto timer = uv::Timer::Create(m_stream.GetLoopRef())) {
timer->timeout.connect(
[this]() { Terminate(1006, "connection timed out"); });
timer->Start(options.handshakeTimeout);
m_clientHandshake->timer = timer;
}
}
}
void WebSocket::StartServer(std::string_view key, std::string_view version,
std::string_view protocol) {
m_protocol = protocol;
// Build server response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, kWriteAllocSize};
// Handle unsupported version
if (version != "13") {
os << "HTTP/1.1 426 Upgrade Required\r\n";
os << "Upgrade: WebSocket\r\n";
os << "Sec-WebSocket-Version: 13\r\n\r\n";
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
for (auto& buf : bufs) {
buf.Deallocate();
}
// XXX: Should we support sending a new handshake on the same connection?
// XXX: "this->" is required by GCC 5.5 (bug)
this->Terminate(1003, "unsupported protocol version");
});
return;
}
os << "HTTP/1.1 101 Switching Protocols\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
// accept hash
SmallString<64> acceptBuf;
os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
if (!protocol.empty()) {
os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
}
// end headers
os << "\r\n";
// Send server response
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
for (auto& buf : bufs) {
buf.Deallocate();
}
if (m_state == CONNECTING) {
m_state = OPEN;
open(m_protocol);
}
});
}
// Sends a Close control frame.
//
// The payload is the 2-byte big-endian status code followed by the reason
// text; a code of 1005 ("no status") produces an empty payload.
//
// Control frames may carry at most 125 payload bytes (RFC 6455 section 5.5),
// two of which are the status code, so the reason is limited to 123 bytes.
// A longer reason is truncated, backing up to a UTF-8 character boundary so
// the text that goes on the wire does not end in a partial sequence.
//
// @param code    close status code
// @param reason  close reason text (truncated if longer than 123 bytes)
void WebSocket::SendClose(uint16_t code, std::string_view reason) {
  constexpr size_t kMaxReasonLen = 123;
  if (reason.size() > kMaxReasonLen) {
    size_t len = kMaxReasonLen;
    // reason[len] is the first byte that would be dropped; if it is a UTF-8
    // continuation byte (10xxxxxx) the cut would split a character.
    while (len > 0 && (static_cast<uint8_t>(reason[len]) & 0xc0) == 0x80) {
      --len;
    }
    reason = reason.substr(0, len);
  }

  SmallVector<uv::Buffer, 4> bufs;
  if (code != 1005) {
    raw_uv_ostream os{bufs, kWriteAllocSize};
    const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
                               static_cast<uint8_t>(code & 0xff)};
    os << std::span{codeMsb};
    os << reason;
  }
  // The payload buffers were allocated above, so free them once the frame has
  // been sent (or rejected).
  SendControl(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
    for (auto&& buf : bufs) {
      buf.Deallocate();
    }
  });
}
// Moves the socket to its final state (FAILED when `failed` is set, CLOSED
// otherwise) and emits the closed signal. Only the first call has an effect.
void WebSocket::SetClosed(uint16_t code, std::string_view reason, bool failed) {
  const bool alreadyFinal = (m_state == FAILED || m_state == CLOSED);
  if (alreadyFinal) {
    return;
  }
  if (failed) {
    m_state = FAILED;
  } else {
    m_state = CLOSED;
  }
  closed(code, reason);
}
// Shuts down the underlying stream and closes it from the shutdown callback.
void WebSocket::Shutdown() {
  m_stream.Shutdown([this]() {
    m_stream.Close();
  });
}
// XORs `data` in place with the 4-byte masking key, cycling through the key
// bytes in order (byte i is combined with key[i % 4]).
static inline void Unmask(std::span<uint8_t> data,
                          std::span<const uint8_t, 4> key) {
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] ^= key[i % 4];
  }
}
// Consumes bytes read from the stream.
//
// While CONNECTING on a client, the bytes are fed to the handshake response
// parser; any bytes left over once the handshake completes are processed as
// frames. In later states the bytes are parsed incrementally as WebSocket
// frames: the header is accumulated in m_header, data-frame payload in
// m_payload, and control-frame payload in m_controlPayload. Each completed
// frame is dispatched to the text/binary/ping/pong signals or to the close
// handling. Protocol violations fail the connection.
//
// The declared frame length comes straight from the peer, so the size limit
// is checked without adding it to the accumulated size (that sum can wrap
// around for lengths near 2^64), and a length equal to UINT64_MAX -- the value
// m_frameSize uses to mean "no frame header parsed yet" -- is always rejected.
//
// @param buf   buffer containing the received bytes
// @param size  number of valid bytes in buf
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
  // ignore incoming data if we're failed or closed
  if (m_state == FAILED || m_state == CLOSED) {
    return;
  }

  std::string_view data{buf.base, size};

  // Handle connecting state (mainly on client)
  if (m_state == CONNECTING) {
    if (m_clientHandshake) {
      data = m_clientHandshake->parser.Execute(data);
      // check for parser failure
      if (m_clientHandshake->parser.HasError()) {
        return Terminate(1003, "invalid response");
      }
      if (m_state != OPEN) {
        return;  // not done with handshake yet
      }

      // we're done with the handshake, so release its memory
      m_clientHandshake.reset();

      // fall through to process additional data after handshake
    } else {
      return Terminate(1003, "got data on server before response");
    }
  }

  // Message processing
  while (!data.empty()) {
    if (m_frameSize == UINT64_MAX) {
      // Need at least two bytes to determine header length
      if (m_header.size() < 2u) {
        size_t toCopy = (std::min)(2u - m_header.size(), data.size());
        m_header.append(data.data(), data.data() + toCopy);
        data.remove_prefix(toCopy);
        if (m_header.size() < 2u) {
          return;  // need more data
        }

        // Validate RSV bits are zero
        if ((m_header[0] & 0x70) != 0) {
          return Fail(1002, "nonzero RSV");
        }
      }

      // Once we have first two bytes, we can calculate the header size
      if (m_headerSize == 0) {
        m_headerSize = 2;
        uint8_t len = m_header[1] & kLenMask;
        if (len == 126) {
          m_headerSize += 2;
        } else if (len == 127) {
          m_headerSize += 8;
        }
        bool masking = (m_header[1] & kFlagMasking) != 0;
        if (masking) {
          m_headerSize += 4;  // masking key
        }
        // On server side, incoming messages MUST be masked
        // On client side, incoming messages MUST NOT be masked
        if (m_server && !masking) {
          return Fail(1002, "client data not masked");
        }
        if (!m_server && masking) {
          return Fail(1002, "server data masked");
        }
      }

      // Need to complete header to calculate message size
      if (m_header.size() < m_headerSize) {
        size_t toCopy = (std::min)(m_headerSize - m_header.size(), data.size());
        m_header.append(data.data(), data.data() + toCopy);
        data.remove_prefix(toCopy);
        if (m_header.size() < m_headerSize) {
          return;  // need more data
        }
      }

      if (m_header.size() >= m_headerSize) {
        // get payload length
        uint8_t len = m_header[1] & kLenMask;
        if (len == 126) {
          m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
                        static_cast<uint16_t>(m_header[3]);
        } else if (len == 127) {
          m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
                        (static_cast<uint64_t>(m_header[3]) << 48) |
                        (static_cast<uint64_t>(m_header[4]) << 40) |
                        (static_cast<uint64_t>(m_header[5]) << 32) |
                        (static_cast<uint64_t>(m_header[6]) << 24) |
                        (static_cast<uint64_t>(m_header[7]) << 16) |
                        (static_cast<uint64_t>(m_header[8]) << 8) |
                        static_cast<uint64_t>(m_header[9]);
        } else {
          m_frameSize = len;
        }

        // limit maximum size
        //
        // Written so that nothing can overflow: "accumulated + m_frameSize"
        // would wrap for a huge declared length and slip under the limit.
        // UINT64_MAX is rejected outright because it doubles as the "no
        // frame" marker; letting it through would leave this loop spinning
        // on the same header without consuming any data.
        bool control = (m_header[0] & kFlagControl) != 0;
        const uint64_t accumulated =
            control ? m_controlPayload.size() : m_payload.size();
        const uint64_t maxSize = m_maxMessageSize;
        if (m_frameSize == UINT64_MAX || m_frameSize > maxSize ||
            accumulated > maxSize - m_frameSize) {
          return Fail(1009, "message too large");
        }
      }
    }

    if (m_frameSize != UINT64_MAX) {
      bool control = (m_header[0] & kFlagControl) != 0;
      // Number of payload bytes still missing for the current frame
      size_t need;
      if (control) {
        need = m_frameSize - m_controlPayload.size();
      } else {
        need = m_frameStart + m_frameSize - m_payload.size();
      }
      size_t toCopy = (std::min)(need, data.size());
      if (control) {
        m_controlPayload.append(data.data(), data.data() + toCopy);
      } else {
        m_payload.append(data.data(), data.data() + toCopy);
      }
      data.remove_prefix(toCopy);
      need -= toCopy;
      if (need == 0) {
        // We have a complete frame
        // If the message had masking, unmask it
        if ((m_header[1] & kFlagMasking) != 0) {
          Unmask(control ? std::span{m_controlPayload}
                         : std::span{m_payload}.subspan(m_frameStart),
                 std::span<const uint8_t, 4>{&m_header[m_headerSize - 4], 4});
        }

        // Handle message
        bool fin = (m_header[0] & kFlagFin) != 0;
        uint8_t opcode = m_header[0] & kOpMask;
        switch (opcode) {
          case kOpCont:
            WS_DEBUG("WS Fragment {} [{}]\n", m_payload.size(),
                     DebugBinary(m_payload));
            switch (m_fragmentOpcode) {
              case kOpText:
                if (!m_combineFragments || fin) {
                  std::string_view content{
                      reinterpret_cast<char*>(m_payload.data()),
                      m_payload.size()};
                  WS_DEBUG("WS RecvText(Defrag) {} ({})\n", m_payload.size(),
                           DebugText(content));
                  text(content, fin);
                }
                break;
              case kOpBinary:
                if (!m_combineFragments || fin) {
                  WS_DEBUG("WS RecvBinary(Defrag) {} ({})\n", m_payload.size(),
                           DebugBinary(m_payload));
                  binary(m_payload, fin);
                }
                break;
              default:
                // no preceding message?
                return Fail(1002, "invalid continuation message");
            }
            if (fin) {
              m_fragmentOpcode = 0;
            }
            break;
          case kOpText: {
            std::string_view content{reinterpret_cast<char*>(m_payload.data()),
                                     m_payload.size()};
            if (m_fragmentOpcode != 0) {
              WS_DEBUG("WS RecvText {} ({}) -> INCOMPLETE FRAGMENT\n",
                       m_payload.size(), DebugText(content));
              return Fail(1002, "incomplete fragment");
            }
            if (!m_combineFragments || fin) {
              WS_DEBUG("WS RecvText {} ({})\n", m_payload.size(),
                       DebugText(content));
              text(content, fin);
            }
            if (!fin) {
              WS_DEBUG("WS RecvText {} StartFrag\n", m_payload.size());
              m_fragmentOpcode = opcode;
            }
            break;
          }
          case kOpBinary:
            if (m_fragmentOpcode != 0) {
              WS_DEBUG("WS RecvBinary {} ({}) -> INCOMPLETE FRAGMENT\n",
                       m_payload.size(), DebugBinary(m_payload));
              return Fail(1002, "incomplete fragment");
            }
            if (!m_combineFragments || fin) {
              WS_DEBUG("WS RecvBinary {} ({})\n", m_payload.size(),
                       DebugBinary(m_payload));
              binary(m_payload, fin);
            }
            if (!fin) {
              WS_DEBUG("WS RecvBinary {} StartFrag\n", m_payload.size());
              m_fragmentOpcode = opcode;
            }
            break;
          case kOpClose: {
            uint16_t code;
            std::string_view reason;
            if (!fin) {
              code = 1002;
              reason = "cannot fragment control frames";
            } else if (m_controlPayload.size() < 2) {
              code = 1005;
            } else {
              code = (static_cast<uint16_t>(m_controlPayload[0]) << 8) |
                     static_cast<uint16_t>(m_controlPayload[1]);
              reason =
                  drop_front({reinterpret_cast<char*>(m_controlPayload.data()),
                              m_controlPayload.size()},
                             2);
            }
            // Echo the close if we didn't previously send it
            if (m_state != CLOSING) {
              SendClose(code, reason);
            }
            SetClosed(code, reason);
            // If we're the server, shutdown the connection.
            if (m_server) {
              Shutdown();
            }
            break;
          }
          case kOpPing:
            if (!fin) {
              return Fail(1002, "cannot fragment control frames");
            }
            // If the connection is open, send a Pong in response
            if (m_state == OPEN) {
              SmallVector<uv::Buffer, 4> bufs;
              {
                raw_uv_ostream os{bufs, kWriteAllocSize};
                os << m_controlPayload;
              }
              SendPong(bufs, [](auto bufs, uv::Error) {
                for (auto&& buf : bufs) {
                  buf.Deallocate();
                }
              });
            }
            WS_DEBUG("WS RecvPing() {} ({})\n", m_controlPayload.size(),
                     DebugBinary(m_controlPayload));
            ping(m_controlPayload);
            break;
          case kOpPong:
            if (!fin) {
              return Fail(1002, "cannot fragment control frames");
            }
            WS_DEBUG("WS RecvPong() {} ({})\n", m_controlPayload.size(),
                     DebugBinary(m_controlPayload));
            pong(m_controlPayload);
            break;
          default:
            return Fail(1002, "invalid message opcode");
        }

        // Prepare for next message
        m_header.clear();
        m_headerSize = 0;
        if (!m_combineFragments || fin) {
          if (control) {
            m_controlPayload.clear();
          } else {
            m_payload.clear();
          }
        }
        m_frameStart = m_payload.size();
        m_frameSize = UINT64_MAX;
      }
    }
  }
}
// Prints a one-line trace of an outgoing frame when verbose debugging is
// compiled in. Text frames (opcode 0x01, ignoring the top bit) are reported
// as SendText, binary frames (0x02) as SendBinary, and anything else as SendOp
// with the raw opcode. The payload is included only when content logging is
// also compiled in: verbatim for text frames, hex for all others.
static void VerboseDebug(const WebSocket::Frame& frame) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
  const auto op = frame.opcode & 0x7f;
  const bool isText = (op == 0x01);
  SmallString<128> content;
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
  if (isText) {
    for (auto&& chunk : frame.data) {
      content.append(std::string_view(chunk.base, chunk.len));
    }
  } else {
    raw_svector_ostream hexOs{content};
    for (auto&& chunk : frame.data) {
      for (auto byte : chunk.data()) {
        hexOs << fmt::format("{:02x},", static_cast<unsigned int>(byte) & 0xff);
      }
    }
  }
#endif
  if (isText) {
    fmt::print("WS SendText({})\n", content.str());
  } else if (op == 0x02) {
    fmt::print("WS SendBinary({})\n", content.str());
  } else {
    fmt::print("WS SendOp({}, {})\n", frame.opcode, content.str());
  }
#endif
}
// Queues `frames` for transmission. If the socket is not OPEN the callback is
// invoked with an error instead. If a write is already in progress, the new
// request is appended to the end of the pending chain; otherwise it becomes
// the current request and is started immediately.
void WebSocket::SendFrames(
    std::span<const Frame> frames,
    std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
  WS_DEBUG("SendFrames({})\n", frames.size());

  // Nothing can be sent unless the connection is open
  if (m_state != OPEN) {
    SendError(frames, callback);
    return;
  }

  // Serialize every frame into a single write request, recording the byte
  // offset at which each frame ends
  auto writeReq =
      std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
  int totalBytes = 0;
  for (auto&& frame : frames) {
    VerboseDebug(frame);
    totalBytes += writeReq->m_frames.AddFrame(frame, m_server);
    writeReq->m_continueFrameOffs.emplace_back(totalBytes);
    writeReq->m_userBufs.append(frame.data.begin(), frame.data.end());
  }

  // With a write in flight, chain onto the last queued request
  std::shared_ptr<WriteReq> tail;
  if (m_writeInProgress) {
    tail = m_lastWriteReq.lock();
  }
  if (tail) {
    m_lastWriteReq = writeReq;
    // walk forward in case further continuations were attached
    while (tail->m_cont) {
      tail = tail->m_cont;
    }
    tail->m_cont = std::move(writeReq);
    return;
  }

  // Otherwise this request is both first and last; kick it off now
  m_writeInProgress = true;
  m_curWriteReq = writeReq;
  m_lastWriteReq = writeReq;
  writeReq->Send({});
}
// Attempts to send `frames` right away. Returns an empty span after reporting
// an error when the socket is not OPEN, returns all of `frames` untouched when
// another write is in flight, and otherwise returns whatever
// detail::TrySendFrames() returns.
std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
    std::span<const Frame> frames,
    std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
  // Not open: report the error, nothing is sent
  if (m_state != WebSocket::OPEN) {
    SendError(frames, callback);
    return {};
  }

  // Something else is still being written: hand every frame back
  if (m_writeInProgress) {
    return frames;
  }

  // Factory used by the helper when it needs a real write request; creating
  // one marks a write as in progress on this socket.
  auto makeReq =
      [this](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
        auto writeReq =
            std::make_shared<WriteReq>(weak_from_this(), std::move(cb));
        m_writeInProgress = true;
        m_curWriteReq = writeReq;
        m_lastWriteReq = writeReq;
        return writeReq;
      };
  return detail::TrySendFrames(m_server, m_stream, frames, std::move(makeReq),
                               std::move(callback));
}
// Sends a single control frame. With nothing in flight this is an ordinary
// SendFrames() call. Otherwise the frame is attached to the current write
// request as a control continuation so that it goes out ahead of that
// request's remaining frames, after any control frames queued before it.
void WebSocket::SendControl(
    uint8_t opcode, std::span<const uv::Buffer> data,
    std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
  Frame frame{opcode, data};

  // Not open: report the error, nothing is sent
  if (m_state != WebSocket::OPEN) {
    SendError({{frame}}, callback);
    return;
  }

  // No write in flight: the regular path is enough
  std::shared_ptr<WriteReq> inFlight = m_curWriteReq.lock();
  if (!m_writeInProgress || !inFlight) {
    SendFrames({{frame}}, std::move(callback));
    return;
  }

  // Build a request for the control frame that resumes the in-flight request
  // once it has been sent
  auto controlReq =
      std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
  VerboseDebug(frame);
  size_t numBytes = controlReq->m_frames.AddFrame(frame, m_server);
  controlReq->m_userBufs.append(frame.data.begin(), frame.data.end());
  controlReq->m_continueFrameOffs.emplace_back(numBytes);
  controlReq->m_cont = inFlight;

  // First pending control frame: hang it directly off the in-flight request
  if (!inFlight->m_controlCont) {
    inFlight->m_controlCont = std::move(controlReq);
    return;
  }

  // Others are already pending; keep them in order by appending after the
  // last one (the one whose continuation is the in-flight request). The
  // linear walk should be rare.
  std::shared_ptr<WriteReq> lastControl = inFlight->m_controlCont;
  while (lastControl->m_cont != inFlight) {
    lastControl = lastControl->m_cont;
  }
  lastControl->m_cont = std::move(controlReq);
}
// Reports a failed send: invokes `callback` with every data buffer of
// `frames` and an error of UV_EAGAIN while still CONNECTING, or UV_ESHUTDOWN
// in any other state.
void WebSocket::SendError(
    std::span<const Frame> frames,
    const std::function<void(std::span<uv::Buffer>, uv::Error)>& callback) {
  const int errCode =
      (m_state == WebSocket::CONNECTING) ? UV_EAGAIN : UV_ESHUTDOWN;

  // Gather the caller's buffers so they are all handed back together
  SmallVector<uv::Buffer, 4> userBufs;
  for (auto&& frame : frames) {
    userBufs.append(frame.data.begin(), frame.data.end());
  }
  callback(userBufs, uv::Error{errCode});
}