From 3e7ba2cc6fc56b720ac341718601b9734857885e Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Sun, 29 Oct 2023 16:48:25 -0700 Subject: [PATCH] [wpinet] WebSocket: Fix write behavior (#5841) On Windows, TryWrite will always return 0 if there is a Write in progress. The previous behavior for SendFrames and SendControl just used a normal Write, which caused issues with code that combined these with TrySendFrames. Instead, have SendFrames and SendControl also use TryWrite under the hood if possible, and create write requests if not. The implementation preserves the priority of SendControl against an existing write request with multiple frames. --- wpinet/src/main/native/cpp/WebSocket.cpp | 114 ++++++++++++------ .../src/main/native/cpp/WebSocketSerializer.h | 4 +- .../main/native/include/wpinet/WebSocket.h | 3 +- .../native/cpp/WebSocketIntegrationTest.cpp | 17 ++- .../native/cpp/WebSocketSerializerTest.cpp | 4 + 5 files changed, 102 insertions(+), 40 deletions(-) diff --git a/wpinet/src/main/native/cpp/WebSocket.cpp b/wpinet/src/main/native/cpp/WebSocket.cpp index 986d090b95..43b901ecb6 100644 --- a/wpinet/src/main/native/cpp/WebSocket.cpp +++ b/wpinet/src/main/native/cpp/WebSocket.cpp @@ -55,27 +55,48 @@ class WebSocket::WriteReq : public uv::WriteReq, std::weak_ptr ws, std::function, uv::Error)> callback) : m_ws{std::move(ws)}, m_callback{std::move(callback)} { - finish.connect([this](uv::Error err) { - int result = Continue(GetStream(), shared_from_this()); - WS_DEBUG("Continue() -> {}\n", result); - if (result <= 0) { - m_frames.ReleaseBufs(); - auto ws = m_ws.lock(); - if (ws) { - ws->m_writeInProgress = false; - } - m_callback(m_userBufs, err); - if (result == 0 && m_cont && ws) { - WS_DEBUG("Continuing with another write\n"); - ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont); - } + finish.connect([this](uv::Error err) { Send(err); }); + } + + void Send(uv::Error err) { + auto ws = m_ws.lock(); + if (!ws || err) { + WS_DEBUG("no WS or error, calling callback\n"); + m_frames.ReleaseBufs(); + m_callback(m_userBufs, err); + return; + } + + // Continue() is designed so this is *only* called on frame boundaries + if (m_controlCont) { + // We have a control frame; switch to it. We will come back here via + // the control frame's m_cont when it's done. + WS_DEBUG("Continuing with a control write\n"); + auto controlCont = std::move(m_controlCont); + m_controlCont.reset(); + return controlCont->Send({}); + } + int result = Continue(ws->m_stream, shared_from_this()); + WS_DEBUG("Continue() -> {}\n", result); + if (result <= 0) { + m_frames.ReleaseBufs(); + m_callback(m_userBufs, uv::Error{result}); + if (result == 0 && m_cont) { + WS_DEBUG("Continuing with another write\n"); + ws->m_curWriteReq = m_cont; + return m_cont->Send({}); + } else { + ws->m_writeInProgress = false; + ws->m_curWriteReq.reset(); + ws->m_lastWriteReq.reset(); } - }); + } } std::weak_ptr m_ws; std::function, uv::Error)> m_callback; std::shared_ptr m_cont; + std::shared_ptr m_controlCont; }; static constexpr uint8_t kFlagMasking = 0x80; @@ -720,24 +741,33 @@ void WebSocket::SendFrames( return; } - auto req = std::make_shared(std::weak_ptr{}, - std::move(callback)); + // Build request + auto req = std::make_shared(weak_from_this(), std::move(callback)); + int numBytes = 0; for (auto&& frame : frames) { VerboseDebug(frame); - req->m_frames.AddFrame(frame, m_server); + numBytes += req->m_frames.AddFrame(frame, m_server); + req->m_continueFrameOffs.emplace_back(numBytes); req->m_userBufs.append(frame.data.begin(), frame.data.end()); } - req->m_continueBufPos = req->m_frames.m_bufs.size(); + if (m_writeInProgress) { - if (auto curReq = m_writeReq.lock()) { + if (auto lastReq = m_lastWriteReq.lock()) { // if write currently in progress, process as a continuation of that - m_writeReq = req; - curReq->m_cont = std::move(req); + m_lastWriteReq = req; + // make sure we're really at the end + while (lastReq->m_cont) { + lastReq = lastReq->m_cont; + } + lastReq->m_cont = std::move(req); return; } } - WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size()); - m_stream.Write(req->m_frames.m_bufs, req); + + m_writeInProgress = true; + m_curWriteReq = req; + m_lastWriteReq = req; + req->Send({}); } std::span WebSocket::TrySendFrames( @@ -759,7 +789,8 @@ std::span WebSocket::TrySendFrames( [this](std::function, uv::Error)>&& cb) { auto req = std::make_shared(weak_from_this(), std::move(cb)); m_writeInProgress = true; - m_writeReq = req; + m_curWriteReq = req; + m_lastWriteReq = req; return req; }, std::move(callback)); @@ -775,17 +806,32 @@ void WebSocket::SendControl( return; } - // Control messages always send immediately with their own request, whether - // or not something else is in flight. The protocol allows control messages - // interspersed with fragmented frames, and we otherwise make sure that any - // pending Write() is already at a frame boundary. - auto req = std::make_shared(std::weak_ptr{}, - std::move(callback)); + // If nothing else is in flight, just use SendFrames() + std::shared_ptr curReq = m_curWriteReq.lock(); + if (!m_writeInProgress || !curReq) { + return SendFrames({{frame}}, std::move(callback)); + } + + // There's a write request in flight, but since this is a control frame, we + // want to send it as soon as we can, without waiting for all frames in that + // request (or any continuations) to be sent. + auto req = std::make_shared(weak_from_this(), std::move(callback)); VerboseDebug(frame); - req->m_frames.AddFrame(frame, m_server); + size_t numBytes = req->m_frames.AddFrame(frame, m_server); req->m_userBufs.append(frame.data.begin(), frame.data.end()); - req->m_continueBufPos = req->m_frames.m_bufs.size(); - m_stream.Write(req->m_frames.m_bufs, req); + req->m_continueFrameOffs.emplace_back(numBytes); + req->m_cont = curReq; + // There may be multiple control packets in flight; maintain in-order + // transmission. Linear search here is O(n^2), but should be pretty rare. + if (!curReq->m_controlCont) { + curReq->m_controlCont = std::move(req); + } else { + curReq = curReq->m_controlCont; + while (curReq->m_cont != req->m_cont) { + curReq = curReq->m_cont; + } + curReq->m_cont = std::move(req); + } } void WebSocket::SendError( diff --git a/wpinet/src/main/native/cpp/WebSocketSerializer.h b/wpinet/src/main/native/cpp/WebSocketSerializer.h index 50c1e0656b..264b8f594d 100644 --- a/wpinet/src/main/native/cpp/WebSocketSerializer.h +++ b/wpinet/src/main/native/cpp/WebSocketSerializer.h @@ -130,7 +130,7 @@ int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr req) { } m_continueFramePos = offIt - m_continueFrameOffs.begin(); - m_continueBufPos = &(*bufIt) - &(*m_frames.m_bufs.begin()); + m_continueBufPos += bufIt - bufs.begin(); if (writeBufs.empty()) { WS_DEBUG("Write Done\n"); @@ -256,7 +256,6 @@ std::span TrySendFrames( // continuation (so the caller isn't responsible for doing this) size_t continuePos = 0; while (frameStart != frameEnd && !isFin) { - req->m_continueFrameOffs.emplace_back(continuePos); if (offIt != offEnd) { // we already generated the wire buffers for this frame, use them while (pos < *offIt && bufIt != bufEnd) { @@ -271,6 +270,7 @@ std::span TrySendFrames( // need to generate and add this frame continuePos += req->m_frames.AddFrame(*frameStart, server); } + req->m_continueFrameOffs.emplace_back(continuePos); isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; ++frameStart; } diff --git a/wpinet/src/main/native/include/wpinet/WebSocket.h b/wpinet/src/main/native/include/wpinet/WebSocket.h index 98ce9a9693..297c1235fd 100644 --- a/wpinet/src/main/native/include/wpinet/WebSocket.h +++ b/wpinet/src/main/native/include/wpinet/WebSocket.h @@ -514,7 +514,8 @@ class WebSocket : public std::enable_shared_from_this { // outgoing write request bool m_writeInProgress = false; class WriteReq; - std::weak_ptr m_writeReq; + std::weak_ptr m_curWriteReq; + std::weak_ptr m_lastWriteReq; // operating state State m_state = CONNECTING; diff --git a/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp b/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp index cba37e3ed5..6b74d822d2 100644 --- a/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp +++ b/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp @@ -150,18 +150,24 @@ TEST_F(WebSocketIntegrationTest, ClientSendText) { TEST_F(WebSocketIntegrationTest, ServerSendPing) { int gotPing = 0; int gotPong = 0; + int gotData = 0; serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); auto server = WebSocketServer::Create(*conn); server->connected.connect([&](std::string_view, WebSocket& ws) { + ws.SendText({{"hello"}}, [&](auto, uv::Error) {}); ws.SendPing({uv::Buffer{"\x03\x04", 2}}, [&](auto, uv::Error) {}); + ws.SendPing({uv::Buffer{"\x03\x04", 2}}, [&](auto, uv::Error) {}); + ws.SendText({{"hello"}}, [&](auto, uv::Error) {}); ws.pong.connect([&](auto data) { ++gotPong; std::vector recvData{data.begin(), data.end()}; std::vector expectData{0x03, 0x04}; ASSERT_EQ(recvData, expectData); - ws.Close(); + if (gotPong == 2) { + ws.Close(); + } }); }); }); @@ -180,12 +186,17 @@ TEST_F(WebSocketIntegrationTest, ServerSendPing) { std::vector expectData{0x03, 0x04}; ASSERT_EQ(recvData, expectData); }); + ws->text.connect([&](std::string_view data, bool) { + ++gotData; + ASSERT_EQ(data, "hello"); + }); }); loop->Run(); - ASSERT_EQ(gotPing, 1); - ASSERT_EQ(gotPong, 1); + ASSERT_EQ(gotPing, 2); + ASSERT_EQ(gotPong, 2); + ASSERT_EQ(gotData, 2); } } // namespace wpi diff --git a/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp index 49a23a127d..6767a23c68 100644 --- a/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp +++ b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp @@ -281,10 +281,14 @@ TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf0) { std::array remBufs{std::span{m_buf0data}.subspan(2), m_bufs[1]}; std::array contBufs{m_frameHeaders[1], m_bufs[2]}; + std::array contFrameOffs{static_cast(m_serialized[1].size())}; EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); ASSERT_EQ(makeReqCalled, 1); ASSERT_THAT(req->m_frames.m_bufs, SpanEq(contBufs)); + ASSERT_EQ(req->m_continueBufPos, 0u); + ASSERT_EQ(req->m_continueFramePos, 0u); + ASSERT_THAT(req->m_continueFrameOffs, SpanEq(contFrameOffs)); ASSERT_EQ(callbackCalled, 0); }