From c395b29fb4923865a5a83447df54fdc20853e328 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Mon, 18 Sep 2023 19:49:54 -0700 Subject: [PATCH] [wpinet] Add WebSocket::TrySendFrames() (#5607) This takes advantage of the underlying byte-level TryWrite() functionality to minimize blocking behavior and enable higher layers to do things smartly when the network blocks. Also: - Fix handling of control packets in middle of fragmented - Clean up debugging features --- wpinet/.styleguide | 1 + wpinet/CMakeLists.txt | 4 +- wpinet/src/main/native/cpp/WebSocket.cpp | 384 +++++++++++------- wpinet/src/main/native/cpp/WebSocketDebug.h | 21 + .../main/native/cpp/WebSocketSerializer.cpp | 108 +++++ .../src/main/native/cpp/WebSocketSerializer.h | 286 +++++++++++++ .../main/native/include/wpinet/WebSocket.h | 46 ++- .../native/cpp/WebSocketSerializerTest.cpp | 375 +++++++++++++++++ .../test/native/cpp/WebSocketServerTest.cpp | 43 ++ 9 files changed, 1104 insertions(+), 164 deletions(-) create mode 100644 wpinet/src/main/native/cpp/WebSocketDebug.h create mode 100644 wpinet/src/main/native/cpp/WebSocketSerializer.cpp create mode 100644 wpinet/src/main/native/cpp/WebSocketSerializer.h create mode 100644 wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp diff --git a/wpinet/.styleguide b/wpinet/.styleguide index bfd5e75084..04f54ebf61 100644 --- a/wpinet/.styleguide +++ b/wpinet/.styleguide @@ -32,6 +32,7 @@ includeGuardRoots { includeOtherLibs { ^fmt/ + ^gmock/ ^gtest/ ^wpi/ } diff --git a/wpinet/CMakeLists.txt b/wpinet/CMakeLists.txt index 580593b410..312d8a0fb0 100644 --- a/wpinet/CMakeLists.txt +++ b/wpinet/CMakeLists.txt @@ -201,6 +201,6 @@ set_property(TARGET netconsoleTee PROPERTY FOLDER "examples") if (WITH_TESTS) wpilib_add_test(wpinet src/test/native/cpp) - target_include_directories(wpinet_test PRIVATE src/test/native/include) - target_link_libraries(wpinet_test wpinet ${LIBUTIL} gmock_main) + target_include_directories(wpinet_test PRIVATE src/test/native/include src/main/native/cpp) + target_link_libraries(wpinet_test wpinet ${LIBUTIL} gmock_main wpiutil_testlib) endif() diff --git a/wpinet/src/main/native/cpp/WebSocket.cpp b/wpinet/src/main/native/cpp/WebSocket.cpp index 0d9b66981a..986d090b95 100644 --- a/wpinet/src/main/native/cpp/WebSocket.cpp +++ b/wpinet/src/main/native/cpp/WebSocket.cpp @@ -5,6 +5,9 @@ #include "wpinet/WebSocket.h" #include +#include +#include +#include #include #include @@ -14,37 +17,70 @@ #include #include +#include "WebSocketDebug.h" +#include "WebSocketSerializer.h" #include "wpinet/HttpParser.h" #include "wpinet/raw_uv_ostream.h" #include "wpinet/uv/Stream.h" using namespace wpi; -namespace { -class WebSocketWriteReq : public uv::WriteReq { +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG +static std::string DebugBinary(std::span val) { +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT + std::string str; + wpi::raw_string_ostream stros{str}; + for (auto ch : val) { + stros << fmt::format("{:02x},", static_cast(ch) & 0xff); + } + return str; +#else + return ""; +#endif +} + +static inline std::string_view DebugText(std::string_view val) { +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT + return val; +#else + return ""; +#endif +} +#endif // WPINET_WEBSOCKET_VERBOSE_DEBUG + +class WebSocket::WriteReq : public uv::WriteReq, + public detail::WebSocketWriteReqBase { public: - explicit WebSocketWriteReq( + explicit WriteReq( + std::weak_ptr ws, std::function, uv::Error)> callback) - : m_callback{std::move(callback)} { + : m_ws{std::move(ws)}, m_callback{std::move(callback)} { finish.connect([this](uv::Error err) { - for (auto&& buf : m_internalBufs) { - buf.Deallocate(); + int result = Continue(GetStream(), shared_from_this()); + WS_DEBUG("Continue() -> {}\n", result); + if (result <= 0) { + m_frames.ReleaseBufs(); + auto ws = m_ws.lock(); + if (ws) { + ws->m_writeInProgress = false; + } + m_callback(m_userBufs, err); + if (result == 0 && m_cont && ws) { + WS_DEBUG("Continuing with another write\n"); + ws->m_stream.Write(m_cont->m_frames.m_bufs, m_cont); + } } - m_callback(m_userBufs, err); }); } + std::weak_ptr m_ws; std::function, uv::Error)> m_callback; - SmallVector m_internalBufs; - SmallVector m_userBufs; - - // for server - size_t m_internalBufPos = 0; + std::shared_ptr m_cont; }; -} // namespace static constexpr uint8_t kFlagMasking = 0x80; static constexpr uint8_t kLenMask = 0x7f; +static constexpr size_t kWriteAllocSize = 4096; class WebSocket::ClientHandshakeData { public: @@ -157,7 +193,7 @@ void WebSocket::StartClient(std::string_view uri, std::string_view host, // Build client request SmallVector bufs; - raw_uv_ostream os{bufs, 4096}; + raw_uv_ostream os{bufs, kWriteAllocSize}; os << "GET " << uri << " HTTP/1.1\r\n"; os << "Host: " << host << "\r\n"; @@ -276,7 +312,7 @@ void WebSocket::StartServer(std::string_view key, std::string_view version, // Build server response SmallVector bufs; - raw_uv_ostream os{bufs, 4096}; + raw_uv_ostream os{bufs, kWriteAllocSize}; // Handle unsupported version if (version != "13") { @@ -324,13 +360,13 @@ void WebSocket::StartServer(std::string_view key, std::string_view version, void WebSocket::SendClose(uint16_t code, std::string_view reason) { SmallVector bufs; if (code != 1005) { - raw_uv_ostream os{bufs, 4096}; + raw_uv_ostream os{bufs, kWriteAllocSize}; const uint8_t codeMsb[] = {static_cast((code >> 8) & 0xff), static_cast(code & 0xff)}; os << std::span{codeMsb}; os << reason; } - Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) { + SendControl(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) { for (auto&& buf : bufs) { buf.Deallocate(); } @@ -349,6 +385,17 @@ void WebSocket::Shutdown() { m_stream.Shutdown([this] { m_stream.Close(); }); } +static inline void Unmask(std::span data, + std::span key) { + int n = 0; + for (uint8_t& ch : data) { + ch ^= key[n++]; + if (n >= 4) { + n = 0; + } + } +} + void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { // ignore incoming data if we're failed or closed if (m_state == FAILED || m_state == CLOSED) { @@ -449,32 +496,37 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { } // limit maximum size - if ((m_payload.size() + m_frameSize) > m_maxMessageSize) { + bool control = (m_header[0] & kFlagControl) != 0; + if (((control ? m_controlPayload.size() : m_payload.size()) + + m_frameSize) > m_maxMessageSize) { return Fail(1009, "message too large"); } } } if (m_frameSize != UINT64_MAX) { - size_t need = m_frameStart + m_frameSize - m_payload.size(); + bool control = (m_header[0] & kFlagControl) != 0; + size_t need; + if (control) { + need = m_frameSize - m_controlPayload.size(); + } else { + need = m_frameStart + m_frameSize - m_payload.size(); + } size_t toCopy = (std::min)(need, data.size()); - m_payload.append(data.data(), data.data() + toCopy); + if (control) { + m_controlPayload.append(data.data(), data.data() + toCopy); + } else { + m_payload.append(data.data(), data.data() + toCopy); + } data.remove_prefix(toCopy); need -= toCopy; if (need == 0) { // We have a complete frame // If the message had masking, unmask it if ((m_header[1] & kFlagMasking) != 0) { - uint8_t key[4] = { - m_header[m_headerSize - 4], m_header[m_headerSize - 3], - m_header[m_headerSize - 2], m_header[m_headerSize - 1]}; - int n = 0; - for (uint8_t& ch : std::span{m_payload}.subspan(m_frameStart)) { - ch ^= key[n++]; - if (n >= 4) { - n = 0; - } - } + Unmask(control ? std::span{m_controlPayload} + : std::span{m_payload}.subspan(m_frameStart), + std::span{&m_header[m_headerSize - 4], 4}); } // Handle message @@ -482,17 +534,23 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { uint8_t opcode = m_header[0] & kOpMask; switch (opcode) { case kOpCont: + WS_DEBUG("WS Fragment {} [{}]\n", m_payload.size(), + DebugBinary(m_payload)); switch (m_fragmentOpcode) { case kOpText: if (!m_combineFragments || fin) { - text(std::string_view{reinterpret_cast( - m_payload.data()), - m_payload.size()}, - fin); + std::string_view content{ + reinterpret_cast(m_payload.data()), + m_payload.size()}; + WS_DEBUG("WS RecvText(Defrag) {} ({})\n", m_payload.size(), + DebugText(content)); + text(content, fin); } break; case kOpBinary: if (!m_combineFragments || fin) { + WS_DEBUG("WS RecvBinary(Defrag) {} ({})\n", m_payload.size(), + DebugBinary(m_payload)); binary(m_payload, fin); } break; @@ -504,42 +562,38 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { m_fragmentOpcode = 0; } break; - case kOpText: + case kOpText: { + std::string_view content{reinterpret_cast(m_payload.data()), + m_payload.size()}; if (m_fragmentOpcode != 0) { + WS_DEBUG("WS RecvText {} ({}) -> INCOMPLETE FRAGMENT\n", + m_payload.size(), DebugText(content)); return Fail(1002, "incomplete fragment"); } if (!m_combineFragments || fin) { -#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG - fmt::print( - "WS RecvText({})\n", - std::string_view{reinterpret_cast(m_payload.data()), - m_payload.size()}); -#endif - text(std::string_view{reinterpret_cast(m_payload.data()), - m_payload.size()}, - fin); + WS_DEBUG("WS RecvText {} ({})\n", m_payload.size(), + DebugText(content)); + text(content, fin); } if (!fin) { + WS_DEBUG("WS RecvText {} StartFrag\n", m_payload.size()); m_fragmentOpcode = opcode; } break; + } case kOpBinary: if (m_fragmentOpcode != 0) { + WS_DEBUG("WS RecvBinary {} ({}) -> INCOMPLETE FRAGMENT\n", + m_payload.size(), DebugBinary(m_payload)); return Fail(1002, "incomplete fragment"); } if (!m_combineFragments || fin) { -#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG - SmallString<128> str; - raw_svector_ostream stros{str}; - for (auto ch : m_payload) { - stros << fmt::format("{:02x},", - static_cast(ch) & 0xff); - } - fmt::print("WS RecvBinary({})\n", str.str()); -#endif + WS_DEBUG("WS RecvBinary {} ({})\n", m_payload.size(), + DebugBinary(m_payload)); binary(m_payload, fin); } if (!fin) { + WS_DEBUG("WS RecvBinary {} StartFrag\n", m_payload.size()); m_fragmentOpcode = opcode; } break; @@ -549,14 +603,15 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { if (!fin) { code = 1002; reason = "cannot fragment control frames"; - } else if (m_payload.size() < 2) { + } else if (m_controlPayload.size() < 2) { code = 1005; } else { - code = (static_cast(m_payload[0]) << 8) | - static_cast(m_payload[1]); - reason = drop_front( - {reinterpret_cast(m_payload.data()), m_payload.size()}, - 2); + code = (static_cast(m_controlPayload[0]) << 8) | + static_cast(m_controlPayload[1]); + reason = + drop_front({reinterpret_cast(m_controlPayload.data()), + m_controlPayload.size()}, + 2); } // Echo the close if we didn't previously send it if (m_state != CLOSING) { @@ -577,8 +632,8 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { if (m_state == OPEN) { SmallVector bufs; { - raw_uv_ostream os{bufs, 4096}; - os << m_payload; + raw_uv_ostream os{bufs, kWriteAllocSize}; + os << m_controlPayload; } SendPong(bufs, [](auto bufs, uv::Error) { for (auto&& buf : bufs) { @@ -586,13 +641,17 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { } }); } - ping(m_payload); + WS_DEBUG("WS RecvPing() {} ({})\n", m_controlPayload.size(), + DebugBinary(m_controlPayload)); + ping(m_controlPayload); break; case kOpPong: if (!fin) { return Fail(1002, "cannot fragment control frames"); } - pong(m_payload); + WS_DEBUG("WS RecvPong() {} ({})\n", m_controlPayload.size(), + DebugBinary(m_controlPayload)); + pong(m_controlPayload); break; default: return Fail(1002, "invalid message opcode"); @@ -602,7 +661,11 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { m_header.clear(); m_headerSize = 0; if (!m_combineFragments || fin) { - m_payload.clear(); + if (control) { + m_controlPayload.clear(); + } else { + m_payload.clear(); + } } m_frameStart = m_payload.size(); m_frameSize = UINT64_MAX; @@ -611,127 +674,132 @@ void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { } } -static void WriteFrame(WebSocketWriteReq& req, - SmallVectorImpl& bufs, bool server, - uint8_t opcode, std::span data) { +static void VerboseDebug(const WebSocket::Frame& frame) { #ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG - if ((opcode & 0x7f) == 0x01) { + if ((frame.opcode & 0x7f) == 0x01) { SmallString<128> str; - for (auto&& d : data) { +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT + for (auto&& d : frame.data) { str.append(std::string_view(d.base, d.len)); } +#endif fmt::print("WS SendText({})\n", str.str()); - } else if ((opcode & 0x7f) == 0x02) { + } else if ((frame.opcode & 0x7f) == 0x02) { SmallString<128> str; +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT raw_svector_ostream stros{str}; - for (auto&& d : data) { + for (auto&& d : frame.data) { for (auto ch : d.data()) { stros << fmt::format("{:02x},", static_cast(ch) & 0xff); } } - fmt::print("WS SendBinary({})\n", str.str()); - } #endif - - uint8_t header[10]; - uint8_t* pHeader = header; - - // opcode (includes FIN bit) - *pHeader++ = opcode; - - // payload length - uint64_t size = 0; - for (auto&& buf : data) { - size += buf.len; - } - if (size < 126) { - *pHeader++ = (server ? 0x00 : kFlagMasking) | size; - } else if (size <= 0xffff) { - *pHeader++ = (server ? 0x00 : kFlagMasking) | 126; - *pHeader++ = (size >> 8) & 0xff; - *pHeader++ = size & 0xff; + fmt::print("WS SendBinary({})\n", str.str()); } else { - *pHeader++ = (server ? 0x00 : kFlagMasking) | 127; - *pHeader++ = (size >> 56) & 0xff; - *pHeader++ = (size >> 48) & 0xff; - *pHeader++ = (size >> 40) & 0xff; - *pHeader++ = (size >> 32) & 0xff; - *pHeader++ = (size >> 24) & 0xff; - *pHeader++ = (size >> 16) & 0xff; - *pHeader++ = (size >> 8) & 0xff; - *pHeader++ = size & 0xff; - } - size_t headerSize = pHeader - header; - - // clients need to mask the input data - if (!server) { - SmallVector internalBufs; - raw_uv_ostream os{internalBufs, 4096}; - - os << std::span{header, headerSize}; - // generate masking key - static std::random_device rd; - static std::default_random_engine gen{rd()}; - std::uniform_int_distribution dist(0, 255); - uint8_t key[4]; - for (uint8_t& v : key) { - v = dist(gen); - } - os << std::span{key, 4}; - // copy and mask data - int n = 0; - for (auto&& buf : data) { - for (auto&& ch : buf.data()) { - os << static_cast(static_cast(ch) ^ key[n++]); - if (n >= 4) { - n = 0; - } + SmallString<128> str; +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT + raw_svector_ostream stros{str}; + for (auto&& d : frame.data) { + for (auto ch : d.data()) { + stros << fmt::format("{:02x},", static_cast(ch) & 0xff); } } - bufs.append(internalBufs.begin(), internalBufs.end()); - req.m_internalBufs.append(internalBufs.begin(), internalBufs.end()); - // don't send the user bufs as we copied their data - } else { - // manage m_internalBufs to efficiently store header - if (req.m_internalBufs.empty() || - (req.m_internalBufPos + headerSize) > 4096) { - req.m_internalBufs.emplace_back(uv::Buffer::Allocate(4096)); - req.m_internalBufPos = 0; - } - char* internalBuf = - req.m_internalBufs.back().data().data() + req.m_internalBufPos; - std::memcpy(internalBuf, header, headerSize); - bufs.emplace_back(internalBuf, headerSize); - req.m_internalBufPos += headerSize; - // servers can just send the buffers directly without masking - bufs.append(data.begin(), data.end()); +#endif + fmt::print("WS SendOp({}, {})\n", frame.opcode, str.str()); } - req.m_userBufs.append(data.begin(), data.end()); +#endif } void WebSocket::SendFrames( std::span frames, std::function, uv::Error)> callback) { // If we're not open, emit an error and don't send the data + WS_DEBUG("SendFrames({})\n", frames.size()); if (m_state != OPEN) { - int err; - if (m_state == CONNECTING) { - err = UV_EAGAIN; - } else { - err = UV_ESHUTDOWN; - } - SmallVector bufs; - for (auto&& frame : frames) { - bufs.append(frame.data.begin(), frame.data.end()); - } - callback(bufs, uv::Error{err}); + SendError(frames, callback); return; } - auto req = std::make_shared(std::move(callback)); + auto req = std::make_shared(std::weak_ptr{}, + std::move(callback)); + for (auto&& frame : frames) { + VerboseDebug(frame); + req->m_frames.AddFrame(frame, m_server); + req->m_userBufs.append(frame.data.begin(), frame.data.end()); + } + req->m_continueBufPos = req->m_frames.m_bufs.size(); + if (m_writeInProgress) { + if (auto curReq = m_writeReq.lock()) { + // if write currently in progress, process as a continuation of that + m_writeReq = req; + curReq->m_cont = std::move(req); + return; + } + } + WS_DEBUG("Write({})\n", req->m_frames.m_bufs.size()); + m_stream.Write(req->m_frames.m_bufs, req); +} + +std::span WebSocket::TrySendFrames( + std::span frames, + std::function, uv::Error)> callback) { + // If we're not open, emit an error and don't send the data + if (m_state != WebSocket::OPEN) { + SendError(frames, callback); + return {}; + } + + // If something else is still in flight, don't send anything + if (m_writeInProgress) { + return frames; + } + + return detail::TrySendFrames( + m_server, m_stream, frames, + [this](std::function, uv::Error)>&& cb) { + auto req = std::make_shared(weak_from_this(), std::move(cb)); + m_writeInProgress = true; + m_writeReq = req; + return req; + }, + std::move(callback)); +} + +void WebSocket::SendControl( + uint8_t opcode, std::span data, + std::function, uv::Error)> callback) { + Frame frame{opcode, data}; + // If we're not open, emit an error and don't send the data + if (m_state != WebSocket::OPEN) { + SendError({{frame}}, callback); + return; + } + + // Control messages always send immediately with their own request, whether + // or not something else is in flight. The protocol allows control messages + // interspersed with fragmented frames, and we otherwise make sure that any + // pending Write() is already at a frame boundary. + auto req = std::make_shared(std::weak_ptr{}, + std::move(callback)); + VerboseDebug(frame); + req->m_frames.AddFrame(frame, m_server); + req->m_userBufs.append(frame.data.begin(), frame.data.end()); + req->m_continueBufPos = req->m_frames.m_bufs.size(); + m_stream.Write(req->m_frames.m_bufs, req); +} + +void WebSocket::SendError( + std::span frames, + const std::function, uv::Error)>& callback) { + int err; + if (m_state == WebSocket::CONNECTING) { + err = UV_EAGAIN; + } else { + err = UV_ESHUTDOWN; + } SmallVector bufs; for (auto&& frame : frames) { - WriteFrame(*req, bufs, m_server, frame.opcode, frame.data); + bufs.append(frame.data.begin(), frame.data.end()); } - m_stream.Write(bufs, req); + callback(bufs, uv::Error{err}); } diff --git a/wpinet/src/main/native/cpp/WebSocketDebug.h b/wpinet/src/main/native/cpp/WebSocketDebug.h new file mode 100644 index 0000000000..5653b5f783 --- /dev/null +++ b/wpinet/src/main/native/cpp/WebSocketDebug.h @@ -0,0 +1,21 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +// #define WPINET_WEBSOCKET_VERBOSE_DEBUG +// #define WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#endif + +#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG +#define WS_DEBUG(format, ...) \ + ::fmt::print(FMT_STRING(format) __VA_OPT__(, ) __VA_ARGS__) +#else +#define WS_DEBUG(fmt, ...) +#endif diff --git a/wpinet/src/main/native/cpp/WebSocketSerializer.cpp b/wpinet/src/main/native/cpp/WebSocketSerializer.cpp new file mode 100644 index 0000000000..c5d9548f39 --- /dev/null +++ b/wpinet/src/main/native/cpp/WebSocketSerializer.cpp @@ -0,0 +1,108 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include "WebSocketSerializer.h" + +#include + +using namespace wpi::detail; + +static constexpr uint8_t kFlagMasking = 0x80; +static constexpr size_t kWriteAllocSize = 4096; + +static std::span BuildHeader(std::span header, + bool server, + const wpi::WebSocket::Frame& frame) { + uint8_t* pHeader = header.data(); + + // opcode (includes FIN bit) + *pHeader++ = frame.opcode; + + // payload length + uint64_t size = 0; + for (auto&& buf : frame.data) { + size += buf.len; + } + if (size < 126) { + *pHeader++ = (server ? 0x00 : kFlagMasking) | size; + } else if (size <= 0xffff) { + *pHeader++ = (server ? 0x00 : kFlagMasking) | 126; + *pHeader++ = (size >> 8) & 0xff; + *pHeader++ = size & 0xff; + } else { + *pHeader++ = (server ? 0x00 : kFlagMasking) | 127; + *pHeader++ = (size >> 56) & 0xff; + *pHeader++ = (size >> 48) & 0xff; + *pHeader++ = (size >> 40) & 0xff; + *pHeader++ = (size >> 32) & 0xff; + *pHeader++ = (size >> 24) & 0xff; + *pHeader++ = (size >> 16) & 0xff; + *pHeader++ = (size >> 8) & 0xff; + *pHeader++ = size & 0xff; + } + return header.subspan(0, pHeader - header.data()); +} + +size_t SerializedFrames::AddClientFrame(const WebSocket::Frame& frame) { + uint8_t headerBuf[10]; + auto header = BuildHeader(headerBuf, false, frame); + + // allocate a buffer per frame + size_t size = header.size() + 4; + for (auto&& buf : frame.data) { + size += buf.len; + } + m_allocBufs.emplace_back(uv::Buffer::Allocate(size)); + m_bufs.emplace_back(m_allocBufs.back()); + + char* internalBuf = m_allocBufs.back().data().data(); + std::memcpy(internalBuf, header.data(), header.size()); + internalBuf += header.size(); + + // generate masking key + static std::random_device rd; + static std::default_random_engine gen{rd()}; + std::uniform_int_distribution dist(0, 255); + uint8_t key[4]; + for (uint8_t& v : key) { + v = dist(gen); + } + std::memcpy(internalBuf, key, 4); + internalBuf += 4; + + // copy and mask data + int n = 0; + for (auto&& buf : frame.data) { + for (auto&& ch : buf.data()) { + *internalBuf++ = static_cast(ch) ^ key[n++]; + if (n >= 4) { + n = 0; + } + } + } + return size; +} + +size_t SerializedFrames::AddServerFrame(const WebSocket::Frame& frame) { + uint8_t headerBuf[10]; + auto header = BuildHeader(headerBuf, true, frame); + + // manage allocBufs to efficiently store header + if (m_allocBufs.empty() || + (m_allocBufPos + header.size()) > kWriteAllocSize) { + m_allocBufs.emplace_back(uv::Buffer::Allocate(kWriteAllocSize)); + m_allocBufPos = 0; + } + char* internalBuf = m_allocBufs.back().data().data() + m_allocBufPos; + std::memcpy(internalBuf, header.data(), header.size()); + m_bufs.emplace_back(internalBuf, header.size()); + m_allocBufPos += header.size(); + // servers can just send the buffers directly without masking + m_bufs.append(frame.data.begin(), frame.data.end()); + size_t sent = header.size(); + for (auto&& buf : frame.data) { + sent += buf.len; + } + return sent; +} diff --git a/wpinet/src/main/native/cpp/WebSocketSerializer.h b/wpinet/src/main/native/cpp/WebSocketSerializer.h new file mode 100644 index 0000000000..825dfbae26 --- /dev/null +++ b/wpinet/src/main/native/cpp/WebSocketSerializer.h @@ -0,0 +1,286 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include +#include +#include + +#include +#include + +#include "WebSocketDebug.h" +#include "wpinet/WebSocket.h" +#include "wpinet/uv/Buffer.h" + +namespace wpi::detail { + +class SerializedFrames { + public: + SerializedFrames() = default; + SerializedFrames(const SerializedFrames&) = delete; + SerializedFrames& operator=(const SerializedFrames&) = delete; + ~SerializedFrames() { ReleaseBufs(); } + + size_t AddFrame(const WebSocket::Frame& frame, bool server) { + if (server) { + return AddServerFrame(frame); + } else { + return AddClientFrame(frame); + } + } + + size_t AddClientFrame(const WebSocket::Frame& frame); + size_t AddServerFrame(const WebSocket::Frame& frame); + + void ReleaseBufs() { + for (auto&& buf : m_allocBufs) { + buf.Deallocate(); + } + m_allocBufs.clear(); + } + + SmallVector m_allocBufs; + SmallVector m_bufs; + size_t m_allocBufPos = 0; +}; + +class WebSocketWriteReqBase { + public: + template + int Continue(Stream& stream, std::shared_ptr req); + + SmallVector m_userBufs; + SerializedFrames m_frames; + SmallVector m_continueFrameOffs; + size_t m_continueBufPos = 0; + size_t m_continueFramePos = 0; +}; + +template +int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr req) { + if (m_continueBufPos >= m_frames.m_bufs.size()) { + return 0; // nothing more to send + } + + // try writing everything remaining + std::span bufs = std::span{m_frames.m_bufs}.subspan(m_continueBufPos); + int numBytes = 0; + for (auto&& buf : bufs) { + numBytes += buf.len; + } + + int sentBytes = stream.TryWrite(bufs); + WS_DEBUG("TryWrite({}) -> {} (expected {})\n", bufs.size(), sentBytes, + numBytes); + if (sentBytes < 0) { + return sentBytes; // error + } + + if (sentBytes == numBytes) { + m_continueBufPos = m_frames.m_bufs.size(); + return 0; // nothing more to send + } + + // we didn't send everything; deal with the leftovers + + // figure out what the last (partially) frame sent actually was + auto offIt = m_continueFrameOffs.begin() + m_continueFramePos; + auto offEnd = m_continueFrameOffs.end(); + while (offIt != offEnd && *offIt < sentBytes) { + ++offIt; + } + assert(offIt != offEnd); + + // build a list of buffers to send as a normal write: + SmallVector writeBufs; + auto bufIt = bufs.begin(); + auto bufEnd = bufs.end(); + + // start with the remaining portion of the last buffer actually sent + int pos = 0; + while (bufIt != bufEnd && pos < sentBytes) { + pos += (bufIt++)->len; + } + if (bufIt != bufs.begin() && pos != sentBytes) { + writeBufs.emplace_back( + wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes)); + } + + // continue through the last buffer of the last partial frame + while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) { + pos += bufIt->len; + writeBufs.emplace_back(*bufIt++); + } + if (offIt != offEnd) { + ++offIt; + } + + // if writeBufs is still empty, write all of the next frame + if (writeBufs.empty()) { + while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) { + pos += bufIt->len; + writeBufs.emplace_back(*bufIt++); + } + if (offIt != offEnd) { + ++offIt; + } + } + + m_continueFramePos = offIt - m_continueFrameOffs.begin(); + m_continueBufPos = &(*bufIt) - &(*m_frames.m_bufs.begin()); + + if (writeBufs.empty()) { + WS_DEBUG("Write Done\n"); + return 0; + } + WS_DEBUG("Write({})\n", writeBufs.size()); + stream.Write(writeBufs, req); + return 1; +} + +template +std::span TrySendFrames( + bool server, Stream& stream, std::span frames, + MakeReq&& makeReq, + std::function, uv::Error)> callback) { + WS_DEBUG("TrySendFrames({})\n", frames.size()); + auto frameIt = frames.begin(); + auto frameEnd = frames.end(); + while (frameIt != frameEnd) { + auto frameStart = frameIt; + + // build buffers to send + SerializedFrames sendFrames; + SmallVector frameOffs; + int numBytes = 0; + while (frameIt != frameEnd) { + frameOffs.emplace_back(numBytes); + numBytes += sendFrames.AddFrame(*frameIt++, server); + if ((server && (numBytes >= 65536 || frameOffs.size() > 32)) || + (!server && numBytes >= 8192)) { + // don't waste too much memory or effort on header generation or masking + break; + } + } + + // try to send + int sentBytes = stream.TryWrite(sendFrames.m_bufs); + WS_DEBUG("TryWrite({}) -> {} (expected {})\n", sendFrames.m_bufs.size(), + sentBytes, numBytes); + + if (sentBytes == 0) { + // we haven't started a frame yet; clean up any bufs that have actually + // sent, and return unsent frames + SmallVector bufs; + for (auto it = frames.begin(); it != frameStart; ++it) { + bufs.append(it->data.begin(), it->data.end()); + } + callback(bufs, {}); + return {&*frameStart, &*frameEnd}; + } else if (sentBytes < 0) { + // error + SmallVector bufs; + for (auto&& frame : frames) { + bufs.append(frame.data.begin(), frame.data.end()); + } + callback(bufs, uv::Error{sentBytes}); + return frames; + } else if (sentBytes != numBytes) { + // we didn't send everything; deal with the leftovers + + // figure out what the last (partially) frame sent actually was + auto offIt = frameOffs.begin(); + auto offEnd = frameOffs.end(); + bool isFin = true; + while (offIt != offEnd && *offIt < sentBytes) { + ++offIt; + isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; + ++frameStart; + } + + if (offIt != offEnd && *offIt == sentBytes && isFin) { + // we finished at a normal FIN frame boundary; no need for a Write() + SmallVector bufs; + for (auto it = frames.begin(); it != frameStart; ++it) { + bufs.append(it->data.begin(), it->data.end()); + } + callback(bufs, {}); + return {&*frameStart, &*frameEnd}; + } + + // build a list of buffers to send as a normal write: + SmallVector writeBufs; + auto bufIt = sendFrames.m_bufs.begin(); + auto bufEnd = sendFrames.m_bufs.end(); + + // start with the remaining portion of the last buffer actually sent + int pos = 0; + while (bufIt != bufEnd && pos < sentBytes) { + pos += (bufIt++)->len; + } + if (bufIt != sendFrames.m_bufs.begin() && pos != sentBytes) { + writeBufs.emplace_back( + wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes)); + } + + // continue through the last buffer of the last partial frame + while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) { + pos += bufIt->len; + writeBufs.emplace_back(*bufIt++); + } + if (offIt != offEnd) { + ++offIt; + } + + // move allocated buffers into request + auto req = makeReq(std::move(callback)); + req->m_frames.m_allocBufs = std::move(sendFrames.m_allocBufs); + req->m_frames.m_allocBufPos = sendFrames.m_allocBufPos; + + // if partial frame was non-FIN, put any additional non-FIN frames into + // continuation (so the caller isn't responsible for doing this) + size_t continuePos = 0; + while (frameStart != frameEnd && !isFin) { + req->m_continueFrameOffs.emplace_back(continuePos); + if (offIt != offEnd) { + // we already generated the wire buffers for this frame, use them + while (pos < *offIt && bufIt != bufEnd) { + pos += bufIt->len; + continuePos += bufIt->len; + req->m_frames.m_bufs.emplace_back(*bufIt++); + } + ++offIt; + } else { + // WS_DEBUG("generating frame for continuation {} {}\n", + // frameStart->opcode, frameStart->data.size()); + // need to generate and add this frame + continuePos += req->m_frames.AddFrame(*frameStart, server); + } + isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; + ++frameStart; + } + + // only the non-returned user buffers are added to the request + for (auto it = frames.begin(); it != frameStart; ++it) { + req->m_userBufs.append(it->data.begin(), it->data.end()); + } + + WS_DEBUG("Write({})\n", writeBufs.size()); + stream.Write(writeBufs, req); + return {&*frameStart, &*frameEnd}; + } + } + + // nothing left to send + SmallVector bufs; + for (auto&& frame : frames) { + bufs.append(frame.data.begin(), frame.data.end()); + } + callback(bufs, {}); + return {}; +} + +} // namespace wpi::detail diff --git a/wpinet/src/main/native/include/wpinet/WebSocket.h b/wpinet/src/main/native/include/wpinet/WebSocket.h index 6e886c3941..98ce9a9693 100644 --- a/wpinet/src/main/native/include/wpinet/WebSocket.h +++ b/wpinet/src/main/native/include/wpinet/WebSocket.h @@ -34,6 +34,7 @@ class Stream; class WebSocket : public std::enable_shared_from_this { struct private_init {}; + public: static constexpr uint8_t kOpCont = 0x00; static constexpr uint8_t kOpText = 0x01; static constexpr uint8_t kOpBinary = 0x02; @@ -42,8 +43,8 @@ class WebSocket : public std::enable_shared_from_this { static constexpr uint8_t kOpPong = 0x0A; static constexpr uint8_t kOpMask = 0x0F; static constexpr uint8_t kFlagFin = 0x80; + static constexpr uint8_t kFlagControl = 0x08; - public: WebSocket(uv::Stream& stream, bool server, const private_init&); WebSocket(const WebSocket&) = delete; WebSocket(WebSocket&&) = delete; @@ -93,7 +94,7 @@ class WebSocket : public std::enable_shared_from_this { static constexpr uint8_t kPing = kFlagFin | kOpPing; static constexpr uint8_t kPong = kFlagFin | kOpPong; - Frame(uint8_t opcode, std::span data) + constexpr Frame(uint8_t opcode, std::span data) : opcode{opcode}, data{data} {} uint8_t opcode; @@ -339,7 +340,7 @@ class WebSocket : public std::enable_shared_from_this { void SendPing( std::span data, std::function, uv::Error)> callback) { - Send(kFlagFin | kOpPing, data, std::move(callback)); + SendControl(kFlagFin | kOpPing, data, std::move(callback)); } /** @@ -376,7 +377,7 @@ class WebSocket : public std::enable_shared_from_this { void SendPong( std::span data, std::function, uv::Error)> callback) { - Send(kFlagFin | kOpPong, data, std::move(callback)); + SendControl(kFlagFin | kOpPong, data, std::move(callback)); } /** @@ -401,6 +402,31 @@ class WebSocket : public std::enable_shared_from_this { std::span frames, std::function, uv::Error)> callback); + /** + * Try to send multiple frames. Tries to send as many frames as possible + * immediately, and only queues the "last" frame it can (as the network queue + * will almost always fill partway through a frame). The frames following + * the last frame will NOT be queued for transmission; the caller is + * responsible for how to handle (e.g. re-send) those frames (e.g. when the + * callback is called). + * + * @param frames Frame type/data pairs + * @param callback Callback which is invoked when the write completes of the + * last frame that is not returned. + * @return Remaining frames that will not be sent + */ + std::span TrySendFrames( + std::span frames, + std::function, uv::Error)> callback); + + /** + * Returns whether or not a previous TrySendFrames is still in progress. + * Calling TrySendFrames if this returns true will return all frames. + * + * @return True if a TryWrite is in progress + */ + bool IsWriteInProgress() const { return m_writeInProgress; } + /** * Fail the connection. */ @@ -485,6 +511,11 @@ class WebSocket : public std::enable_shared_from_this { size_t m_maxMessageSize = 128 * 1024; bool m_combineFragments = true; + // outgoing write request + bool m_writeInProgress = false; + class WriteReq; + std::weak_ptr m_writeReq; + // operating state State m_state = CONNECTING; @@ -492,6 +523,7 @@ class WebSocket : public std::enable_shared_from_this { SmallVector m_header; size_t m_headerSize = 0; SmallVector m_payload; + SmallVector m_controlPayload; size_t m_frameStart = 0; uint64_t m_frameSize = UINT64_MAX; uint8_t m_fragmentOpcode = 0; @@ -508,10 +540,16 @@ class WebSocket : public std::enable_shared_from_this { void SendClose(uint16_t code, std::string_view reason); void SetClosed(uint16_t code, std::string_view reason, bool failed = false); void HandleIncoming(uv::Buffer& buf, size_t size); + void SendControl( + uint8_t opcode, std::span data, + std::function, uv::Error)> callback); void Send(uint8_t opcode, std::span data, std::function, uv::Error)> callback) { SendFrames({{Frame{opcode, data}}}, std::move(callback)); } + void SendError( + std::span frames, + const std::function, uv::Error)>& callback); }; } // namespace wpi diff --git a/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp new file mode 100644 index 0000000000..49a23a127d --- /dev/null +++ b/wpinet/src/test/native/cpp/WebSocketSerializerTest.cpp @@ -0,0 +1,375 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include "WebSocketSerializer.h" // NOLINT(build/include_order) + +#include +#include +#include +#include + +#include +#include + +#include "WebSocketTest.h" +#include "wpinet/uv/Buffer.h" + +using ::testing::_; +using ::testing::AnyOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::Pointee; +using ::testing::Return; + +namespace wpi::uv { +inline bool operator==(const Buffer& lhs, const Buffer& rhs) { + return lhs.len == rhs.len && + std::equal(lhs.base, lhs.base + lhs.len, rhs.base); +} +inline void PrintTo(const Buffer& buf, ::std::ostream* os) { + ::wpi::PrintTo(buf.bytes(), os); +} +} // namespace wpi::uv + +namespace wpi { +inline bool operator==(const WebSocket::Frame& lhs, + const WebSocket::Frame& rhs) { + return lhs.opcode == rhs.opcode && + std::equal(lhs.data.begin(), lhs.data.end(), rhs.data.begin()); +} +inline void PrintTo(const WebSocket::Frame& frame, ::std::ostream* os) { + *os << frame.opcode << ": "; + ::wpi::PrintTo(frame.data, os); +} +} // namespace wpi + +namespace wpi::detail { + +class MockWebSocketWriteReq + : public std::enable_shared_from_this, + public detail::WebSocketWriteReqBase { + public: + explicit MockWebSocketWriteReq( + std::function, uv::Error)> callback) {} +}; + +class MockStream { + public: + MOCK_METHOD(int, TryWrite, (std::span)); + void Write(std::span bufs, + const std::shared_ptr& req) { + // std::cout << "Write("; + // PrintTo(bufs, &std::cout); + // std::cout << ")\n"; + DoWrite(bufs, req); + } + MOCK_METHOD(void, DoWrite, + (std::span bufs, + const std::shared_ptr& req)); +}; + +class WebSocketWriteReqTest : public ::testing::Test { + public: + WebSocketWriteReqTest() { + req->m_frames.m_bufs.emplace_back(m_buf0); + req->m_frames.m_bufs.emplace_back(m_buf1); + req->m_frames.m_bufs.emplace_back(m_buf2); + req->m_continueFrameOffs.emplace_back(5); // frame 0: first 2 buffers + req->m_continueFrameOffs.emplace_back(9); // frame 1: last buffer + } + + std::shared_ptr req = + std::make_shared([](auto, auto) {}); + ::testing::StrictMock stream; + static const uint8_t m_buf0[3]; + static const uint8_t m_buf1[2]; + static const uint8_t m_buf2[4]; +}; + +const uint8_t WebSocketWriteReqTest::m_buf0[3] = {1, 2, 3}; +const uint8_t WebSocketWriteReqTest::m_buf1[2] = {4, 5}; +const uint8_t WebSocketWriteReqTest::m_buf2[4] = {6, 7, 8, 9}; + +TEST_F(WebSocketWriteReqTest, ContinueDone) { + req->m_continueBufPos = 3; + ASSERT_EQ(req->Continue(stream, req), 0); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWriteComplete) { + EXPECT_CALL(stream, TryWrite(wpi::SpanEq(req->m_frames.m_bufs))) + .WillOnce(Return(9)); + ASSERT_EQ(req->Continue(stream, req), 0); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWriteNoProgress) { + // if TryWrite returns 0 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(0)); + // Write should get called for all of next frame - make forward progress + uv::Buffer remBufs[2] = {uv::Buffer{m_buf0}, uv::Buffer{m_buf1}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWriteError) { + // if TryWrite returns -1, the error is passed along + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(-1)); + ASSERT_EQ(req->Continue(stream, req), -1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf1) { + // stop partway through buf 0 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(2)); + // Write should get called for remainder of buf 0 and all of buf 1 + uv::Buffer remBufs[2] = {uv::Buffer{&m_buf0[2], 1}, uv::Buffer{m_buf1}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameBufBoundary) { + // stop at end of buf 0 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(3)); + // Write should get called for all of buf 1 + uv::Buffer remBufs[1] = {uv::Buffer{m_buf1}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf2) { + // stop partway through buf 1 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(4)); + // Write should get called for remainder of buf 1 + uv::Buffer remBufs[1] = {uv::Buffer{&m_buf1[1], 1}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialFrameBoundary) { + // stop at end of buf 1 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(5)); + // Write should get called for all of next frame + uv::Buffer remBufs[1] = {uv::Buffer{m_buf2}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf3) { + // stop partway through buf 2 + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(6)); + // Write should get called for remainder of buf 2 + uv::Buffer remBufs[1] = {uv::Buffer{&m_buf2[1], 3}}; + EXPECT_CALL(stream, + DoWrite(wpi::SpanEq(std::span(remBufs)), _)); + ASSERT_EQ(req->Continue(stream, req), 1); +} + +class WebSocketTrySendTest : public ::testing::Test { + public: + ::testing::StrictMock stream; + std::shared_ptr req; + static const std::array m_buf0data; + static const std::array m_buf1data; + static const std::array m_buf2data; + static const std::array m_buf3data; + static const std::array m_bufs; + static const std::array m_frame0data; + static const std::array m_frame0bufs; + static const std::array m_frame1bufs; + static const std::array m_frame2bufs; + static const std::array m_frames; + static const std::array, 3> m_serialized; + static const std::array m_frameHeaders; + + int makeReqCalled = 0; + int callbackCalled = 0; + void CheckTrySendFrames(std::span expectCbBufs, + std::span expectRet, + int expectErr = 0); +}; + +const std::array WebSocketTrySendTest::m_buf0data{1, 2, 3}; +const std::array WebSocketTrySendTest::m_buf1data{4, 5}; +const std::array WebSocketTrySendTest::m_buf2data{6, 7, 8, 9}; +const std::array WebSocketTrySendTest::m_buf3data{10, 11, 12, 13}; +const std::array WebSocketTrySendTest::m_bufs{ + uv::Buffer{m_buf0data}, uv::Buffer{m_buf1data}, uv::Buffer{m_buf2data}, + uv::Buffer{m_buf3data}}; +const std::array WebSocketTrySendTest::m_frame0data{1, 2, 3, 4, 5}; +const std::array WebSocketTrySendTest::m_frame0bufs{m_bufs[0], + m_bufs[1]}; +const std::array WebSocketTrySendTest::m_frame1bufs{m_bufs[2]}; +const std::array WebSocketTrySendTest::m_frame2bufs{m_bufs[3]}; +const std::array WebSocketTrySendTest::m_frames{ + WebSocket::Frame{WebSocket::Frame::kBinaryFragment, m_frame0bufs}, + WebSocket::Frame{WebSocket::Frame::kBinary, m_frame1bufs}, + WebSocket::Frame{WebSocket::Frame::kText, m_frame2bufs}, +}; +const std::array, 3> WebSocketTrySendTest::m_serialized{ + WebSocketTest::BuildMessage(m_frames[0].opcode, false, false, m_frame0data), + WebSocketTest::BuildMessage(m_frames[1].opcode, true, false, m_buf2data), + WebSocketTest::BuildMessage(m_frames[2].opcode, true, false, m_buf3data), +}; +const std::array WebSocketTrySendTest::m_frameHeaders{ + uv::Buffer{m_serialized[0].data(), + m_serialized[0].size() - m_frame0data.size()}, + uv::Buffer{m_serialized[1].data(), + m_serialized[1].size() - m_buf2data.size()}, + uv::Buffer{m_serialized[2].data(), + m_serialized[2].size() - m_buf3data.size()}, +}; + +void WebSocketTrySendTest::CheckTrySendFrames( + std::span expectCbBufs, + std::span expectRet, int expectErr) { + ASSERT_THAT( + TrySendFrames( + true, stream, m_frames, + [&](std::function, uv::Error)>&& cb) { + ++makeReqCalled; + req = std::make_shared(std::move(cb)); + return req; + }, + [&](auto bufs, auto err) { + ++callbackCalled; + ASSERT_THAT(bufs, + SpanEq(std::span(expectCbBufs))); + ASSERT_EQ(err.code(), expectErr); + }), + SpanEq(expectRet)); +} + +TEST_F(WebSocketTrySendTest, ServerComplete) { + // if trywrite sends everything + EXPECT_CALL(stream, TryWrite(_)) + .WillOnce(Return(m_serialized[0].size() + m_serialized[1].size() + + m_serialized[2].size())); + // return nothing, and call callback immediately + CheckTrySendFrames(m_bufs, {}); + ASSERT_EQ(makeReqCalled, 0); + ASSERT_EQ(callbackCalled, 1); +} + +TEST_F(WebSocketTrySendTest, ServerNoProgress) { + // if trywrite sends nothing + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(0)); + // we should get all the frames back (the callback may be called with an empty + // set of buffers) + CheckTrySendFrames({}, m_frames); + ASSERT_EQ(makeReqCalled, 0); + ASSERT_THAT(callbackCalled, AnyOf(0, 1)); +} + +TEST_F(WebSocketTrySendTest, ServerError) { + // if TryWrite returns -1, the error is passed along + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(-1)); + CheckTrySendFrames(m_bufs, m_frames, -1); + ASSERT_EQ(makeReqCalled, 0); + ASSERT_EQ(callbackCalled, 1); +} + +TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf0) { + // stop partway through buf 0 (first buf of frame 0) + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 2)); + // Write should get called for remainder of buf 0 and all of buf 1 + // buf 2 should get put into continuation because frame 0 is a fragment + // return will be frame 2 only + std::array remBufs{std::span{m_buf0data}.subspan(2), + m_bufs[1]}; + std::array contBufs{m_frameHeaders[1], m_bufs[2]}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_THAT(req->m_frames.m_bufs, SpanEq(contBufs)); + ASSERT_EQ(callbackCalled, 0); +} + +TEST_F(WebSocketTrySendTest, ServerPartialMidFrameBufBoundary) { + // stop at end of buf 0 (first buf of frame 0) + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 3)); + // Write should get called for all of buf 1 + // buf 2 should get put into continuation because frame 0 is a fragment + // return will be frame 2 only + std::array remBufs{m_bufs[1]}; + std::array contBufs{m_frameHeaders[1], m_bufs[2]}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_THAT(req->m_frames.m_bufs, SpanEq(contBufs)); + ASSERT_EQ(callbackCalled, 0); +} + +TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf1) { + // stop partway through buf 1 (second buf of frame 0) + EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 4)); + // Write should get called for remainder of buf 1 + // buf 2 should get put into continuation because frame 0 is a fragment + // return will be frame 2 only + std::array remBufs{std::span{m_buf1data}.subspan(1)}; + std::array contBufs{m_frameHeaders[1], m_bufs[2]}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_THAT(req->m_frames.m_bufs, SpanEq(contBufs)); + ASSERT_EQ(callbackCalled, 0); +} + +TEST_F(WebSocketTrySendTest, ServerPartialFrameBoundary) { + // stop at end of buf 1 (end of frame 0) + EXPECT_CALL(stream, TryWrite(_)) + .WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 5)); + // Write should get called for all of buf 2 because frame 0 is a fragment + // no continuation + // return will be frame 2 only + std::array remBufs{m_bufs[2]}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_TRUE(req->m_frames.m_bufs.empty()); + ASSERT_EQ(callbackCalled, 0); +} + +TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf2) { + // stop partway through buf 2 (frame 1 buf) + EXPECT_CALL(stream, TryWrite(_)) + .WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 6)); + // Write should get called for remainder of buf 2; no continuation + // return will be frame 2 only + std::array remBufs{std::span{m_buf2data}.subspan(1)}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_TRUE(req->m_frames.m_bufs.empty()); + ASSERT_EQ(callbackCalled, 0); +} + +TEST_F(WebSocketTrySendTest, ServerFrameBoundary) { + // stop at end of buf 2 (end of frame 1) + EXPECT_CALL(stream, TryWrite(_)) + .WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 9)); + // call callback immediately for bufs 0-2 and return frame 2 + CheckTrySendFrames(std::span{m_bufs}.subspan(0, 3), + std::span{m_frames}.subspan(2)); + ASSERT_EQ(makeReqCalled, 0); + ASSERT_EQ(callbackCalled, 1); +} + +TEST_F(WebSocketTrySendTest, ServerPartialLastFrame) { + // stop partway through buf 3 + EXPECT_CALL(stream, TryWrite(_)) + .WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + + m_frameHeaders[2].len + 10)); + // Write called for remainder of buf 3; no continuation + std::array remBufs{std::span{m_buf3data}.subspan(1)}; + EXPECT_CALL(stream, DoWrite(wpi::SpanEq(remBufs), _)); + CheckTrySendFrames({}, {}); + ASSERT_EQ(makeReqCalled, 1); + ASSERT_TRUE(req->m_frames.m_bufs.empty()); + ASSERT_EQ(callbackCalled, 0); +} + +} // namespace wpi::detail diff --git a/wpinet/src/test/native/cpp/WebSocketServerTest.cpp b/wpinet/src/test/native/cpp/WebSocketServerTest.cpp index 18dd50bef6..3d74ca2dab 100644 --- a/wpinet/src/test/native/cpp/WebSocketServerTest.cpp +++ b/wpinet/src/test/native/cpp/WebSocketServerTest.cpp @@ -471,6 +471,49 @@ TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) { ASSERT_EQ(gotCallback, 3); } +// Control frames can happen in the middle of a fragmented message +TEST_F(WebSocketServerTest, ReceiveFragmentWithControl) { + int gotCallback = 0; + int gotPongCallback = 0; + + std::vector data(4, 0x03); + std::vector data2(4, 0x04); + std::vector data3(4, 0x05); + std::vector data4(4, 0x06); + std::vector combData{data}; + combData.insert(combData.end(), data2.begin(), data2.end()); + combData.insert(combData.end(), data4.begin(), data4.end()); + + setupWebSocket = [&] { + ws->binary.connect([&](auto inData, bool fin) { + ASSERT_TRUE(gotPongCallback); + ++gotCallback; + ws->Terminate(); + ASSERT_TRUE(fin); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(combData, recvData); + }); + ws->pong.connect([&](auto inData) { + ASSERT_FALSE(gotCallback); + ++gotPongCallback; + }); + }; + + auto message = BuildMessage(0x02, false, true, data); + auto message2 = BuildMessage(0x00, false, true, data2); + auto message3 = BuildMessage(0x0a, true, true, data3); + auto message4 = BuildMessage(0x00, true, true, data4); + resp.headersComplete.connect([&](bool) { + clientPipe->Write({{message}, {message2}, {message3}, {message4}}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); + ASSERT_EQ(gotPongCallback, 1); +} + // // Maximum message size is limited. //