mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-29 02:21:44 +00:00
327 lines
9.1 KiB
C++
327 lines
9.1 KiB
C++
// Copyright (c) FIRST and other WPILib contributors.
|
|
// Open Source Software; you can modify and/or share it under the terms of
|
|
// the WPILib BSD license file in the root directory of this project.
|
|
|
|
// clang-format off
|
|
#include "wpi/net/WebSocket.hpp"
|
|
// clang-format on
|
|
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <catch2/generators/catch_generators.hpp>
|
|
|
|
#include "WebSocketTest.hpp"
|
|
#include "wpi/net/HttpParser.hpp"
|
|
#include "wpi/net/raw_uv_ostream.hpp"
|
|
#include "wpi/util/Base64.hpp"
|
|
#include "wpi/util/SmallString.hpp"
|
|
#include "wpi/util/StringExtras.hpp"
|
|
#include "wpi/util/sha1.hpp"
|
|
|
|
namespace wpi::net {
|
|
|
|
class WebSocketClientTest : public WebSocketTest {
|
|
public:
|
|
WebSocketClientTest() {
|
|
// Bare bones server
|
|
req.header.connect([this](std::string_view name, std::string_view value) {
|
|
// save key (required for valid response)
|
|
if (wpi::util::equals_lower(name, "sec-websocket-key")) {
|
|
clientKey = value;
|
|
}
|
|
});
|
|
req.headersComplete.connect([this](bool) {
|
|
// send response
|
|
wpi::util::SmallVector<uv::Buffer, 4> bufs;
|
|
raw_uv_ostream os{bufs, 4096};
|
|
os << "HTTP/1.1 101 Switching Protocols\r\n";
|
|
os << "Upgrade: websocket\r\n";
|
|
os << "Connection: Upgrade\r\n";
|
|
|
|
// accept hash
|
|
wpi::util::SHA1 hash;
|
|
hash.Update(clientKey);
|
|
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
|
if (mockBadAccept) {
|
|
hash.Update("1");
|
|
}
|
|
wpi::util::SmallString<64> hashBuf;
|
|
wpi::util::SmallString<64> acceptBuf;
|
|
os << "Sec-WebSocket-Accept: "
|
|
<< wpi::util::Base64Encode(hash.RawFinal(hashBuf), acceptBuf)
|
|
<< "\r\n";
|
|
|
|
if (!mockProtocol.empty()) {
|
|
os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
|
|
}
|
|
|
|
os << "\r\n";
|
|
|
|
conn->Write(bufs, [](auto bufs, uv::Error) {
|
|
for (auto& buf : bufs) {
|
|
buf.Deallocate();
|
|
}
|
|
});
|
|
|
|
serverHeadersDone = true;
|
|
if (connected) {
|
|
connected();
|
|
}
|
|
});
|
|
|
|
serverPipe->Listen([this] {
|
|
conn = serverPipe->Accept();
|
|
conn->StartRead();
|
|
conn->data.connect([this](uv::Buffer& buf, size_t size) {
|
|
std::string_view data{buf.base, size};
|
|
if (!serverHeadersDone) {
|
|
data = req.Execute(data);
|
|
if (req.HasError()) {
|
|
Finish();
|
|
}
|
|
INFO(http_errno_name(req.GetError()));
|
|
REQUIRE(req.GetError() == HPE_OK);
|
|
if (data.empty()) {
|
|
return;
|
|
}
|
|
}
|
|
wireData.insert(wireData.end(), data.begin(), data.end());
|
|
});
|
|
conn->end.connect([this] { Finish(); });
|
|
});
|
|
}
|
|
|
|
bool mockBadAccept = false;
|
|
std::vector<uint8_t> wireData;
|
|
std::shared_ptr<uv::Pipe> conn;
|
|
HttpParser req{HttpParser::Type::REQUEST};
|
|
wpi::util::SmallString<64> clientKey;
|
|
std::string mockProtocol;
|
|
bool serverHeadersDone = false;
|
|
std::function<void()> connected;
|
|
};
|
|
|
|
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest Open",
                 "[websocket][client][handshake]") {
  // How many times the client reported a completed handshake.
  int openCount = 0;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    // Only close codes 1005 and 1006 are tolerated (RFC 6455 reserves them
    // for "no status received" and "abnormal closure"); any other code fails.
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      const bool tolerated = closeCode == 1005 || closeCode == 1006;
      if (!tolerated) {
        FAIL("Code: " << closeCode << " Reason: " << why);
      }
    });

    // No subprotocol was offered, so none may be reported as negotiated.
    client->open.connect([&](std::string_view negotiated) {
      ++openCount;
      Finish();
      REQUIRE(negotiated.empty());
    });
  });

  loop->Run();

  REQUIRE(openCount == 1);
}
|
|
|
|
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest BadAccept",
                 "[websocket][client][handshake][protocol]") {
  int closeCount = 0;

  // Make the server answer with a Sec-WebSocket-Accept that does not match
  // the client's key.
  mockBadAccept = true;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    // The client must reject the handshake and close with code 1002.
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closeCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1002);
    });

    // Reaching the open state means the bad accept value was not detected.
    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });

  loop->Run();

  REQUIRE(closeCount == 1);
}
|
|
|
|
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolGood",
                 "[websocket][client][protocol]") {
  int gotOpen = 0;

  // The server selects one of the subprotocols the client offers.
  mockProtocol = "myProtocol";

  clientPipe->Connect(pipeName, [&] {
    auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                      {"myProtocol", "myProtocol2"});
    // Only close codes 1005/1006 are tolerated; anything else fails the test.
    ws->closed.connect([&](uint16_t code, std::string_view msg) {
      Finish();
      if (code != 1005 && code != 1006) {
        // Leading space keeps the code and message readable
        // (previously rendered as e.g. "Code: 1002Message: ...").
        FAIL("Code: " << code << " Message: " << msg);
      }
    });
    // The negotiated subprotocol must be the one the server returned.
    ws->open.connect([&](std::string_view protocol) {
      ++gotOpen;
      Finish();
      REQUIRE(protocol == "myProtocol");
    });
  });

  loop->Run();

  REQUIRE(gotOpen == 1);
}
|
|
|
|
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolRespNotReq",
                 "[websocket][client][protocol]") {
  int closeCount = 0;

  // The server names a subprotocol even though the client offers none.
  mockProtocol = "myProtocol";

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    // The unsolicited subprotocol must make the client close with 1003.
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closeCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1003);
    });

    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });

  loop->Run();

  REQUIRE(closeCount == 1);
}
|
|
|
|
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolReqNotResp",
                 "[websocket][client][protocol]") {
  int closeCount = 0;

  // mockProtocol is left empty, so the server omits Sec-WebSocket-Protocol
  // even though the client asks for "myProtocol".
  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                          {{"myProtocol"}});

    // A missing subprotocol in the response must make the client close
    // with code 1002.
    client->closed.connect([&](uint16_t closeCode, std::string_view why) {
      Finish();
      ++closeCount;
      INFO("Message: " << why);
      REQUIRE(closeCode == 1002);
    });

    client->open.connect([&](std::string_view) {
      Finish();
      FAIL("Got open");
    });
  });

  loop->Run();

  REQUIRE(closeCount == 1);
}
|
|
|
|
//
|
|
// Send and receive data. Most of these cases are tested in
|
|
// WebSocketServerTest, so only spot check differences like masking.
|
|
//
|
|
|
|
class WebSocketClientDataTest : public WebSocketClientTest {
|
|
public:
|
|
WebSocketClientDataTest() {
|
|
clientPipe->Connect(pipeName, [&] {
|
|
ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
|
if (setupWebSocket) {
|
|
setupWebSocket();
|
|
}
|
|
});
|
|
}
|
|
|
|
std::function<void()> setupWebSocket;
|
|
std::shared_ptr<WebSocket> ws;
|
|
};
|
|
|
|
TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest SendBinary",
                 "[websocket][client][data]") {
  int callbackCount = 0;
  // Payload sizes straddle the 7-bit / 16-bit / 64-bit length encodings.
  const size_t payloadSize = GENERATE(size_t{0}, size_t{1}, size_t{125},
                                      size_t{126}, size_t{65535}, size_t{65536});
  std::vector<uint8_t> payload(payloadSize, 0x03u);

  setupWebSocket = [&] {
    ws->open.connect([&](std::string_view) {
      ws->SendBinary({{payload}}, [&](auto sent, uv::Error) {
        ++callbackCount;
        ws->Terminate();
        // The completion must report the caller's own buffer (same base
        // pointer as the payload).
        REQUIRE_FALSE(sent.empty());
        REQUIRE(sent[0].base == reinterpret_cast<const char*>(payload.data()));
      });
    });
  };

  loop->Run();

  // Compare against a final, masked binary frame (opcode 0x02). AdjustMasking
  // comes from WebSocketTest and is applied to the captured bytes first —
  // presumably to normalize the random mask key; confirm in WebSocketTest.
  auto expected = BuildMessage(0x02, true, true, payload);
  AdjustMasking(wireData);
  REQUIRE(wireData == expected);
  REQUIRE(callbackCount == 1);
}
|
|
|
|
TEST_CASE_METHOD(WebSocketClientDataTest,
                 "WebSocketClientDataTest ReceiveBinary",
                 "[websocket][client][data]") {
  int callbackCount = 0;
  const size_t payloadSize = GENERATE(size_t{0}, size_t{1}, size_t{125},
                                      size_t{126}, size_t{65535}, size_t{65536});
  std::vector<uint8_t> payload(payloadSize, 0x03u);

  // The whole payload must arrive as one final binary message.
  setupWebSocket = [&] {
    ws->binary.connect([&](auto received, bool isFinal) {
      ++callbackCount;
      ws->Terminate();
      REQUIRE(isFinal);
      std::vector<uint8_t> got(received.begin(), received.end());
      REQUIRE(payload == got);
    });
  };

  // Final binary frame (opcode 0x02) built with the masking argument false,
  // sent by the server as soon as its handshake response is out.
  auto frame = BuildMessage(0x02, true, false, payload);
  connected = [&] { conn->Write({{frame}}, [](auto, uv::Error) {}); };

  loop->Run();

  REQUIRE(callbackCount == 1);
}
|
|
|
|
//
|
|
// The client must close the connection if a masked frame is received.
|
|
//
|
|
|
|
TEST_CASE_METHOD(WebSocketClientDataTest,
                 "WebSocketClientDataTest ReceiveMasked",
                 "[websocket][client][data][protocol]") {
  int closeCount = 0;
  const size_t payloadSize = GENERATE(size_t{0}, size_t{1}, size_t{125},
                                      size_t{126}, size_t{65535}, size_t{65536});
  std::vector<uint8_t> payload(payloadSize, ' ');

  setupWebSocket = [&] {
    // Delivering the masked frame to the application is a failure.
    ws->text.connect([&](std::string_view, bool) {
      ws->Terminate();
      FAIL("Should not have gotten masked message");
    });
    // Instead the client must close with code 1002.
    ws->closed.connect([&](uint16_t closeCode, std::string_view why) {
      ++closeCount;
      INFO("reason: " << why);
      REQUIRE(closeCode == 1002);
    });
  };

  // Final text frame (opcode 0x01) built with the masking argument true,
  // sent by the server right after its handshake response.
  auto frame = BuildMessage(0x01, true, true, payload);
  connected = [&] { conn->Write({{frame}}, [](auto, uv::Error) {}); };

  loop->Run();

  REQUIRE(closeCount == 1);
}
|
|
|
|
} // namespace wpi::net
|