diff --git a/wpiutil/src/main/native/cpp/WebSocket.cpp b/wpiutil/src/main/native/cpp/WebSocket.cpp new file mode 100644 index 0000000000..8d0a65cf14 --- /dev/null +++ b/wpiutil/src/main/native/cpp/WebSocket.cpp @@ -0,0 +1,565 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocket.h" + +#include + +#include "wpi/Base64.h" +#include "wpi/HttpParser.h" +#include "wpi/SmallString.h" +#include "wpi/SmallVector.h" +#include "wpi/raw_uv_ostream.h" +#include "wpi/sha1.h" +#include "wpi/uv/Stream.h" + +using namespace wpi; + +namespace { +class WebSocketWriteReq : public uv::WriteReq { + public: + explicit WebSocketWriteReq( + std::function, uv::Error)> callback) { + finish.connect([=](uv::Error err) { + MutableArrayRef bufs{m_bufs}; + for (auto&& buf : bufs.slice(0, m_startUser)) buf.Deallocate(); + callback(bufs.slice(m_startUser), err); + }); + } + + SmallVector m_bufs; + size_t m_startUser; +}; +} // namespace + +class WebSocket::ClientHandshakeData { + public: + ClientHandshakeData() { + // key is a random nonce + static std::random_device rd; + static std::default_random_engine gen{rd()}; + std::uniform_int_distribution dist(0, 255); + char nonce[16]; // the nonce sent to the server + for (char& v : nonce) v = static_cast(dist(gen)); + raw_svector_ostream os(key); + Base64Encode(os, StringRef{nonce, 16}); + } + ~ClientHandshakeData() { + if (auto t = timer.lock()) { + t->Stop(); + t->Close(); + } + } + + SmallString<64> key; // the key sent to the server + SmallVector protocols; // valid protocols + HttpParser parser{HttpParser::kResponse}; // server response parser + bool hasUpgrade = false; + bool hasConnection = false; + bool hasAccept = false; + bool hasProtocol = false; + + std::weak_ptr timer; +}; + +static StringRef AcceptHash(StringRef key, SmallVectorImpl& buf) { + SHA1 hash; + hash.Update(key); + hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + SmallString<64> hashBuf; + return Base64Encode(hash.Final(hashBuf), buf); +} + +WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&) + : m_stream{stream}, m_server{server} { + // Connect closed and error signals to ourselves + m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); }); + m_stream.error.connect([this](uv::Error err) { + Terminate(1006, "stream error: " + Twine(err.name())); + }); + + // Start reading + m_stream.StopRead(); // we may have been reading + m_stream.StartRead(); + m_stream.data.connect( + [this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); }); + m_stream.end.connect( + [this]() { Terminate(1006, "remote end closed connection"); }); +} + +WebSocket::~WebSocket() {} + +std::shared_ptr WebSocket::CreateClient( + uv::Stream& stream, const Twine& uri, const Twine& host, + ArrayRef protocols, const ClientOptions& options) { + auto ws = std::make_shared(stream, false, private_init{}); + stream.SetData(ws); + ws->StartClient(uri, host, protocols, options); + return ws; +} + +std::shared_ptr WebSocket::CreateServer(uv::Stream& stream, + StringRef key, + StringRef version, + StringRef protocol) { + auto ws = std::make_shared(stream, true, private_init{}); + stream.SetData(ws); + ws->StartServer(key, version, protocol); + return ws; +} + +void WebSocket::Close(uint16_t code, const Twine& reason) { + SendClose(code, reason); + if (m_state != FAILED && m_state != CLOSED) m_state = CLOSING; +} + +void WebSocket::Fail(uint16_t code, const Twine& reason) { + if (m_state == FAILED || m_state == CLOSED) return; + SendClose(code, reason); + SetClosed(code, reason, true); + Shutdown(); +} + +void WebSocket::Terminate(uint16_t code, const Twine& reason) { + if (m_state == FAILED || m_state == CLOSED) return; + SetClosed(code, reason); + Shutdown(); +} + +void WebSocket::StartClient(const Twine& uri, const Twine& host, + ArrayRef protocols, + const ClientOptions& options) { + // Create client handshake data + m_clientHandshake = std::make_unique(); + + // Build client request + SmallVector bufs; + raw_uv_ostream os{bufs, 4096}; + + os << "GET " << uri << " HTTP/1.1\r\n"; + os << "Host: " << host << "\r\n"; + os << "Upgrade: websocket\r\n"; + os << "Connection: Upgrade\r\n"; + os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n"; + os << "Sec-WebSocket-Version: 13\r\n"; + + // protocols (if provided) + if (!protocols.empty()) { + os << "Sec-WebSocket-Protocol: "; + bool first = true; + for (auto protocol : protocols) { + if (!first) + os << ", "; + else + first = false; + os << protocol; + // also save for later checking against server response + m_clientHandshake->protocols.emplace_back(protocol); + } + os << "\r\n"; + } + + // other headers + for (auto&& header : options.extraHeaders) + os << header.first << ": " << header.second << "\r\n"; + + // finish headers + os << "\r\n"; + + // Send client request + m_stream.Write(bufs, [](auto bufs, uv::Error) { + for (auto& buf : bufs) buf.Deallocate(); + }); + + // Set up client response handling + m_clientHandshake->parser.status.connect([this](StringRef status) { + unsigned int code = m_clientHandshake->parser.GetStatusCode(); + if (code != 101) Terminate(code, status); + }); + m_clientHandshake->parser.header.connect( + [this](StringRef name, StringRef value) { + value = value.trim(); + if (name.equals_lower("upgrade")) { + if (!value.equals_lower("websocket")) + return Terminate(1002, "invalid upgrade response value"); + m_clientHandshake->hasUpgrade = true; + } else if (name.equals_lower("connection")) { + if (!value.equals_lower("upgrade")) + return Terminate(1002, "invalid connection response value"); + m_clientHandshake->hasConnection = true; + } else if (name.equals_lower("sec-websocket-accept")) { + // Check against expected response + SmallString<64> acceptBuf; + if (!value.equals(AcceptHash(m_clientHandshake->key, acceptBuf))) + return Terminate(1002, "invalid accept key"); + m_clientHandshake->hasAccept = true; + } else if (name.equals_lower("sec-websocket-extensions")) { + // No extensions are supported + if (!value.empty()) return Terminate(1010, "unsupported extension"); + } else if (name.equals_lower("sec-websocket-protocol")) { + // Make sure it was one of the provided protocols + bool match = false; + for (auto&& protocol : m_clientHandshake->protocols) { + if (value.equals_lower(protocol)) { + match = true; + break; + } + } + if (!match) return Terminate(1003, "unsupported protocol"); + m_clientHandshake->hasProtocol = true; + m_protocol = value; + } + }); + m_clientHandshake->parser.headersComplete.connect([this](bool) { + if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection || + !m_clientHandshake->hasAccept || + (!m_clientHandshake->hasProtocol && + !m_clientHandshake->protocols.empty())) { + return Terminate(1002, "invalid response"); + } + if (m_state == CONNECTING) { + m_state = OPEN; + open(m_protocol); + } + }); + + // Start handshake timer if a timeout was specified + if (options.handshakeTimeout != uv::Timer::Time::max()) { + auto timer = uv::Timer::Create(m_stream.GetLoopRef()); + timer->timeout.connect( + [this]() { Terminate(1006, "connection timed out"); }); + timer->Start(options.handshakeTimeout); + m_clientHandshake->timer = timer; + } +} + +void WebSocket::StartServer(StringRef key, StringRef version, + StringRef protocol) { + m_protocol = protocol; + + // Build server response + SmallVector bufs; + raw_uv_ostream os{bufs, 4096}; + + // Handle unsupported version + if (version != "13") { + os << "HTTP/1.1 426 Upgrade Required\r\n"; + os << "Upgrade: WebSocket\r\n"; + os << "Sec-WebSocket-Version: 13\r\n\r\n"; + m_stream.Write(bufs, [this](auto bufs, uv::Error) { + for (auto& buf : bufs) buf.Deallocate(); + // XXX: Should we support sending a new handshake on the same connection? + // XXX: "this->" is required by GCC 5.5 (bug) + this->Terminate(1003, "unsupported protocol version"); + }); + return; + } + + os << "HTTP/1.1 101 Switching Protocols\r\n"; + os << "Upgrade: websocket\r\n"; + os << "Connection: Upgrade\r\n"; + + // accept hash + SmallString<64> acceptBuf; + os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n"; + + if (!protocol.empty()) os << "Sec-WebSocket-Protocol: " << protocol << "\r\n"; + + // end headers + os << "\r\n"; + + // Send server response + m_stream.Write(bufs, [this](auto bufs, uv::Error) { + for (auto& buf : bufs) buf.Deallocate(); + if (m_state == CONNECTING) { + m_state = OPEN; + open(m_protocol); + } + }); +} + +void WebSocket::SendClose(uint16_t code, const Twine& reason) { + SmallVector bufs; + if (code != 1005) { + raw_uv_ostream os{bufs, 4096}; + os << ArrayRef{static_cast((code >> 8) & 0xff), + static_cast(code & 0xff)}; + reason.print(os); + } + Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) { + for (auto&& buf : bufs) buf.Deallocate(); + }); +} + +void WebSocket::SetClosed(uint16_t code, const Twine& reason, bool failed) { + if (m_state == FAILED || m_state == CLOSED) return; + m_state = failed ? FAILED : CLOSED; + SmallString<64> reasonBuf; + closed(code, reason.toStringRef(reasonBuf)); +} + +void WebSocket::Shutdown() { + m_stream.Shutdown([this] { m_stream.Close(); }); +} + +void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) { + // ignore incoming data if we're failed or closed + if (m_state == FAILED || m_state == CLOSED) return; + + StringRef data{buf.base, size}; + + // Handle connecting state (mainly on client) + if (m_state == CONNECTING) { + if (m_clientHandshake) { + data = m_clientHandshake->parser.Execute(data); + // check for parser failure + if (m_clientHandshake->parser.HasError()) + return Terminate(1003, "invalid response"); + if (m_state != OPEN) return; // not done with handshake yet + + // we're done with the handshake, so release its memory + m_clientHandshake.reset(); + + // fall through to process additional data after handshake + } else { + return Terminate(1003, "got data on server before response"); + } + } + + // Message processing + while (!data.empty()) { + if (m_frameSize == UINT64_MAX) { + // Need at least two bytes to determine header length + if (m_header.size() < 2u) { + size_t toCopy = std::min(2u - m_header.size(), data.size()); + m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy); + data = data.drop_front(toCopy); + if (m_header.size() < 2u) return; // need more data + + // Validate RSV bits are zero + if ((m_header[0] & 0x70) != 0) return Fail(1002, "nonzero RSV"); + } + + // Once we have first two bytes, we can calculate the header size + if (m_headerSize == 0) { + m_headerSize = 2; + uint8_t len = m_header[1] & kLenMask; + if (len == 126) + m_headerSize += 2; + else if (len == 127) + m_headerSize += 8; + bool masking = (m_header[1] & kFlagMasking) != 0; + if (masking) m_headerSize += 4; // masking key + // On server side, incoming messages MUST be masked + // On client side, incoming messages MUST NOT be masked + if (m_server && !masking) return Fail(1002, "client data not masked"); + if (!m_server && masking) return Fail(1002, "server data masked"); + } + + // Need to complete header to calculate message size + if (m_header.size() < m_headerSize) { + size_t toCopy = std::min(m_headerSize - m_header.size(), data.size()); + m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy); + data = data.drop_front(toCopy); + if (m_header.size() < m_headerSize) return; // need more data + } + + if (m_header.size() >= m_headerSize) { + // get payload length + uint8_t len = m_header[1] & kLenMask; + if (len == 126) + m_frameSize = (static_cast(m_header[2]) << 8) | + static_cast(m_header[3]); + else if (len == 127) + m_frameSize = (static_cast(m_header[2]) << 56) | + (static_cast(m_header[3]) << 48) | + (static_cast(m_header[4]) << 40) | + (static_cast(m_header[5]) << 32) | + (static_cast(m_header[6]) << 24) | + (static_cast(m_header[7]) << 16) | + (static_cast(m_header[8]) << 8) | + static_cast(m_header[9]); + else + m_frameSize = len; + + // limit maximum size + if ((m_payload.size() + m_frameSize) > m_maxMessageSize) + return Fail(1009, "message too large"); + } + } + + if (m_frameSize != UINT64_MAX) { + size_t need = m_frameStart + m_frameSize - m_payload.size(); + size_t toCopy = std::min(need, data.size()); + m_payload.append(data.bytes_begin(), data.bytes_begin() + toCopy); + data = data.drop_front(toCopy); + need -= toCopy; + if (need == 0) { + // We have a complete frame + // If the message had masking, unmask it + if ((m_header[1] & kFlagMasking) != 0) { + uint8_t key[4] = { + m_header[m_headerSize - 4], m_header[m_headerSize - 3], + m_header[m_headerSize - 2], m_header[m_headerSize - 1]}; + int n = 0; + for (uint8_t& ch : + MutableArrayRef{m_payload}.slice(m_frameStart)) { + ch ^= key[n++]; + if (n >= 4) n = 0; + } + } + + // Handle message + bool fin = (m_header[0] & kFlagFin) != 0; + uint8_t opcode = m_header[0] & kOpMask; + switch (opcode) { + case kOpCont: + switch (m_fragmentOpcode) { + case kOpText: + if (!m_combineFragments || fin) + text(StringRef{reinterpret_cast(m_payload.data()), + m_payload.size()}, + fin); + break; + case kOpBinary: + if (!m_combineFragments || fin) binary(m_payload, fin); + break; + default: + // no preceding message? + return Fail(1002, "invalid continuation message"); + } + if (fin) m_fragmentOpcode = 0; + break; + case kOpText: + if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment"); + if (!m_combineFragments || fin) + text(StringRef{reinterpret_cast(m_payload.data()), + m_payload.size()}, + fin); + if (!fin) m_fragmentOpcode = opcode; + break; + case kOpBinary: + if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment"); + if (!m_combineFragments || fin) binary(m_payload, fin); + if (!fin) m_fragmentOpcode = opcode; + break; + case kOpClose: { + uint16_t code; + StringRef reason; + if (!fin) { + code = 1002; + reason = "cannot fragment control frames"; + } else if (m_payload.size() < 2) { + code = 1005; + } else { + code = (static_cast(m_payload[0]) << 8) | + static_cast(m_payload[1]); + reason = StringRef{reinterpret_cast(m_payload.data()), + m_payload.size()} + .drop_front(2); + } + // Echo the close if we didn't previously send it + if (m_state != CLOSING) SendClose(code, reason); + SetClosed(code, reason); + // If we're the server, shutdown the connection. + if (m_server) Shutdown(); + break; + } + case kOpPing: + if (!fin) return Fail(1002, "cannot fragment control frames"); + ping(m_payload); + break; + case kOpPong: + if (!fin) return Fail(1002, "cannot fragment control frames"); + pong(m_payload); + break; + default: + return Fail(1002, "invalid message opcode"); + } + + // Prepare for next message + m_header.clear(); + m_headerSize = 0; + if (!m_combineFragments || fin) m_payload.clear(); + m_frameStart = m_payload.size(); + m_frameSize = UINT64_MAX; + } + } + } +} + +void WebSocket::Send( + uint8_t opcode, ArrayRef data, + std::function, uv::Error)> callback) { + // If we're not open, emit an error and don't send the data + if (m_state != OPEN) { + int err; + if (m_state == CONNECTING) + err = UV_EAGAIN; + else + err = UV_ESHUTDOWN; + SmallVector bufs{data.begin(), data.end()}; + callback(bufs, uv::Error{err}); + return; + } + + auto req = std::make_shared(callback); + raw_uv_ostream os{req->m_bufs, 4096}; + + // opcode (includes FIN bit) + os << static_cast(opcode); + + // payload length + uint64_t size = 0; + for (auto&& buf : data) size += buf.len; + if (size < 126) { + os << static_cast((m_server ? 0x00 : kFlagMasking) | size); + } else if (size <= 0xffff) { + os << static_cast((m_server ? 0x00 : kFlagMasking) | 126); + os << ArrayRef{static_cast((size >> 8) & 0xff), + static_cast(size & 0xff)}; + } else { + os << static_cast((m_server ? 0x00 : kFlagMasking) | 127); + os << ArrayRef{static_cast((size >> 56) & 0xff), + static_cast((size >> 48) & 0xff), + static_cast((size >> 40) & 0xff), + static_cast((size >> 32) & 0xff), + static_cast((size >> 24) & 0xff), + static_cast((size >> 16) & 0xff), + static_cast((size >> 8) & 0xff), + static_cast(size & 0xff)}; + } + + // clients need to mask the input data + if (!m_server) { + // generate masking key + static std::random_device rd; + static std::default_random_engine gen{rd()}; + std::uniform_int_distribution dist(0, 255); + uint8_t key[4]; + for (uint8_t& v : key) v = dist(gen); + os << ArrayRef{key, 4}; + // copy and mask data + int n = 0; + for (auto&& buf : data) { + for (auto&& ch : buf.data()) { + os << static_cast(static_cast(ch) ^ key[n++]); + if (n >= 4) n = 0; + } + } + req->m_startUser = req->m_bufs.size(); + req->m_bufs.append(data.begin(), data.end()); + // don't send the user bufs as we copied their data + m_stream.Write(ArrayRef{req->m_bufs}.slice(0, req->m_startUser), + req); + } else { + // servers can just send the buffers directly without masking + req->m_startUser = req->m_bufs.size(); + req->m_bufs.append(data.begin(), data.end()); + m_stream.Write(req->m_bufs, req); + } +} diff --git a/wpiutil/src/main/native/cpp/WebSocketServer.cpp b/wpiutil/src/main/native/cpp/WebSocketServer.cpp new file mode 100644 index 0000000000..2bd8e38d48 --- /dev/null +++ b/wpiutil/src/main/native/cpp/WebSocketServer.cpp @@ -0,0 +1,142 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocketServer.h" + +#include "wpi/raw_uv_ostream.h" +#include "wpi/uv/Buffer.h" +#include "wpi/uv/Stream.h" + +using namespace wpi; + +WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) { + req.header.connect([this](StringRef name, StringRef value) { + if (name.equals_lower("host")) { + m_gotHost = true; + } else if (name.equals_lower("upgrade")) { + if (value.equals_lower("websocket")) m_websocket = true; + } else if (name.equals_lower("sec-websocket-key")) { + m_key = value; + } else if (name.equals_lower("sec-websocket-version")) { + m_version = value; + } else if (name.equals_lower("sec-websocket-protocol")) { + // Protocols are comma delimited, repeated headers add to list + SmallVector protocols; + value.split(protocols, ",", -1, false); + for (auto protocol : protocols) { + protocol = protocol.trim(); + if (!protocol.empty()) m_protocols.emplace_back(protocol); + } + } + }); + req.headersComplete.connect([&req, this](bool) { + if (req.IsUpgrade() && IsUpgrade()) upgrade(); + }); +} + +std::pair WebSocketServerHelper::MatchProtocol( + ArrayRef protocols) { + if (protocols.empty() && m_protocols.empty()) + return std::make_pair(true, StringRef{}); + for (auto protocol : protocols) { + for (auto&& clientProto : m_protocols) { + if (protocol == clientProto) return std::make_pair(true, protocol); + } + } + return std::make_pair(false, StringRef{}); +} + +WebSocketServer::WebSocketServer(uv::Stream& stream, + ArrayRef protocols, + const ServerOptions& options, + const private_init&) + : m_stream{stream}, + m_helper{m_req}, + m_protocols{protocols.begin(), protocols.end()}, + m_options{options} { + // Header handling + m_req.header.connect([this](StringRef name, StringRef value) { + if (name.equals_lower("host")) { + if (m_options.checkHost) { + if (!m_options.checkHost(value)) Abort(401, "Unrecognized Host"); + } + } + }); + m_req.url.connect([this](StringRef name) { + if (m_options.checkUrl) { + if (!m_options.checkUrl(name)) Abort(404, "Not Found"); + } + }); + m_req.headersComplete.connect([this](bool) { + // We only accept websocket connections + if (!m_helper.IsUpgrade() || !m_req.IsUpgrade()) + Abort(426, "Upgrade Required"); + }); + + // Handle upgrade event + m_helper.upgrade.connect([this] { + if (m_aborted) return; + + // Negotiate sub-protocol + SmallVector protocols{m_protocols.begin(), m_protocols.end()}; + StringRef protocol = m_helper.MatchProtocol(protocols).second; + + // Disconnect our header reader + m_headerConn.disconnect(); + + // Accepting the stream may destroy this (as it replaces the stream user + // data), so grab a shared pointer first. + auto self = shared_from_this(); + + // Accept the upgrade + auto ws = m_helper.Accept(m_stream, protocol); + + // Connect the websocket open event to our connected event. + ws->open.connect_extended([ self, s = ws.get() ](auto conn, StringRef) { + self->connected(self->m_req.GetUrl(), *s); + conn.disconnect(); // one-shot + }); + }); + + // Set up stream + stream.StartRead(); + m_headerConn = + stream.data.connect_connection([this](uv::Buffer& buf, size_t size) { + if (m_aborted) return; + m_req.Execute(StringRef{buf.base, size}); + if (m_req.HasError()) Abort(400, "Bad Request"); + }); + stream.error.connect([this](uv::Error) { m_stream.Close(); }); + stream.end.connect([this] { m_stream.Close(); }); +} + +std::shared_ptr WebSocketServer::Create( + uv::Stream& stream, ArrayRef protocols, + const ServerOptions& options) { + auto server = std::make_shared(stream, protocols, options, + private_init{}); + stream.SetData(server); + return server; +} + +void WebSocketServer::Abort(uint16_t code, StringRef reason) { + if (m_aborted) return; + m_aborted = true; + + // Build response + SmallVector bufs; + raw_uv_ostream os{bufs, 1024}; + + // Handle unsupported version + os << "HTTP/1.1 " << code << ' ' << reason << "\r\n"; + if (code == 426) os << "Upgrade: WebSocket\r\n"; + os << "\r\n"; + m_stream.Write(bufs, [this](auto bufs, uv::Error) { + for (auto& buf : bufs) buf.Deallocate(); + m_stream.Shutdown([this] { m_stream.Close(); }); + }); +} diff --git a/wpiutil/src/main/native/include/wpi/WebSocket.h b/wpiutil/src/main/native/include/wpi/WebSocket.h new file mode 100644 index 0000000000..418c134292 --- /dev/null +++ b/wpiutil/src/main/native/include/wpi/WebSocket.h @@ -0,0 +1,378 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#ifndef WPIUTIL_WPI_WEBSOCKET_H_ +#define WPIUTIL_WPI_WEBSOCKET_H_ + +#include + +#include +#include +#include +#include + +#include "wpi/ArrayRef.h" +#include "wpi/Signal.h" +#include "wpi/SmallVector.h" +#include "wpi/StringRef.h" +#include "wpi/Twine.h" +#include "wpi/uv/Buffer.h" +#include "wpi/uv/Error.h" +#include "wpi/uv/Timer.h" + +namespace wpi { + +namespace uv { +class Stream; +} // namespace uv + +/** + * RFC 6455 compliant WebSocket client and server implementation. + */ +class WebSocket : public std::enable_shared_from_this { + struct private_init {}; + + static constexpr uint8_t kOpCont = 0x00; + static constexpr uint8_t kOpText = 0x01; + static constexpr uint8_t kOpBinary = 0x02; + static constexpr uint8_t kOpClose = 0x08; + static constexpr uint8_t kOpPing = 0x09; + static constexpr uint8_t kOpPong = 0x0A; + static constexpr uint8_t kOpMask = 0x0F; + static constexpr uint8_t kFlagFin = 0x80; + static constexpr uint8_t kFlagMasking = 0x80; + static constexpr uint8_t kLenMask = 0x7f; + + public: + WebSocket(uv::Stream& stream, bool server, const private_init&); + WebSocket(const WebSocket&) = delete; + WebSocket(WebSocket&&) = delete; + WebSocket& operator=(const WebSocket&) = delete; + WebSocket& operator=(WebSocket&&) = delete; + ~WebSocket(); + + /** + * Connection states. + */ + enum State { + /** The connection is not yet open. */ + CONNECTING = 0, + /** The connection is open and ready to communicate. */ + OPEN, + /** The connection is in the process of closing. */ + CLOSING, + /** The connection failed. */ + FAILED, + /** The connection is closed. */ + CLOSED + }; + + /** + * Client connection options. + */ + struct ClientOptions { + ClientOptions() : handshakeTimeout{uv::Timer::Time::max()} {} + + /** Timeout for the handshake request. */ + uv::Timer::Time handshakeTimeout; + + /** Additional headers to include in handshake. */ + ArrayRef> extraHeaders; + }; + + /** + * Starts a client connection by performing the initial client handshake. + * An open event is emitted when the handshake completes. + * This sets the stream user data to the websocket. + * @param stream Connection stream + * @param uri The Request-URI to send + * @param host The host or host:port to send + * @param protocols The list of subprotocols + * @param options Handshake options + */ + static std::shared_ptr CreateClient( + uv::Stream& stream, const Twine& uri, const Twine& host, + ArrayRef protocols = ArrayRef{}, + const ClientOptions& options = ClientOptions{}); + + /** + * Starts a server connection by performing the initial server side handshake. + * This should be called after the HTTP headers have been received. + * An open event is emitted when the handshake completes. + * This sets the stream user data to the websocket. + * @param stream Connection stream + * @param key The value of the Sec-WebSocket-Key header field in the client + * request + * @param version The value of the Sec-WebSocket-Version header field in the + * client request + * @param protocol The subprotocol to send to the client (in the + * Sec-WebSocket-Protocol header field). + */ + static std::shared_ptr CreateServer( + uv::Stream& stream, StringRef key, StringRef version, + StringRef protocol = StringRef{}); + + /** + * Get connection state. + */ + State GetState() const { return m_state; } + + /** + * Return if the connection is open. Messages can only be sent on open + * connections. + */ + bool IsOpen() const { return m_state == OPEN; } + + /** + * Get the underlying stream. + */ + uv::Stream& GetStream() const { return m_stream; } + + /** + * Get the selected sub-protocol. Only valid in or after the open() event. + */ + StringRef GetProtocol() const { return m_protocol; } + + /** + * Set the maximum message size. Default is 128 KB. If configured to combine + * fragments this maximum applies to the entire message (all combined + * fragments). + * @param size Maximum message size in bytes + */ + void SetMaxMessageSize(size_t size) { m_maxMessageSize = size; } + + /** + * Set whether or not fragmented frames should be combined. Default is to + * combine. If fragmented frames are combined, the text and binary callbacks + * will always have the second parameter (fin) set to true. + * @param combine True if fragmented frames should be combined. + */ + void SetCombineFragments(bool combine) { m_combineFragments = combine; } + + /** + * Initiate a closing handshake. + * @param code A numeric status code (defaults to 1005, no status code) + * @param reason A human-readable string explaining why the connection is + * closing (optional). + */ + void Close(uint16_t code = 1005, const Twine& reason = Twine{}); + + /** + * Send a text message. + * @param data UTF-8 encoded data to send + * @param callback Callback which is invoked when the write completes. + */ + void SendText( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kFlagFin | kOpText, data, callback); + } + + /** + * Send a binary message. + * @param data Data to send + * @param callback Callback which is invoked when the write completes. + */ + void SendBinary( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kFlagFin | kOpBinary, data, callback); + } + + /** + * Send a text message fragment. This must be followed by one or more + * SendFragment() calls, where the last one has fin=True, to complete the + * message. + * @param data UTF-8 encoded data to send + * @param callback Callback which is invoked when the write completes. + */ + void SendTextFragment( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kOpText, data, callback); + } + + /** + * Send a text message fragment. This must be followed by one or more + * SendFragment() calls, where the last one has fin=True, to complete the + * message. + * @param data Data to send + * @param callback Callback which is invoked when the write completes. + */ + void SendBinaryFragment( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kOpBinary, data, callback); + } + + /** + * Send a continuation frame. This is used to send additional parts of a + * message started with SendTextFragment() or SendBinaryFragment(). + * @param data Data to send + * @param fin Set to true if this is the final fragment of the message + * @param callback Callback which is invoked when the write completes. + */ + void SendFragment( + ArrayRef data, bool fin, + std::function, uv::Error)> callback) { + Send(kOpCont | (fin ? kFlagFin : 0), data, callback); + } + + /** + * Send a ping frame with no data. + * @param callback Optional callback which is invoked when the ping frame + * write completes. + */ + void SendPing(std::function callback = nullptr) { + SendPing(ArrayRef{}, [callback](auto bufs, uv::Error err) { + if (callback) callback(err); + }); + } + + /** + * Send a ping frame. + * @param data Data to send in the ping frame + * @param callback Callback which is invoked when the ping frame + * write completes. + */ + void SendPing( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kFlagFin | kOpPing, data, callback); + } + + /** + * Send a pong frame with no data. + * @param callback Optional callback which is invoked when the pong frame + * write completes. + */ + void SendPong(std::function callback = nullptr) { + SendPong(ArrayRef{}, [callback](auto bufs, uv::Error err) { + if (callback) callback(err); + }); + } + + /** + * Send a pong frame. + * @param data Data to send in the pong frame + * @param callback Callback which is invoked when the pong frame + * write completes. + */ + void SendPong( + ArrayRef data, + std::function, uv::Error)> callback) { + Send(kFlagFin | kOpPong, data, callback); + } + + /** + * Fail the connection. + */ + void Fail(uint16_t code = 1002, const Twine& reason = "protocol error"); + + /** + * Forcibly close the connection. + */ + void Terminate(uint16_t code = 1006, const Twine& reason = "terminated"); + + /** + * Gets user-defined data. + * @return User-defined data if any, nullptr otherwise. + */ + template + std::shared_ptr GetData() const { + return std::static_pointer_cast(m_data); + } + + /** + * Sets user-defined data. + * @param data User-defined arbitrary data. + */ + void SetData(std::shared_ptr data) { m_data = std::move(data); } + + /** + * Open event. Emitted when the connection is open and ready to communicate. + * The parameter is the selected subprotocol. + */ + sig::Signal open; + + /** + * Close event. Emitted when the connection is closed. The first parameter + * is a numeric value indicating the status code explaining why the connection + * has been closed. The second parameter is a human-readable string + * explaining the reason why the connection has been closed. + */ + sig::Signal closed; + + /** + * Text message event. Emitted when a text message is received. + * The first parameter is the data, the second parameter is true if the + * data is the last fragment of the message. + */ + sig::Signal text; + + /** + * Binary message event. Emitted when a binary message is received. + * The first parameter is the data, the second parameter is true if the + * data is the last fragment of the message. + */ + sig::Signal, bool> binary; + + /** + * Ping event. Emitted when a ping message is received. + */ + sig::Signal> ping; + + /** + * Pong event. Emitted when a pong message is received. + */ + sig::Signal> pong; + + private: + // user data + std::shared_ptr m_data; + + // constructor parameters + uv::Stream& m_stream; + bool m_server; + + // subprotocol, set via constructor (server) or handshake (client) + std::string m_protocol; + + // user-settable configuration + size_t m_maxMessageSize = 128 * 1024; + bool m_combineFragments = true; + + // operating state + State m_state = CONNECTING; + + // incoming message buffers/state + SmallVector m_header; + size_t m_headerSize = 0; + SmallVector m_payload; + size_t m_frameStart = 0; + uint64_t m_frameSize = UINT64_MAX; + uint8_t m_fragmentOpcode = 0; + + // temporary data used only during client handshake + class ClientHandshakeData; + std::unique_ptr m_clientHandshake; + + void StartClient(const Twine& uri, const Twine& host, + ArrayRef protocols, const ClientOptions& options); + void StartServer(StringRef key, StringRef version, StringRef protocol); + void SendClose(uint16_t code, const Twine& reason); + void SetClosed(uint16_t code, const Twine& reason, bool failed = false); + void Shutdown(); + void HandleIncoming(uv::Buffer& buf, size_t size); + void Send( + uint8_t opcode, ArrayRef data, + std::function, uv::Error)> callback); +}; + +} // namespace wpi + +#endif // WPIUTIL_WPI_WEBSOCKET_H_ diff --git a/wpiutil/src/main/native/include/wpi/WebSocketServer.h b/wpiutil/src/main/native/include/wpi/WebSocketServer.h new file mode 100644 index 0000000000..bad3f68d8f --- /dev/null +++ b/wpiutil/src/main/native/include/wpi/WebSocketServer.h @@ -0,0 +1,147 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#ifndef WPIUTIL_WPI_WEBSOCKETSERVER_H_ +#define WPIUTIL_WPI_WEBSOCKETSERVER_H_ + +#include +#include +#include +#include + +#include "wpi/ArrayRef.h" +#include "wpi/HttpParser.h" +#include "wpi/Signal.h" +#include "wpi/SmallString.h" +#include "wpi/SmallVector.h" +#include "wpi/StringRef.h" +#include "wpi/WebSocket.h" + +namespace wpi { + +namespace uv { +class Stream; +} // namespace uv + +/** + * WebSocket HTTP server helper. Handles websocket-specific headers. User + * must provide the HttpParser. + */ +class WebSocketServerHelper { + public: + /** + * Constructor. + * @param req HttpParser for request + */ + explicit WebSocketServerHelper(HttpParser& req); + + /** + * Get whether or not this was a websocket upgrade. + * Only valid during and after the upgrade event. + */ + bool IsWebsocket() const { return m_websocket; } + + /** + * Try to find a match to the list of sub-protocols provided by the client. + * The list is priority ordered, so the first match wins. + * Only valid during and after the upgrade event. + * @param protocols Acceptable protocols + * @return Pair; first item is true if a match was made, false if not. + * Second item is the matched protocol if a match was made, otherwise + * is empty. + */ + std::pair MatchProtocol(ArrayRef protocols); + + /** + * Accept the upgrade. Disconnect other readers (such as the HttpParser + * reader) before calling this. See also WebSocket::CreateServer(). + * @param stream Connection stream + * @param protocol The subprotocol to send to the client + */ + std::shared_ptr Accept(uv::Stream& stream, + StringRef protocol = StringRef{}) { + return WebSocket::CreateServer(stream, m_key, m_version, protocol); + } + + bool IsUpgrade() const { return m_gotHost && m_websocket; } + + /** + * Upgrade event. Call Accept() to accept the upgrade. + */ + sig::Signal<> upgrade; + + private: + bool m_gotHost = false; + bool m_websocket = false; + SmallVector m_protocols; + SmallString<64> m_key; + SmallString<16> m_version; +}; + +/** + * Dedicated WebSocket server. + */ +class WebSocketServer : public std::enable_shared_from_this { + struct private_init {}; + + public: + /** + * Server options. + */ + struct ServerOptions { + /** + * Checker for URL. Return true if URL should be accepted. By default all + * URLs are accepted. + */ + std::function checkUrl; + + /** + * Checker for Host header. Return true if Host should be accepted. By + * default all hosts are accepted. + */ + std::function checkHost; + }; + + /** + * Private constructor. + */ + WebSocketServer(uv::Stream& stream, ArrayRef protocols, + const ServerOptions& options, const private_init&); + + /** + * Starts a dedicated WebSocket server on the provided connection. The + * connection should be an accepted client stream. + * This also sets the stream user data to the socket server. + * A connected event is emitted when the connection is opened. + * @param stream Connection stream + * @param protocols Acceptable subprotocols + * @param options Handshake options + */ + static std::shared_ptr Create( + uv::Stream& stream, ArrayRef protocols = ArrayRef{}, + const ServerOptions& options = ServerOptions{}); + + /** + * Connected event. First parameter is the URL, second is the websocket. + */ + sig::Signal connected; + + private: + uv::Stream& m_stream; + HttpParser m_req{HttpParser::kRequest}; + WebSocketServerHelper m_helper; + SmallVector m_protocols; + ServerOptions m_options; + bool m_aborted = false; + sig::Connection m_headerConn; + + void Abort(uint16_t code, StringRef reason); +}; + +} // namespace wpi + +#endif // WPIUTIL_WPI_WEBSOCKETSERVER_H_ diff --git a/wpiutil/src/test/native/cpp/WebSocketClientTest.cpp b/wpiutil/src/test/native/cpp/WebSocketClientTest.cpp new file mode 100644 index 0000000000..87e698a151 --- /dev/null +++ b/wpiutil/src/test/native/cpp/WebSocketClientTest.cpp @@ -0,0 +1,299 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocket.h" // NOLINT(build/include_order) + +#include "WebSocketTest.h" +#include "wpi/Base64.h" +#include "wpi/HttpParser.h" +#include "wpi/SmallString.h" +#include "wpi/raw_uv_ostream.h" +#include "wpi/sha1.h" + +namespace wpi { + +class WebSocketClientTest : public WebSocketTest { + public: + WebSocketClientTest() { + // Bare bones server + req.header.connect([this](StringRef name, StringRef value) { + // save key (required for valid response) + if (name.equals_lower("sec-websocket-key")) clientKey = value; + }); + req.headersComplete.connect([this](bool) { + // send response + SmallVector bufs; + raw_uv_ostream os{bufs, 4096}; + os << "HTTP/1.1 101 Switching Protocols\r\n"; + os << "Upgrade: websocket\r\n"; + os << "Connection: Upgrade\r\n"; + + // accept hash + SHA1 hash; + hash.Update(clientKey); + hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + if (mockBadAccept) hash.Update("1"); + SmallString<64> hashBuf; + SmallString<64> acceptBuf; + os << "Sec-WebSocket-Accept: " + << Base64Encode(hash.Final(hashBuf), acceptBuf) << "\r\n"; + + if (!mockProtocol.empty()) + os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n"; + + os << "\r\n"; + + conn->Write(bufs, [](auto bufs, uv::Error) { + for (auto& buf : bufs) buf.Deallocate(); + }); + + serverHeadersDone = true; + if (connected) connected(); + }); + + serverPipe->Listen([this] { + conn = serverPipe->Accept(); + conn->StartRead(); + conn->data.connect([this](uv::Buffer& buf, size_t size) { + StringRef data{buf.base, size}; + if (!serverHeadersDone) { + data = req.Execute(data); + if (req.HasError()) Finish(); + ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); + if (data.empty()) return; + } + wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end()); + }); + conn->end.connect([this] { Finish(); }); + }); + } + + bool mockBadAccept = false; + std::vector wireData; + std::shared_ptr conn; + HttpParser req{HttpParser::kRequest}; + SmallString<64> clientKey; + std::string mockProtocol; + bool serverHeadersDone = false; + std::function connected; +}; + +TEST_F(WebSocketClientTest, Open) { + int gotOpen = 0; + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef reason) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << " Reason: " << reason; + }); + ws->open.connect([&](StringRef protocol) { + ++gotOpen; + Finish(); + ASSERT_TRUE(protocol.empty()); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotOpen, 1); +} + +TEST_F(WebSocketClientTest, BadAccept) { + int gotClosed = 0; + + mockBadAccept = true; + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef msg) { + Finish(); + ++gotClosed; + ASSERT_EQ(code, 1002) << "Message: " << msg; + }); + ws->open.connect([&](StringRef protocol) { + Finish(); + FAIL() << "Got open"; + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketClientTest, ProtocolGood) { + int gotOpen = 0; + + mockProtocol = "myProtocol"; + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient( + *clientPipe, "/test", pipeName, + ArrayRef{"myProtocol", "myProtocol2"}); + ws->closed.connect([&](uint16_t code, StringRef msg) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << "Message: " << msg; + }); + ws->open.connect([&](StringRef protocol) { + ++gotOpen; + Finish(); + ASSERT_EQ(protocol, "myProtocol"); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotOpen, 1); +} + +TEST_F(WebSocketClientTest, ProtocolRespNotReq) { + int gotClosed = 0; + + mockProtocol = "myProtocol"; + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef msg) { + Finish(); + ++gotClosed; + ASSERT_EQ(code, 1003) << "Message: " << msg; + }); + ws->open.connect([&](StringRef protocol) { + Finish(); + FAIL() << "Got open"; + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketClientTest, ProtocolReqNotResp) { + int gotClosed = 0; + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, + StringRef{"myProtocol"}); + ws->closed.connect([&](uint16_t code, StringRef msg) { + Finish(); + ++gotClosed; + ASSERT_EQ(code, 1002) << "Message: " << msg; + }); + ws->open.connect([&](StringRef protocol) { + Finish(); + FAIL() << "Got open"; + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotClosed, 1); +} + +// +// Send and receive data. Most of these cases are tested in +// WebSocketServerTest, so only spot check differences like masking. +// + +class WebSocketClientDataTest : public WebSocketClientTest, + public ::testing::WithParamInterface { + public: + WebSocketClientDataTest() { + clientPipe->Connect(pipeName, [&] { + ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + if (setupWebSocket) setupWebSocket(); + }); + } + + std::function setupWebSocket; + std::shared_ptr ws; +}; + +INSTANTIATE_TEST_CASE_P(WebSocketClientDataTests, WebSocketClientDataTest, + ::testing::Values(0, 1, 125, 126, 65535, 65536), ); + +TEST_P(WebSocketClientDataTest, SendBinary) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { + ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) { + ++gotCallback; + ws->Terminate(); + ASSERT_FALSE(bufs.empty()); + ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); + }); + }); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x02, true, true, data); + AdjustMasking(wireData); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketClientDataTest, ReceiveBinary) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->binary.connect([&](ArrayRef inData, bool fin) { + ++gotCallback; + ws->Terminate(); + ASSERT_TRUE(fin); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(data, recvData); + }); + }; + auto message = BuildMessage(0x02, true, false, data); + connected = [&] { + conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }; + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// +// The client must close the connection if a masked frame is received. +// + +TEST_P(WebSocketClientDataTest, ReceiveMasked) { + int gotCallback = 0; + std::vector data(GetParam(), ' '); + setupWebSocket = [&] { + ws->text.connect([&](StringRef, bool) { + ws->Terminate(); + FAIL() << "Should not have gotten masked message"; + }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, true, true, data); + connected = [&] { + conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }; + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +} // namespace wpi diff --git a/wpiutil/src/test/native/cpp/WebSocketIntegrationTest.cpp b/wpiutil/src/test/native/cpp/WebSocketIntegrationTest.cpp new file mode 100644 index 0000000000..f51e8fcb64 --- /dev/null +++ b/wpiutil/src/test/native/cpp/WebSocketIntegrationTest.cpp @@ -0,0 +1,148 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocketServer.h" // NOLINT(build/include_order) + +#include "WebSocketTest.h" +#include "wpi/HttpParser.h" +#include "wpi/SmallString.h" + +namespace wpi { + +class WebSocketIntegrationTest : public WebSocketTest {}; + +TEST_F(WebSocketIntegrationTest, Open) { + int gotServerOpen = 0; + int gotClientOpen = 0; + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto server = WebSocketServer::Create(*conn); + server->connected.connect([&](StringRef url, WebSocket&) { + ++gotServerOpen; + ASSERT_EQ(url, "/test"); + }); + }); + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef reason) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << " Reason: " << reason; + }); + ws->open.connect([&, s = ws.get() ](StringRef) { + ++gotClientOpen; + s->Close(); + }); + }); + + loop->Run(); + + ASSERT_EQ(gotServerOpen, 1); + ASSERT_EQ(gotClientOpen, 1); +} + +TEST_F(WebSocketIntegrationTest, Protocol) { + int gotServerOpen = 0; + int gotClientOpen = 0; + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto server = WebSocketServer::Create(*conn, {"proto1", "proto2"}); + server->connected.connect([&](StringRef, WebSocket& ws) { + ++gotServerOpen; + ASSERT_EQ(ws.GetProtocol(), "proto1"); + }); + }); + + clientPipe->Connect(pipeName, [&] { + auto ws = + WebSocket::CreateClient(*clientPipe, "/test", pipeName, {"proto1"}); + ws->closed.connect([&](uint16_t code, StringRef reason) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << " Reason: " << reason; + }); + ws->open.connect([&, s = ws.get() ](StringRef protocol) { + ++gotClientOpen; + s->Close(); + ASSERT_EQ(protocol, "proto1"); + }); + }); + + loop->Run(); + + ASSERT_EQ(gotServerOpen, 1); + ASSERT_EQ(gotClientOpen, 1); +} + +TEST_F(WebSocketIntegrationTest, ServerSendBinary) { + int gotData = 0; + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto server = WebSocketServer::Create(*conn); + server->connected.connect([&](StringRef, WebSocket& ws) { + ws.SendBinary(uv::Buffer{"\x03\x04", 2}, [&](auto, uv::Error) {}); + ws.Close(); + }); + }); + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef reason) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << " Reason: " << reason; + }); + ws->binary.connect([&](ArrayRef data, bool) { + ++gotData; + std::vector recvData{data.begin(), data.end()}; + std::vector expectData{0x03, 0x04}; + ASSERT_EQ(recvData, expectData); + }); + }); + + loop->Run(); + + ASSERT_EQ(gotData, 1); +} + +TEST_F(WebSocketIntegrationTest, ClientSendText) { + int gotData = 0; + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto server = WebSocketServer::Create(*conn); + server->connected.connect([&](StringRef, WebSocket& ws) { + ws.text.connect([&](StringRef data, bool) { + ++gotData; + ASSERT_EQ(data, "hello"); + }); + }); + }); + + clientPipe->Connect(pipeName, [&] { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + ws->closed.connect([&](uint16_t code, StringRef reason) { + Finish(); + if (code != 1005 && code != 1006) + FAIL() << "Code: " << code << " Reason: " << reason; + }); + ws->open.connect([&, s = ws.get() ](StringRef) { + s->SendText(uv::Buffer{"hello"}, [&](auto, uv::Error) {}); + s->Close(); + }); + }); + + loop->Run(); + + ASSERT_EQ(gotData, 1); +} + +} // namespace wpi diff --git a/wpiutil/src/test/native/cpp/WebSocketServerTest.cpp b/wpiutil/src/test/native/cpp/WebSocketServerTest.cpp new file mode 100644 index 0000000000..7d723e3831 --- /dev/null +++ b/wpiutil/src/test/native/cpp/WebSocketServerTest.cpp @@ -0,0 +1,736 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocket.h" // NOLINT(build/include_order) + +#include "WebSocketTest.h" +#include "wpi/Base64.h" +#include "wpi/HttpParser.h" +#include "wpi/SmallString.h" +#include "wpi/raw_uv_ostream.h" +#include "wpi/sha1.h" + +namespace wpi { + +class WebSocketServerTest : public WebSocketTest { + public: + WebSocketServerTest() { + resp.headersComplete.connect([this](bool) { headersDone = true; }); + + serverPipe->Listen([this]() { + auto conn = serverPipe->Accept(); + ws = WebSocket::CreateServer(*conn, "foo", "13"); + if (setupWebSocket) setupWebSocket(); + }); + clientPipe->Connect(pipeName, [this]() { + clientPipe->StartRead(); + clientPipe->data.connect([this](uv::Buffer& buf, size_t size) { + StringRef data{buf.base, size}; + if (!headersDone) { + data = resp.Execute(data); + if (resp.HasError()) Finish(); + ASSERT_EQ(resp.GetError(), HPE_OK) + << http_errno_name(resp.GetError()); + if (data.empty()) return; + } + wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end()); + if (handleData) handleData(data); + }); + clientPipe->end.connect([this]() { Finish(); }); + }); + } + + std::function setupWebSocket; + std::function handleData; + std::vector wireData; + std::shared_ptr ws; + HttpParser resp{HttpParser::kResponse}; + bool headersDone = false; +}; + +// +// Terminate closes the endpoint but doesn't send a close frame. +// + +TEST_F(WebSocketServerTest, Terminate) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Terminate(); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1006) << "reason: " << reason; + }); + }; + + loop->Run(); + + ASSERT_TRUE(wireData.empty()); // terminate doesn't send data + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, TerminateCode) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Terminate(1000); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000) << "reason: " << reason; + }); + }; + + loop->Run(); + + ASSERT_TRUE(wireData.empty()); // terminate doesn't send data + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, TerminateReason) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Terminate(1000, "reason"); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000); + ASSERT_EQ(reason, "reason"); + }); + }; + + loop->Run(); + + ASSERT_TRUE(wireData.empty()); // terminate doesn't send data + ASSERT_EQ(gotClosed, 1); +} + +// +// Close() sends a close frame. +// + +TEST_F(WebSocketServerTest, CloseBasic) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Close(); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1005) << "reason: " << reason; + }); + }; + // need to respond with close for server to finish shutdown + auto message = BuildMessage(0x08, true, true, {}); + handleData = [&](StringRef) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x08, true, false, {}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, CloseCode) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Close(1000); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000) << "reason: " << reason; + }); + }; + // need to respond with close for server to finish shutdown + auto message = BuildMessage(0x08, true, true, {0x03u, 0xe8u}); + handleData = [&](StringRef) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x08, true, false, {0x03u, 0xe8u}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, CloseReason) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { ws->Close(1000, "hangup"); }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000); + ASSERT_EQ(reason, "hangup"); + }); + }; + // need to respond with close for server to finish shutdown + auto message = BuildMessage(0x08, true, true, + {0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'}); + handleData = [&](StringRef) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x08, true, false, + {0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +// +// Receiving a close frame results in closure and echoing the close frame. +// + +TEST_F(WebSocketServerTest, ReceiveCloseBasic) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1005) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x08, true, true, {}); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + // the endpoint should echo the message + auto expectData = BuildMessage(0x08, true, false, {}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, ReceiveCloseCode) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x08, true, true, {0x03u, 0xe8u}); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + // the endpoint should echo the message + auto expectData = BuildMessage(0x08, true, false, {0x03u, 0xe8u}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketServerTest, ReceiveCloseReason) { + int gotClosed = 0; + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotClosed; + ASSERT_EQ(code, 1000); + ASSERT_EQ(reason, "hangup"); + }); + }; + auto message = BuildMessage(0x08, true, true, + {0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'}); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + // the endpoint should echo the message + auto expectData = BuildMessage(0x08, true, false, + {0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'}); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotClosed, 1); +} + +// +// If an unknown opcode is received, the receiving endpoint MUST _Fail the +// WebSocket Connection_. +// + +class WebSocketServerBadOpcodeTest + : public WebSocketServerTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P(WebSocketServerBadOpcodeTests, + WebSocketServerBadOpcodeTest, + ::testing::Values(3, 4, 5, 6, 7, 0xb, 0xc, 0xd, 0xe, + 0xf), ); + +TEST_P(WebSocketServerBadOpcodeTest, Receive) { + int gotCallback = 0; + std::vector data(4, 0x03); + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(GetParam(), true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// +// Control frames themselves MUST NOT be fragmented. +// + +class WebSocketServerControlFrameTest + : public WebSocketServerTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P(WebSocketServerControlFrameTests, + WebSocketServerControlFrameTest, + ::testing::Values(0x8, 0x9, 0xa), ); + +TEST_P(WebSocketServerControlFrameTest, ReceiveFragment) { + int gotCallback = 0; + std::vector data(4, 0x03); + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(GetParam(), false, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// +// A fragmented message consists of a single frame with the FIN bit +// clear and an opcode other than 0, followed by zero or more frames +// with the FIN bit clear and the opcode set to 0, and terminated by +// a single frame with the FIN bit set and an opcode of 0. +// + +// No previous message +TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFrame) { + int gotCallback = 0; + std::vector data(4, 0x03); + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x00, false, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// No previous message with FIN=1. +TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFragment) { + int gotCallback = 0; + std::vector data(4, 0x03); + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, true, true, {}); // FIN=1 + auto message2 = BuildMessage(0x00, false, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write({uv::Buffer(message), uv::Buffer(message2)}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// Incomplete fragment +TEST_F(WebSocketServerTest, ReceiveFragmentInvalidIncomplete) { + int gotCallback = 0; + setupWebSocket = [&] { + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, false, true, {}); + auto message2 = BuildMessage(0x00, false, true, {}); + auto message3 = BuildMessage(0x01, true, true, {}); + resp.headersComplete.connect([&](bool) { + clientPipe->Write( + {uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// Normally fragments are combined into a single callback +TEST_F(WebSocketServerTest, ReceiveFragment) { + int gotCallback = 0; + + std::vector data(4, 0x03); + std::vector data2(4, 0x04); + std::vector data3(4, 0x05); + std::vector combData{data}; + combData.insert(combData.end(), data2.begin(), data2.end()); + combData.insert(combData.end(), data3.begin(), data3.end()); + + setupWebSocket = [&] { + ws->binary.connect([&](ArrayRef inData, bool fin) { + ++gotCallback; + ws->Terminate(); + ASSERT_TRUE(fin); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(combData, recvData); + }); + }; + + auto message = BuildMessage(0x02, false, true, data); + auto message2 = BuildMessage(0x00, false, true, data2); + auto message3 = BuildMessage(0x00, true, true, data3); + resp.headersComplete.connect([&](bool) { + clientPipe->Write( + {uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// But can be configured for multiple callbacks +TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) { + int gotCallback = 0; + + std::vector data(4, 0x03); + std::vector data2(4, 0x04); + std::vector data3(4, 0x05); + std::vector combData{data}; + combData.insert(combData.end(), data2.begin(), data2.end()); + combData.insert(combData.end(), data3.begin(), data3.end()); + + setupWebSocket = [&] { + ws->SetCombineFragments(false); + ws->binary.connect([&](ArrayRef inData, bool fin) { + std::vector recvData{inData.begin(), inData.end()}; + switch (++gotCallback) { + case 1: + ASSERT_FALSE(fin); + ASSERT_EQ(data, recvData); + break; + case 2: + ASSERT_FALSE(fin); + ASSERT_EQ(data2, recvData); + break; + case 3: + ws->Terminate(); + ASSERT_TRUE(fin); + ASSERT_EQ(data3, recvData); + break; + default: + FAIL() << "too many callbacks"; + break; + } + }); + }; + + auto message = BuildMessage(0x02, false, true, data); + auto message2 = BuildMessage(0x00, false, true, data2); + auto message3 = BuildMessage(0x00, true, true, data3); + resp.headersComplete.connect([&](bool) { + clientPipe->Write( + {uv::Buffer(message), uv::Buffer(message2), uv::Buffer(message3)}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 3); +} + +// +// Maximum message size is limited. +// + +// Single message +TEST_F(WebSocketServerTest, ReceiveTooLarge) { + int gotCallback = 0; + std::vector data(2048, 0x03u); + setupWebSocket = [&] { + ws->SetMaxMessageSize(1024); + ws->binary.connect([&](auto, bool) { + ws->Terminate(); + FAIL() << "Should not have gotten unmasked message"; + }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1009) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// Applied across fragments if combining +TEST_F(WebSocketServerTest, ReceiveTooLargeFragmented) { + int gotCallback = 0; + std::vector data(768, 0x03u); + setupWebSocket = [&] { + ws->SetMaxMessageSize(1024); + ws->binary.connect([&](auto, bool) { + ws->Terminate(); + FAIL() << "Should not have gotten unmasked message"; + }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1009) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, false, true, data); + auto message2 = BuildMessage(0x00, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write({uv::Buffer(message), uv::Buffer(message2)}, + [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// +// Send and receive data. +// + +class WebSocketServerDataTest : public WebSocketServerTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P(WebSocketServerDataTests, WebSocketServerDataTest, + ::testing::Values(0, 1, 125, 126, 65535, 65536), ); + +TEST_P(WebSocketServerDataTest, SendText) { + int gotCallback = 0; + std::vector data(GetParam(), ' '); + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { + ws->SendText(uv::Buffer(data), [&](auto bufs, uv::Error) { + ++gotCallback; + ws->Terminate(); + ASSERT_FALSE(bufs.empty()); + ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); + }); + }); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x01, true, false, data); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, SendBinary) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { + ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) { + ++gotCallback; + ws->Terminate(); + ASSERT_FALSE(bufs.empty()); + ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); + }); + }); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x02, true, false, data); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, SendPing) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { + ws->SendPing(uv::Buffer(data), [&](auto bufs, uv::Error) { + ++gotCallback; + ws->Terminate(); + ASSERT_FALSE(bufs.empty()); + ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); + }); + }); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x09, true, false, data); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, SendPong) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->open.connect([&](StringRef) { + ws->SendPong(uv::Buffer(data), [&](auto bufs, uv::Error) { + ++gotCallback; + ws->Terminate(); + ASSERT_FALSE(bufs.empty()); + ASSERT_EQ(bufs[0].base, reinterpret_cast(data.data())); + }); + }); + }; + + loop->Run(); + + auto expectData = BuildMessage(0x0a, true, false, data); + ASSERT_EQ(wireData, expectData); + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, ReceiveText) { + int gotCallback = 0; + std::vector data(GetParam(), ' '); + setupWebSocket = [&] { + ws->text.connect([&](StringRef inData, bool fin) { + ++gotCallback; + ws->Terminate(); + ASSERT_TRUE(fin); + std::vector recvData; + recvData.insert(recvData.end(), inData.bytes_begin(), inData.bytes_end()); + ASSERT_EQ(data, recvData); + }); + }; + auto message = BuildMessage(0x01, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, ReceiveBinary) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->binary.connect([&](ArrayRef inData, bool fin) { + ++gotCallback; + ws->Terminate(); + ASSERT_TRUE(fin); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(data, recvData); + }); + }; + auto message = BuildMessage(0x02, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, ReceivePing) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->ping.connect([&](ArrayRef inData) { + ++gotCallback; + ws->Terminate(); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(data, recvData); + }); + }; + auto message = BuildMessage(0x09, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +TEST_P(WebSocketServerDataTest, ReceivePong) { + int gotCallback = 0; + std::vector data(GetParam(), 0x03u); + setupWebSocket = [&] { + ws->pong.connect([&](ArrayRef inData) { + ++gotCallback; + ws->Terminate(); + std::vector recvData{inData.begin(), inData.end()}; + ASSERT_EQ(data, recvData); + }); + }; + auto message = BuildMessage(0x0a, true, true, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +// +// The server must close the connection if an unmasked frame is received. +// + +TEST_P(WebSocketServerDataTest, ReceiveUnmasked) { + int gotCallback = 0; + std::vector data(GetParam(), ' '); + setupWebSocket = [&] { + ws->text.connect([&](StringRef, bool) { + ws->Terminate(); + FAIL() << "Should not have gotten unmasked message"; + }); + ws->closed.connect([&](uint16_t code, StringRef reason) { + ++gotCallback; + ASSERT_EQ(code, 1002) << "reason: " << reason; + }); + }; + auto message = BuildMessage(0x01, true, false, data); + resp.headersComplete.connect([&](bool) { + clientPipe->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {}); + }); + + loop->Run(); + + ASSERT_EQ(gotCallback, 1); +} + +} // namespace wpi diff --git a/wpiutil/src/test/native/cpp/WebSocketTest.cpp b/wpiutil/src/test/native/cpp/WebSocketTest.cpp new file mode 100644 index 0000000000..1062017b49 --- /dev/null +++ b/wpiutil/src/test/native/cpp/WebSocketTest.cpp @@ -0,0 +1,345 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "wpi/WebSocket.h" // NOLINT(build/include_order) + +#include "WebSocketTest.h" + +#include "wpi/HttpParser.h" + +namespace wpi { + +#ifdef _WIN32 +const char* WebSocketTest::pipeName = "\\\\.\\pipe\\websocket-unit-test"; +#else +const char* WebSocketTest::pipeName = "/tmp/websocket-unit-test"; +#endif +const uint8_t WebSocketTest::testMask[4] = {0x11, 0x22, 0x33, 0x44}; + +void WebSocketTest::SetUpTestCase() { +#ifndef _WIN32 + unlink(pipeName); +#endif +} + +std::vector WebSocketTest::BuildHeader(uint8_t opcode, bool fin, + bool masking, uint64_t len) { + std::vector data; + data.push_back(opcode | (fin ? 0x80u : 0x00u)); + if (len < 126) { + data.push_back(len | (masking ? 0x80 : 0x00u)); + } else if (len < 65536) { + data.push_back(126u | (masking ? 0x80 : 0x00u)); + data.push_back(len >> 8); + data.push_back(len & 0xff); + } else { + data.push_back(127u | (masking ? 0x80u : 0x00u)); + for (int i = 56; i >= 0; i -= 8) data.push_back((len >> i) & 0xff); + } + if (masking) data.insert(data.end(), &testMask[0], &testMask[4]); + return data; +} + +std::vector WebSocketTest::BuildMessage(uint8_t opcode, bool fin, + bool masking, + ArrayRef data) { + auto finalData = BuildHeader(opcode, fin, masking, data.size()); + size_t headerSize = finalData.size(); + finalData.insert(finalData.end(), data.begin(), data.end()); + if (masking) { + uint8_t mask[4] = {finalData[headerSize - 4], finalData[headerSize - 3], + finalData[headerSize - 2], finalData[headerSize - 1]}; + int n = 0; + for (size_t i = headerSize, end = finalData.size(); i < end; ++i) { + finalData[i] ^= mask[n++]; + if (n >= 4) n = 0; + } + } + return finalData; +} + +// If the message is masked, changes the mask to match the mask set by +// BuildHeader() by unmasking and remasking. +void WebSocketTest::AdjustMasking(MutableArrayRef message) { + if (message.size() < 2) return; + if ((message[1] & 0x80) == 0) return; // not masked + size_t maskPos; + uint8_t len = message[1] & 0x7f; + if (len == 126) + maskPos = 4; + else if (len == 127) + maskPos = 10; + else + maskPos = 2; + uint8_t mask[4] = {message[maskPos], message[maskPos + 1], + message[maskPos + 2], message[maskPos + 3]}; + message[maskPos] = testMask[0]; + message[maskPos + 1] = testMask[1]; + message[maskPos + 2] = testMask[2]; + message[maskPos + 3] = testMask[3]; + int n = 0; + for (auto& ch : message.slice(maskPos + 4)) { + ch ^= mask[n] ^ testMask[n]; + if (++n >= 4) n = 0; + } +} + +TEST_F(WebSocketTest, CreateClientBasic) { + int gotHost = 0; + int gotUpgrade = 0; + int gotConnection = 0; + int gotKey = 0; + int gotVersion = 0; + + HttpParser req{HttpParser::kRequest}; + req.url.connect([](StringRef url) { ASSERT_EQ(url, "/test"); }); + req.header.connect([&](StringRef name, StringRef value) { + if (name.equals_lower("host")) { + ASSERT_EQ(value, pipeName); + ++gotHost; + } else if (name.equals_lower("upgrade")) { + ASSERT_EQ(value, "websocket"); + ++gotUpgrade; + } else if (name.equals_lower("connection")) { + ASSERT_EQ(value, "Upgrade"); + ++gotConnection; + } else if (name.equals_lower("sec-websocket-key")) { + ++gotKey; + } else if (name.equals_lower("sec-websocket-version")) { + ASSERT_EQ(value, "13"); + ++gotVersion; + } else { + FAIL() << "unexpected header " << name.str(); + } + }); + req.headersComplete.connect([&](bool) { Finish(); }); + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + conn->StartRead(); + conn->data.connect([&](uv::Buffer& buf, size_t size) { + req.Execute(StringRef{buf.base, size}); + if (req.HasError()) Finish(); + ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); + }); + }); + clientPipe->Connect(pipeName, [&]() { + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotHost, 1); + ASSERT_EQ(gotUpgrade, 1); + ASSERT_EQ(gotConnection, 1); + ASSERT_EQ(gotKey, 1); + ASSERT_EQ(gotVersion, 1); +} + +TEST_F(WebSocketTest, CreateClientExtraHeaders) { + int gotExtra1 = 0; + int gotExtra2 = 0; + HttpParser req{HttpParser::kRequest}; + req.header.connect([&](StringRef name, StringRef value) { + if (name.equals("Extra1")) { + ASSERT_EQ(value, "Data1"); + ++gotExtra1; + } else if (name.equals("Extra2")) { + ASSERT_EQ(value, "Data2"); + ++gotExtra2; + } + }); + req.headersComplete.connect([&](bool) { Finish(); }); + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + conn->StartRead(); + conn->data.connect([&](uv::Buffer& buf, size_t size) { + req.Execute(StringRef{buf.base, size}); + if (req.HasError()) Finish(); + ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); + }); + }); + clientPipe->Connect(pipeName, [&]() { + WebSocket::ClientOptions options; + SmallVector, 4> extraHeaders; + extraHeaders.emplace_back("Extra1", "Data1"); + extraHeaders.emplace_back("Extra2", "Data2"); + options.extraHeaders = extraHeaders; + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, + ArrayRef{}, options); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotExtra1, 1); + ASSERT_EQ(gotExtra2, 1); +} + +TEST_F(WebSocketTest, CreateClientTimeout) { + int gotClosed = 0; + serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); }); + clientPipe->Connect(pipeName, [&]() { + WebSocket::ClientOptions options; + options.handshakeTimeout = uv::Timer::Time{100}; + auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName, + ArrayRef{}, options); + ws->closed.connect([&](uint16_t code, StringRef) { + Finish(); + ++gotClosed; + ASSERT_EQ(code, 1006); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotClosed, 1); +} + +TEST_F(WebSocketTest, CreateServerBasic) { + int gotStatus = 0; + int gotUpgrade = 0; + int gotConnection = 0; + int gotAccept = 0; + int gotOpen = 0; + + HttpParser resp{HttpParser::kResponse}; + resp.status.connect([&](StringRef status) { + ++gotStatus; + ASSERT_EQ(resp.GetStatusCode(), 101u) << "status: " << status; + }); + resp.header.connect([&](StringRef name, StringRef value) { + if (name.equals_lower("upgrade")) { + ASSERT_EQ(value, "websocket"); + ++gotUpgrade; + } else if (name.equals_lower("connection")) { + ASSERT_EQ(value, "Upgrade"); + ++gotConnection; + } else if (name.equals_lower("sec-websocket-accept")) { + ++gotAccept; + } else { + FAIL() << "unexpected header " << name.str(); + } + }); + resp.headersComplete.connect([&](bool) { Finish(); }); + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto ws = WebSocket::CreateServer(*conn, "foo", "13"); + ws->open.connect([&](StringRef protocol) { + ++gotOpen; + ASSERT_TRUE(protocol.empty()); + }); + }); + clientPipe->Connect(pipeName, [&] { + clientPipe->StartRead(); + clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { + resp.Execute(StringRef{buf.base, size}); + if (resp.HasError()) Finish(); + ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotStatus, 1); + ASSERT_EQ(gotUpgrade, 1); + ASSERT_EQ(gotConnection, 1); + ASSERT_EQ(gotAccept, 1); + ASSERT_EQ(gotOpen, 1); +} + +TEST_F(WebSocketTest, CreateServerProtocol) { + int gotProtocol = 0; + int gotOpen = 0; + + HttpParser resp{HttpParser::kResponse}; + resp.header.connect([&](StringRef name, StringRef value) { + if (name.equals_lower("sec-websocket-protocol")) { + ++gotProtocol; + ASSERT_EQ(value, "myProtocol"); + } + }); + resp.headersComplete.connect([&](bool) { Finish(); }); + + serverPipe->Listen([&]() { + auto conn = serverPipe->Accept(); + auto ws = WebSocket::CreateServer(*conn, "foo", "13", "myProtocol"); + ws->open.connect([&](StringRef protocol) { + ++gotOpen; + ASSERT_EQ(protocol, "myProtocol"); + }); + }); + clientPipe->Connect(pipeName, [&] { + clientPipe->StartRead(); + clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { + resp.Execute(StringRef{buf.base, size}); + if (resp.HasError()) Finish(); + ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotProtocol, 1); + ASSERT_EQ(gotOpen, 1); +} + +TEST_F(WebSocketTest, CreateServerBadVersion) { + int gotStatus = 0; + int gotVersion = 0; + int gotUpgrade = 0; + + HttpParser resp{HttpParser::kResponse}; + resp.status.connect([&](StringRef status) { + ++gotStatus; + ASSERT_EQ(resp.GetStatusCode(), 426u) << "status: " << status; + }); + resp.header.connect([&](StringRef name, StringRef value) { + if (name.equals_lower("sec-websocket-version")) { + ++gotVersion; + ASSERT_EQ(value, "13"); + } else if (name.equals_lower("upgrade")) { + ++gotUpgrade; + ASSERT_EQ(value, "WebSocket"); + } else { + FAIL() << "unexpected header " << name.str(); + } + }); + resp.headersComplete.connect([&](bool) { Finish(); }); + + serverPipe->Listen([&] { + auto conn = serverPipe->Accept(); + auto ws = WebSocket::CreateServer(*conn, "foo", "14"); + ws->open.connect([&](StringRef) { + Finish(); + FAIL(); + }); + }); + clientPipe->Connect(pipeName, [&] { + clientPipe->StartRead(); + clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { + resp.Execute(StringRef{buf.base, size}); + if (resp.HasError()) Finish(); + ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); + }); + }); + + loop->Run(); + + if (HasFatalFailure()) return; + ASSERT_EQ(gotStatus, 1); + ASSERT_EQ(gotVersion, 1); + ASSERT_EQ(gotUpgrade, 1); +} + +} // namespace wpi diff --git a/wpiutil/src/test/native/cpp/WebSocketTest.h b/wpiutil/src/test/native/cpp/WebSocketTest.h new file mode 100644 index 0000000000..8b40440ec5 --- /dev/null +++ b/wpiutil/src/test/native/cpp/WebSocketTest.h @@ -0,0 +1,73 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2018 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#pragma once + +#include +#include +#include + +#include "gtest/gtest.h" +#include "wpi/ArrayRef.h" +#include "wpi/uv/Loop.h" +#include "wpi/uv/Pipe.h" +#include "wpi/uv/Timer.h" + +namespace wpi { + +class WebSocketTest : public ::testing::Test { + public: + static const char* pipeName; + + static void SetUpTestCase(); + + WebSocketTest() { + loop = uv::Loop::Create(); + clientPipe = uv::Pipe::Create(loop); + serverPipe = uv::Pipe::Create(loop); + + serverPipe->Bind(pipeName); + +#if 0 + auto debugTimer = uv::Timer::Create(loop); + debugTimer->timeout.connect([this] { + std::printf("Active handles:\n"); + uv_print_active_handles(loop->GetRaw(), stdout); + }); + debugTimer->Start(uv::Timer::Time{100}, uv::Timer::Time{100}); + debugTimer->Unreference(); +#endif + + auto failTimer = uv::Timer::Create(loop); + failTimer->timeout.connect([this] { + loop->Stop(); + FAIL() << "loop failed to terminate"; + }); + failTimer->Start(uv::Timer::Time{1000}); + failTimer->Unreference(); + } + + ~WebSocketTest() { Finish(); } + + void Finish() { + loop->Walk([](uv::Handle& it) { it.Close(); }); + } + + static std::vector BuildHeader(uint8_t opcode, bool fin, + bool masking, uint64_t len); + static std::vector BuildMessage(uint8_t opcode, bool fin, + bool masking, + ArrayRef data); + static void AdjustMasking(MutableArrayRef message); + static const uint8_t testMask[4]; + + std::shared_ptr loop; + std::shared_ptr clientPipe; + std::shared_ptr serverPipe; +}; + +} // namespace wpi