// Copyright (c) FIRST and other WPILib contributors. // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. // clang-format off #include "wpi/net/WebSocket.hpp" // clang-format on #include #include #include #include #include #include "WebSocketTest.hpp" #include "wpi/net/HttpParser.hpp" #include "wpi/net/raw_uv_ostream.hpp" #include "wpi/util/Base64.hpp" #include "wpi/util/SmallString.hpp" #include "wpi/util/StringExtras.hpp" #include "wpi/util/sha1.hpp" namespace wpi::net { class WebSocketClientTest : public WebSocketTest { public: WebSocketClientTest() { // Bare bones server req.header.connect([this](std::string_view name, std::string_view value) { // save key (required for valid response) if (wpi::util::equals_lower(name, "sec-websocket-key")) { clientKey = value; } }); req.headersComplete.connect([this](bool) { // send response wpi::util::SmallVector bufs; raw_uv_ostream os{bufs, 4096}; os << "HTTP/1.1 101 Switching Protocols\r\n"; os << "Upgrade: websocket\r\n"; os << "Connection: Upgrade\r\n"; // accept hash wpi::util::SHA1 hash; hash.Update(clientKey); hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); if (mockBadAccept) { hash.Update("1"); } wpi::util::SmallString<64> hashBuf; wpi::util::SmallString<64> acceptBuf; os << "Sec-WebSocket-Accept: " << wpi::util::Base64Encode(hash.RawFinal(hashBuf), acceptBuf) << "\r\n"; if (!mockProtocol.empty()) { os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n"; } os << "\r\n"; conn->Write(bufs, [](auto bufs, uv::Error) { for (auto& buf : bufs) { buf.Deallocate(); } }); serverHeadersDone = true; if (connected) { connected(); } }); serverPipe->Listen([this] { conn = serverPipe->Accept(); conn->StartRead(); conn->data.connect([this](uv::Buffer& buf, size_t size) { std::string_view data{buf.base, size}; if (!serverHeadersDone) { data = req.Execute(data); if (req.HasError()) { Finish(); } INFO(http_errno_name(req.GetError())); REQUIRE(req.GetError() == HPE_OK); if (data.empty()) { return; } } wireData.insert(wireData.end(), data.begin(), data.end()); }); conn->end.connect([this] { Finish(); }); }); } bool mockBadAccept = false; std::vector wireData; std::shared_ptr conn; HttpParser req{HttpParser::Type::REQUEST}; wpi::util::SmallString<64> clientKey; std::string mockProtocol; bool serverHeadersDone = false; std::function connected; }; TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest Open", "[websocket][client][handshake]") { int gotOpen = 0; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, std::string_view reason) { Finish(); if (code != 1005 && code != 1006) { FAIL("Code: " << code << " Reason: " << reason); } }); ws->open.connect([&](std::string_view protocol) { ++gotOpen; Finish(); REQUIRE(protocol.empty()); }); }); loop->Run(); REQUIRE(gotOpen == 1); } TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest BadAccept", "[websocket][client][handshake][protocol]") { int gotClosed = 0; mockBadAccept = true; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, std::string_view msg) { Finish(); ++gotClosed; INFO("Message: " << msg); REQUIRE(code == 1002); }); ws->open.connect([&](std::string_view protocol) { Finish(); FAIL("Got open"); }); }); loop->Run(); REQUIRE(gotClosed == 1); } TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolGood", "[websocket][client][protocol]") { int gotOpen = 0; mockProtocol = "myProtocol"; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, {"myProtocol", "myProtocol2"}); ws->closed.connect([&](uint16_t code, std::string_view msg) { Finish(); if (code != 1005 && code != 1006) { FAIL("Code: " << code << "Message: " << msg); } }); ws->open.connect([&](std::string_view protocol) { ++gotOpen; Finish(); REQUIRE(protocol == "myProtocol"); }); }); loop->Run(); REQUIRE(gotOpen == 1); } TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolRespNotReq", "[websocket][client][protocol]") { int gotClosed = 0; mockProtocol = "myProtocol"; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, std::string_view msg) { Finish(); ++gotClosed; INFO("Message: " << msg); REQUIRE(code == 1003); }); ws->open.connect([&](std::string_view protocol) { Finish(); FAIL("Got open"); }); }); loop->Run(); REQUIRE(gotClosed == 1); } TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolReqNotResp", "[websocket][client][protocol]") { int gotClosed = 0; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, {{"myProtocol"}}); ws->closed.connect([&](uint16_t code, std::string_view msg) { Finish(); ++gotClosed; INFO("Message: " << msg); REQUIRE(code == 1002); }); ws->open.connect([&](std::string_view protocol) { Finish(); FAIL("Got open"); }); }); loop->Run(); REQUIRE(gotClosed == 1); } // // Send and receive data. Most of these cases are tested in // WebSocketServerTest, so only spot check differences like masking. // class WebSocketClientDataTest : public WebSocketClientTest { public: WebSocketClientDataTest() { clientPipe->Connect(pipeName, [&] { ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); if (setupWebSocket) { setupWebSocket(); } }); } std::function setupWebSocket; std::shared_ptr ws; }; TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest SendBinary", "[websocket][client][data]") { int gotCallback = 0; std::vector data(GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535}, size_t{65536}), 0x03u); setupWebSocket = [&] { ws->open.connect([&](std::string_view) { ws->SendBinary({{data}}, [&](auto bufs, uv::Error) { ++gotCallback; ws->Terminate(); REQUIRE_FALSE(bufs.empty()); REQUIRE(bufs[0].base == reinterpret_cast(data.data())); }); }); }; loop->Run(); auto expectData = BuildMessage(0x02, true, true, data); AdjustMasking(wireData); REQUIRE(wireData == expectData); REQUIRE(gotCallback == 1); } TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest ReceiveBinary", "[websocket][client][data]") { int gotCallback = 0; std::vector data(GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535}, size_t{65536}), 0x03u); setupWebSocket = [&] { ws->binary.connect([&](auto inData, bool fin) { ++gotCallback; ws->Terminate(); REQUIRE(fin); std::vector recvData{inData.begin(), inData.end()}; REQUIRE(data == recvData); }); }; auto message = BuildMessage(0x02, true, false, data); connected = [&] { conn->Write({{message}}, [&](auto bufs, uv::Error) {}); }; loop->Run(); REQUIRE(gotCallback == 1); } // // The client must close the connection if a masked frame is received. // TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest ReceiveMasked", "[websocket][client][data][protocol]") { int gotCallback = 0; std::vector data(GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535}, size_t{65536}), ' '); setupWebSocket = [&] { ws->text.connect([&](std::string_view, bool) { ws->Terminate(); FAIL("Should not have gotten masked message"); }); ws->closed.connect([&](uint16_t code, std::string_view reason) { ++gotCallback; INFO("reason: " << reason); REQUIRE(code == 1002); }); }; auto message = BuildMessage(0x01, true, true, data); connected = [&] { conn->Write({{message}}, [&](auto bufs, uv::Error) {}); }; loop->Run(); REQUIRE(gotCallback == 1); } } // namespace wpi::net