mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-07-05 03:21:42 +00:00
wpiutil: Add WebSocket implementation (#1186)
This is a RFC 6455 compliant implementation with both client and server support.
This commit is contained in:
565
wpiutil/src/main/native/cpp/WebSocket.cpp
Normal file
565
wpiutil/src/main/native/cpp/WebSocket.cpp
Normal file
@@ -0,0 +1,565 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocket.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "wpi/Base64.h"
|
||||
#include "wpi/HttpParser.h"
|
||||
#include "wpi/SmallString.h"
|
||||
#include "wpi/SmallVector.h"
|
||||
#include "wpi/raw_uv_ostream.h"
|
||||
#include "wpi/sha1.h"
|
||||
#include "wpi/uv/Stream.h"
|
||||
|
||||
using namespace wpi;
|
||||
|
||||
namespace {
|
||||
class WebSocketWriteReq : public uv::WriteReq {
|
||||
public:
|
||||
explicit WebSocketWriteReq(
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
finish.connect([=](uv::Error err) {
|
||||
MutableArrayRef<uv::Buffer> bufs{m_bufs};
|
||||
for (auto&& buf : bufs.slice(0, m_startUser)) buf.Deallocate();
|
||||
callback(bufs.slice(m_startUser), err);
|
||||
});
|
||||
}
|
||||
|
||||
SmallVector<uv::Buffer, 4> m_bufs;
|
||||
size_t m_startUser;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
class WebSocket::ClientHandshakeData {
|
||||
public:
|
||||
ClientHandshakeData() {
|
||||
// key is a random nonce
|
||||
static std::random_device rd;
|
||||
static std::default_random_engine gen{rd()};
|
||||
std::uniform_int_distribution<unsigned int> dist(0, 255);
|
||||
char nonce[16]; // the nonce sent to the server
|
||||
for (char& v : nonce) v = static_cast<char>(dist(gen));
|
||||
raw_svector_ostream os(key);
|
||||
Base64Encode(os, StringRef{nonce, 16});
|
||||
}
|
||||
~ClientHandshakeData() {
|
||||
if (auto t = timer.lock()) {
|
||||
t->Stop();
|
||||
t->Close();
|
||||
}
|
||||
}
|
||||
|
||||
SmallString<64> key; // the key sent to the server
|
||||
SmallVector<std::string, 2> protocols; // valid protocols
|
||||
HttpParser parser{HttpParser::kResponse}; // server response parser
|
||||
bool hasUpgrade = false;
|
||||
bool hasConnection = false;
|
||||
bool hasAccept = false;
|
||||
bool hasProtocol = false;
|
||||
|
||||
std::weak_ptr<uv::Timer> timer;
|
||||
};
|
||||
|
||||
static StringRef AcceptHash(StringRef key, SmallVectorImpl<char>& buf) {
|
||||
SHA1 hash;
|
||||
hash.Update(key);
|
||||
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
SmallString<64> hashBuf;
|
||||
return Base64Encode(hash.Final(hashBuf), buf);
|
||||
}
|
||||
|
||||
WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
|
||||
: m_stream{stream}, m_server{server} {
|
||||
// Connect closed and error signals to ourselves
|
||||
m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
|
||||
m_stream.error.connect([this](uv::Error err) {
|
||||
Terminate(1006, "stream error: " + Twine(err.name()));
|
||||
});
|
||||
|
||||
// Start reading
|
||||
m_stream.StopRead(); // we may have been reading
|
||||
m_stream.StartRead();
|
||||
m_stream.data.connect(
|
||||
[this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
|
||||
m_stream.end.connect(
|
||||
[this]() { Terminate(1006, "remote end closed connection"); });
|
||||
}
|
||||
|
||||
WebSocket::~WebSocket() {}
|
||||
|
||||
std::shared_ptr<WebSocket> WebSocket::CreateClient(
|
||||
uv::Stream& stream, const Twine& uri, const Twine& host,
|
||||
ArrayRef<StringRef> protocols, const ClientOptions& options) {
|
||||
auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
|
||||
stream.SetData(ws);
|
||||
ws->StartClient(uri, host, protocols, options);
|
||||
return ws;
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
|
||||
StringRef key,
|
||||
StringRef version,
|
||||
StringRef protocol) {
|
||||
auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
|
||||
stream.SetData(ws);
|
||||
ws->StartServer(key, version, protocol);
|
||||
return ws;
|
||||
}
|
||||
|
||||
void WebSocket::Close(uint16_t code, const Twine& reason) {
|
||||
SendClose(code, reason);
|
||||
if (m_state != FAILED && m_state != CLOSED) m_state = CLOSING;
|
||||
}
|
||||
|
||||
void WebSocket::Fail(uint16_t code, const Twine& reason) {
|
||||
if (m_state == FAILED || m_state == CLOSED) return;
|
||||
SendClose(code, reason);
|
||||
SetClosed(code, reason, true);
|
||||
Shutdown();
|
||||
}
|
||||
|
||||
void WebSocket::Terminate(uint16_t code, const Twine& reason) {
|
||||
if (m_state == FAILED || m_state == CLOSED) return;
|
||||
SetClosed(code, reason);
|
||||
Shutdown();
|
||||
}
|
||||
|
||||
void WebSocket::StartClient(const Twine& uri, const Twine& host,
|
||||
ArrayRef<StringRef> protocols,
|
||||
const ClientOptions& options) {
|
||||
// Create client handshake data
|
||||
m_clientHandshake = std::make_unique<ClientHandshakeData>();
|
||||
|
||||
// Build client request
|
||||
SmallVector<uv::Buffer, 4> bufs;
|
||||
raw_uv_ostream os{bufs, 4096};
|
||||
|
||||
os << "GET " << uri << " HTTP/1.1\r\n";
|
||||
os << "Host: " << host << "\r\n";
|
||||
os << "Upgrade: websocket\r\n";
|
||||
os << "Connection: Upgrade\r\n";
|
||||
os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
|
||||
os << "Sec-WebSocket-Version: 13\r\n";
|
||||
|
||||
// protocols (if provided)
|
||||
if (!protocols.empty()) {
|
||||
os << "Sec-WebSocket-Protocol: ";
|
||||
bool first = true;
|
||||
for (auto protocol : protocols) {
|
||||
if (!first)
|
||||
os << ", ";
|
||||
else
|
||||
first = false;
|
||||
os << protocol;
|
||||
// also save for later checking against server response
|
||||
m_clientHandshake->protocols.emplace_back(protocol);
|
||||
}
|
||||
os << "\r\n";
|
||||
}
|
||||
|
||||
// other headers
|
||||
for (auto&& header : options.extraHeaders)
|
||||
os << header.first << ": " << header.second << "\r\n";
|
||||
|
||||
// finish headers
|
||||
os << "\r\n";
|
||||
|
||||
// Send client request
|
||||
m_stream.Write(bufs, [](auto bufs, uv::Error) {
|
||||
for (auto& buf : bufs) buf.Deallocate();
|
||||
});
|
||||
|
||||
// Set up client response handling
|
||||
m_clientHandshake->parser.status.connect([this](StringRef status) {
|
||||
unsigned int code = m_clientHandshake->parser.GetStatusCode();
|
||||
if (code != 101) Terminate(code, status);
|
||||
});
|
||||
m_clientHandshake->parser.header.connect(
|
||||
[this](StringRef name, StringRef value) {
|
||||
value = value.trim();
|
||||
if (name.equals_lower("upgrade")) {
|
||||
if (!value.equals_lower("websocket"))
|
||||
return Terminate(1002, "invalid upgrade response value");
|
||||
m_clientHandshake->hasUpgrade = true;
|
||||
} else if (name.equals_lower("connection")) {
|
||||
if (!value.equals_lower("upgrade"))
|
||||
return Terminate(1002, "invalid connection response value");
|
||||
m_clientHandshake->hasConnection = true;
|
||||
} else if (name.equals_lower("sec-websocket-accept")) {
|
||||
// Check against expected response
|
||||
SmallString<64> acceptBuf;
|
||||
if (!value.equals(AcceptHash(m_clientHandshake->key, acceptBuf)))
|
||||
return Terminate(1002, "invalid accept key");
|
||||
m_clientHandshake->hasAccept = true;
|
||||
} else if (name.equals_lower("sec-websocket-extensions")) {
|
||||
// No extensions are supported
|
||||
if (!value.empty()) return Terminate(1010, "unsupported extension");
|
||||
} else if (name.equals_lower("sec-websocket-protocol")) {
|
||||
// Make sure it was one of the provided protocols
|
||||
bool match = false;
|
||||
for (auto&& protocol : m_clientHandshake->protocols) {
|
||||
if (value.equals_lower(protocol)) {
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) return Terminate(1003, "unsupported protocol");
|
||||
m_clientHandshake->hasProtocol = true;
|
||||
m_protocol = value;
|
||||
}
|
||||
});
|
||||
m_clientHandshake->parser.headersComplete.connect([this](bool) {
|
||||
if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
|
||||
!m_clientHandshake->hasAccept ||
|
||||
(!m_clientHandshake->hasProtocol &&
|
||||
!m_clientHandshake->protocols.empty())) {
|
||||
return Terminate(1002, "invalid response");
|
||||
}
|
||||
if (m_state == CONNECTING) {
|
||||
m_state = OPEN;
|
||||
open(m_protocol);
|
||||
}
|
||||
});
|
||||
|
||||
// Start handshake timer if a timeout was specified
|
||||
if (options.handshakeTimeout != uv::Timer::Time::max()) {
|
||||
auto timer = uv::Timer::Create(m_stream.GetLoopRef());
|
||||
timer->timeout.connect(
|
||||
[this]() { Terminate(1006, "connection timed out"); });
|
||||
timer->Start(options.handshakeTimeout);
|
||||
m_clientHandshake->timer = timer;
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocket::StartServer(StringRef key, StringRef version,
|
||||
StringRef protocol) {
|
||||
m_protocol = protocol;
|
||||
|
||||
// Build server response
|
||||
SmallVector<uv::Buffer, 4> bufs;
|
||||
raw_uv_ostream os{bufs, 4096};
|
||||
|
||||
// Handle unsupported version
|
||||
if (version != "13") {
|
||||
os << "HTTP/1.1 426 Upgrade Required\r\n";
|
||||
os << "Upgrade: WebSocket\r\n";
|
||||
os << "Sec-WebSocket-Version: 13\r\n\r\n";
|
||||
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
|
||||
for (auto& buf : bufs) buf.Deallocate();
|
||||
// XXX: Should we support sending a new handshake on the same connection?
|
||||
// XXX: "this->" is required by GCC 5.5 (bug)
|
||||
this->Terminate(1003, "unsupported protocol version");
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
os << "HTTP/1.1 101 Switching Protocols\r\n";
|
||||
os << "Upgrade: websocket\r\n";
|
||||
os << "Connection: Upgrade\r\n";
|
||||
|
||||
// accept hash
|
||||
SmallString<64> acceptBuf;
|
||||
os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
|
||||
|
||||
if (!protocol.empty()) os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
|
||||
|
||||
// end headers
|
||||
os << "\r\n";
|
||||
|
||||
// Send server response
|
||||
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
|
||||
for (auto& buf : bufs) buf.Deallocate();
|
||||
if (m_state == CONNECTING) {
|
||||
m_state = OPEN;
|
||||
open(m_protocol);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void WebSocket::SendClose(uint16_t code, const Twine& reason) {
|
||||
SmallVector<uv::Buffer, 4> bufs;
|
||||
if (code != 1005) {
|
||||
raw_uv_ostream os{bufs, 4096};
|
||||
os << ArrayRef<uint8_t>{static_cast<uint8_t>((code >> 8) & 0xff),
|
||||
static_cast<uint8_t>(code & 0xff)};
|
||||
reason.print(os);
|
||||
}
|
||||
Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
|
||||
for (auto&& buf : bufs) buf.Deallocate();
|
||||
});
|
||||
}
|
||||
|
||||
void WebSocket::SetClosed(uint16_t code, const Twine& reason, bool failed) {
|
||||
if (m_state == FAILED || m_state == CLOSED) return;
|
||||
m_state = failed ? FAILED : CLOSED;
|
||||
SmallString<64> reasonBuf;
|
||||
closed(code, reason.toStringRef(reasonBuf));
|
||||
}
|
||||
|
||||
void WebSocket::Shutdown() {
|
||||
m_stream.Shutdown([this] { m_stream.Close(); });
|
||||
}
|
||||
|
||||
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
|
||||
// ignore incoming data if we're failed or closed
|
||||
if (m_state == FAILED || m_state == CLOSED) return;
|
||||
|
||||
StringRef data{buf.base, size};
|
||||
|
||||
// Handle connecting state (mainly on client)
|
||||
if (m_state == CONNECTING) {
|
||||
if (m_clientHandshake) {
|
||||
data = m_clientHandshake->parser.Execute(data);
|
||||
// check for parser failure
|
||||
if (m_clientHandshake->parser.HasError())
|
||||
return Terminate(1003, "invalid response");
|
||||
if (m_state != OPEN) return; // not done with handshake yet
|
||||
|
||||
// we're done with the handshake, so release its memory
|
||||
m_clientHandshake.reset();
|
||||
|
||||
// fall through to process additional data after handshake
|
||||
} else {
|
||||
return Terminate(1003, "got data on server before response");
|
||||
}
|
||||
}
|
||||
|
||||
// Message processing
|
||||
while (!data.empty()) {
|
||||
if (m_frameSize == UINT64_MAX) {
|
||||
// Need at least two bytes to determine header length
|
||||
if (m_header.size() < 2u) {
|
||||
size_t toCopy = std::min(2u - m_header.size(), data.size());
|
||||
m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
|
||||
data = data.drop_front(toCopy);
|
||||
if (m_header.size() < 2u) return; // need more data
|
||||
|
||||
// Validate RSV bits are zero
|
||||
if ((m_header[0] & 0x70) != 0) return Fail(1002, "nonzero RSV");
|
||||
}
|
||||
|
||||
// Once we have first two bytes, we can calculate the header size
|
||||
if (m_headerSize == 0) {
|
||||
m_headerSize = 2;
|
||||
uint8_t len = m_header[1] & kLenMask;
|
||||
if (len == 126)
|
||||
m_headerSize += 2;
|
||||
else if (len == 127)
|
||||
m_headerSize += 8;
|
||||
bool masking = (m_header[1] & kFlagMasking) != 0;
|
||||
if (masking) m_headerSize += 4; // masking key
|
||||
// On server side, incoming messages MUST be masked
|
||||
// On client side, incoming messages MUST NOT be masked
|
||||
if (m_server && !masking) return Fail(1002, "client data not masked");
|
||||
if (!m_server && masking) return Fail(1002, "server data masked");
|
||||
}
|
||||
|
||||
// Need to complete header to calculate message size
|
||||
if (m_header.size() < m_headerSize) {
|
||||
size_t toCopy = std::min(m_headerSize - m_header.size(), data.size());
|
||||
m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
|
||||
data = data.drop_front(toCopy);
|
||||
if (m_header.size() < m_headerSize) return; // need more data
|
||||
}
|
||||
|
||||
if (m_header.size() >= m_headerSize) {
|
||||
// get payload length
|
||||
uint8_t len = m_header[1] & kLenMask;
|
||||
if (len == 126)
|
||||
m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
|
||||
static_cast<uint16_t>(m_header[3]);
|
||||
else if (len == 127)
|
||||
m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
|
||||
(static_cast<uint64_t>(m_header[3]) << 48) |
|
||||
(static_cast<uint64_t>(m_header[4]) << 40) |
|
||||
(static_cast<uint64_t>(m_header[5]) << 32) |
|
||||
(static_cast<uint64_t>(m_header[6]) << 24) |
|
||||
(static_cast<uint64_t>(m_header[7]) << 16) |
|
||||
(static_cast<uint64_t>(m_header[8]) << 8) |
|
||||
static_cast<uint64_t>(m_header[9]);
|
||||
else
|
||||
m_frameSize = len;
|
||||
|
||||
// limit maximum size
|
||||
if ((m_payload.size() + m_frameSize) > m_maxMessageSize)
|
||||
return Fail(1009, "message too large");
|
||||
}
|
||||
}
|
||||
|
||||
if (m_frameSize != UINT64_MAX) {
|
||||
size_t need = m_frameStart + m_frameSize - m_payload.size();
|
||||
size_t toCopy = std::min(need, data.size());
|
||||
m_payload.append(data.bytes_begin(), data.bytes_begin() + toCopy);
|
||||
data = data.drop_front(toCopy);
|
||||
need -= toCopy;
|
||||
if (need == 0) {
|
||||
// We have a complete frame
|
||||
// If the message had masking, unmask it
|
||||
if ((m_header[1] & kFlagMasking) != 0) {
|
||||
uint8_t key[4] = {
|
||||
m_header[m_headerSize - 4], m_header[m_headerSize - 3],
|
||||
m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
|
||||
int n = 0;
|
||||
for (uint8_t& ch :
|
||||
MutableArrayRef<uint8_t>{m_payload}.slice(m_frameStart)) {
|
||||
ch ^= key[n++];
|
||||
if (n >= 4) n = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message
|
||||
bool fin = (m_header[0] & kFlagFin) != 0;
|
||||
uint8_t opcode = m_header[0] & kOpMask;
|
||||
switch (opcode) {
|
||||
case kOpCont:
|
||||
switch (m_fragmentOpcode) {
|
||||
case kOpText:
|
||||
if (!m_combineFragments || fin)
|
||||
text(StringRef{reinterpret_cast<char*>(m_payload.data()),
|
||||
m_payload.size()},
|
||||
fin);
|
||||
break;
|
||||
case kOpBinary:
|
||||
if (!m_combineFragments || fin) binary(m_payload, fin);
|
||||
break;
|
||||
default:
|
||||
// no preceding message?
|
||||
return Fail(1002, "invalid continuation message");
|
||||
}
|
||||
if (fin) m_fragmentOpcode = 0;
|
||||
break;
|
||||
case kOpText:
|
||||
if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
|
||||
if (!m_combineFragments || fin)
|
||||
text(StringRef{reinterpret_cast<char*>(m_payload.data()),
|
||||
m_payload.size()},
|
||||
fin);
|
||||
if (!fin) m_fragmentOpcode = opcode;
|
||||
break;
|
||||
case kOpBinary:
|
||||
if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
|
||||
if (!m_combineFragments || fin) binary(m_payload, fin);
|
||||
if (!fin) m_fragmentOpcode = opcode;
|
||||
break;
|
||||
case kOpClose: {
|
||||
uint16_t code;
|
||||
StringRef reason;
|
||||
if (!fin) {
|
||||
code = 1002;
|
||||
reason = "cannot fragment control frames";
|
||||
} else if (m_payload.size() < 2) {
|
||||
code = 1005;
|
||||
} else {
|
||||
code = (static_cast<uint16_t>(m_payload[0]) << 8) |
|
||||
static_cast<uint16_t>(m_payload[1]);
|
||||
reason = StringRef{reinterpret_cast<char*>(m_payload.data()),
|
||||
m_payload.size()}
|
||||
.drop_front(2);
|
||||
}
|
||||
// Echo the close if we didn't previously send it
|
||||
if (m_state != CLOSING) SendClose(code, reason);
|
||||
SetClosed(code, reason);
|
||||
// If we're the server, shutdown the connection.
|
||||
if (m_server) Shutdown();
|
||||
break;
|
||||
}
|
||||
case kOpPing:
|
||||
if (!fin) return Fail(1002, "cannot fragment control frames");
|
||||
ping(m_payload);
|
||||
break;
|
||||
case kOpPong:
|
||||
if (!fin) return Fail(1002, "cannot fragment control frames");
|
||||
pong(m_payload);
|
||||
break;
|
||||
default:
|
||||
return Fail(1002, "invalid message opcode");
|
||||
}
|
||||
|
||||
// Prepare for next message
|
||||
m_header.clear();
|
||||
m_headerSize = 0;
|
||||
if (!m_combineFragments || fin) m_payload.clear();
|
||||
m_frameStart = m_payload.size();
|
||||
m_frameSize = UINT64_MAX;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocket::Send(
|
||||
uint8_t opcode, ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
// If we're not open, emit an error and don't send the data
|
||||
if (m_state != OPEN) {
|
||||
int err;
|
||||
if (m_state == CONNECTING)
|
||||
err = UV_EAGAIN;
|
||||
else
|
||||
err = UV_ESHUTDOWN;
|
||||
SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
|
||||
callback(bufs, uv::Error{err});
|
||||
return;
|
||||
}
|
||||
|
||||
auto req = std::make_shared<WebSocketWriteReq>(callback);
|
||||
raw_uv_ostream os{req->m_bufs, 4096};
|
||||
|
||||
// opcode (includes FIN bit)
|
||||
os << static_cast<unsigned char>(opcode);
|
||||
|
||||
// payload length
|
||||
uint64_t size = 0;
|
||||
for (auto&& buf : data) size += buf.len;
|
||||
if (size < 126) {
|
||||
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
|
||||
} else if (size <= 0xffff) {
|
||||
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
|
||||
os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 8) & 0xff),
|
||||
static_cast<uint8_t>(size & 0xff)};
|
||||
} else {
|
||||
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
|
||||
os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 56) & 0xff),
|
||||
static_cast<uint8_t>((size >> 48) & 0xff),
|
||||
static_cast<uint8_t>((size >> 40) & 0xff),
|
||||
static_cast<uint8_t>((size >> 32) & 0xff),
|
||||
static_cast<uint8_t>((size >> 24) & 0xff),
|
||||
static_cast<uint8_t>((size >> 16) & 0xff),
|
||||
static_cast<uint8_t>((size >> 8) & 0xff),
|
||||
static_cast<uint8_t>(size & 0xff)};
|
||||
}
|
||||
|
||||
// clients need to mask the input data
|
||||
if (!m_server) {
|
||||
// generate masking key
|
||||
static std::random_device rd;
|
||||
static std::default_random_engine gen{rd()};
|
||||
std::uniform_int_distribution<unsigned int> dist(0, 255);
|
||||
uint8_t key[4];
|
||||
for (uint8_t& v : key) v = dist(gen);
|
||||
os << ArrayRef<uint8_t>{key, 4};
|
||||
// copy and mask data
|
||||
int n = 0;
|
||||
for (auto&& buf : data) {
|
||||
for (auto&& ch : buf.data()) {
|
||||
os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
|
||||
if (n >= 4) n = 0;
|
||||
}
|
||||
}
|
||||
req->m_startUser = req->m_bufs.size();
|
||||
req->m_bufs.append(data.begin(), data.end());
|
||||
// don't send the user bufs as we copied their data
|
||||
m_stream.Write(ArrayRef<uv::Buffer>{req->m_bufs}.slice(0, req->m_startUser),
|
||||
req);
|
||||
} else {
|
||||
// servers can just send the buffers directly without masking
|
||||
req->m_startUser = req->m_bufs.size();
|
||||
req->m_bufs.append(data.begin(), data.end());
|
||||
m_stream.Write(req->m_bufs, req);
|
||||
}
|
||||
}
|
||||
142
wpiutil/src/main/native/cpp/WebSocketServer.cpp
Normal file
142
wpiutil/src/main/native/cpp/WebSocketServer.cpp
Normal file
@@ -0,0 +1,142 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocketServer.h"
|
||||
|
||||
#include "wpi/raw_uv_ostream.h"
|
||||
#include "wpi/uv/Buffer.h"
|
||||
#include "wpi/uv/Stream.h"
|
||||
|
||||
using namespace wpi;
|
||||
|
||||
WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) {
|
||||
req.header.connect([this](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("host")) {
|
||||
m_gotHost = true;
|
||||
} else if (name.equals_lower("upgrade")) {
|
||||
if (value.equals_lower("websocket")) m_websocket = true;
|
||||
} else if (name.equals_lower("sec-websocket-key")) {
|
||||
m_key = value;
|
||||
} else if (name.equals_lower("sec-websocket-version")) {
|
||||
m_version = value;
|
||||
} else if (name.equals_lower("sec-websocket-protocol")) {
|
||||
// Protocols are comma delimited, repeated headers add to list
|
||||
SmallVector<StringRef, 2> protocols;
|
||||
value.split(protocols, ",", -1, false);
|
||||
for (auto protocol : protocols) {
|
||||
protocol = protocol.trim();
|
||||
if (!protocol.empty()) m_protocols.emplace_back(protocol);
|
||||
}
|
||||
}
|
||||
});
|
||||
req.headersComplete.connect([&req, this](bool) {
|
||||
if (req.IsUpgrade() && IsUpgrade()) upgrade();
|
||||
});
|
||||
}
|
||||
|
||||
std::pair<bool, StringRef> WebSocketServerHelper::MatchProtocol(
|
||||
ArrayRef<StringRef> protocols) {
|
||||
if (protocols.empty() && m_protocols.empty())
|
||||
return std::make_pair(true, StringRef{});
|
||||
for (auto protocol : protocols) {
|
||||
for (auto&& clientProto : m_protocols) {
|
||||
if (protocol == clientProto) return std::make_pair(true, protocol);
|
||||
}
|
||||
}
|
||||
return std::make_pair(false, StringRef{});
|
||||
}
|
||||
|
||||
WebSocketServer::WebSocketServer(uv::Stream& stream,
|
||||
ArrayRef<StringRef> protocols,
|
||||
const ServerOptions& options,
|
||||
const private_init&)
|
||||
: m_stream{stream},
|
||||
m_helper{m_req},
|
||||
m_protocols{protocols.begin(), protocols.end()},
|
||||
m_options{options} {
|
||||
// Header handling
|
||||
m_req.header.connect([this](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("host")) {
|
||||
if (m_options.checkHost) {
|
||||
if (!m_options.checkHost(value)) Abort(401, "Unrecognized Host");
|
||||
}
|
||||
}
|
||||
});
|
||||
m_req.url.connect([this](StringRef name) {
|
||||
if (m_options.checkUrl) {
|
||||
if (!m_options.checkUrl(name)) Abort(404, "Not Found");
|
||||
}
|
||||
});
|
||||
m_req.headersComplete.connect([this](bool) {
|
||||
// We only accept websocket connections
|
||||
if (!m_helper.IsUpgrade() || !m_req.IsUpgrade())
|
||||
Abort(426, "Upgrade Required");
|
||||
});
|
||||
|
||||
// Handle upgrade event
|
||||
m_helper.upgrade.connect([this] {
|
||||
if (m_aborted) return;
|
||||
|
||||
// Negotiate sub-protocol
|
||||
SmallVector<StringRef, 2> protocols{m_protocols.begin(), m_protocols.end()};
|
||||
StringRef protocol = m_helper.MatchProtocol(protocols).second;
|
||||
|
||||
// Disconnect our header reader
|
||||
m_headerConn.disconnect();
|
||||
|
||||
// Accepting the stream may destroy this (as it replaces the stream user
|
||||
// data), so grab a shared pointer first.
|
||||
auto self = shared_from_this();
|
||||
|
||||
// Accept the upgrade
|
||||
auto ws = m_helper.Accept(m_stream, protocol);
|
||||
|
||||
// Connect the websocket open event to our connected event.
|
||||
ws->open.connect_extended([ self, s = ws.get() ](auto conn, StringRef) {
|
||||
self->connected(self->m_req.GetUrl(), *s);
|
||||
conn.disconnect(); // one-shot
|
||||
});
|
||||
});
|
||||
|
||||
// Set up stream
|
||||
stream.StartRead();
|
||||
m_headerConn =
|
||||
stream.data.connect_connection([this](uv::Buffer& buf, size_t size) {
|
||||
if (m_aborted) return;
|
||||
m_req.Execute(StringRef{buf.base, size});
|
||||
if (m_req.HasError()) Abort(400, "Bad Request");
|
||||
});
|
||||
stream.error.connect([this](uv::Error) { m_stream.Close(); });
|
||||
stream.end.connect([this] { m_stream.Close(); });
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketServer> WebSocketServer::Create(
|
||||
uv::Stream& stream, ArrayRef<StringRef> protocols,
|
||||
const ServerOptions& options) {
|
||||
auto server = std::make_shared<WebSocketServer>(stream, protocols, options,
|
||||
private_init{});
|
||||
stream.SetData(server);
|
||||
return server;
|
||||
}
|
||||
|
||||
void WebSocketServer::Abort(uint16_t code, StringRef reason) {
|
||||
if (m_aborted) return;
|
||||
m_aborted = true;
|
||||
|
||||
// Build response
|
||||
SmallVector<uv::Buffer, 4> bufs;
|
||||
raw_uv_ostream os{bufs, 1024};
|
||||
|
||||
// Handle unsupported version
|
||||
os << "HTTP/1.1 " << code << ' ' << reason << "\r\n";
|
||||
if (code == 426) os << "Upgrade: WebSocket\r\n";
|
||||
os << "\r\n";
|
||||
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
|
||||
for (auto& buf : bufs) buf.Deallocate();
|
||||
m_stream.Shutdown([this] { m_stream.Close(); });
|
||||
});
|
||||
}
|
||||
378
wpiutil/src/main/native/include/wpi/WebSocket.h
Normal file
378
wpiutil/src/main/native/include/wpi/WebSocket.h
Normal file
@@ -0,0 +1,378 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#ifndef WPIUTIL_WPI_WEBSOCKET_H_
|
||||
#define WPIUTIL_WPI_WEBSOCKET_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "wpi/ArrayRef.h"
|
||||
#include "wpi/Signal.h"
|
||||
#include "wpi/SmallVector.h"
|
||||
#include "wpi/StringRef.h"
|
||||
#include "wpi/Twine.h"
|
||||
#include "wpi/uv/Buffer.h"
|
||||
#include "wpi/uv/Error.h"
|
||||
#include "wpi/uv/Timer.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
namespace uv {
|
||||
class Stream;
|
||||
} // namespace uv
|
||||
|
||||
/**
|
||||
* RFC 6455 compliant WebSocket client and server implementation.
|
||||
*/
|
||||
class WebSocket : public std::enable_shared_from_this<WebSocket> {
|
||||
struct private_init {};
|
||||
|
||||
static constexpr uint8_t kOpCont = 0x00;
|
||||
static constexpr uint8_t kOpText = 0x01;
|
||||
static constexpr uint8_t kOpBinary = 0x02;
|
||||
static constexpr uint8_t kOpClose = 0x08;
|
||||
static constexpr uint8_t kOpPing = 0x09;
|
||||
static constexpr uint8_t kOpPong = 0x0A;
|
||||
static constexpr uint8_t kOpMask = 0x0F;
|
||||
static constexpr uint8_t kFlagFin = 0x80;
|
||||
static constexpr uint8_t kFlagMasking = 0x80;
|
||||
static constexpr uint8_t kLenMask = 0x7f;
|
||||
|
||||
public:
|
||||
WebSocket(uv::Stream& stream, bool server, const private_init&);
|
||||
WebSocket(const WebSocket&) = delete;
|
||||
WebSocket(WebSocket&&) = delete;
|
||||
WebSocket& operator=(const WebSocket&) = delete;
|
||||
WebSocket& operator=(WebSocket&&) = delete;
|
||||
~WebSocket();
|
||||
|
||||
/**
|
||||
* Connection states.
|
||||
*/
|
||||
enum State {
|
||||
/** The connection is not yet open. */
|
||||
CONNECTING = 0,
|
||||
/** The connection is open and ready to communicate. */
|
||||
OPEN,
|
||||
/** The connection is in the process of closing. */
|
||||
CLOSING,
|
||||
/** The connection failed. */
|
||||
FAILED,
|
||||
/** The connection is closed. */
|
||||
CLOSED
|
||||
};
|
||||
|
||||
/**
|
||||
* Client connection options.
|
||||
*/
|
||||
struct ClientOptions {
|
||||
ClientOptions() : handshakeTimeout{uv::Timer::Time::max()} {}
|
||||
|
||||
/** Timeout for the handshake request. */
|
||||
uv::Timer::Time handshakeTimeout;
|
||||
|
||||
/** Additional headers to include in handshake. */
|
||||
ArrayRef<std::pair<StringRef, StringRef>> extraHeaders;
|
||||
};
|
||||
|
||||
/**
|
||||
* Starts a client connection by performing the initial client handshake.
|
||||
* An open event is emitted when the handshake completes.
|
||||
* This sets the stream user data to the websocket.
|
||||
* @param stream Connection stream
|
||||
* @param uri The Request-URI to send
|
||||
* @param host The host or host:port to send
|
||||
* @param protocols The list of subprotocols
|
||||
* @param options Handshake options
|
||||
*/
|
||||
static std::shared_ptr<WebSocket> CreateClient(
|
||||
uv::Stream& stream, const Twine& uri, const Twine& host,
|
||||
ArrayRef<StringRef> protocols = ArrayRef<StringRef>{},
|
||||
const ClientOptions& options = ClientOptions{});
|
||||
|
||||
/**
|
||||
* Starts a server connection by performing the initial server side handshake.
|
||||
* This should be called after the HTTP headers have been received.
|
||||
* An open event is emitted when the handshake completes.
|
||||
* This sets the stream user data to the websocket.
|
||||
* @param stream Connection stream
|
||||
* @param key The value of the Sec-WebSocket-Key header field in the client
|
||||
* request
|
||||
* @param version The value of the Sec-WebSocket-Version header field in the
|
||||
* client request
|
||||
* @param protocol The subprotocol to send to the client (in the
|
||||
* Sec-WebSocket-Protocol header field).
|
||||
*/
|
||||
static std::shared_ptr<WebSocket> CreateServer(
|
||||
uv::Stream& stream, StringRef key, StringRef version,
|
||||
StringRef protocol = StringRef{});
|
||||
|
||||
/**
|
||||
* Get connection state.
|
||||
*/
|
||||
State GetState() const { return m_state; }
|
||||
|
||||
/**
|
||||
* Return if the connection is open. Messages can only be sent on open
|
||||
* connections.
|
||||
*/
|
||||
bool IsOpen() const { return m_state == OPEN; }
|
||||
|
||||
/**
|
||||
* Get the underlying stream.
|
||||
*/
|
||||
uv::Stream& GetStream() const { return m_stream; }
|
||||
|
||||
/**
|
||||
* Get the selected sub-protocol. Only valid in or after the open() event.
|
||||
*/
|
||||
StringRef GetProtocol() const { return m_protocol; }
|
||||
|
||||
/**
|
||||
* Set the maximum message size. Default is 128 KB. If configured to combine
|
||||
* fragments this maximum applies to the entire message (all combined
|
||||
* fragments).
|
||||
* @param size Maximum message size in bytes
|
||||
*/
|
||||
void SetMaxMessageSize(size_t size) { m_maxMessageSize = size; }
|
||||
|
||||
/**
|
||||
* Set whether or not fragmented frames should be combined. Default is to
|
||||
* combine. If fragmented frames are combined, the text and binary callbacks
|
||||
* will always have the second parameter (fin) set to true.
|
||||
* @param combine True if fragmented frames should be combined.
|
||||
*/
|
||||
void SetCombineFragments(bool combine) { m_combineFragments = combine; }
|
||||
|
||||
/**
|
||||
* Initiate a closing handshake.
|
||||
* @param code A numeric status code (defaults to 1005, no status code)
|
||||
* @param reason A human-readable string explaining why the connection is
|
||||
* closing (optional).
|
||||
*/
|
||||
void Close(uint16_t code = 1005, const Twine& reason = Twine{});
|
||||
|
||||
/**
|
||||
* Send a text message.
|
||||
* @param data UTF-8 encoded data to send
|
||||
* @param callback Callback which is invoked when the write completes.
|
||||
*/
|
||||
void SendText(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kFlagFin | kOpText, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a binary message.
|
||||
* @param data Data to send
|
||||
* @param callback Callback which is invoked when the write completes.
|
||||
*/
|
||||
void SendBinary(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kFlagFin | kOpBinary, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a text message fragment. This must be followed by one or more
|
||||
* SendFragment() calls, where the last one has fin=True, to complete the
|
||||
* message.
|
||||
* @param data UTF-8 encoded data to send
|
||||
* @param callback Callback which is invoked when the write completes.
|
||||
*/
|
||||
void SendTextFragment(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kOpText, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a text message fragment. This must be followed by one or more
|
||||
* SendFragment() calls, where the last one has fin=True, to complete the
|
||||
* message.
|
||||
* @param data Data to send
|
||||
* @param callback Callback which is invoked when the write completes.
|
||||
*/
|
||||
void SendBinaryFragment(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kOpBinary, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a continuation frame. This is used to send additional parts of a
|
||||
* message started with SendTextFragment() or SendBinaryFragment().
|
||||
* @param data Data to send
|
||||
* @param fin Set to true if this is the final fragment of the message
|
||||
* @param callback Callback which is invoked when the write completes.
|
||||
*/
|
||||
void SendFragment(
|
||||
ArrayRef<uv::Buffer> data, bool fin,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kOpCont | (fin ? kFlagFin : 0), data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a ping frame with no data.
|
||||
* @param callback Optional callback which is invoked when the ping frame
|
||||
* write completes.
|
||||
*/
|
||||
void SendPing(std::function<void(uv::Error)> callback = nullptr) {
|
||||
SendPing(ArrayRef<uv::Buffer>{}, [callback](auto bufs, uv::Error err) {
|
||||
if (callback) callback(err);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a ping frame.
|
||||
* @param data Data to send in the ping frame
|
||||
* @param callback Callback which is invoked when the ping frame
|
||||
* write completes.
|
||||
*/
|
||||
void SendPing(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kFlagFin | kOpPing, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a pong frame with no data.
|
||||
* @param callback Optional callback which is invoked when the pong frame
|
||||
* write completes.
|
||||
*/
|
||||
void SendPong(std::function<void(uv::Error)> callback = nullptr) {
|
||||
SendPong(ArrayRef<uv::Buffer>{}, [callback](auto bufs, uv::Error err) {
|
||||
if (callback) callback(err);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a pong frame.
|
||||
* @param data Data to send in the pong frame
|
||||
* @param callback Callback which is invoked when the pong frame
|
||||
* write completes.
|
||||
*/
|
||||
void SendPong(
|
||||
ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
|
||||
Send(kFlagFin | kOpPong, data, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fail the connection.
|
||||
*/
|
||||
void Fail(uint16_t code = 1002, const Twine& reason = "protocol error");
|
||||
|
||||
/**
|
||||
* Forcibly close the connection.
|
||||
*/
|
||||
void Terminate(uint16_t code = 1006, const Twine& reason = "terminated");
|
||||
|
||||
/**
|
||||
* Gets user-defined data.
|
||||
* @return User-defined data if any, nullptr otherwise.
|
||||
*/
|
||||
template <typename T = void>
|
||||
std::shared_ptr<T> GetData() const {
|
||||
return std::static_pointer_cast<T>(m_data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets user-defined data.
|
||||
* @param data User-defined arbitrary data.
|
||||
*/
|
||||
void SetData(std::shared_ptr<void> data) { m_data = std::move(data); }
|
||||
|
||||
/**
|
||||
* Open event. Emitted when the connection is open and ready to communicate.
|
||||
* The parameter is the selected subprotocol.
|
||||
*/
|
||||
sig::Signal<StringRef> open;
|
||||
|
||||
/**
|
||||
* Close event. Emitted when the connection is closed. The first parameter
|
||||
* is a numeric value indicating the status code explaining why the connection
|
||||
* has been closed. The second parameter is a human-readable string
|
||||
* explaining the reason why the connection has been closed.
|
||||
*/
|
||||
sig::Signal<uint16_t, StringRef> closed;
|
||||
|
||||
/**
|
||||
* Text message event. Emitted when a text message is received.
|
||||
* The first parameter is the data, the second parameter is true if the
|
||||
* data is the last fragment of the message.
|
||||
*/
|
||||
sig::Signal<StringRef, bool> text;
|
||||
|
||||
/**
|
||||
* Binary message event. Emitted when a binary message is received.
|
||||
* The first parameter is the data, the second parameter is true if the
|
||||
* data is the last fragment of the message.
|
||||
*/
|
||||
sig::Signal<ArrayRef<uint8_t>, bool> binary;
|
||||
|
||||
/**
|
||||
* Ping event. Emitted when a ping message is received.
|
||||
*/
|
||||
sig::Signal<ArrayRef<uint8_t>> ping;
|
||||
|
||||
/**
|
||||
* Pong event. Emitted when a pong message is received.
|
||||
*/
|
||||
sig::Signal<ArrayRef<uint8_t>> pong;
|
||||
|
||||
private:
|
||||
// user data
|
||||
std::shared_ptr<void> m_data;
|
||||
|
||||
// constructor parameters
|
||||
uv::Stream& m_stream;
|
||||
bool m_server;
|
||||
|
||||
// subprotocol, set via constructor (server) or handshake (client)
|
||||
std::string m_protocol;
|
||||
|
||||
// user-settable configuration
|
||||
size_t m_maxMessageSize = 128 * 1024;
|
||||
bool m_combineFragments = true;
|
||||
|
||||
// operating state
|
||||
State m_state = CONNECTING;
|
||||
|
||||
// incoming message buffers/state
|
||||
SmallVector<uint8_t, 14> m_header;
|
||||
size_t m_headerSize = 0;
|
||||
SmallVector<uint8_t, 1024> m_payload;
|
||||
size_t m_frameStart = 0;
|
||||
uint64_t m_frameSize = UINT64_MAX;
|
||||
uint8_t m_fragmentOpcode = 0;
|
||||
|
||||
// temporary data used only during client handshake
|
||||
class ClientHandshakeData;
|
||||
std::unique_ptr<ClientHandshakeData> m_clientHandshake;
|
||||
|
||||
void StartClient(const Twine& uri, const Twine& host,
|
||||
ArrayRef<StringRef> protocols, const ClientOptions& options);
|
||||
void StartServer(StringRef key, StringRef version, StringRef protocol);
|
||||
void SendClose(uint16_t code, const Twine& reason);
|
||||
void SetClosed(uint16_t code, const Twine& reason, bool failed = false);
|
||||
void Shutdown();
|
||||
void HandleIncoming(uv::Buffer& buf, size_t size);
|
||||
void Send(
|
||||
uint8_t opcode, ArrayRef<uv::Buffer> data,
|
||||
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback);
|
||||
};
|
||||
|
||||
} // namespace wpi
|
||||
|
||||
#endif // WPIUTIL_WPI_WEBSOCKET_H_
|
||||
147
wpiutil/src/main/native/include/wpi/WebSocketServer.h
Normal file
147
wpiutil/src/main/native/include/wpi/WebSocketServer.h
Normal file
@@ -0,0 +1,147 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#ifndef WPIUTIL_WPI_WEBSOCKETSERVER_H_
|
||||
#define WPIUTIL_WPI_WEBSOCKETSERVER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "wpi/ArrayRef.h"
|
||||
#include "wpi/HttpParser.h"
|
||||
#include "wpi/Signal.h"
|
||||
#include "wpi/SmallString.h"
|
||||
#include "wpi/SmallVector.h"
|
||||
#include "wpi/StringRef.h"
|
||||
#include "wpi/WebSocket.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
namespace uv {
|
||||
class Stream;
|
||||
} // namespace uv
|
||||
|
||||
/**
|
||||
* WebSocket HTTP server helper. Handles websocket-specific headers. User
|
||||
* must provide the HttpParser.
|
||||
*/
|
||||
class WebSocketServerHelper {
|
||||
public:
|
||||
/**
|
||||
* Constructor.
|
||||
* @param req HttpParser for request
|
||||
*/
|
||||
explicit WebSocketServerHelper(HttpParser& req);
|
||||
|
||||
/**
|
||||
* Get whether or not this was a websocket upgrade.
|
||||
* Only valid during and after the upgrade event.
|
||||
*/
|
||||
bool IsWebsocket() const { return m_websocket; }
|
||||
|
||||
/**
|
||||
* Try to find a match to the list of sub-protocols provided by the client.
|
||||
* The list is priority ordered, so the first match wins.
|
||||
* Only valid during and after the upgrade event.
|
||||
* @param protocols Acceptable protocols
|
||||
* @return Pair; first item is true if a match was made, false if not.
|
||||
* Second item is the matched protocol if a match was made, otherwise
|
||||
* is empty.
|
||||
*/
|
||||
std::pair<bool, StringRef> MatchProtocol(ArrayRef<StringRef> protocols);
|
||||
|
||||
/**
|
||||
* Accept the upgrade. Disconnect other readers (such as the HttpParser
|
||||
* reader) before calling this. See also WebSocket::CreateServer().
|
||||
* @param stream Connection stream
|
||||
* @param protocol The subprotocol to send to the client
|
||||
*/
|
||||
std::shared_ptr<WebSocket> Accept(uv::Stream& stream,
|
||||
StringRef protocol = StringRef{}) {
|
||||
return WebSocket::CreateServer(stream, m_key, m_version, protocol);
|
||||
}
|
||||
|
||||
bool IsUpgrade() const { return m_gotHost && m_websocket; }
|
||||
|
||||
/**
|
||||
* Upgrade event. Call Accept() to accept the upgrade.
|
||||
*/
|
||||
sig::Signal<> upgrade;
|
||||
|
||||
private:
|
||||
bool m_gotHost = false;
|
||||
bool m_websocket = false;
|
||||
SmallVector<std::string, 2> m_protocols;
|
||||
SmallString<64> m_key;
|
||||
SmallString<16> m_version;
|
||||
};
|
||||
|
||||
/**
|
||||
* Dedicated WebSocket server.
|
||||
*/
|
||||
class WebSocketServer : public std::enable_shared_from_this<WebSocketServer> {
|
||||
struct private_init {};
|
||||
|
||||
public:
|
||||
/**
|
||||
* Server options.
|
||||
*/
|
||||
struct ServerOptions {
|
||||
/**
|
||||
* Checker for URL. Return true if URL should be accepted. By default all
|
||||
* URLs are accepted.
|
||||
*/
|
||||
std::function<bool(StringRef)> checkUrl;
|
||||
|
||||
/**
|
||||
* Checker for Host header. Return true if Host should be accepted. By
|
||||
* default all hosts are accepted.
|
||||
*/
|
||||
std::function<bool(StringRef)> checkHost;
|
||||
};
|
||||
|
||||
/**
|
||||
* Private constructor.
|
||||
*/
|
||||
WebSocketServer(uv::Stream& stream, ArrayRef<StringRef> protocols,
|
||||
const ServerOptions& options, const private_init&);
|
||||
|
||||
/**
|
||||
* Starts a dedicated WebSocket server on the provided connection. The
|
||||
* connection should be an accepted client stream.
|
||||
* This also sets the stream user data to the socket server.
|
||||
* A connected event is emitted when the connection is opened.
|
||||
* @param stream Connection stream
|
||||
* @param protocols Acceptable subprotocols
|
||||
* @param options Handshake options
|
||||
*/
|
||||
static std::shared_ptr<WebSocketServer> Create(
|
||||
uv::Stream& stream, ArrayRef<StringRef> protocols = ArrayRef<StringRef>{},
|
||||
const ServerOptions& options = ServerOptions{});
|
||||
|
||||
/**
|
||||
* Connected event. First parameter is the URL, second is the websocket.
|
||||
*/
|
||||
sig::Signal<StringRef, WebSocket&> connected;
|
||||
|
||||
private:
|
||||
uv::Stream& m_stream;
|
||||
HttpParser m_req{HttpParser::kRequest};
|
||||
WebSocketServerHelper m_helper;
|
||||
SmallVector<std::string, 2> m_protocols;
|
||||
ServerOptions m_options;
|
||||
bool m_aborted = false;
|
||||
sig::Connection m_headerConn;
|
||||
|
||||
void Abort(uint16_t code, StringRef reason);
|
||||
};
|
||||
|
||||
} // namespace wpi
|
||||
|
||||
#endif // WPIUTIL_WPI_WEBSOCKETSERVER_H_
|
||||
299
wpiutil/src/test/native/cpp/WebSocketClientTest.cpp
Normal file
299
wpiutil/src/test/native/cpp/WebSocketClientTest.cpp
Normal file
@@ -0,0 +1,299 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocket.h" // NOLINT(build/include_order)
|
||||
|
||||
#include "WebSocketTest.h"
|
||||
#include "wpi/Base64.h"
|
||||
#include "wpi/HttpParser.h"
|
||||
#include "wpi/SmallString.h"
|
||||
#include "wpi/raw_uv_ostream.h"
|
||||
#include "wpi/sha1.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
class WebSocketClientTest : public WebSocketTest {
|
||||
public:
|
||||
WebSocketClientTest() {
|
||||
// Bare bones server
|
||||
req.header.connect([this](StringRef name, StringRef value) {
|
||||
// save key (required for valid response)
|
||||
if (name.equals_lower("sec-websocket-key")) clientKey = value;
|
||||
});
|
||||
req.headersComplete.connect([this](bool) {
|
||||
// send response
|
||||
SmallVector<uv::Buffer, 4> bufs;
|
||||
raw_uv_ostream os{bufs, 4096};
|
||||
os << "HTTP/1.1 101 Switching Protocols\r\n";
|
||||
os << "Upgrade: websocket\r\n";
|
||||
os << "Connection: Upgrade\r\n";
|
||||
|
||||
// accept hash
|
||||
SHA1 hash;
|
||||
hash.Update(clientKey);
|
||||
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
if (mockBadAccept) hash.Update("1");
|
||||
SmallString<64> hashBuf;
|
||||
SmallString<64> acceptBuf;
|
||||
os << "Sec-WebSocket-Accept: "
|
||||
<< Base64Encode(hash.Final(hashBuf), acceptBuf) << "\r\n";
|
||||
|
||||
if (!mockProtocol.empty())
|
||||
os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
|
||||
|
||||
os << "\r\n";
|
||||
|
||||
conn->Write(bufs, [](auto bufs, uv::Error) {
|
||||
for (auto& buf : bufs) buf.Deallocate();
|
||||
});
|
||||
|
||||
serverHeadersDone = true;
|
||||
if (connected) connected();
|
||||
});
|
||||
|
||||
serverPipe->Listen([this] {
|
||||
conn = serverPipe->Accept();
|
||||
conn->StartRead();
|
||||
conn->data.connect([this](uv::Buffer& buf, size_t size) {
|
||||
StringRef data{buf.base, size};
|
||||
if (!serverHeadersDone) {
|
||||
data = req.Execute(data);
|
||||
if (req.HasError()) Finish();
|
||||
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
|
||||
if (data.empty()) return;
|
||||
}
|
||||
wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end());
|
||||
});
|
||||
conn->end.connect([this] { Finish(); });
|
||||
});
|
||||
}
|
||||
|
||||
bool mockBadAccept = false;
|
||||
std::vector<uint8_t> wireData;
|
||||
std::shared_ptr<uv::Pipe> conn;
|
||||
HttpParser req{HttpParser::kRequest};
|
||||
SmallString<64> clientKey;
|
||||
std::string mockProtocol;
|
||||
bool serverHeadersDone = false;
|
||||
std::function<void()> connected;
|
||||
};
|
||||
|
||||
TEST_F(WebSocketClientTest, Open) {
|
||||
int gotOpen = 0;
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << " Reason: " << reason;
|
||||
});
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
++gotOpen;
|
||||
Finish();
|
||||
ASSERT_TRUE(protocol.empty());
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketClientTest, BadAccept) {
|
||||
int gotClosed = 0;
|
||||
|
||||
mockBadAccept = true;
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef msg) {
|
||||
Finish();
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1002) << "Message: " << msg;
|
||||
});
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
Finish();
|
||||
FAIL() << "Got open";
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketClientTest, ProtocolGood) {
|
||||
int gotOpen = 0;
|
||||
|
||||
mockProtocol = "myProtocol";
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(
|
||||
*clientPipe, "/test", pipeName,
|
||||
ArrayRef<StringRef>{"myProtocol", "myProtocol2"});
|
||||
ws->closed.connect([&](uint16_t code, StringRef msg) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << "Message: " << msg;
|
||||
});
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
++gotOpen;
|
||||
Finish();
|
||||
ASSERT_EQ(protocol, "myProtocol");
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
|
||||
int gotClosed = 0;
|
||||
|
||||
mockProtocol = "myProtocol";
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef msg) {
|
||||
Finish();
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1003) << "Message: " << msg;
|
||||
});
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
Finish();
|
||||
FAIL() << "Got open";
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
|
||||
int gotClosed = 0;
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
|
||||
StringRef{"myProtocol"});
|
||||
ws->closed.connect([&](uint16_t code, StringRef msg) {
|
||||
Finish();
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1002) << "Message: " << msg;
|
||||
});
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
Finish();
|
||||
FAIL() << "Got open";
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// Send and receive data. Most of these cases are tested in
|
||||
// WebSocketServerTest, so only spot check differences like masking.
|
||||
//
|
||||
|
||||
class WebSocketClientDataTest : public WebSocketClientTest,
|
||||
public ::testing::WithParamInterface<size_t> {
|
||||
public:
|
||||
WebSocketClientDataTest() {
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
if (setupWebSocket) setupWebSocket();
|
||||
});
|
||||
}
|
||||
|
||||
std::function<void()> setupWebSocket;
|
||||
std::shared_ptr<WebSocket> ws;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(WebSocketClientDataTests, WebSocketClientDataTest,
|
||||
::testing::Values(0, 1, 125, 126, 65535, 65536), );
|
||||
|
||||
TEST_P(WebSocketClientDataTest, SendBinary) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) {
|
||||
ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_FALSE(bufs.empty());
|
||||
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x02, true, true, data);
|
||||
AdjustMasking(wireData);
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketClientDataTest, ReceiveBinary) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_TRUE(fin);
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
ASSERT_EQ(data, recvData);
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x02, true, false, data);
|
||||
connected = [&] {
|
||||
conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// The client must close the connection if a masked frame is received.
|
||||
//
|
||||
|
||||
TEST_P(WebSocketClientDataTest, ReceiveMasked) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), ' ');
|
||||
setupWebSocket = [&] {
|
||||
ws->text.connect([&](StringRef, bool) {
|
||||
ws->Terminate();
|
||||
FAIL() << "Should not have gotten masked message";
|
||||
});
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, true, true, data);
|
||||
connected = [&] {
|
||||
conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
} // namespace wpi
|
||||
148
wpiutil/src/test/native/cpp/WebSocketIntegrationTest.cpp
Normal file
148
wpiutil/src/test/native/cpp/WebSocketIntegrationTest.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocketServer.h" // NOLINT(build/include_order)
|
||||
|
||||
#include "WebSocketTest.h"
|
||||
#include "wpi/HttpParser.h"
|
||||
#include "wpi/SmallString.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
class WebSocketIntegrationTest : public WebSocketTest {};
|
||||
|
||||
TEST_F(WebSocketIntegrationTest, Open) {
|
||||
int gotServerOpen = 0;
|
||||
int gotClientOpen = 0;
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto server = WebSocketServer::Create(*conn);
|
||||
server->connected.connect([&](StringRef url, WebSocket&) {
|
||||
++gotServerOpen;
|
||||
ASSERT_EQ(url, "/test");
|
||||
});
|
||||
});
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << " Reason: " << reason;
|
||||
});
|
||||
ws->open.connect([&, s = ws.get() ](StringRef) {
|
||||
++gotClientOpen;
|
||||
s->Close();
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotServerOpen, 1);
|
||||
ASSERT_EQ(gotClientOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketIntegrationTest, Protocol) {
|
||||
int gotServerOpen = 0;
|
||||
int gotClientOpen = 0;
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto server = WebSocketServer::Create(*conn, {"proto1", "proto2"});
|
||||
server->connected.connect([&](StringRef, WebSocket& ws) {
|
||||
++gotServerOpen;
|
||||
ASSERT_EQ(ws.GetProtocol(), "proto1");
|
||||
});
|
||||
});
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws =
|
||||
WebSocket::CreateClient(*clientPipe, "/test", pipeName, {"proto1"});
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << " Reason: " << reason;
|
||||
});
|
||||
ws->open.connect([&, s = ws.get() ](StringRef protocol) {
|
||||
++gotClientOpen;
|
||||
s->Close();
|
||||
ASSERT_EQ(protocol, "proto1");
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotServerOpen, 1);
|
||||
ASSERT_EQ(gotClientOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketIntegrationTest, ServerSendBinary) {
|
||||
int gotData = 0;
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto server = WebSocketServer::Create(*conn);
|
||||
server->connected.connect([&](StringRef, WebSocket& ws) {
|
||||
ws.SendBinary(uv::Buffer{"\x03\x04", 2}, [&](auto, uv::Error) {});
|
||||
ws.Close();
|
||||
});
|
||||
});
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << " Reason: " << reason;
|
||||
});
|
||||
ws->binary.connect([&](ArrayRef<uint8_t> data, bool) {
|
||||
++gotData;
|
||||
std::vector<uint8_t> recvData{data.begin(), data.end()};
|
||||
std::vector<uint8_t> expectData{0x03, 0x04};
|
||||
ASSERT_EQ(recvData, expectData);
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotData, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketIntegrationTest, ClientSendText) {
|
||||
int gotData = 0;
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto server = WebSocketServer::Create(*conn);
|
||||
server->connected.connect([&](StringRef, WebSocket& ws) {
|
||||
ws.text.connect([&](StringRef data, bool) {
|
||||
++gotData;
|
||||
ASSERT_EQ(data, "hello");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
Finish();
|
||||
if (code != 1005 && code != 1006)
|
||||
FAIL() << "Code: " << code << " Reason: " << reason;
|
||||
});
|
||||
ws->open.connect([&, s = ws.get() ](StringRef) {
|
||||
s->SendText(uv::Buffer{"hello"}, [&](auto, uv::Error) {});
|
||||
s->Close();
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotData, 1);
|
||||
}
|
||||
|
||||
} // namespace wpi
|
||||
736
wpiutil/src/test/native/cpp/WebSocketServerTest.cpp
Normal file
736
wpiutil/src/test/native/cpp/WebSocketServerTest.cpp
Normal file
@@ -0,0 +1,736 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocket.h" // NOLINT(build/include_order)
|
||||
|
||||
#include "WebSocketTest.h"
|
||||
#include "wpi/Base64.h"
|
||||
#include "wpi/HttpParser.h"
|
||||
#include "wpi/SmallString.h"
|
||||
#include "wpi/raw_uv_ostream.h"
|
||||
#include "wpi/sha1.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
class WebSocketServerTest : public WebSocketTest {
|
||||
public:
|
||||
WebSocketServerTest() {
|
||||
resp.headersComplete.connect([this](bool) { headersDone = true; });
|
||||
|
||||
serverPipe->Listen([this]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
ws = WebSocket::CreateServer(*conn, "foo", "13");
|
||||
if (setupWebSocket) setupWebSocket();
|
||||
});
|
||||
clientPipe->Connect(pipeName, [this]() {
|
||||
clientPipe->StartRead();
|
||||
clientPipe->data.connect([this](uv::Buffer& buf, size_t size) {
|
||||
StringRef data{buf.base, size};
|
||||
if (!headersDone) {
|
||||
data = resp.Execute(data);
|
||||
if (resp.HasError()) Finish();
|
||||
ASSERT_EQ(resp.GetError(), HPE_OK)
|
||||
<< http_errno_name(resp.GetError());
|
||||
if (data.empty()) return;
|
||||
}
|
||||
wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end());
|
||||
if (handleData) handleData(data);
|
||||
});
|
||||
clientPipe->end.connect([this]() { Finish(); });
|
||||
});
|
||||
}
|
||||
|
||||
std::function<void()> setupWebSocket;
|
||||
std::function<void(StringRef)> handleData;
|
||||
std::vector<uint8_t> wireData;
|
||||
std::shared_ptr<WebSocket> ws;
|
||||
HttpParser resp{HttpParser::kResponse};
|
||||
bool headersDone = false;
|
||||
};
|
||||
|
||||
//
|
||||
// Terminate closes the endpoint but doesn't send a close frame.
|
||||
//
|
||||
|
||||
TEST_F(WebSocketServerTest, Terminate) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Terminate(); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1006) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, TerminateCode) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Terminate(1000); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, TerminateReason) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Terminate(1000, "reason"); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000);
|
||||
ASSERT_EQ(reason, "reason");
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// Close() sends a close frame.
|
||||
//
|
||||
|
||||
TEST_F(WebSocketServerTest, CloseBasic) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Close(); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1005) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
// need to respond with close for server to finish shutdown
|
||||
auto message = BuildMessage(0x08, true, true, {});
|
||||
handleData = [&](StringRef) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x08, true, false, {});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, CloseCode) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Close(1000); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
// need to respond with close for server to finish shutdown
|
||||
auto message = BuildMessage(0x08, true, true, {0x03u, 0xe8u});
|
||||
handleData = [&](StringRef) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x08, true, false, {0x03u, 0xe8u});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, CloseReason) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) { ws->Close(1000, "hangup"); });
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000);
|
||||
ASSERT_EQ(reason, "hangup");
|
||||
});
|
||||
};
|
||||
// need to respond with close for server to finish shutdown
|
||||
auto message = BuildMessage(0x08, true, true,
|
||||
{0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'});
|
||||
handleData = [&](StringRef) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x08, true, false,
|
||||
{0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// Receiving a close frame results in closure and echoing the close frame.
|
||||
//
|
||||
|
||||
TEST_F(WebSocketServerTest, ReceiveCloseBasic) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1005) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x08, true, true, {});
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
// the endpoint should echo the message
|
||||
auto expectData = BuildMessage(0x08, true, false, {});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, ReceiveCloseCode) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x08, true, true, {0x03u, 0xe8u});
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
// the endpoint should echo the message
|
||||
auto expectData = BuildMessage(0x08, true, false, {0x03u, 0xe8u});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, ReceiveCloseReason) {
|
||||
int gotClosed = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1000);
|
||||
ASSERT_EQ(reason, "hangup");
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x08, true, true,
|
||||
{0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'});
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
// the endpoint should echo the message
|
||||
auto expectData = BuildMessage(0x08, true, false,
|
||||
{0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'});
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// If an unknown opcode is received, the receiving endpoint MUST _Fail the
|
||||
// WebSocket Connection_.
|
||||
//
|
||||
|
||||
class WebSocketServerBadOpcodeTest
|
||||
: public WebSocketServerTest,
|
||||
public ::testing::WithParamInterface<uint8_t> {};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(WebSocketServerBadOpcodeTests,
|
||||
WebSocketServerBadOpcodeTest,
|
||||
::testing::Values(3, 4, 5, 6, 7, 0xb, 0xc, 0xd, 0xe,
|
||||
0xf), );
|
||||
|
||||
TEST_P(WebSocketServerBadOpcodeTest, Receive) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(GetParam(), true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// Control frames themselves MUST NOT be fragmented.
|
||||
//
|
||||
|
||||
class WebSocketServerControlFrameTest
|
||||
: public WebSocketServerTest,
|
||||
public ::testing::WithParamInterface<uint8_t> {};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(WebSocketServerControlFrameTests,
|
||||
WebSocketServerControlFrameTest,
|
||||
::testing::Values(0x8, 0x9, 0xa), );
|
||||
|
||||
TEST_P(WebSocketServerControlFrameTest, ReceiveFragment) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(GetParam(), false, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// A fragmented message consists of a single frame with the FIN bit
|
||||
// clear and an opcode other than 0, followed by zero or more frames
|
||||
// with the FIN bit clear and the opcode set to 0, and terminated by
|
||||
// a single frame with the FIN bit set and an opcode of 0.
|
||||
//
|
||||
|
||||
// No previous message
|
||||
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFrame) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x00, false, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
// No previous message with FIN=1.
|
||||
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFragment) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, true, true, {}); // FIN=1
|
||||
auto message2 = BuildMessage(0x00, false, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write({uv::Buffer(message), uv::Buffer(message2)},
|
||||
[&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
// Incomplete fragment
|
||||
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidIncomplete) {
|
||||
int gotCallback = 0;
|
||||
setupWebSocket = [&] {
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, false, true, {});
|
||||
auto message2 = BuildMessage(0x00, false, true, {});
|
||||
auto message3 = BuildMessage(0x01, true, true, {});
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(
|
||||
{uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)},
|
||||
[&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
// Normally fragments are combined into a single callback
|
||||
TEST_F(WebSocketServerTest, ReceiveFragment) {
|
||||
int gotCallback = 0;
|
||||
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
std::vector<uint8_t> data2(4, 0x04);
|
||||
std::vector<uint8_t> data3(4, 0x05);
|
||||
std::vector<uint8_t> combData{data};
|
||||
combData.insert(combData.end(), data2.begin(), data2.end());
|
||||
combData.insert(combData.end(), data3.begin(), data3.end());
|
||||
|
||||
setupWebSocket = [&] {
|
||||
ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_TRUE(fin);
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
ASSERT_EQ(combData, recvData);
|
||||
});
|
||||
};
|
||||
|
||||
auto message = BuildMessage(0x02, false, true, data);
|
||||
auto message2 = BuildMessage(0x00, false, true, data2);
|
||||
auto message3 = BuildMessage(0x00, true, true, data3);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(
|
||||
{uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)},
|
||||
[&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
// But can be configured for multiple callbacks
|
||||
TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) {
|
||||
int gotCallback = 0;
|
||||
|
||||
std::vector<uint8_t> data(4, 0x03);
|
||||
std::vector<uint8_t> data2(4, 0x04);
|
||||
std::vector<uint8_t> data3(4, 0x05);
|
||||
std::vector<uint8_t> combData{data};
|
||||
combData.insert(combData.end(), data2.begin(), data2.end());
|
||||
combData.insert(combData.end(), data3.begin(), data3.end());
|
||||
|
||||
setupWebSocket = [&] {
|
||||
ws->SetCombineFragments(false);
|
||||
ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
switch (++gotCallback) {
|
||||
case 1:
|
||||
ASSERT_FALSE(fin);
|
||||
ASSERT_EQ(data, recvData);
|
||||
break;
|
||||
case 2:
|
||||
ASSERT_FALSE(fin);
|
||||
ASSERT_EQ(data2, recvData);
|
||||
break;
|
||||
case 3:
|
||||
ws->Terminate();
|
||||
ASSERT_TRUE(fin);
|
||||
ASSERT_EQ(data3, recvData);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "too many callbacks";
|
||||
break;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
auto message = BuildMessage(0x02, false, true, data);
|
||||
auto message2 = BuildMessage(0x00, false, true, data2);
|
||||
auto message3 = BuildMessage(0x00, true, true, data3);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(
|
||||
{uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)},
|
||||
[&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 3);
|
||||
}
|
||||
|
||||
//
|
||||
// Maximum message size is limited.
|
||||
//
|
||||
|
||||
// Single message
|
||||
TEST_F(WebSocketServerTest, ReceiveTooLarge) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(2048, 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->SetMaxMessageSize(1024);
|
||||
ws->binary.connect([&](auto, bool) {
|
||||
ws->Terminate();
|
||||
FAIL() << "Should not have gotten unmasked message";
|
||||
});
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1009) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
// Applied across fragments if combining
|
||||
TEST_F(WebSocketServerTest, ReceiveTooLargeFragmented) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(768, 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->SetMaxMessageSize(1024);
|
||||
ws->binary.connect([&](auto, bool) {
|
||||
ws->Terminate();
|
||||
FAIL() << "Should not have gotten unmasked message";
|
||||
});
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1009) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, false, true, data);
|
||||
auto message2 = BuildMessage(0x00, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write({uv::Buffer(message), uv::Buffer(message2)},
|
||||
[&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// Send and receive data.
|
||||
//
|
||||
|
||||
class WebSocketServerDataTest : public WebSocketServerTest,
|
||||
public ::testing::WithParamInterface<size_t> {};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(WebSocketServerDataTests, WebSocketServerDataTest,
|
||||
::testing::Values(0, 1, 125, 126, 65535, 65536), );
|
||||
|
||||
TEST_P(WebSocketServerDataTest, SendText) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), ' ');
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) {
|
||||
ws->SendText(uv::Buffer(data), [&](auto bufs, uv::Error) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_FALSE(bufs.empty());
|
||||
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x01, true, false, data);
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, SendBinary) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) {
|
||||
ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_FALSE(bufs.empty());
|
||||
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x02, true, false, data);
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, SendPing) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) {
|
||||
ws->SendPing(uv::Buffer(data), [&](auto bufs, uv::Error) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_FALSE(bufs.empty());
|
||||
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x09, true, false, data);
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, SendPong) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->open.connect([&](StringRef) {
|
||||
ws->SendPong(uv::Buffer(data), [&](auto bufs, uv::Error) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_FALSE(bufs.empty());
|
||||
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
loop->Run();
|
||||
|
||||
auto expectData = BuildMessage(0x0a, true, false, data);
|
||||
ASSERT_EQ(wireData, expectData);
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, ReceiveText) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), ' ');
|
||||
setupWebSocket = [&] {
|
||||
ws->text.connect([&](StringRef inData, bool fin) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_TRUE(fin);
|
||||
std::vector<uint8_t> recvData;
|
||||
recvData.insert(recvData.end(), inData.bytes_begin(), inData.bytes_end());
|
||||
ASSERT_EQ(data, recvData);
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, ReceiveBinary) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
ASSERT_TRUE(fin);
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
ASSERT_EQ(data, recvData);
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x02, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, ReceivePing) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->ping.connect([&](ArrayRef<uint8_t> inData) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
ASSERT_EQ(data, recvData);
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x09, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
TEST_P(WebSocketServerDataTest, ReceivePong) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), 0x03u);
|
||||
setupWebSocket = [&] {
|
||||
ws->pong.connect([&](ArrayRef<uint8_t> inData) {
|
||||
++gotCallback;
|
||||
ws->Terminate();
|
||||
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
|
||||
ASSERT_EQ(data, recvData);
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x0a, true, true, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
//
|
||||
// The server must close the connection if an unmasked frame is received.
|
||||
//
|
||||
|
||||
TEST_P(WebSocketServerDataTest, ReceiveUnmasked) {
|
||||
int gotCallback = 0;
|
||||
std::vector<uint8_t> data(GetParam(), ' ');
|
||||
setupWebSocket = [&] {
|
||||
ws->text.connect([&](StringRef, bool) {
|
||||
ws->Terminate();
|
||||
FAIL() << "Should not have gotten unmasked message";
|
||||
});
|
||||
ws->closed.connect([&](uint16_t code, StringRef reason) {
|
||||
++gotCallback;
|
||||
ASSERT_EQ(code, 1002) << "reason: " << reason;
|
||||
});
|
||||
};
|
||||
auto message = BuildMessage(0x01, true, false, data);
|
||||
resp.headersComplete.connect([&](bool) {
|
||||
clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
ASSERT_EQ(gotCallback, 1);
|
||||
}
|
||||
|
||||
} // namespace wpi
|
||||
345
wpiutil/src/test/native/cpp/WebSocketTest.cpp
Normal file
345
wpiutil/src/test/native/cpp/WebSocketTest.cpp
Normal file
@@ -0,0 +1,345 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#include "wpi/WebSocket.h" // NOLINT(build/include_order)
|
||||
|
||||
#include "WebSocketTest.h"
|
||||
|
||||
#include "wpi/HttpParser.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
#ifdef _WIN32
|
||||
const char* WebSocketTest::pipeName = "\\\\.\\pipe\\websocket-unit-test";
|
||||
#else
|
||||
const char* WebSocketTest::pipeName = "/tmp/websocket-unit-test";
|
||||
#endif
|
||||
const uint8_t WebSocketTest::testMask[4] = {0x11, 0x22, 0x33, 0x44};
|
||||
|
||||
void WebSocketTest::SetUpTestCase() {
|
||||
#ifndef _WIN32
|
||||
unlink(pipeName);
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<uint8_t> WebSocketTest::BuildHeader(uint8_t opcode, bool fin,
|
||||
bool masking, uint64_t len) {
|
||||
std::vector<uint8_t> data;
|
||||
data.push_back(opcode | (fin ? 0x80u : 0x00u));
|
||||
if (len < 126) {
|
||||
data.push_back(len | (masking ? 0x80 : 0x00u));
|
||||
} else if (len < 65536) {
|
||||
data.push_back(126u | (masking ? 0x80 : 0x00u));
|
||||
data.push_back(len >> 8);
|
||||
data.push_back(len & 0xff);
|
||||
} else {
|
||||
data.push_back(127u | (masking ? 0x80u : 0x00u));
|
||||
for (int i = 56; i >= 0; i -= 8) data.push_back((len >> i) & 0xff);
|
||||
}
|
||||
if (masking) data.insert(data.end(), &testMask[0], &testMask[4]);
|
||||
return data;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> WebSocketTest::BuildMessage(uint8_t opcode, bool fin,
|
||||
bool masking,
|
||||
ArrayRef<uint8_t> data) {
|
||||
auto finalData = BuildHeader(opcode, fin, masking, data.size());
|
||||
size_t headerSize = finalData.size();
|
||||
finalData.insert(finalData.end(), data.begin(), data.end());
|
||||
if (masking) {
|
||||
uint8_t mask[4] = {finalData[headerSize - 4], finalData[headerSize - 3],
|
||||
finalData[headerSize - 2], finalData[headerSize - 1]};
|
||||
int n = 0;
|
||||
for (size_t i = headerSize, end = finalData.size(); i < end; ++i) {
|
||||
finalData[i] ^= mask[n++];
|
||||
if (n >= 4) n = 0;
|
||||
}
|
||||
}
|
||||
return finalData;
|
||||
}
|
||||
|
||||
// If the message is masked, changes the mask to match the mask set by
|
||||
// BuildHeader() by unmasking and remasking.
|
||||
void WebSocketTest::AdjustMasking(MutableArrayRef<uint8_t> message) {
|
||||
if (message.size() < 2) return;
|
||||
if ((message[1] & 0x80) == 0) return; // not masked
|
||||
size_t maskPos;
|
||||
uint8_t len = message[1] & 0x7f;
|
||||
if (len == 126)
|
||||
maskPos = 4;
|
||||
else if (len == 127)
|
||||
maskPos = 10;
|
||||
else
|
||||
maskPos = 2;
|
||||
uint8_t mask[4] = {message[maskPos], message[maskPos + 1],
|
||||
message[maskPos + 2], message[maskPos + 3]};
|
||||
message[maskPos] = testMask[0];
|
||||
message[maskPos + 1] = testMask[1];
|
||||
message[maskPos + 2] = testMask[2];
|
||||
message[maskPos + 3] = testMask[3];
|
||||
int n = 0;
|
||||
for (auto& ch : message.slice(maskPos + 4)) {
|
||||
ch ^= mask[n] ^ testMask[n];
|
||||
if (++n >= 4) n = 0;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateClientBasic) {
|
||||
int gotHost = 0;
|
||||
int gotUpgrade = 0;
|
||||
int gotConnection = 0;
|
||||
int gotKey = 0;
|
||||
int gotVersion = 0;
|
||||
|
||||
HttpParser req{HttpParser::kRequest};
|
||||
req.url.connect([](StringRef url) { ASSERT_EQ(url, "/test"); });
|
||||
req.header.connect([&](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("host")) {
|
||||
ASSERT_EQ(value, pipeName);
|
||||
++gotHost;
|
||||
} else if (name.equals_lower("upgrade")) {
|
||||
ASSERT_EQ(value, "websocket");
|
||||
++gotUpgrade;
|
||||
} else if (name.equals_lower("connection")) {
|
||||
ASSERT_EQ(value, "Upgrade");
|
||||
++gotConnection;
|
||||
} else if (name.equals_lower("sec-websocket-key")) {
|
||||
++gotKey;
|
||||
} else if (name.equals_lower("sec-websocket-version")) {
|
||||
ASSERT_EQ(value, "13");
|
||||
++gotVersion;
|
||||
} else {
|
||||
FAIL() << "unexpected header " << name.str();
|
||||
}
|
||||
});
|
||||
req.headersComplete.connect([&](bool) { Finish(); });
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
conn->StartRead();
|
||||
conn->data.connect([&](uv::Buffer& buf, size_t size) {
|
||||
req.Execute(StringRef{buf.base, size});
|
||||
if (req.HasError()) Finish();
|
||||
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
|
||||
});
|
||||
});
|
||||
clientPipe->Connect(pipeName, [&]() {
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotHost, 1);
|
||||
ASSERT_EQ(gotUpgrade, 1);
|
||||
ASSERT_EQ(gotConnection, 1);
|
||||
ASSERT_EQ(gotKey, 1);
|
||||
ASSERT_EQ(gotVersion, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateClientExtraHeaders) {
|
||||
int gotExtra1 = 0;
|
||||
int gotExtra2 = 0;
|
||||
HttpParser req{HttpParser::kRequest};
|
||||
req.header.connect([&](StringRef name, StringRef value) {
|
||||
if (name.equals("Extra1")) {
|
||||
ASSERT_EQ(value, "Data1");
|
||||
++gotExtra1;
|
||||
} else if (name.equals("Extra2")) {
|
||||
ASSERT_EQ(value, "Data2");
|
||||
++gotExtra2;
|
||||
}
|
||||
});
|
||||
req.headersComplete.connect([&](bool) { Finish(); });
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
conn->StartRead();
|
||||
conn->data.connect([&](uv::Buffer& buf, size_t size) {
|
||||
req.Execute(StringRef{buf.base, size});
|
||||
if (req.HasError()) Finish();
|
||||
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
|
||||
});
|
||||
});
|
||||
clientPipe->Connect(pipeName, [&]() {
|
||||
WebSocket::ClientOptions options;
|
||||
SmallVector<std::pair<StringRef, StringRef>, 4> extraHeaders;
|
||||
extraHeaders.emplace_back("Extra1", "Data1");
|
||||
extraHeaders.emplace_back("Extra2", "Data2");
|
||||
options.extraHeaders = extraHeaders;
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
|
||||
ArrayRef<StringRef>{}, options);
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotExtra1, 1);
|
||||
ASSERT_EQ(gotExtra2, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateClientTimeout) {
|
||||
int gotClosed = 0;
|
||||
serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); });
|
||||
clientPipe->Connect(pipeName, [&]() {
|
||||
WebSocket::ClientOptions options;
|
||||
options.handshakeTimeout = uv::Timer::Time{100};
|
||||
auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
|
||||
ArrayRef<StringRef>{}, options);
|
||||
ws->closed.connect([&](uint16_t code, StringRef) {
|
||||
Finish();
|
||||
++gotClosed;
|
||||
ASSERT_EQ(code, 1006);
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotClosed, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateServerBasic) {
|
||||
int gotStatus = 0;
|
||||
int gotUpgrade = 0;
|
||||
int gotConnection = 0;
|
||||
int gotAccept = 0;
|
||||
int gotOpen = 0;
|
||||
|
||||
HttpParser resp{HttpParser::kResponse};
|
||||
resp.status.connect([&](StringRef status) {
|
||||
++gotStatus;
|
||||
ASSERT_EQ(resp.GetStatusCode(), 101u) << "status: " << status;
|
||||
});
|
||||
resp.header.connect([&](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("upgrade")) {
|
||||
ASSERT_EQ(value, "websocket");
|
||||
++gotUpgrade;
|
||||
} else if (name.equals_lower("connection")) {
|
||||
ASSERT_EQ(value, "Upgrade");
|
||||
++gotConnection;
|
||||
} else if (name.equals_lower("sec-websocket-accept")) {
|
||||
++gotAccept;
|
||||
} else {
|
||||
FAIL() << "unexpected header " << name.str();
|
||||
}
|
||||
});
|
||||
resp.headersComplete.connect([&](bool) { Finish(); });
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto ws = WebSocket::CreateServer(*conn, "foo", "13");
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
++gotOpen;
|
||||
ASSERT_TRUE(protocol.empty());
|
||||
});
|
||||
});
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
clientPipe->StartRead();
|
||||
clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
|
||||
resp.Execute(StringRef{buf.base, size});
|
||||
if (resp.HasError()) Finish();
|
||||
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotStatus, 1);
|
||||
ASSERT_EQ(gotUpgrade, 1);
|
||||
ASSERT_EQ(gotConnection, 1);
|
||||
ASSERT_EQ(gotAccept, 1);
|
||||
ASSERT_EQ(gotOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateServerProtocol) {
|
||||
int gotProtocol = 0;
|
||||
int gotOpen = 0;
|
||||
|
||||
HttpParser resp{HttpParser::kResponse};
|
||||
resp.header.connect([&](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("sec-websocket-protocol")) {
|
||||
++gotProtocol;
|
||||
ASSERT_EQ(value, "myProtocol");
|
||||
}
|
||||
});
|
||||
resp.headersComplete.connect([&](bool) { Finish(); });
|
||||
|
||||
serverPipe->Listen([&]() {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto ws = WebSocket::CreateServer(*conn, "foo", "13", "myProtocol");
|
||||
ws->open.connect([&](StringRef protocol) {
|
||||
++gotOpen;
|
||||
ASSERT_EQ(protocol, "myProtocol");
|
||||
});
|
||||
});
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
clientPipe->StartRead();
|
||||
clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
|
||||
resp.Execute(StringRef{buf.base, size});
|
||||
if (resp.HasError()) Finish();
|
||||
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotProtocol, 1);
|
||||
ASSERT_EQ(gotOpen, 1);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTest, CreateServerBadVersion) {
|
||||
int gotStatus = 0;
|
||||
int gotVersion = 0;
|
||||
int gotUpgrade = 0;
|
||||
|
||||
HttpParser resp{HttpParser::kResponse};
|
||||
resp.status.connect([&](StringRef status) {
|
||||
++gotStatus;
|
||||
ASSERT_EQ(resp.GetStatusCode(), 426u) << "status: " << status;
|
||||
});
|
||||
resp.header.connect([&](StringRef name, StringRef value) {
|
||||
if (name.equals_lower("sec-websocket-version")) {
|
||||
++gotVersion;
|
||||
ASSERT_EQ(value, "13");
|
||||
} else if (name.equals_lower("upgrade")) {
|
||||
++gotUpgrade;
|
||||
ASSERT_EQ(value, "WebSocket");
|
||||
} else {
|
||||
FAIL() << "unexpected header " << name.str();
|
||||
}
|
||||
});
|
||||
resp.headersComplete.connect([&](bool) { Finish(); });
|
||||
|
||||
serverPipe->Listen([&] {
|
||||
auto conn = serverPipe->Accept();
|
||||
auto ws = WebSocket::CreateServer(*conn, "foo", "14");
|
||||
ws->open.connect([&](StringRef) {
|
||||
Finish();
|
||||
FAIL();
|
||||
});
|
||||
});
|
||||
clientPipe->Connect(pipeName, [&] {
|
||||
clientPipe->StartRead();
|
||||
clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
|
||||
resp.Execute(StringRef{buf.base, size});
|
||||
if (resp.HasError()) Finish();
|
||||
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
|
||||
});
|
||||
});
|
||||
|
||||
loop->Run();
|
||||
|
||||
if (HasFatalFailure()) return;
|
||||
ASSERT_EQ(gotStatus, 1);
|
||||
ASSERT_EQ(gotVersion, 1);
|
||||
ASSERT_EQ(gotUpgrade, 1);
|
||||
}
|
||||
|
||||
} // namespace wpi
|
||||
73
wpiutil/src/test/native/cpp/WebSocketTest.h
Normal file
73
wpiutil/src/test/native/cpp/WebSocketTest.h
Normal file
@@ -0,0 +1,73 @@
|
||||
/*----------------------------------------------------------------------------*/
|
||||
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
|
||||
/* Open Source Software - may be modified and shared by FRC teams. The code */
|
||||
/* must be accompanied by the FIRST BSD license file in the root directory of */
|
||||
/* the project. */
|
||||
/*----------------------------------------------------------------------------*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "wpi/ArrayRef.h"
|
||||
#include "wpi/uv/Loop.h"
|
||||
#include "wpi/uv/Pipe.h"
|
||||
#include "wpi/uv/Timer.h"
|
||||
|
||||
namespace wpi {
|
||||
|
||||
class WebSocketTest : public ::testing::Test {
|
||||
public:
|
||||
static const char* pipeName;
|
||||
|
||||
static void SetUpTestCase();
|
||||
|
||||
WebSocketTest() {
|
||||
loop = uv::Loop::Create();
|
||||
clientPipe = uv::Pipe::Create(loop);
|
||||
serverPipe = uv::Pipe::Create(loop);
|
||||
|
||||
serverPipe->Bind(pipeName);
|
||||
|
||||
#if 0
|
||||
auto debugTimer = uv::Timer::Create(loop);
|
||||
debugTimer->timeout.connect([this] {
|
||||
std::printf("Active handles:\n");
|
||||
uv_print_active_handles(loop->GetRaw(), stdout);
|
||||
});
|
||||
debugTimer->Start(uv::Timer::Time{100}, uv::Timer::Time{100});
|
||||
debugTimer->Unreference();
|
||||
#endif
|
||||
|
||||
auto failTimer = uv::Timer::Create(loop);
|
||||
failTimer->timeout.connect([this] {
|
||||
loop->Stop();
|
||||
FAIL() << "loop failed to terminate";
|
||||
});
|
||||
failTimer->Start(uv::Timer::Time{1000});
|
||||
failTimer->Unreference();
|
||||
}
|
||||
|
||||
~WebSocketTest() { Finish(); }
|
||||
|
||||
void Finish() {
|
||||
loop->Walk([](uv::Handle& it) { it.Close(); });
|
||||
}
|
||||
|
||||
static std::vector<uint8_t> BuildHeader(uint8_t opcode, bool fin,
|
||||
bool masking, uint64_t len);
|
||||
static std::vector<uint8_t> BuildMessage(uint8_t opcode, bool fin,
|
||||
bool masking,
|
||||
ArrayRef<uint8_t> data);
|
||||
static void AdjustMasking(MutableArrayRef<uint8_t> message);
|
||||
static const uint8_t testMask[4];
|
||||
|
||||
std::shared_ptr<uv::Loop> loop;
|
||||
std::shared_ptr<uv::Pipe> clientPipe;
|
||||
std::shared_ptr<uv::Pipe> serverPipe;
|
||||
};
|
||||
|
||||
} // namespace wpi
|
||||
Reference in New Issue
Block a user