wpiutil: Add WebSocket implementation (#1186)

This is a RFC 6455 compliant implementation with both client and server support.
This commit is contained in:
Peter Johnson
2018-08-24 20:54:23 -07:00
committed by GitHub
parent d6d5321828
commit c8482cd6d2
9 changed files with 2833 additions and 0 deletions

View File

@@ -0,0 +1,565 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
#include "wpi/WebSocket.h"
#include <random>
#include "wpi/Base64.h"
#include "wpi/HttpParser.h"
#include "wpi/SmallString.h"
#include "wpi/SmallVector.h"
#include "wpi/raw_uv_ostream.h"
#include "wpi/sha1.h"
#include "wpi/uv/Stream.h"
using namespace wpi;
namespace {
class WebSocketWriteReq : public uv::WriteReq {
public:
explicit WebSocketWriteReq(
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
finish.connect([=](uv::Error err) {
MutableArrayRef<uv::Buffer> bufs{m_bufs};
for (auto&& buf : bufs.slice(0, m_startUser)) buf.Deallocate();
callback(bufs.slice(m_startUser), err);
});
}
SmallVector<uv::Buffer, 4> m_bufs;
size_t m_startUser;
};
} // namespace
class WebSocket::ClientHandshakeData {
public:
ClientHandshakeData() {
// key is a random nonce
static std::random_device rd;
static std::default_random_engine gen{rd()};
std::uniform_int_distribution<unsigned int> dist(0, 255);
char nonce[16]; // the nonce sent to the server
for (char& v : nonce) v = static_cast<char>(dist(gen));
raw_svector_ostream os(key);
Base64Encode(os, StringRef{nonce, 16});
}
~ClientHandshakeData() {
if (auto t = timer.lock()) {
t->Stop();
t->Close();
}
}
SmallString<64> key; // the key sent to the server
SmallVector<std::string, 2> protocols; // valid protocols
HttpParser parser{HttpParser::kResponse}; // server response parser
bool hasUpgrade = false;
bool hasConnection = false;
bool hasAccept = false;
bool hasProtocol = false;
std::weak_ptr<uv::Timer> timer;
};
static StringRef AcceptHash(StringRef key, SmallVectorImpl<char>& buf) {
SHA1 hash;
hash.Update(key);
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
SmallString<64> hashBuf;
return Base64Encode(hash.Final(hashBuf), buf);
}
WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
: m_stream{stream}, m_server{server} {
// Connect closed and error signals to ourselves
m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
m_stream.error.connect([this](uv::Error err) {
Terminate(1006, "stream error: " + Twine(err.name()));
});
// Start reading
m_stream.StopRead(); // we may have been reading
m_stream.StartRead();
m_stream.data.connect(
[this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
m_stream.end.connect(
[this]() { Terminate(1006, "remote end closed connection"); });
}
WebSocket::~WebSocket() {}
std::shared_ptr<WebSocket> WebSocket::CreateClient(
uv::Stream& stream, const Twine& uri, const Twine& host,
ArrayRef<StringRef> protocols, const ClientOptions& options) {
auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
stream.SetData(ws);
ws->StartClient(uri, host, protocols, options);
return ws;
}
std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
StringRef key,
StringRef version,
StringRef protocol) {
auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
stream.SetData(ws);
ws->StartServer(key, version, protocol);
return ws;
}
void WebSocket::Close(uint16_t code, const Twine& reason) {
SendClose(code, reason);
if (m_state != FAILED && m_state != CLOSED) m_state = CLOSING;
}
void WebSocket::Fail(uint16_t code, const Twine& reason) {
if (m_state == FAILED || m_state == CLOSED) return;
SendClose(code, reason);
SetClosed(code, reason, true);
Shutdown();
}
void WebSocket::Terminate(uint16_t code, const Twine& reason) {
if (m_state == FAILED || m_state == CLOSED) return;
SetClosed(code, reason);
Shutdown();
}
void WebSocket::StartClient(const Twine& uri, const Twine& host,
ArrayRef<StringRef> protocols,
const ClientOptions& options) {
// Create client handshake data
m_clientHandshake = std::make_unique<ClientHandshakeData>();
// Build client request
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
os << "GET " << uri << " HTTP/1.1\r\n";
os << "Host: " << host << "\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
os << "Sec-WebSocket-Version: 13\r\n";
// protocols (if provided)
if (!protocols.empty()) {
os << "Sec-WebSocket-Protocol: ";
bool first = true;
for (auto protocol : protocols) {
if (!first)
os << ", ";
else
first = false;
os << protocol;
// also save for later checking against server response
m_clientHandshake->protocols.emplace_back(protocol);
}
os << "\r\n";
}
// other headers
for (auto&& header : options.extraHeaders)
os << header.first << ": " << header.second << "\r\n";
// finish headers
os << "\r\n";
// Send client request
m_stream.Write(bufs, [](auto bufs, uv::Error) {
for (auto& buf : bufs) buf.Deallocate();
});
// Set up client response handling
m_clientHandshake->parser.status.connect([this](StringRef status) {
unsigned int code = m_clientHandshake->parser.GetStatusCode();
if (code != 101) Terminate(code, status);
});
m_clientHandshake->parser.header.connect(
[this](StringRef name, StringRef value) {
value = value.trim();
if (name.equals_lower("upgrade")) {
if (!value.equals_lower("websocket"))
return Terminate(1002, "invalid upgrade response value");
m_clientHandshake->hasUpgrade = true;
} else if (name.equals_lower("connection")) {
if (!value.equals_lower("upgrade"))
return Terminate(1002, "invalid connection response value");
m_clientHandshake->hasConnection = true;
} else if (name.equals_lower("sec-websocket-accept")) {
// Check against expected response
SmallString<64> acceptBuf;
if (!value.equals(AcceptHash(m_clientHandshake->key, acceptBuf)))
return Terminate(1002, "invalid accept key");
m_clientHandshake->hasAccept = true;
} else if (name.equals_lower("sec-websocket-extensions")) {
// No extensions are supported
if (!value.empty()) return Terminate(1010, "unsupported extension");
} else if (name.equals_lower("sec-websocket-protocol")) {
// Make sure it was one of the provided protocols
bool match = false;
for (auto&& protocol : m_clientHandshake->protocols) {
if (value.equals_lower(protocol)) {
match = true;
break;
}
}
if (!match) return Terminate(1003, "unsupported protocol");
m_clientHandshake->hasProtocol = true;
m_protocol = value;
}
});
m_clientHandshake->parser.headersComplete.connect([this](bool) {
if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
!m_clientHandshake->hasAccept ||
(!m_clientHandshake->hasProtocol &&
!m_clientHandshake->protocols.empty())) {
return Terminate(1002, "invalid response");
}
if (m_state == CONNECTING) {
m_state = OPEN;
open(m_protocol);
}
});
// Start handshake timer if a timeout was specified
if (options.handshakeTimeout != uv::Timer::Time::max()) {
auto timer = uv::Timer::Create(m_stream.GetLoopRef());
timer->timeout.connect(
[this]() { Terminate(1006, "connection timed out"); });
timer->Start(options.handshakeTimeout);
m_clientHandshake->timer = timer;
}
}
void WebSocket::StartServer(StringRef key, StringRef version,
StringRef protocol) {
m_protocol = protocol;
// Build server response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
// Handle unsupported version
if (version != "13") {
os << "HTTP/1.1 426 Upgrade Required\r\n";
os << "Upgrade: WebSocket\r\n";
os << "Sec-WebSocket-Version: 13\r\n\r\n";
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
for (auto& buf : bufs) buf.Deallocate();
// XXX: Should we support sending a new handshake on the same connection?
// XXX: "this->" is required by GCC 5.5 (bug)
this->Terminate(1003, "unsupported protocol version");
});
return;
}
os << "HTTP/1.1 101 Switching Protocols\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
// accept hash
SmallString<64> acceptBuf;
os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
if (!protocol.empty()) os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
// end headers
os << "\r\n";
// Send server response
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
for (auto& buf : bufs) buf.Deallocate();
if (m_state == CONNECTING) {
m_state = OPEN;
open(m_protocol);
}
});
}
void WebSocket::SendClose(uint16_t code, const Twine& reason) {
SmallVector<uv::Buffer, 4> bufs;
if (code != 1005) {
raw_uv_ostream os{bufs, 4096};
os << ArrayRef<uint8_t>{static_cast<uint8_t>((code >> 8) & 0xff),
static_cast<uint8_t>(code & 0xff)};
reason.print(os);
}
Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
for (auto&& buf : bufs) buf.Deallocate();
});
}
void WebSocket::SetClosed(uint16_t code, const Twine& reason, bool failed) {
if (m_state == FAILED || m_state == CLOSED) return;
m_state = failed ? FAILED : CLOSED;
SmallString<64> reasonBuf;
closed(code, reason.toStringRef(reasonBuf));
}
void WebSocket::Shutdown() {
m_stream.Shutdown([this] { m_stream.Close(); });
}
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
// ignore incoming data if we're failed or closed
if (m_state == FAILED || m_state == CLOSED) return;
StringRef data{buf.base, size};
// Handle connecting state (mainly on client)
if (m_state == CONNECTING) {
if (m_clientHandshake) {
data = m_clientHandshake->parser.Execute(data);
// check for parser failure
if (m_clientHandshake->parser.HasError())
return Terminate(1003, "invalid response");
if (m_state != OPEN) return; // not done with handshake yet
// we're done with the handshake, so release its memory
m_clientHandshake.reset();
// fall through to process additional data after handshake
} else {
return Terminate(1003, "got data on server before response");
}
}
// Message processing
while (!data.empty()) {
if (m_frameSize == UINT64_MAX) {
// Need at least two bytes to determine header length
if (m_header.size() < 2u) {
size_t toCopy = std::min(2u - m_header.size(), data.size());
m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
data = data.drop_front(toCopy);
if (m_header.size() < 2u) return; // need more data
// Validate RSV bits are zero
if ((m_header[0] & 0x70) != 0) return Fail(1002, "nonzero RSV");
}
// Once we have first two bytes, we can calculate the header size
if (m_headerSize == 0) {
m_headerSize = 2;
uint8_t len = m_header[1] & kLenMask;
if (len == 126)
m_headerSize += 2;
else if (len == 127)
m_headerSize += 8;
bool masking = (m_header[1] & kFlagMasking) != 0;
if (masking) m_headerSize += 4; // masking key
// On server side, incoming messages MUST be masked
// On client side, incoming messages MUST NOT be masked
if (m_server && !masking) return Fail(1002, "client data not masked");
if (!m_server && masking) return Fail(1002, "server data masked");
}
// Need to complete header to calculate message size
if (m_header.size() < m_headerSize) {
size_t toCopy = std::min(m_headerSize - m_header.size(), data.size());
m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
data = data.drop_front(toCopy);
if (m_header.size() < m_headerSize) return; // need more data
}
if (m_header.size() >= m_headerSize) {
// get payload length
uint8_t len = m_header[1] & kLenMask;
if (len == 126)
m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
static_cast<uint16_t>(m_header[3]);
else if (len == 127)
m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
(static_cast<uint64_t>(m_header[3]) << 48) |
(static_cast<uint64_t>(m_header[4]) << 40) |
(static_cast<uint64_t>(m_header[5]) << 32) |
(static_cast<uint64_t>(m_header[6]) << 24) |
(static_cast<uint64_t>(m_header[7]) << 16) |
(static_cast<uint64_t>(m_header[8]) << 8) |
static_cast<uint64_t>(m_header[9]);
else
m_frameSize = len;
// limit maximum size
if ((m_payload.size() + m_frameSize) > m_maxMessageSize)
return Fail(1009, "message too large");
}
}
if (m_frameSize != UINT64_MAX) {
size_t need = m_frameStart + m_frameSize - m_payload.size();
size_t toCopy = std::min(need, data.size());
m_payload.append(data.bytes_begin(), data.bytes_begin() + toCopy);
data = data.drop_front(toCopy);
need -= toCopy;
if (need == 0) {
// We have a complete frame
// If the message had masking, unmask it
if ((m_header[1] & kFlagMasking) != 0) {
uint8_t key[4] = {
m_header[m_headerSize - 4], m_header[m_headerSize - 3],
m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
int n = 0;
for (uint8_t& ch :
MutableArrayRef<uint8_t>{m_payload}.slice(m_frameStart)) {
ch ^= key[n++];
if (n >= 4) n = 0;
}
}
// Handle message
bool fin = (m_header[0] & kFlagFin) != 0;
uint8_t opcode = m_header[0] & kOpMask;
switch (opcode) {
case kOpCont:
switch (m_fragmentOpcode) {
case kOpText:
if (!m_combineFragments || fin)
text(StringRef{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()},
fin);
break;
case kOpBinary:
if (!m_combineFragments || fin) binary(m_payload, fin);
break;
default:
// no preceding message?
return Fail(1002, "invalid continuation message");
}
if (fin) m_fragmentOpcode = 0;
break;
case kOpText:
if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
if (!m_combineFragments || fin)
text(StringRef{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()},
fin);
if (!fin) m_fragmentOpcode = opcode;
break;
case kOpBinary:
if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
if (!m_combineFragments || fin) binary(m_payload, fin);
if (!fin) m_fragmentOpcode = opcode;
break;
case kOpClose: {
uint16_t code;
StringRef reason;
if (!fin) {
code = 1002;
reason = "cannot fragment control frames";
} else if (m_payload.size() < 2) {
code = 1005;
} else {
code = (static_cast<uint16_t>(m_payload[0]) << 8) |
static_cast<uint16_t>(m_payload[1]);
reason = StringRef{reinterpret_cast<char*>(m_payload.data()),
m_payload.size()}
.drop_front(2);
}
// Echo the close if we didn't previously send it
if (m_state != CLOSING) SendClose(code, reason);
SetClosed(code, reason);
// If we're the server, shutdown the connection.
if (m_server) Shutdown();
break;
}
case kOpPing:
if (!fin) return Fail(1002, "cannot fragment control frames");
ping(m_payload);
break;
case kOpPong:
if (!fin) return Fail(1002, "cannot fragment control frames");
pong(m_payload);
break;
default:
return Fail(1002, "invalid message opcode");
}
// Prepare for next message
m_header.clear();
m_headerSize = 0;
if (!m_combineFragments || fin) m_payload.clear();
m_frameStart = m_payload.size();
m_frameSize = UINT64_MAX;
}
}
}
}
void WebSocket::Send(
uint8_t opcode, ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
// If we're not open, emit an error and don't send the data
if (m_state != OPEN) {
int err;
if (m_state == CONNECTING)
err = UV_EAGAIN;
else
err = UV_ESHUTDOWN;
SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
callback(bufs, uv::Error{err});
return;
}
auto req = std::make_shared<WebSocketWriteReq>(callback);
raw_uv_ostream os{req->m_bufs, 4096};
// opcode (includes FIN bit)
os << static_cast<unsigned char>(opcode);
// payload length
uint64_t size = 0;
for (auto&& buf : data) size += buf.len;
if (size < 126) {
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
} else if (size <= 0xffff) {
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 8) & 0xff),
static_cast<uint8_t>(size & 0xff)};
} else {
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 56) & 0xff),
static_cast<uint8_t>((size >> 48) & 0xff),
static_cast<uint8_t>((size >> 40) & 0xff),
static_cast<uint8_t>((size >> 32) & 0xff),
static_cast<uint8_t>((size >> 24) & 0xff),
static_cast<uint8_t>((size >> 16) & 0xff),
static_cast<uint8_t>((size >> 8) & 0xff),
static_cast<uint8_t>(size & 0xff)};
}
// clients need to mask the input data
if (!m_server) {
// generate masking key
static std::random_device rd;
static std::default_random_engine gen{rd()};
std::uniform_int_distribution<unsigned int> dist(0, 255);
uint8_t key[4];
for (uint8_t& v : key) v = dist(gen);
os << ArrayRef<uint8_t>{key, 4};
// copy and mask data
int n = 0;
for (auto&& buf : data) {
for (auto&& ch : buf.data()) {
os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
if (n >= 4) n = 0;
}
}
req->m_startUser = req->m_bufs.size();
req->m_bufs.append(data.begin(), data.end());
// don't send the user bufs as we copied their data
m_stream.Write(ArrayRef<uv::Buffer>{req->m_bufs}.slice(0, req->m_startUser),
req);
} else {
// servers can just send the buffers directly without masking
req->m_startUser = req->m_bufs.size();
req->m_bufs.append(data.begin(), data.end());
m_stream.Write(req->m_bufs, req);
}
}

View File

@@ -0,0 +1,142 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
#include "wpi/WebSocketServer.h"
#include "wpi/raw_uv_ostream.h"
#include "wpi/uv/Buffer.h"
#include "wpi/uv/Stream.h"
using namespace wpi;
WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) {
req.header.connect([this](StringRef name, StringRef value) {
if (name.equals_lower("host")) {
m_gotHost = true;
} else if (name.equals_lower("upgrade")) {
if (value.equals_lower("websocket")) m_websocket = true;
} else if (name.equals_lower("sec-websocket-key")) {
m_key = value;
} else if (name.equals_lower("sec-websocket-version")) {
m_version = value;
} else if (name.equals_lower("sec-websocket-protocol")) {
// Protocols are comma delimited, repeated headers add to list
SmallVector<StringRef, 2> protocols;
value.split(protocols, ",", -1, false);
for (auto protocol : protocols) {
protocol = protocol.trim();
if (!protocol.empty()) m_protocols.emplace_back(protocol);
}
}
});
req.headersComplete.connect([&req, this](bool) {
if (req.IsUpgrade() && IsUpgrade()) upgrade();
});
}
std::pair<bool, StringRef> WebSocketServerHelper::MatchProtocol(
ArrayRef<StringRef> protocols) {
if (protocols.empty() && m_protocols.empty())
return std::make_pair(true, StringRef{});
for (auto protocol : protocols) {
for (auto&& clientProto : m_protocols) {
if (protocol == clientProto) return std::make_pair(true, protocol);
}
}
return std::make_pair(false, StringRef{});
}
WebSocketServer::WebSocketServer(uv::Stream& stream,
ArrayRef<StringRef> protocols,
const ServerOptions& options,
const private_init&)
: m_stream{stream},
m_helper{m_req},
m_protocols{protocols.begin(), protocols.end()},
m_options{options} {
// Header handling
m_req.header.connect([this](StringRef name, StringRef value) {
if (name.equals_lower("host")) {
if (m_options.checkHost) {
if (!m_options.checkHost(value)) Abort(401, "Unrecognized Host");
}
}
});
m_req.url.connect([this](StringRef name) {
if (m_options.checkUrl) {
if (!m_options.checkUrl(name)) Abort(404, "Not Found");
}
});
m_req.headersComplete.connect([this](bool) {
// We only accept websocket connections
if (!m_helper.IsUpgrade() || !m_req.IsUpgrade())
Abort(426, "Upgrade Required");
});
// Handle upgrade event
m_helper.upgrade.connect([this] {
if (m_aborted) return;
// Negotiate sub-protocol
SmallVector<StringRef, 2> protocols{m_protocols.begin(), m_protocols.end()};
StringRef protocol = m_helper.MatchProtocol(protocols).second;
// Disconnect our header reader
m_headerConn.disconnect();
// Accepting the stream may destroy this (as it replaces the stream user
// data), so grab a shared pointer first.
auto self = shared_from_this();
// Accept the upgrade
auto ws = m_helper.Accept(m_stream, protocol);
// Connect the websocket open event to our connected event.
ws->open.connect_extended([ self, s = ws.get() ](auto conn, StringRef) {
self->connected(self->m_req.GetUrl(), *s);
conn.disconnect(); // one-shot
});
});
// Set up stream
stream.StartRead();
m_headerConn =
stream.data.connect_connection([this](uv::Buffer& buf, size_t size) {
if (m_aborted) return;
m_req.Execute(StringRef{buf.base, size});
if (m_req.HasError()) Abort(400, "Bad Request");
});
stream.error.connect([this](uv::Error) { m_stream.Close(); });
stream.end.connect([this] { m_stream.Close(); });
}
std::shared_ptr<WebSocketServer> WebSocketServer::Create(
uv::Stream& stream, ArrayRef<StringRef> protocols,
const ServerOptions& options) {
auto server = std::make_shared<WebSocketServer>(stream, protocols, options,
private_init{});
stream.SetData(server);
return server;
}
void WebSocketServer::Abort(uint16_t code, StringRef reason) {
if (m_aborted) return;
m_aborted = true;
// Build response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 1024};
// Handle unsupported version
os << "HTTP/1.1 " << code << ' ' << reason << "\r\n";
if (code == 426) os << "Upgrade: WebSocket\r\n";
os << "\r\n";
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
for (auto& buf : bufs) buf.Deallocate();
m_stream.Shutdown([this] { m_stream.Close(); });
});
}

View File

@@ -0,0 +1,378 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
#ifndef WPIUTIL_WPI_WEBSOCKET_H_
#define WPIUTIL_WPI_WEBSOCKET_H_
#include <stdint.h>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include "wpi/ArrayRef.h"
#include "wpi/Signal.h"
#include "wpi/SmallVector.h"
#include "wpi/StringRef.h"
#include "wpi/Twine.h"
#include "wpi/uv/Buffer.h"
#include "wpi/uv/Error.h"
#include "wpi/uv/Timer.h"
namespace wpi {
namespace uv {
class Stream;
} // namespace uv
/**
* RFC 6455 compliant WebSocket client and server implementation.
*/
class WebSocket : public std::enable_shared_from_this<WebSocket> {
struct private_init {};
static constexpr uint8_t kOpCont = 0x00;
static constexpr uint8_t kOpText = 0x01;
static constexpr uint8_t kOpBinary = 0x02;
static constexpr uint8_t kOpClose = 0x08;
static constexpr uint8_t kOpPing = 0x09;
static constexpr uint8_t kOpPong = 0x0A;
static constexpr uint8_t kOpMask = 0x0F;
static constexpr uint8_t kFlagFin = 0x80;
static constexpr uint8_t kFlagMasking = 0x80;
static constexpr uint8_t kLenMask = 0x7f;
public:
WebSocket(uv::Stream& stream, bool server, const private_init&);
WebSocket(const WebSocket&) = delete;
WebSocket(WebSocket&&) = delete;
WebSocket& operator=(const WebSocket&) = delete;
WebSocket& operator=(WebSocket&&) = delete;
~WebSocket();
/**
* Connection states.
*/
enum State {
/** The connection is not yet open. */
CONNECTING = 0,
/** The connection is open and ready to communicate. */
OPEN,
/** The connection is in the process of closing. */
CLOSING,
/** The connection failed. */
FAILED,
/** The connection is closed. */
CLOSED
};
/**
* Client connection options.
*/
struct ClientOptions {
ClientOptions() : handshakeTimeout{uv::Timer::Time::max()} {}
/** Timeout for the handshake request. */
uv::Timer::Time handshakeTimeout;
/** Additional headers to include in handshake. */
ArrayRef<std::pair<StringRef, StringRef>> extraHeaders;
};
/**
* Starts a client connection by performing the initial client handshake.
* An open event is emitted when the handshake completes.
* This sets the stream user data to the websocket.
* @param stream Connection stream
* @param uri The Request-URI to send
* @param host The host or host:port to send
* @param protocols The list of subprotocols
* @param options Handshake options
*/
static std::shared_ptr<WebSocket> CreateClient(
uv::Stream& stream, const Twine& uri, const Twine& host,
ArrayRef<StringRef> protocols = ArrayRef<StringRef>{},
const ClientOptions& options = ClientOptions{});
/**
* Starts a server connection by performing the initial server side handshake.
* This should be called after the HTTP headers have been received.
* An open event is emitted when the handshake completes.
* This sets the stream user data to the websocket.
* @param stream Connection stream
* @param key The value of the Sec-WebSocket-Key header field in the client
* request
* @param version The value of the Sec-WebSocket-Version header field in the
* client request
* @param protocol The subprotocol to send to the client (in the
* Sec-WebSocket-Protocol header field).
*/
static std::shared_ptr<WebSocket> CreateServer(
uv::Stream& stream, StringRef key, StringRef version,
StringRef protocol = StringRef{});
/**
* Get connection state.
*/
State GetState() const { return m_state; }
/**
* Return if the connection is open. Messages can only be sent on open
* connections.
*/
bool IsOpen() const { return m_state == OPEN; }
/**
* Get the underlying stream.
*/
uv::Stream& GetStream() const { return m_stream; }
/**
* Get the selected sub-protocol. Only valid in or after the open() event.
*/
StringRef GetProtocol() const { return m_protocol; }
/**
* Set the maximum message size. Default is 128 KB. If configured to combine
* fragments this maximum applies to the entire message (all combined
* fragments).
* @param size Maximum message size in bytes
*/
void SetMaxMessageSize(size_t size) { m_maxMessageSize = size; }
/**
* Set whether or not fragmented frames should be combined. Default is to
* combine. If fragmented frames are combined, the text and binary callbacks
* will always have the second parameter (fin) set to true.
* @param combine True if fragmented frames should be combined.
*/
void SetCombineFragments(bool combine) { m_combineFragments = combine; }
/**
* Initiate a closing handshake.
* @param code A numeric status code (defaults to 1005, no status code)
* @param reason A human-readable string explaining why the connection is
* closing (optional).
*/
void Close(uint16_t code = 1005, const Twine& reason = Twine{});
/**
* Send a text message.
* @param data UTF-8 encoded data to send
* @param callback Callback which is invoked when the write completes.
*/
void SendText(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kFlagFin | kOpText, data, callback);
}
/**
* Send a binary message.
* @param data Data to send
* @param callback Callback which is invoked when the write completes.
*/
void SendBinary(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kFlagFin | kOpBinary, data, callback);
}
/**
* Send a text message fragment. This must be followed by one or more
* SendFragment() calls, where the last one has fin=True, to complete the
* message.
* @param data UTF-8 encoded data to send
* @param callback Callback which is invoked when the write completes.
*/
void SendTextFragment(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kOpText, data, callback);
}
/**
* Send a text message fragment. This must be followed by one or more
* SendFragment() calls, where the last one has fin=True, to complete the
* message.
* @param data Data to send
* @param callback Callback which is invoked when the write completes.
*/
void SendBinaryFragment(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kOpBinary, data, callback);
}
/**
* Send a continuation frame. This is used to send additional parts of a
* message started with SendTextFragment() or SendBinaryFragment().
* @param data Data to send
* @param fin Set to true if this is the final fragment of the message
* @param callback Callback which is invoked when the write completes.
*/
void SendFragment(
ArrayRef<uv::Buffer> data, bool fin,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kOpCont | (fin ? kFlagFin : 0), data, callback);
}
/**
* Send a ping frame with no data.
* @param callback Optional callback which is invoked when the ping frame
* write completes.
*/
void SendPing(std::function<void(uv::Error)> callback = nullptr) {
SendPing(ArrayRef<uv::Buffer>{}, [callback](auto bufs, uv::Error err) {
if (callback) callback(err);
});
}
/**
* Send a ping frame.
* @param data Data to send in the ping frame
* @param callback Callback which is invoked when the ping frame
* write completes.
*/
void SendPing(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kFlagFin | kOpPing, data, callback);
}
/**
* Send a pong frame with no data.
* @param callback Optional callback which is invoked when the pong frame
* write completes.
*/
void SendPong(std::function<void(uv::Error)> callback = nullptr) {
SendPong(ArrayRef<uv::Buffer>{}, [callback](auto bufs, uv::Error err) {
if (callback) callback(err);
});
}
/**
* Send a pong frame.
* @param data Data to send in the pong frame
* @param callback Callback which is invoked when the pong frame
* write completes.
*/
void SendPong(
ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
Send(kFlagFin | kOpPong, data, callback);
}
/**
* Fail the connection.
*/
void Fail(uint16_t code = 1002, const Twine& reason = "protocol error");
/**
* Forcibly close the connection.
*/
void Terminate(uint16_t code = 1006, const Twine& reason = "terminated");
/**
* Gets user-defined data.
* @return User-defined data if any, nullptr otherwise.
*/
template <typename T = void>
std::shared_ptr<T> GetData() const {
return std::static_pointer_cast<T>(m_data);
}
/**
* Sets user-defined data.
* @param data User-defined arbitrary data.
*/
void SetData(std::shared_ptr<void> data) { m_data = std::move(data); }
/**
* Open event. Emitted when the connection is open and ready to communicate.
* The parameter is the selected subprotocol.
*/
sig::Signal<StringRef> open;
/**
* Close event. Emitted when the connection is closed. The first parameter
* is a numeric value indicating the status code explaining why the connection
* has been closed. The second parameter is a human-readable string
* explaining the reason why the connection has been closed.
*/
sig::Signal<uint16_t, StringRef> closed;
/**
* Text message event. Emitted when a text message is received.
* The first parameter is the data, the second parameter is true if the
* data is the last fragment of the message.
*/
sig::Signal<StringRef, bool> text;
/**
* Binary message event. Emitted when a binary message is received.
* The first parameter is the data, the second parameter is true if the
* data is the last fragment of the message.
*/
sig::Signal<ArrayRef<uint8_t>, bool> binary;
/**
* Ping event. Emitted when a ping message is received.
*/
sig::Signal<ArrayRef<uint8_t>> ping;
/**
* Pong event. Emitted when a pong message is received.
*/
sig::Signal<ArrayRef<uint8_t>> pong;
private:
// user data
std::shared_ptr<void> m_data;
// constructor parameters
uv::Stream& m_stream;
bool m_server;
// subprotocol, set via constructor (server) or handshake (client)
std::string m_protocol;
// user-settable configuration
size_t m_maxMessageSize = 128 * 1024;
bool m_combineFragments = true;
// operating state
State m_state = CONNECTING;
// incoming message buffers/state
SmallVector<uint8_t, 14> m_header;
size_t m_headerSize = 0;
SmallVector<uint8_t, 1024> m_payload;
size_t m_frameStart = 0;
uint64_t m_frameSize = UINT64_MAX;
uint8_t m_fragmentOpcode = 0;
// temporary data used only during client handshake
class ClientHandshakeData;
std::unique_ptr<ClientHandshakeData> m_clientHandshake;
void StartClient(const Twine& uri, const Twine& host,
ArrayRef<StringRef> protocols, const ClientOptions& options);
void StartServer(StringRef key, StringRef version, StringRef protocol);
void SendClose(uint16_t code, const Twine& reason);
void SetClosed(uint16_t code, const Twine& reason, bool failed = false);
void Shutdown();
void HandleIncoming(uv::Buffer& buf, size_t size);
void Send(
uint8_t opcode, ArrayRef<uv::Buffer> data,
std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback);
};
} // namespace wpi
#endif // WPIUTIL_WPI_WEBSOCKET_H_

View File

@@ -0,0 +1,147 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
#ifndef WPIUTIL_WPI_WEBSOCKETSERVER_H_
#define WPIUTIL_WPI_WEBSOCKETSERVER_H_
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include "wpi/ArrayRef.h"
#include "wpi/HttpParser.h"
#include "wpi/Signal.h"
#include "wpi/SmallString.h"
#include "wpi/SmallVector.h"
#include "wpi/StringRef.h"
#include "wpi/WebSocket.h"
namespace wpi {
namespace uv {
class Stream;
} // namespace uv
/**
* WebSocket HTTP server helper. Handles websocket-specific headers. User
* must provide the HttpParser.
*/
class WebSocketServerHelper {
public:
/**
* Constructor.
* @param req HttpParser for request
*/
explicit WebSocketServerHelper(HttpParser& req);
/**
* Get whether or not this was a websocket upgrade.
* Only valid during and after the upgrade event.
*/
bool IsWebsocket() const { return m_websocket; }
/**
* Try to find a match to the list of sub-protocols provided by the client.
* The list is priority ordered, so the first match wins.
* Only valid during and after the upgrade event.
* @param protocols Acceptable protocols
* @return Pair; first item is true if a match was made, false if not.
* Second item is the matched protocol if a match was made, otherwise
* is empty.
*/
std::pair<bool, StringRef> MatchProtocol(ArrayRef<StringRef> protocols);
/**
* Accept the upgrade. Disconnect other readers (such as the HttpParser
* reader) before calling this. See also WebSocket::CreateServer().
* @param stream Connection stream
* @param protocol The subprotocol to send to the client
*/
std::shared_ptr<WebSocket> Accept(uv::Stream& stream,
StringRef protocol = StringRef{}) {
return WebSocket::CreateServer(stream, m_key, m_version, protocol);
}
bool IsUpgrade() const { return m_gotHost && m_websocket; }
/**
* Upgrade event. Call Accept() to accept the upgrade.
*/
sig::Signal<> upgrade;
private:
bool m_gotHost = false;
bool m_websocket = false;
SmallVector<std::string, 2> m_protocols;
SmallString<64> m_key;
SmallString<16> m_version;
};
/**
* Dedicated WebSocket server.
*/
class WebSocketServer : public std::enable_shared_from_this<WebSocketServer> {
struct private_init {};
public:
/**
* Server options.
*/
struct ServerOptions {
/**
* Checker for URL. Return true if URL should be accepted. By default all
* URLs are accepted.
*/
std::function<bool(StringRef)> checkUrl;
/**
* Checker for Host header. Return true if Host should be accepted. By
* default all hosts are accepted.
*/
std::function<bool(StringRef)> checkHost;
};
/**
* Private constructor.
*/
WebSocketServer(uv::Stream& stream, ArrayRef<StringRef> protocols,
const ServerOptions& options, const private_init&);
/**
* Starts a dedicated WebSocket server on the provided connection. The
* connection should be an accepted client stream.
* This also sets the stream user data to the socket server.
* A connected event is emitted when the connection is opened.
* @param stream Connection stream
* @param protocols Acceptable subprotocols
* @param options Handshake options
*/
static std::shared_ptr<WebSocketServer> Create(
uv::Stream& stream, ArrayRef<StringRef> protocols = ArrayRef<StringRef>{},
const ServerOptions& options = ServerOptions{});
/**
* Connected event. First parameter is the URL, second is the websocket.
*/
sig::Signal<StringRef, WebSocket&> connected;
private:
uv::Stream& m_stream;
HttpParser m_req{HttpParser::kRequest};
WebSocketServerHelper m_helper;
SmallVector<std::string, 2> m_protocols;
ServerOptions m_options;
bool m_aborted = false;
sig::Connection m_headerConn;
void Abort(uint16_t code, StringRef reason);
};
} // namespace wpi
#endif // WPIUTIL_WPI_WEBSOCKETSERVER_H_