2020-12-26 14:12:05 -08:00
|
|
|
// Copyright (c) FIRST and other WPILib contributors.
|
|
|
|
|
// Open Source Software; you can modify and/or share it under the terms of
|
|
|
|
|
// the WPILib BSD license file in the root directory of this project.
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
#include "wpi/WebSocket.h"
|
|
|
|
|
|
|
|
|
|
#include <random>
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
#include "fmt/format.h"
|
2018-08-24 20:54:23 -07:00
|
|
|
#include "wpi/Base64.h"
|
|
|
|
|
#include "wpi/HttpParser.h"
|
|
|
|
|
#include "wpi/SmallString.h"
|
|
|
|
|
#include "wpi/SmallVector.h"
|
2021-06-06 16:13:58 -07:00
|
|
|
#include "wpi/StringExtras.h"
|
2018-08-24 20:54:23 -07:00
|
|
|
#include "wpi/raw_uv_ostream.h"
|
|
|
|
|
#include "wpi/sha1.h"
|
|
|
|
|
#include "wpi/uv/Stream.h"
|
|
|
|
|
|
|
|
|
|
using namespace wpi;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class WebSocketWriteReq : public uv::WriteReq {
|
|
|
|
|
public:
|
|
|
|
|
explicit WebSocketWriteReq(
|
2021-06-06 19:51:14 -07:00
|
|
|
std::function<void(span<uv::Buffer>, uv::Error)> callback) {
|
2018-08-24 20:54:23 -07:00
|
|
|
finish.connect([=](uv::Error err) {
|
2021-06-06 19:51:14 -07:00
|
|
|
span<uv::Buffer> bufs{m_bufs};
|
|
|
|
|
for (auto&& buf : bufs.subspan(0, m_startUser)) {
|
2020-12-28 12:58:06 -08:00
|
|
|
buf.Deallocate();
|
|
|
|
|
}
|
2021-06-06 19:51:14 -07:00
|
|
|
callback(bufs.subspan(m_startUser), err);
|
2018-08-24 20:54:23 -07:00
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<uv::Buffer, 4> m_bufs;
|
|
|
|
|
size_t m_startUser;
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// State that is only needed while the client side of the opening handshake is
// in progress: the key that was sent, the protocols that were offered, the
// HTTP response parser, and flags recording which required response headers
// have been seen so far.
class WebSocket::ClientHandshakeData {
 public:
  // Builds the Sec-WebSocket-Key value: 16 random bytes, base64 encoded.
  ClientHandshakeData() {
    // One generator shared by all handshakes, seeded on first use.
    static std::random_device seedSource;
    static std::default_random_engine engine{seedSource()};
    // std::uniform_int_distribution is not specified for character types, so
    // draw unsigned ints in [0, 255] and narrow each one to a char.
    std::uniform_int_distribution<unsigned int> byteDist(0, 255);

    char nonce[16];  // the nonce sent to the server
    for (size_t i = 0; i < sizeof(nonce); ++i) {
      nonce[i] = static_cast<char>(byteDist(engine));
    }

    raw_svector_ostream keyStream(key);
    Base64Encode(keyStream, {nonce, sizeof(nonce)});
  }

  // Stops and closes the handshake timer if it still exists.
  ~ClientHandshakeData() {
    auto activeTimer = timer.lock();
    if (!activeTimer) {
      return;
    }
    activeTimer->Stop();
    activeTimer->Close();
  }

  SmallString<64> key;                       // the key sent to the server
  SmallVector<std::string, 2> protocols;     // valid protocols
  HttpParser parser{HttpParser::kResponse};  // server response parser

  // Set as the corresponding response headers are accepted.
  bool hasUpgrade = false;
  bool hasConnection = false;
  bool hasAccept = false;
  bool hasProtocol = false;

  // Handshake timeout timer; only held weakly here.
  std::weak_ptr<uv::Timer> timer;
};
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
static std::string_view AcceptHash(std::string_view key,
|
|
|
|
|
SmallVectorImpl<char>& buf) {
|
2018-08-24 20:54:23 -07:00
|
|
|
SHA1 hash;
|
|
|
|
|
hash.Update(key);
|
|
|
|
|
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
|
|
|
|
SmallString<64> hashBuf;
|
2018-11-18 18:27:06 -08:00
|
|
|
return Base64Encode(hash.RawFinal(hashBuf), buf);
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Binds the WebSocket to an existing stream and hooks up the stream's signals.
//
// stream: underlying connection; kept by reference.
// server: true when this end is the server side of the connection.
WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
    : m_stream{stream}, m_server{server} {
  // A closed handle or a stream error is reported with close code 1006.
  m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
  m_stream.error.connect([this](uv::Error streamErr) {
    const auto message = fmt::format("stream error: {}", streamErr.name());
    Terminate(1006, message);
  });

  // Restart reading: the stream may already have been reading, so stop it
  // first.
  m_stream.StopRead();
  m_stream.StartRead();

  // Incoming bytes go to the frame parser; end-of-stream from the peer
  // terminates the connection.
  m_stream.data.connect([this](uv::Buffer& chunk, size_t length) {
    HandleIncoming(chunk, length);
  });
  m_stream.end.connect([this]() {
    Terminate(1006, "remote end closed connection");
  });
}
|
|
|
|
|
|
2020-12-28 00:37:33 -08:00
|
|
|
// Defaulted here, in the translation unit where ClientHandshakeData is fully
// defined.
// NOTE(review): presumably required because the header only forward-declares
// ClientHandshakeData for the m_clientHandshake std::unique_ptr member -- confirm.
WebSocket::~WebSocket() = default;
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
std::shared_ptr<WebSocket> WebSocket::CreateClient(
|
2021-06-06 16:13:58 -07:00
|
|
|
uv::Stream& stream, std::string_view uri, std::string_view host,
|
2021-06-06 19:51:14 -07:00
|
|
|
span<const std::string_view> protocols, const ClientOptions& options) {
|
2018-08-24 20:54:23 -07:00
|
|
|
auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
|
|
|
|
|
stream.SetData(ws);
|
|
|
|
|
ws->StartClient(uri, host, protocols, options);
|
|
|
|
|
return ws;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
|
2021-06-06 16:13:58 -07:00
|
|
|
std::string_view key,
|
|
|
|
|
std::string_view version,
|
|
|
|
|
std::string_view protocol) {
|
2018-08-24 20:54:23 -07:00
|
|
|
auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
|
|
|
|
|
stream.SetData(ws);
|
|
|
|
|
ws->StartServer(key, version, protocol);
|
|
|
|
|
return ws;
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::Close(uint16_t code, std::string_view reason) {
|
2018-08-24 20:54:23 -07:00
|
|
|
SendClose(code, reason);
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state != FAILED && m_state != CLOSED) {
|
|
|
|
|
m_state = CLOSING;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::Fail(uint16_t code, std::string_view reason) {
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state == FAILED || m_state == CLOSED) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
SendClose(code, reason);
|
|
|
|
|
SetClosed(code, reason, true);
|
|
|
|
|
Shutdown();
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::Terminate(uint16_t code, std::string_view reason) {
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state == FAILED || m_state == CLOSED) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
SetClosed(code, reason);
|
|
|
|
|
Shutdown();
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::StartClient(std::string_view uri, std::string_view host,
|
2021-06-06 19:51:14 -07:00
|
|
|
span<const std::string_view> protocols,
|
2018-08-24 20:54:23 -07:00
|
|
|
const ClientOptions& options) {
|
|
|
|
|
// Create client handshake data
|
|
|
|
|
m_clientHandshake = std::make_unique<ClientHandshakeData>();
|
|
|
|
|
|
|
|
|
|
// Build client request
|
|
|
|
|
SmallVector<uv::Buffer, 4> bufs;
|
|
|
|
|
raw_uv_ostream os{bufs, 4096};
|
|
|
|
|
|
|
|
|
|
os << "GET " << uri << " HTTP/1.1\r\n";
|
|
|
|
|
os << "Host: " << host << "\r\n";
|
|
|
|
|
os << "Upgrade: websocket\r\n";
|
|
|
|
|
os << "Connection: Upgrade\r\n";
|
|
|
|
|
os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
|
|
|
|
|
os << "Sec-WebSocket-Version: 13\r\n";
|
|
|
|
|
|
|
|
|
|
// protocols (if provided)
|
|
|
|
|
if (!protocols.empty()) {
|
|
|
|
|
os << "Sec-WebSocket-Protocol: ";
|
|
|
|
|
bool first = true;
|
|
|
|
|
for (auto protocol : protocols) {
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!first) {
|
2018-08-24 20:54:23 -07:00
|
|
|
os << ", ";
|
2020-12-28 12:58:06 -08:00
|
|
|
} else {
|
2018-08-24 20:54:23 -07:00
|
|
|
first = false;
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
os << protocol;
|
|
|
|
|
// also save for later checking against server response
|
|
|
|
|
m_clientHandshake->protocols.emplace_back(protocol);
|
|
|
|
|
}
|
|
|
|
|
os << "\r\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// other headers
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto&& header : options.extraHeaders) {
|
2018-08-24 20:54:23 -07:00
|
|
|
os << header.first << ": " << header.second << "\r\n";
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// finish headers
|
|
|
|
|
os << "\r\n";
|
|
|
|
|
|
|
|
|
|
// Send client request
|
|
|
|
|
m_stream.Write(bufs, [](auto bufs, uv::Error) {
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto& buf : bufs) {
|
|
|
|
|
buf.Deallocate();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// Set up client response handling
|
2021-06-06 16:13:58 -07:00
|
|
|
m_clientHandshake->parser.status.connect([this](std::string_view status) {
|
2018-08-24 20:54:23 -07:00
|
|
|
unsigned int code = m_clientHandshake->parser.GetStatusCode();
|
2020-12-28 12:58:06 -08:00
|
|
|
if (code != 101) {
|
|
|
|
|
Terminate(code, status);
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
});
|
|
|
|
|
m_clientHandshake->parser.header.connect(
|
2021-06-06 16:13:58 -07:00
|
|
|
[this](std::string_view name, std::string_view value) {
|
|
|
|
|
value = trim(value);
|
|
|
|
|
if (equals_lower(name, "upgrade")) {
|
|
|
|
|
if (!equals_lower(value, "websocket")) {
|
2018-08-24 20:54:23 -07:00
|
|
|
return Terminate(1002, "invalid upgrade response value");
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_clientHandshake->hasUpgrade = true;
|
2021-06-06 16:13:58 -07:00
|
|
|
} else if (equals_lower(name, "connection")) {
|
|
|
|
|
if (!equals_lower(value, "upgrade")) {
|
2018-08-24 20:54:23 -07:00
|
|
|
return Terminate(1002, "invalid connection response value");
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_clientHandshake->hasConnection = true;
|
2021-06-06 16:13:58 -07:00
|
|
|
} else if (equals_lower(name, "sec-websocket-accept")) {
|
2018-08-24 20:54:23 -07:00
|
|
|
// Check against expected response
|
|
|
|
|
SmallString<64> acceptBuf;
|
2021-06-06 16:13:58 -07:00
|
|
|
if (!equals(value, AcceptHash(m_clientHandshake->key, acceptBuf))) {
|
2018-08-24 20:54:23 -07:00
|
|
|
return Terminate(1002, "invalid accept key");
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_clientHandshake->hasAccept = true;
|
2021-06-06 16:13:58 -07:00
|
|
|
} else if (equals_lower(name, "sec-websocket-extensions")) {
|
2018-08-24 20:54:23 -07:00
|
|
|
// No extensions are supported
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!value.empty()) {
|
|
|
|
|
return Terminate(1010, "unsupported extension");
|
|
|
|
|
}
|
2021-06-06 16:13:58 -07:00
|
|
|
} else if (equals_lower(name, "sec-websocket-protocol")) {
|
2018-08-24 20:54:23 -07:00
|
|
|
// Make sure it was one of the provided protocols
|
|
|
|
|
bool match = false;
|
|
|
|
|
for (auto&& protocol : m_clientHandshake->protocols) {
|
2021-06-06 16:13:58 -07:00
|
|
|
if (equals_lower(value, protocol)) {
|
2018-08-24 20:54:23 -07:00
|
|
|
match = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!match) {
|
|
|
|
|
return Terminate(1003, "unsupported protocol");
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_clientHandshake->hasProtocol = true;
|
|
|
|
|
m_protocol = value;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
m_clientHandshake->parser.headersComplete.connect([this](bool) {
|
|
|
|
|
if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
|
|
|
|
|
!m_clientHandshake->hasAccept ||
|
|
|
|
|
(!m_clientHandshake->hasProtocol &&
|
|
|
|
|
!m_clientHandshake->protocols.empty())) {
|
|
|
|
|
return Terminate(1002, "invalid response");
|
|
|
|
|
}
|
|
|
|
|
if (m_state == CONNECTING) {
|
|
|
|
|
m_state = OPEN;
|
|
|
|
|
open(m_protocol);
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// Start handshake timer if a timeout was specified
|
2019-05-31 13:43:32 -07:00
|
|
|
if (options.handshakeTimeout != (uv::Timer::Time::max)()) {
|
2018-08-24 20:54:23 -07:00
|
|
|
auto timer = uv::Timer::Create(m_stream.GetLoopRef());
|
|
|
|
|
timer->timeout.connect(
|
|
|
|
|
[this]() { Terminate(1006, "connection timed out"); });
|
|
|
|
|
timer->Start(options.handshakeTimeout);
|
|
|
|
|
m_clientHandshake->timer = timer;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::StartServer(std::string_view key, std::string_view version,
|
|
|
|
|
std::string_view protocol) {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_protocol = protocol;
|
|
|
|
|
|
|
|
|
|
// Build server response
|
|
|
|
|
SmallVector<uv::Buffer, 4> bufs;
|
|
|
|
|
raw_uv_ostream os{bufs, 4096};
|
|
|
|
|
|
|
|
|
|
// Handle unsupported version
|
|
|
|
|
if (version != "13") {
|
|
|
|
|
os << "HTTP/1.1 426 Upgrade Required\r\n";
|
|
|
|
|
os << "Upgrade: WebSocket\r\n";
|
|
|
|
|
os << "Sec-WebSocket-Version: 13\r\n\r\n";
|
|
|
|
|
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto& buf : bufs) {
|
|
|
|
|
buf.Deallocate();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
// XXX: Should we support sending a new handshake on the same connection?
|
|
|
|
|
// XXX: "this->" is required by GCC 5.5 (bug)
|
|
|
|
|
this->Terminate(1003, "unsupported protocol version");
|
|
|
|
|
});
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
os << "HTTP/1.1 101 Switching Protocols\r\n";
|
|
|
|
|
os << "Upgrade: websocket\r\n";
|
|
|
|
|
os << "Connection: Upgrade\r\n";
|
|
|
|
|
|
|
|
|
|
// accept hash
|
|
|
|
|
SmallString<64> acceptBuf;
|
|
|
|
|
os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
|
|
|
|
|
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!protocol.empty()) {
|
|
|
|
|
os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// end headers
|
|
|
|
|
os << "\r\n";
|
|
|
|
|
|
|
|
|
|
// Send server response
|
|
|
|
|
m_stream.Write(bufs, [this](auto bufs, uv::Error) {
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto& buf : bufs) {
|
|
|
|
|
buf.Deallocate();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
if (m_state == CONNECTING) {
|
|
|
|
|
m_state = OPEN;
|
|
|
|
|
open(m_protocol);
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::SendClose(uint16_t code, std::string_view reason) {
|
2018-08-24 20:54:23 -07:00
|
|
|
SmallVector<uv::Buffer, 4> bufs;
|
|
|
|
|
if (code != 1005) {
|
|
|
|
|
raw_uv_ostream os{bufs, 4096};
|
2019-06-29 23:54:02 -07:00
|
|
|
const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
|
|
|
|
|
static_cast<uint8_t>(code & 0xff)};
|
2021-06-06 19:51:14 -07:00
|
|
|
os << span{codeMsb};
|
2021-06-06 16:13:58 -07:00
|
|
|
os << reason;
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto&& buf : bufs) {
|
|
|
|
|
buf.Deallocate();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
void WebSocket::SetClosed(uint16_t code, std::string_view reason, bool failed) {
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state == FAILED || m_state == CLOSED) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_state = failed ? FAILED : CLOSED;
|
2021-06-06 16:13:58 -07:00
|
|
|
closed(code, reason);
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Shuts the stream down, then closes it from the shutdown callback.
void WebSocket::Shutdown() {
  auto closeStream = [this]() { m_stream.Close(); };
  m_stream.Shutdown(closeStream);
}
|
|
|
|
|
|
|
|
|
|
void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
|
|
|
|
|
// ignore incoming data if we're failed or closed
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state == FAILED || m_state == CLOSED) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
2021-06-06 16:13:58 -07:00
|
|
|
std::string_view data{buf.base, size};
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// Handle connecting state (mainly on client)
|
|
|
|
|
if (m_state == CONNECTING) {
|
|
|
|
|
if (m_clientHandshake) {
|
|
|
|
|
data = m_clientHandshake->parser.Execute(data);
|
|
|
|
|
// check for parser failure
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_clientHandshake->parser.HasError()) {
|
2018-08-24 20:54:23 -07:00
|
|
|
return Terminate(1003, "invalid response");
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
|
|
|
|
if (m_state != OPEN) {
|
|
|
|
|
return; // not done with handshake yet
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// we're done with the handshake, so release its memory
|
|
|
|
|
m_clientHandshake.reset();
|
|
|
|
|
|
|
|
|
|
// fall through to process additional data after handshake
|
|
|
|
|
} else {
|
|
|
|
|
return Terminate(1003, "got data on server before response");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Message processing
|
|
|
|
|
while (!data.empty()) {
|
|
|
|
|
if (m_frameSize == UINT64_MAX) {
|
|
|
|
|
// Need at least two bytes to determine header length
|
|
|
|
|
if (m_header.size() < 2u) {
|
2019-05-31 13:43:32 -07:00
|
|
|
size_t toCopy = (std::min)(2u - m_header.size(), data.size());
|
2021-06-06 16:13:58 -07:00
|
|
|
m_header.append(data.data(), data.data() + toCopy);
|
|
|
|
|
data.remove_prefix(toCopy);
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_header.size() < 2u) {
|
|
|
|
|
return; // need more data
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// Validate RSV bits are zero
|
2020-12-28 12:58:06 -08:00
|
|
|
if ((m_header[0] & 0x70) != 0) {
|
|
|
|
|
return Fail(1002, "nonzero RSV");
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Once we have first two bytes, we can calculate the header size
|
|
|
|
|
if (m_headerSize == 0) {
|
|
|
|
|
m_headerSize = 2;
|
|
|
|
|
uint8_t len = m_header[1] & kLenMask;
|
2020-12-28 12:58:06 -08:00
|
|
|
if (len == 126) {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_headerSize += 2;
|
2020-12-28 12:58:06 -08:00
|
|
|
} else if (len == 127) {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_headerSize += 8;
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
bool masking = (m_header[1] & kFlagMasking) != 0;
|
2020-12-28 12:58:06 -08:00
|
|
|
if (masking) {
|
|
|
|
|
m_headerSize += 4; // masking key
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
// On server side, incoming messages MUST be masked
|
|
|
|
|
// On client side, incoming messages MUST NOT be masked
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_server && !masking) {
|
|
|
|
|
return Fail(1002, "client data not masked");
|
|
|
|
|
}
|
|
|
|
|
if (!m_server && masking) {
|
|
|
|
|
return Fail(1002, "server data masked");
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Need to complete header to calculate message size
|
|
|
|
|
if (m_header.size() < m_headerSize) {
|
2019-05-31 13:43:32 -07:00
|
|
|
size_t toCopy = (std::min)(m_headerSize - m_header.size(), data.size());
|
2021-06-06 16:13:58 -07:00
|
|
|
m_header.append(data.data(), data.data() + toCopy);
|
|
|
|
|
data.remove_prefix(toCopy);
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_header.size() < m_headerSize) {
|
|
|
|
|
return; // need more data
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (m_header.size() >= m_headerSize) {
|
|
|
|
|
// get payload length
|
|
|
|
|
uint8_t len = m_header[1] & kLenMask;
|
2020-12-28 12:58:06 -08:00
|
|
|
if (len == 126) {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
|
|
|
|
|
static_cast<uint16_t>(m_header[3]);
|
2020-12-28 12:58:06 -08:00
|
|
|
} else if (len == 127) {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[3]) << 48) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[4]) << 40) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[5]) << 32) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[6]) << 24) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[7]) << 16) |
|
|
|
|
|
(static_cast<uint64_t>(m_header[8]) << 8) |
|
|
|
|
|
static_cast<uint64_t>(m_header[9]);
|
2020-12-28 12:58:06 -08:00
|
|
|
} else {
|
2018-08-24 20:54:23 -07:00
|
|
|
m_frameSize = len;
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
|
|
|
|
|
// limit maximum size
|
2020-12-28 12:58:06 -08:00
|
|
|
if ((m_payload.size() + m_frameSize) > m_maxMessageSize) {
|
2018-08-24 20:54:23 -07:00
|
|
|
return Fail(1009, "message too large");
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (m_frameSize != UINT64_MAX) {
|
|
|
|
|
size_t need = m_frameStart + m_frameSize - m_payload.size();
|
2019-05-31 13:43:32 -07:00
|
|
|
size_t toCopy = (std::min)(need, data.size());
|
2021-06-06 16:13:58 -07:00
|
|
|
m_payload.append(data.data(), data.data() + toCopy);
|
|
|
|
|
data.remove_prefix(toCopy);
|
2018-08-24 20:54:23 -07:00
|
|
|
need -= toCopy;
|
|
|
|
|
if (need == 0) {
|
|
|
|
|
// We have a complete frame
|
|
|
|
|
// If the message had masking, unmask it
|
|
|
|
|
if ((m_header[1] & kFlagMasking) != 0) {
|
|
|
|
|
uint8_t key[4] = {
|
|
|
|
|
m_header[m_headerSize - 4], m_header[m_headerSize - 3],
|
|
|
|
|
m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
|
|
|
|
|
int n = 0;
|
2021-06-06 19:51:14 -07:00
|
|
|
for (uint8_t& ch : span{m_payload}.subspan(m_frameStart)) {
|
2018-08-24 20:54:23 -07:00
|
|
|
ch ^= key[n++];
|
2020-12-28 12:58:06 -08:00
|
|
|
if (n >= 4) {
|
|
|
|
|
n = 0;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handle message
|
|
|
|
|
bool fin = (m_header[0] & kFlagFin) != 0;
|
|
|
|
|
uint8_t opcode = m_header[0] & kOpMask;
|
|
|
|
|
switch (opcode) {
|
|
|
|
|
case kOpCont:
|
|
|
|
|
switch (m_fragmentOpcode) {
|
|
|
|
|
case kOpText:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!m_combineFragments || fin) {
|
2021-06-06 16:13:58 -07:00
|
|
|
text(std::string_view{reinterpret_cast<char*>(
|
|
|
|
|
m_payload.data()),
|
|
|
|
|
m_payload.size()},
|
2018-08-24 20:54:23 -07:00
|
|
|
fin);
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
case kOpBinary:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!m_combineFragments || fin) {
|
|
|
|
|
binary(m_payload, fin);
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
// no preceding message?
|
|
|
|
|
return Fail(1002, "invalid continuation message");
|
|
|
|
|
}
|
2020-12-28 12:58:06 -08:00
|
|
|
if (fin) {
|
|
|
|
|
m_fragmentOpcode = 0;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
case kOpText:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_fragmentOpcode != 0) {
|
|
|
|
|
return Fail(1002, "incomplete fragment");
|
|
|
|
|
}
|
|
|
|
|
if (!m_combineFragments || fin) {
|
2021-06-06 16:13:58 -07:00
|
|
|
text(std::string_view{reinterpret_cast<char*>(m_payload.data()),
|
|
|
|
|
m_payload.size()},
|
2018-08-24 20:54:23 -07:00
|
|
|
fin);
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
|
|
|
|
if (!fin) {
|
|
|
|
|
m_fragmentOpcode = opcode;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
case kOpBinary:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_fragmentOpcode != 0) {
|
|
|
|
|
return Fail(1002, "incomplete fragment");
|
|
|
|
|
}
|
|
|
|
|
if (!m_combineFragments || fin) {
|
|
|
|
|
binary(m_payload, fin);
|
|
|
|
|
}
|
|
|
|
|
if (!fin) {
|
|
|
|
|
m_fragmentOpcode = opcode;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
case kOpClose: {
|
|
|
|
|
uint16_t code;
|
2021-06-06 16:13:58 -07:00
|
|
|
std::string_view reason;
|
2018-08-24 20:54:23 -07:00
|
|
|
if (!fin) {
|
|
|
|
|
code = 1002;
|
|
|
|
|
reason = "cannot fragment control frames";
|
|
|
|
|
} else if (m_payload.size() < 2) {
|
|
|
|
|
code = 1005;
|
|
|
|
|
} else {
|
|
|
|
|
code = (static_cast<uint16_t>(m_payload[0]) << 8) |
|
|
|
|
|
static_cast<uint16_t>(m_payload[1]);
|
2021-06-06 16:13:58 -07:00
|
|
|
reason = drop_front(
|
|
|
|
|
{reinterpret_cast<char*>(m_payload.data()), m_payload.size()},
|
|
|
|
|
2);
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
// Echo the close if we didn't previously send it
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state != CLOSING) {
|
|
|
|
|
SendClose(code, reason);
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
SetClosed(code, reason);
|
|
|
|
|
// If we're the server, shutdown the connection.
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_server) {
|
|
|
|
|
Shutdown();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case kOpPing:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!fin) {
|
|
|
|
|
return Fail(1002, "cannot fragment control frames");
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
ping(m_payload);
|
|
|
|
|
break;
|
|
|
|
|
case kOpPong:
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!fin) {
|
|
|
|
|
return Fail(1002, "cannot fragment control frames");
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
pong(m_payload);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
return Fail(1002, "invalid message opcode");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Prepare for next message
|
|
|
|
|
m_header.clear();
|
|
|
|
|
m_headerSize = 0;
|
2020-12-28 12:58:06 -08:00
|
|
|
if (!m_combineFragments || fin) {
|
|
|
|
|
m_payload.clear();
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
m_frameStart = m_payload.size();
|
|
|
|
|
m_frameSize = UINT64_MAX;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void WebSocket::Send(
|
2021-06-06 19:51:14 -07:00
|
|
|
uint8_t opcode, span<const uv::Buffer> data,
|
|
|
|
|
std::function<void(span<uv::Buffer>, uv::Error)> callback) {
|
2018-08-24 20:54:23 -07:00
|
|
|
// If we're not open, emit an error and don't send the data
|
|
|
|
|
if (m_state != OPEN) {
|
|
|
|
|
int err;
|
2020-12-28 12:58:06 -08:00
|
|
|
if (m_state == CONNECTING) {
|
2018-08-24 20:54:23 -07:00
|
|
|
err = UV_EAGAIN;
|
2020-12-28 12:58:06 -08:00
|
|
|
} else {
|
2018-08-24 20:54:23 -07:00
|
|
|
err = UV_ESHUTDOWN;
|
2020-12-28 12:58:06 -08:00
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
|
|
|
|
|
callback(bufs, uv::Error{err});
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto req = std::make_shared<WebSocketWriteReq>(callback);
|
|
|
|
|
raw_uv_ostream os{req->m_bufs, 4096};
|
|
|
|
|
|
|
|
|
|
// opcode (includes FIN bit)
|
|
|
|
|
os << static_cast<unsigned char>(opcode);
|
|
|
|
|
|
|
|
|
|
// payload length
|
|
|
|
|
uint64_t size = 0;
|
2020-12-28 12:58:06 -08:00
|
|
|
for (auto&& buf : data) {
|
|
|
|
|
size += buf.len;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
if (size < 126) {
|
|
|
|
|
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
|
|
|
|
|
} else if (size <= 0xffff) {
|
|
|
|
|
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
|
2019-06-29 23:54:02 -07:00
|
|
|
const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 8) & 0xff),
|
|
|
|
|
static_cast<uint8_t>(size & 0xff)};
|
2021-06-06 19:51:14 -07:00
|
|
|
os << span{sizeMsb};
|
2018-08-24 20:54:23 -07:00
|
|
|
} else {
|
|
|
|
|
os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
|
2019-06-29 23:54:02 -07:00
|
|
|
const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 56) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 48) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 40) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 32) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 24) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 16) & 0xff),
|
|
|
|
|
static_cast<uint8_t>((size >> 8) & 0xff),
|
|
|
|
|
static_cast<uint8_t>(size & 0xff)};
|
2021-06-06 19:51:14 -07:00
|
|
|
os << span{sizeMsb};
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// clients need to mask the input data
|
|
|
|
|
if (!m_server) {
|
|
|
|
|
// generate masking key
|
|
|
|
|
static std::random_device rd;
|
|
|
|
|
static std::default_random_engine gen{rd()};
|
|
|
|
|
std::uniform_int_distribution<unsigned int> dist(0, 255);
|
|
|
|
|
uint8_t key[4];
|
2020-12-28 12:58:06 -08:00
|
|
|
for (uint8_t& v : key) {
|
|
|
|
|
v = dist(gen);
|
|
|
|
|
}
|
2021-06-06 19:51:14 -07:00
|
|
|
os << span<const uint8_t>{key, 4};
|
2018-08-24 20:54:23 -07:00
|
|
|
// copy and mask data
|
|
|
|
|
int n = 0;
|
|
|
|
|
for (auto&& buf : data) {
|
|
|
|
|
for (auto&& ch : buf.data()) {
|
|
|
|
|
os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
|
2020-12-28 12:58:06 -08:00
|
|
|
if (n >= 4) {
|
|
|
|
|
n = 0;
|
|
|
|
|
}
|
2018-08-24 20:54:23 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
req->m_startUser = req->m_bufs.size();
|
|
|
|
|
req->m_bufs.append(data.begin(), data.end());
|
|
|
|
|
// don't send the user bufs as we copied their data
|
2021-06-06 19:51:14 -07:00
|
|
|
m_stream.Write(span{req->m_bufs}.subspan(0, req->m_startUser), req);
|
2018-08-24 20:54:23 -07:00
|
|
|
} else {
|
|
|
|
|
// servers can just send the buffers directly without masking
|
|
|
|
|
req->m_startUser = req->m_bufs.size();
|
|
|
|
|
req->m_bufs.append(data.begin(), data.end());
|
|
|
|
|
m_stream.Write(req->m_bufs, req);
|
|
|
|
|
}
|
|
|
|
|
}
|