diff --git a/wpinet/src/main/native/cpp/WebSocket.cpp b/wpinet/src/main/native/cpp/WebSocket.cpp index 986d090b95..43b901ecb6 100644 --- a/wpinet/src/main/native/cpp/WebSocket.cpp +++ b/wpinet/src/main/native/cpp/WebSocket.cpp @@ -55,27 +55,48 @@ class WebSocket::WriteReq : public uv::WriteReq, std::weak_ptr ws, std::function, uv::Error)> callback) : m_ws{std::move(ws)}, m_callback{std::move(callback)} { - finish.connect([this](uv::Error err) { - int result = Continue(GetStream(), shared_from_this()); - WS_DEBUG("Continue() -> {}\n", result); - if (result <= 0) { - m_frames.ReleaseBufs(); - auto ws = m_ws.lock(); - if (ws) { - ws->m_writeInProgress = false; - } - m_callback(m_userBufs, err); - if (result == 0 && m_cont && ws) { - WS_DEBUG("Continuing with another write\n"); - ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont); - } + finish.connect([this](uv::Error err) { Send(err); }); + } + + void Send(uv::Error err) { + auto ws = m_ws.lock(); + if (!ws || err) { + WS_DEBUG("no WS or error, calling callback\n"); + m_frames.ReleaseBufs(); + m_callback(m_userBufs, err); + return; + } + + // Continue() is designed so this is *only* called on frame boundaries + if (m_controlCont) { + // We have a control frame; switch to it. We will come back here via + // the control frame's m_cont when it's done. + WS_DEBUG("Continuing with a control write\n"); + auto controlCont = std::move(m_controlCont); + m_controlCont.reset(); + return controlCont->Send({}); + } + int result = Continue(ws->m_stream, shared_from_this()); + WS_DEBUG("Continue() -> {}\n", result); + if (result <= 0) { + m_frames.ReleaseBufs(); + m_callback(m_userBufs, uv::Error{result}); + if (result == 0 && m_cont) { + WS_DEBUG("Continuing with another write\n"); + ws->m_curWriteReq = m_cont; + return m_cont->Send({}); + } else { + ws->m_writeInProgress = false; + ws->m_curWriteReq.reset(); + ws->m_lastWriteReq.reset(); } - }); + } } std::weak_ptr m_ws; std::function, uv::Error)> m_callback; std::shared_ptr m_cont; + std::shared_ptr m_controlCont; }; static constexpr uint8_t kFlagMasking = 0x80; @@ -720,24 +741,33 @@ void WebSocket::SendFrames( return; } - auto req = std::make_shared(std::weak_ptr{}, - std::move(callback)); + // Build request + auto req = std::make_shared(weak_from_this(), std::move(callback)); + int numBytes = 0; for (auto&& frame : frames) { VerboseDebug(frame); - req->m_frames.AddFrame(frame, m_server); + numBytes += req->m_frames.AddFrame(frame, m_server); + req->m_continueFrameOffs.emplace_back(numBytes); req->m_userBufs.append(frame.data.begin(), frame.data.end()); } - req->m_continueBufPos = req->m_frames.m_bufs.size(); + if (m_writeInProgress) { - if (auto curReq = m_writeReq.lock()) { + if (auto lastReq = m_lastWriteReq.lock()) { // if write currently in progress, process as a continuation of that - m_writeReq = req; - curReq->m_cont = std::move(req); + m_lastWriteReq = req; + // make sure we're really at the end + while (lastReq->m_cont) { + lastReq = lastReq->m_cont; + } + lastReq->m_cont = std::move(req); return; } } - WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size()); - m_stream.Write(req->m_frames.m_bufs, req); + + m_writeInProgress = true; + m_curWriteReq = req; + m_lastWriteReq = req; + req->Send({}); } std::span WebSocket::TrySendFrames( @@ -759,7 +789,8 @@ std::span WebSocket::TrySendFrames( [this](std::function, uv::Error)>&& cb) { auto req = std::make_shared(weak_from_this(), std::move(cb)); m_writeInProgress = true; - m_writeReq = req; + m_curWriteReq = req; + m_lastWriteReq = req; return req; }, std::move(callback)); @@ -775,17 +806,32 @@ void WebSocket::SendControl( return; } - // Control messages always send immediately with their own request, whether - // or not something else is in flight. The protocol allows control messages - // interspersed with fragmented frames, and we otherwise make sure that any - // pending Write() is already at a frame boundary. - auto req = std::make_shared(std::weak_ptr{}, - std::move(callback)); + // If nothing else is in flight, just use SendFrames() + std::shared_ptr curReq = m_curWriteReq.lock(); + if (!m_writeInProgress || !curReq) { + return SendFrames({{frame}}, std::move(callback)); + } + + // There's a write request in flight, but since this is a control frame, we + // want to send it as soon as we can, without waiting for all frames in that + // request (or any continuations) to be sent. + auto req = std::make_shared(weak_from_this(), std::move(callback)); VerboseDebug(frame); - req->m_frames.AddFrame(frame, m_server); + size_t numBytes = req->m_frames.AddFrame(frame, m_server); req->m_userBufs.append(frame.data.begin(), frame.data.end()); - req->m_continueBufPos = req->m_frames.m_bufs.size(); - m_stream.Write(req->m_frames.m_bufs, req); + req->m_continueFrameOffs.emplace_back(numBytes); + req->m_cont = curReq; + // There may be multiple control packets in flight; maintain in-order + // transmission. Linear search here is O(n^2), but should be pretty rare. + if (!curReq->m_controlCont) { + curReq->m_controlCont = std::move(req); + } else { + curReq = curReq->m_controlCont; + while (curReq->m_cont != req->m_cont) { + curReq = curReq->m_cont; + } + curReq->m_cont = std::move(req); + } } void WebSocket::SendError( diff --git a/wpinet/src/main/native/cpp/WebSocketSerializer.h b/wpinet/src/main/native/cpp/WebSocketSerializer.h index 50c1e0656b..264b8f594d 100644 --- a/wpinet/src/main/native/cpp/WebSocketSerializer.h +++ b/wpinet/src/main/native/cpp/WebSocketSerializer.h @@ -130,7 +130,7 @@ int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr req) { } m_continueFramePos = offIt - m_continueFrameOffs.begin(); - m_continueBufPos = &(*bufIt) - &(*m_frames.m_bufs.begin()); + m_continueBufPos += bufIt - bufs.begin(); if (writeBufs.empty()) { WS_DEBUG("Write Done\n"); @@ -256,7 +256,6 @@ std::span TrySendFrames( // continuation (so the caller isn't responsible for doing this) size_t continuePos = 0; while (frameStart != frameEnd && !isFin) { - req->m_continueFrameOffs.emplace_back(continuePos); if (offIt != offEnd) { // we already generated the wire buffers for this frame, use them while (pos < *offIt && bufIt != bufEnd) { @@ -271,6 +270,7 @@ std::span TrySendFrames( // need to generate and add this frame continuePos += req->m_frames.AddFrame(*frameStart, server); } + req->m_continueFrameOffs.emplace_back(continuePos); isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; ++frameStart; } diff --git a/wpinet/src/main/native/include/wpinet/WebSocket.h b/wpinet/src/main/native/include/wpinet/WebSocket.h index 98ce9a9693..297c1235fd 100644 --- a/wpinet/src/main/native/include/wpinet/WebSocket.h +++ b/wpinet/src/main/native/include/wpinet/WebSocket.h @@ -514,7 +514,8 @@ class WebSocket : public std::enable_shared_from_this { // outgoing write request bool m_writeInProgress = false; class WriteReq; - std::weak_ptr m_writeReq; + std::weak_ptr m_curWriteReq; + std::weak_ptr m_lastWriteReq; // operating state State m_state = CONNECTING; diff --git a/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp b/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp index cba37e3ed5..6b74d822d2 100644 --- a/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp +++ b/wpinet/src/test/native/cpp/WebSocketIntegrationTest.cpp @@ -150,18 +150,24 @@ TEST_F(WebSocketIntegrationTest, ClientSendText) { TEST_F(WebSocketIntegrationTest, ServerSendPing) { int gotPing = 0; int gotPong = 0; + int gotData = 0; serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); auto server = WebSocketServer::Create(*conn); server->connected.connect([&](std::string_view, WebSocket& ws) { + ws.SendText({{"hello"}}, [&](auto, uv::Error) {}); ws.SendPing({uv::Buffer{"\x03\x04", 2}}, [&](auto, uv::Error) {}); + ws.SendPing({uv::Buffer{"\x03\x04", 2}}, [&](auto, uv::Error) {}); + ws.SendText({{"hello"}}, [&](auto, uv::Error) {}); ws.pong.connect([&](auto data) { ++gotPong; std::vector recvData{data.begin(), data.end()}; std::vector expectData{0x03, 0x04}; ASSERT_EQ(recvData, expectData); - ws.Close(); + if (gotPong == 2) { + ws.Close(); + } }); }); }); @@ -180,12 +186,17 @@ TEST_F(WebSocketIntegrationTest, ServerSendPing) { std::vector expectData{0x03, 0x04}; ASSERT_EQ(recvData, expectData); }); + ws->text.connect([&](std::string_view data, bool) { + ++gotData; + ASSERT_EQ(data, "hello"); + }); }); loop->Run(); - ASSERT_EQ(gotPing, 1); - ASSERT_EQ(gotPong, 1); + ASSERT_EQ(gotPing, 2); + ASSERT_EQ(gotPong, 2); + ASSERT_EQ(gotData, 2); } } // namespace wpi diff --git a/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp index 49a23a127d..6767a23c68 100644 --- a/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp +++ b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp @@ -281,10 +281,14 @@ TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf0) { std::array remBufs{std::span{m_buf0data}.subspan(2), m_bufs[1]}; std::array contBufs{m_frameHeaders[1], m_bufs[2]}; + std::array contFrameOffs{static_cast(m_serialized[1].size())}; EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); ASSERT_EQ(makeReqCalled, 1); ASSERT_THAT(req->m_frames.m_bufs, SpanEq(contBufs)); + ASSERT_EQ(req->m_continueBufPos, 0u); + ASSERT_EQ(req->m_continueFramePos, 0u); + ASSERT_THAT(req->m_continueFrameOffs, SpanEq(contFrameOffs)); ASSERT_EQ(callbackCalled, 0); }