[ntcore] Fix write_impl (#5847)

The previous fix didn't handle all cases correctly. Instead, add a new
function to raw_ostream (SetNumBytesInBuffer) to allow always using the
full buffer size, and revamp write_impl to more cleanly handle all
cases.
This commit is contained in:
Peter Johnson
2023-10-30 08:23:33 -07:00
committed by GitHub
parent 1713386869
commit 07e13d60a2
4 changed files with 63 additions and 26 deletions

View File

@@ -4,6 +4,7 @@
#include "WebSocketConnection.h"
#include <algorithm>
#include <span>
#include <wpi/Endian.h>
@@ -30,7 +31,8 @@ class WebSocketConnection::Stream final : public wpi::raw_ostream {
public:
explicit Stream(WebSocketConnection& conn) : m_conn{conn} {
auto& buf = conn.m_bufs.back();
SetBuffer(buf.base + buf.len, kAllocSize - buf.len);
SetBuffer(buf.base, kAllocSize);
SetNumBytesInBuffer(buf.len);
}
~Stream() final {
@@ -48,36 +50,40 @@ class WebSocketConnection::Stream final : public wpi::raw_ostream {
};
void WebSocketConnection::Stream::write_impl(const char* data, size_t len) {
if (len >= kAllocSize) {
// only called by raw_ostream::write() when the buffer is empty and a large
// thing is being written; called with a length that's a multiple of the
// alloc size
assert((len % kAllocSize) == 0);
assert(m_conn.m_bufs.back().len == 0);
while (len > 0) {
auto& buf = m_conn.m_bufs.back();
std::memcpy(buf.base, data, kAllocSize);
buf.len = kAllocSize;
m_conn.m_framePos += kAllocSize;
m_conn.m_written += kAllocSize;
data += kAllocSize;
len -= kAllocSize;
if (data == m_conn.m_bufs.back().base) {
// flush_nonempty() case
m_conn.m_bufs.back().len = len;
if (!m_disableAlloc) {
m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin;
m_conn.StartFrame(wpi::WebSocket::Frame::kFragment);
SetBuffer(m_conn.m_bufs.back().base, kAllocSize);
}
return;
}
bool updateBuffer = false;
while (len > 0) {
auto& buf = m_conn.m_bufs.back();
assert(buf.len <= kAllocSize);
size_t amt = (std::min)(static_cast<int>(kAllocSize - buf.len),
static_cast<int>(len));
if (amt > 0) {
std::memcpy(buf.base + buf.len, data, amt);
buf.len += amt;
m_conn.m_framePos += amt;
m_conn.m_written += amt;
data += amt;
len -= amt;
}
if (buf.len >= kAllocSize && (len > 0 || !m_disableAlloc)) {
// fragment the current frame and start a new one
m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin;
m_conn.StartFrame(wpi::WebSocket::Frame::kFragment);
updateBuffer = true;
}
SetBuffer(m_conn.m_bufs.back().base, kAllocSize);
[[unlikely]] return;
}
auto& buf = m_conn.m_bufs.back();
buf.len += len;
m_conn.m_framePos += len;
m_conn.m_written += len;
if (!m_disableAlloc && buf.len >= kAllocSize) {
// fragment the current frame and start a new one
[[unlikely]] m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin;
m_conn.StartFrame(wpi::WebSocket::Frame::kFragment);
if (updateBuffer) {
SetBuffer(m_conn.m_bufs.back().base, kAllocSize);
}
}
@@ -132,7 +138,7 @@ void WebSocketConnection::StartFrame(uint8_t opcode) {
void WebSocketConnection::FinishText() {
assert(!m_bufs.empty());
auto& buf = m_bufs.back();
assert(buf.len < kAllocSize + 1); // safe because we alloc one more byte
assert(buf.len < (kAllocSize + 1)); // safe because we alloc one more byte
buf.base[buf.len++] = ']';
}