// Copyright (c) FIRST and other WPILib contributors. // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. #include "wpi/WebSocket.h" // NOLINT(build/include_order) #include "WebSocketTest.h" #include "wpi/Base64.h" #include "wpi/HttpParser.h" #include "wpi/SmallString.h" #include "wpi/raw_uv_ostream.h" #include "wpi/sha1.h" namespace wpi { class WebSocketClientTest : public WebSocketTest { public: WebSocketClientTest() { // Bare bones server req.header.connect([this](StringRef name, StringRef value) { // save key (required for valid response) if (name.equals_lower("sec-websocket-key")) clientKey = value; }); req.headersComplete.connect([this](bool) { // send response SmallVector bufs; raw_uv_ostream os{bufs, 4096}; os << "HTTP/1.1 101 Switching Protocols\r\n"; os << "Upgrade: websocket\r\n"; os << "Connection: Upgrade\r\n"; // accept hash SHA1 hash; hash.Update(clientKey); hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); if (mockBadAccept) hash.Update("1"); SmallString<64> hashBuf; SmallString<64> acceptBuf; os << "Sec-WebSocket-Accept: " << Base64Encode(hash.RawFinal(hashBuf), acceptBuf) << "\r\n"; if (!mockProtocol.empty()) os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n"; os << "\r\n"; conn->Write(bufs, [](auto bufs, uv::Error) { for (auto& buf : bufs) buf.Deallocate(); }); serverHeadersDone = true; if (connected) connected(); }); serverPipe->Listen([this] { conn = serverPipe->Accept(); conn->StartRead(); conn->data.connect([this](uv::Buffer& buf, size_t size) { StringRef data{buf.base, size}; if (!serverHeadersDone) { data = req.Execute(data); if (req.HasError()) Finish(); ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); if (data.empty()) return; } wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end()); }); conn->end.connect([this] { Finish(); }); }); } bool mockBadAccept = false; std::vector wireData; std::shared_ptr conn; HttpParser req{HttpParser::kRequest}; SmallString<64> clientKey; std::string mockProtocol; bool serverHeadersDone = false; std::function connected; }; TEST_F(WebSocketClientTest, Open) { int gotOpen = 0; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, StringRef reason) { Finish(); if (code != 1005 && code != 1006) FAIL() << "Code: " << code << " Reason: " << reason; }); ws->open.connect([&](StringRef protocol) { ++gotOpen; Finish(); ASSERT_TRUE(protocol.empty()); }); }); loop->Run(); if (HasFatalFailure()) return; ASSERT_EQ(gotOpen, 1); } TEST_F(WebSocketClientTest, BadAccept) { int gotClosed = 0; mockBadAccept = true; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, StringRef msg) { Finish(); ++gotClosed; ASSERT_EQ(code, 1002) << "Message: " << msg; }); ws->open.connect([&](StringRef protocol) { Finish(); FAIL() << "Got open"; }); }); loop->Run(); if (HasFatalFailure()) return; ASSERT_EQ(gotClosed, 1); } TEST_F(WebSocketClientTest, ProtocolGood) { int gotOpen = 0; mockProtocol = "myProtocol"; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, {"myProtocol", "myProtocol2"}); ws->closed.connect([&](uint16_t code, StringRef msg) { Finish(); if (code != 1005 && code != 1006) FAIL() << "Code: " << code << "Message: " << msg; }); ws->open.connect([&](StringRef protocol) { ++gotOpen; Finish(); ASSERT_EQ(protocol, "myProtocol"); }); }); loop->Run(); if (HasFatalFailure()) return; ASSERT_EQ(gotOpen, 1); } TEST_F(WebSocketClientTest, ProtocolRespNotReq) { int gotClosed = 0; mockProtocol = "myProtocol"; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); ws->closed.connect([&](uint16_t code, StringRef msg) { Finish(); ++gotClosed; ASSERT_EQ(code, 1003) << "Message: " << msg; }); ws->open.connect([&](StringRef protocol) { Finish(); FAIL() << "Got open"; }); }); loop->Run(); if (HasFatalFailure()) return; ASSERT_EQ(gotClosed, 1); } TEST_F(WebSocketClientTest, ProtocolReqNotResp) { int gotClosed = 0; clientPipe->Connect(pipeName, [&] { auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, StringRef{"myProtocol"}); ws->closed.connect([&](uint16_t code, StringRef msg) { Finish(); ++gotClosed; ASSERT_EQ(code, 1002) << "Message: " << msg; }); ws->open.connect([&](StringRef protocol) { Finish(); FAIL() << "Got open"; }); }); loop->Run(); if (HasFatalFailure()) return; ASSERT_EQ(gotClosed, 1); } // // Send and receive data. Most of these cases are tested in // WebSocketServerTest, so only spot check differences like masking. // class WebSocketClientDataTest : public WebSocketClientTest, public ::testing::WithParamInterface { public: WebSocketClientDataTest() { clientPipe->Connect(pipeName, [&] { ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); if (setupWebSocket) setupWebSocket(); }); } std::function setupWebSocket; std::shared_ptr ws; }; INSTANTIATE_TEST_SUITE_P(WebSocketClientDataTests, WebSocketClientDataTest, ::testing::Values(0, 1, 125, 126, 65535, 65536)); TEST_P(WebSocketClientDataTest, SendBinary) { int gotCallback = 0; std::vector data(GetParam(), 0x03u); setupWebSocket = [&] { ws->open.connect([&](StringRef) { ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) { ++gotCallback; ws->Terminate(); ASSERT_FALSE(bufs.empty()); ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); }); }); }; loop->Run(); auto expectData = BuildMessage(0x02, true, true, data); AdjustMasking(wireData); ASSERT_EQ(wireData, expectData); ASSERT_EQ(gotCallback, 1); } TEST_P(WebSocketClientDataTest, ReceiveBinary) { int gotCallback = 0; std::vector data(GetParam(), 0x03u); setupWebSocket = [&] { ws->binary.connect([&](ArrayRef inData, bool fin) { ++gotCallback; ws->Terminate(); ASSERT_TRUE(fin); std::vector recvData{inData.begin(), inData.end()}; ASSERT_EQ(data, recvData); }); }; auto message = BuildMessage(0x02, true, false, data); connected = [&] { conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); }; loop->Run(); ASSERT_EQ(gotCallback, 1); } // // The client must close the connection if a masked frame is received. // TEST_P(WebSocketClientDataTest, ReceiveMasked) { int gotCallback = 0; std::vector data(GetParam(), ' '); setupWebSocket = [&] { ws->text.connect([&](StringRef, bool) { ws->Terminate(); FAIL() << "Should not have gotten masked message"; }); ws->closed.connect([&](uint16_t code, StringRef reason) { ++gotCallback; ASSERT_EQ(code, 1002) << "reason: " << reason; }); }; auto message = BuildMessage(0x01, true, true, data); connected = [&] { conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); }; loop->Run(); ASSERT_EQ(gotCallback, 1); } } // namespace wpi