Files
allwpilib/wpinet/src/test/native/cpp/WebSocketClientTest.cpp

327 lines
9.1 KiB
C++
Raw Normal View History

// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
// clang-format off
#include "wpi/net/WebSocket.hpp"
// clang-format on
2024-09-20 17:43:39 -07:00
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <catch2/generators/catch_generators.hpp>
2025-11-07 19:57:55 -05:00
#include "WebSocketTest.hpp"
#include "wpi/net/HttpParser.hpp"
#include "wpi/net/raw_uv_ostream.hpp"
2025-11-07 19:56:21 -05:00
#include "wpi/util/Base64.hpp"
#include "wpi/util/SmallString.hpp"
#include "wpi/util/StringExtras.hpp"
#include "wpi/util/sha1.hpp"
2025-11-07 20:00:05 -05:00
namespace wpi::net {
class WebSocketClientTest : public WebSocketTest {
public:
WebSocketClientTest() {
// Bare bones server
req.header.connect([this](std::string_view name, std::string_view value) {
// save key (required for valid response)
2025-11-07 20:00:05 -05:00
if (wpi::util::equals_lower(name, "sec-websocket-key")) {
clientKey = value;
}
});
req.headersComplete.connect([this](bool) {
// send response
2025-11-07 20:00:05 -05:00
wpi::util::SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
os << "HTTP/1.1 101 Switching Protocols\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
// accept hash
2025-11-07 20:00:05 -05:00
wpi::util::SHA1 hash;
hash.Update(clientKey);
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
if (mockBadAccept) {
hash.Update("1");
}
2025-11-07 20:00:05 -05:00
wpi::util::SmallString<64> hashBuf;
wpi::util::SmallString<64> acceptBuf;
os << "Sec-WebSocket-Accept: "
2025-11-07 20:01:58 -05:00
<< wpi::util::Base64Encode(hash.RawFinal(hashBuf), acceptBuf)
<< "\r\n";
if (!mockProtocol.empty()) {
os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
}
os << "\r\n";
conn->Write(bufs, [](auto bufs, uv::Error) {
for (auto& buf : bufs) {
buf.Deallocate();
}
});
serverHeadersDone = true;
if (connected) {
connected();
}
});
serverPipe->Listen([this] {
conn = serverPipe->Accept();
conn->StartRead();
conn->data.connect([this](uv::Buffer& buf, size_t size) {
std::string_view data{buf.base, size};
if (!serverHeadersDone) {
data = req.Execute(data);
if (req.HasError()) {
Finish();
}
INFO(http_errno_name(req.GetError()));
REQUIRE(req.GetError() == HPE_OK);
if (data.empty()) {
return;
}
}
wireData.insert(wireData.end(), data.begin(), data.end());
});
conn->end.connect([this] { Finish(); });
});
}
bool mockBadAccept = false;
std::vector<uint8_t> wireData;
std::shared_ptr<uv::Pipe> conn;
HttpParser req{HttpParser::Type::REQUEST};
2025-11-07 20:00:05 -05:00
wpi::util::SmallString<64> clientKey;
std::string mockProtocol;
bool serverHeadersDone = false;
std::function<void()> connected;
};
// With no subprotocol requested or offered, the handshake must succeed: the
// open signal fires exactly once with an empty protocol. Closing with 1005
// or 1006 while the test tears down is not treated as a failure.
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest Open",
                 "[websocket][client][handshake]") {
  int openCount = 0;
  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      const bool toleratedClose = closeCode == 1005 || closeCode == 1006;
      if (!toleratedClose) {
        FAIL("Code: " << closeCode << " Reason: " << why);
      }
    });
    client->open.connect([&](std::string_view protocol) {
      ++openCount;
      Finish();
      REQUIRE(protocol.empty());
    });
  });
  loop->Run();
  REQUIRE(openCount == 1);
}
// The server sends a corrupted Sec-WebSocket-Accept hash. The client must
// never open; it must close exactly once with code 1002.
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest BadAccept",
                 "[websocket][client][handshake][protocol]") {
  int closedCount = 0;
  mockBadAccept = true;
  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closedCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1002);
    });
    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });
  loop->Run();
  REQUIRE(closedCount == 1);
}
// The client offers two subprotocols and the server answers with
// "myProtocol", one of them. The client must open exactly once and report
// that protocol. Closing with 1005 or 1006 while the test tears down is not
// treated as a failure.
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolGood",
                 "[websocket][client][protocol]") {
  int gotOpen = 0;
  mockProtocol = "myProtocol";
  clientPipe->Connect(pipeName, [&] {
    auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                      {"myProtocol", "myProtocol2"});
    ws->closed.connect([&](uint16_t code, std::string_view msg) {
      Finish();
      if (code != 1005 && code != 1006) {
        // Leading space keeps the code and message readable (same layout as
        // the Open test's failure message).
        FAIL("Code: " << code << " Message: " << msg);
      }
    });
    ws->open.connect([&](std::string_view protocol) {
      ++gotOpen;
      Finish();
      REQUIRE(protocol == "myProtocol");
    });
  });
  loop->Run();
  REQUIRE(gotOpen == 1);
}
// The server answers with a Sec-WebSocket-Protocol header even though the
// client requested no subprotocol. The client must refuse to open and must
// close exactly once with code 1003.
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolRespNotReq",
                 "[websocket][client][protocol]") {
  int closedCount = 0;
  mockProtocol = "myProtocol";
  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closedCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1003);
    });
    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });
  loop->Run();
  REQUIRE(closedCount == 1);
}
// The client requests a subprotocol but the server's response names none.
// The client must refuse to open and must close exactly once with code 1002.
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolReqNotResp",
                 "[websocket][client][protocol]") {
  int closedCount = 0;
  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                          {{"myProtocol"}});
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closedCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1002);
    });
    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });
  loop->Run();
  REQUIRE(closedCount == 1);
}
//
// Send and receive data. Most of these cases are already covered by
// WebSocketServerTest, so only spot-check client-specific differences such as
// masking.
//
class WebSocketClientDataTest : public WebSocketClientTest {
public:
WebSocketClientDataTest() {
clientPipe->Connect(pipeName, [&] {
ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
if (setupWebSocket) {
setupWebSocket();
}
});
}
std::function<void()> setupWebSocket;
std::shared_ptr<WebSocket> ws;
};
// Sends one binary message of each generated size once the socket opens. The
// send-completion callback must receive the caller's own buffer. The bytes
// the server captured, normalized by AdjustMasking, must equal
// BuildMessage(0x02, true, true, data); per the note above this section, that
// presumably means a masked client frame.
TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest SendBinary",
                 "[websocket][client][data]") {
  int callbackCount = 0;
  const size_t payloadSize =
      GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535},
               size_t{65536});
  std::vector<uint8_t> data(payloadSize, 0x03u);
  setupWebSocket = [&] {
    ws->open.connect([&](std::string_view) {
      ws->SendBinary({{data}}, [&](auto sent, uv::Error) {
        ++callbackCount;
        ws->Terminate();
        REQUIRE_FALSE(sent.empty());
        REQUIRE(sent[0].base == reinterpret_cast<const char*>(data.data()));
      });
    });
  };
  loop->Run();
  auto expected = BuildMessage(0x02, true, true, data);
  AdjustMasking(wireData);
  REQUIRE(wireData == expected);
  REQUIRE(callbackCount == 1);
}
// After the handshake, the server writes the frame
// BuildMessage(0x02, true, false, data), which presumably is an unmasked
// binary frame. The client's binary signal must fire exactly once with
// fin == true and a payload identical to data.
TEST_CASE_METHOD(WebSocketClientDataTest,
                 "WebSocketClientDataTest ReceiveBinary",
                 "[websocket][client][data]") {
  int callbackCount = 0;
  const size_t payloadSize =
      GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535},
               size_t{65536});
  std::vector<uint8_t> data(payloadSize, 0x03u);
  setupWebSocket = [&] {
    ws->binary.connect([&](auto inData, bool fin) {
      ++callbackCount;
      ws->Terminate();
      REQUIRE(fin);
      const std::vector<uint8_t> received(inData.begin(), inData.end());
      REQUIRE(data == received);
    });
  };
  auto frame = BuildMessage(0x02, true, false, data);
  connected = [&] {
    conn->Write({{frame}}, [](auto, uv::Error) {});
  };
  loop->Run();
  REQUIRE(callbackCount == 1);
}
//
// The client must close the connection if a masked frame is received.
//
// After the handshake, the server writes the frame
// BuildMessage(0x01, true, true, data) with a payload of spaces, which
// presumably is a masked text frame. The text handler must never run. The
// client must close exactly once with code 1002.
TEST_CASE_METHOD(WebSocketClientDataTest,
                 "WebSocketClientDataTest ReceiveMasked",
                 "[websocket][client][data][protocol]") {
  int closedCount = 0;
  const size_t payloadSize =
      GENERATE(size_t{0}, size_t{1}, size_t{125}, size_t{126}, size_t{65535},
               size_t{65536});
  std::vector<uint8_t> data(payloadSize, ' ');
  setupWebSocket = [&] {
    ws->text.connect([&](std::string_view, bool) {
      ws->Terminate();
      FAIL("Should not have gotten masked message");
    });
    ws->closed.connect([&](uint16_t closeCode, std::string_view reason) {
      ++closedCount;
      INFO("reason: " << reason);
      REQUIRE(closeCode == 1002);
    });
  };
  auto maskedFrame = BuildMessage(0x01, true, true, data);
  connected = [&] {
    conn->Write({{maskedFrame}}, [](auto, uv::Error) {});
  };
  loop->Run();
  REQUIRE(closedCount == 1);
}
2025-11-07 20:00:05 -05:00
} // namespace wpi::net