[wpinet] WebSocket: Fix write behavior (#5841)

On Windows, TryWrite will always return 0 if there is a Write in progress. The previous behavior for SendFrames and SendControl just used a normal Write, which caused issues with code that combined these with TrySendFrames. Instead, have SendFrames and SendControl also use TryWrite under the hood if possible, and create write requests if not. The implementation preserves the priority of SendControl against an existing write request with multiple frames.
This commit is contained in:
Peter Johnson
2023-10-29 16:48:25 -07:00
committed by GitHub
parent 80c47da237
commit 3e7ba2cc6f
5 changed files with 102 additions and 40 deletions

View File

@@ -55,27 +55,48 @@ class WebSocket::WriteReq : public uv::WriteReq,
std::weak_ptr<WebSocket> ws,
std::function<void(std::span<uv::Buffer>, uv::Error)> callback)
: m_ws{std::move(ws)}, m_callback{std::move(callback)} {
finish.connect([this](uv::Error err) {
int result = Continue(GetStream(), shared_from_this());
WS_DEBUG("Continue() -> {}\n", result);
if (result <= 0) {
m_frames.ReleaseBufs();
auto ws = m_ws.lock();
if (ws) {
ws->m_writeInProgress = false;
}
m_callback(m_userBufs, err);
if (result == 0 && m_cont && ws) {
WS_DEBUG("Continuing with another write\n");
ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont);
}
finish.connect([this](uv::Error err) { Send(err); });
}
void Send(uv::Error err) {
auto ws = m_ws.lock();
if (!ws || err) {
WS_DEBUG("no WS or error, calling callback\n");
m_frames.ReleaseBufs();
m_callback(m_userBufs, err);
return;
}
// Continue() is designed so this is *only* called on frame boundaries
if (m_controlCont) {
// We have a control frame; switch to it. We will come back here via
// the control frame's m_cont when it's done.
WS_DEBUG("Continuing with a control write\n");
auto controlCont = std::move(m_controlCont);
m_controlCont.reset();
return controlCont->Send({});
}
int result = Continue(ws->m_stream, shared_from_this());
WS_DEBUG("Continue() -> {}\n", result);
if (result <= 0) {
m_frames.ReleaseBufs();
m_callback(m_userBufs, uv::Error{result});
if (result == 0 && m_cont) {
WS_DEBUG("Continuing with another write\n");
ws->m_curWriteReq = m_cont;
return m_cont->Send({});
} else {
ws->m_writeInProgress = false;
ws->m_curWriteReq.reset();
ws->m_lastWriteReq.reset();
}
});
}
}
std::weak_ptr<WebSocket> m_ws;
std::function<void(std::span<uv::Buffer>, uv::Error)> m_callback;
std::shared_ptr<WriteReq> m_cont;
std::shared_ptr<WriteReq> m_controlCont;
};
static constexpr uint8_t kFlagMasking = 0x80;
@@ -720,24 +741,33 @@ void WebSocket::SendFrames(
return;
}
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
// Build request
auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
int numBytes = 0;
for (auto&& frame : frames) {
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
numBytes += req->m_frames.AddFrame(frame, m_server);
req->m_continueFrameOffs.emplace_back(numBytes);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
}
req->m_continueBufPos = req->m_frames.m_bufs.size();
if (m_writeInProgress) {
if (auto curReq = m_writeReq.lock()) {
if (auto lastReq = m_lastWriteReq.lock()) {
// if write currently in progress, process as a continuation of that
m_writeReq = req;
curReq->m_cont = std::move(req);
m_lastWriteReq = req;
// make sure we're really at the end
while (lastReq->m_cont) {
lastReq = lastReq->m_cont;
}
lastReq->m_cont = std::move(req);
return;
}
}
WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size());
m_stream.Write(req->m_frames.m_bufs, req);
m_writeInProgress = true;
m_curWriteReq = req;
m_lastWriteReq = req;
req->Send({});
}
std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
@@ -759,7 +789,8 @@ std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
[this](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(cb));
m_writeInProgress = true;
m_writeReq = req;
m_curWriteReq = req;
m_lastWriteReq = req;
return req;
},
std::move(callback));
@@ -775,17 +806,32 @@ void WebSocket::SendControl(
return;
}
// Control messages always send immediately with their own request, whether
// or not something else is in flight. The protocol allows control messages
// interspersed with fragmented frames, and we otherwise make sure that any
// pending Write() is already at a frame boundary.
auto req = std::make_shared<WriteReq>(std::weak_ptr<WebSocket>{},
std::move(callback));
// If nothing else is in flight, just use SendFrames()
std::shared_ptr<WriteReq> curReq = m_curWriteReq.lock();
if (!m_writeInProgress || !curReq) {
return SendFrames({{frame}}, std::move(callback));
}
// There's a write request in flight, but since this is a control frame, we
// want to send it as soon as we can, without waiting for all frames in that
// request (or any continuations) to be sent.
auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
VerboseDebug(frame);
req->m_frames.AddFrame(frame, m_server);
size_t numBytes = req->m_frames.AddFrame(frame, m_server);
req->m_userBufs.append(frame.data.begin(), frame.data.end());
req->m_continueBufPos = req->m_frames.m_bufs.size();
m_stream.Write(req->m_frames.m_bufs, req);
req->m_continueFrameOffs.emplace_back(numBytes);
req->m_cont = curReq;
// There may be multiple control packets in flight; maintain in-order
// transmission. Linear search here is O(n^2), but should be pretty rare.
if (!curReq->m_controlCont) {
curReq->m_controlCont = std::move(req);
} else {
curReq = curReq->m_controlCont;
while (curReq->m_cont != req->m_cont) {
curReq = curReq->m_cont;
}
curReq->m_cont = std::move(req);
}
}
void WebSocket::SendError(

View File

@@ -130,7 +130,7 @@ int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr<Req> req) {
}
m_continueFramePos = offIt - m_continueFrameOffs.begin();
m_continueBufPos = &(*bufIt) - &(*m_frames.m_bufs.begin());
m_continueBufPos += bufIt - bufs.begin();
if (writeBufs.empty()) {
WS_DEBUG("Write Done\n");
@@ -256,7 +256,6 @@ std::span<const WebSocket::Frame> TrySendFrames(
// continuation (so the caller isn't responsible for doing this)
size_t continuePos = 0;
while (frameStart != frameEnd && !isFin) {
req->m_continueFrameOffs.emplace_back(continuePos);
if (offIt != offEnd) {
// we already generated the wire buffers for this frame, use them
while (pos < *offIt && bufIt != bufEnd) {
@@ -271,6 +270,7 @@ std::span<const WebSocket::Frame> TrySendFrames(
// need to generate and add this frame
continuePos += req->m_frames.AddFrame(*frameStart, server);
}
req->m_continueFrameOffs.emplace_back(continuePos);
isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0;
++frameStart;
}