[wpinet] Translate unit tests to catch2 (#8954)

This commit is contained in:
Peter Johnson
2026-06-19 21:49:50 -05:00
committed by GitHub
parent cfab47e871
commit 396b553069
27 changed files with 952 additions and 776 deletions

View File

@@ -13,3 +13,23 @@ macro(wpilib_add_test name srcdir)
endif()
add_test(NAME ${name} COMMAND ${name}_test)
endmacro()
macro(wpilib_add_test_catch2 name)
set(wpilib_catch2_test_src)
foreach(src ${ARGN})
if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/${src}")
file(GLOB_RECURSE wpilib_catch2_dir_src ${src}/*.cpp)
list(APPEND wpilib_catch2_test_src ${wpilib_catch2_dir_src})
else()
list(APPEND wpilib_catch2_test_src ${src})
endif()
endforeach()
add_executable(${name}_test ${wpilib_catch2_test_src})
set_property(TARGET ${name}_test PROPERTY FOLDER "tests")
wpilib_target_warnings(${name}_test)
target_link_libraries(${name}_test catch2)
catch_discover_tests(${name}_test)
if(MSVC)
target_compile_options(${name}_test PRIVATE /wd4101 /wd4251 /utf-8)
endif()
endmacro()

View File

@@ -28,7 +28,8 @@ ext {
staticGtestConfigs = [:]
}
staticGtestConfigs["${nativeName}Test"] = []
def nativeTestSuiteName = project.findProperty('nativeTestSuiteName') ?: "${nativeName}Test"
staticGtestConfigs[nativeTestSuiteName] = []
apply from: "${rootDir}/shared/googletest.gradle"
apply from: "${rootDir}/shared/catch2.gradle"
@@ -194,7 +195,7 @@ model {
}
}
testSuites {
"${nativeName}Test"(GoogleTestTestSuiteSpec) {
"${nativeTestSuiteName}"(GoogleTestTestSuiteSpec) {
for(NativeComponentSpec c : $.components) {
if (c.name == nativeName) {
testing c
@@ -216,6 +217,11 @@ model {
}
}
}
if (nativeTestSuiteName.contains('Catch2')) {
binaries.all {
lib project: ':thirdparty:catch2', library: 'catch2', linkage: 'static'
}
}
}
}
binaries {

View File

@@ -226,7 +226,7 @@ cc_test(
tags = ["no-asan"],
deps = [
":wpinet",
"//thirdparty/googletest",
"//thirdparty/catch2",
"//wpiutil:wpiutil-testlib",
],
)

View File

@@ -187,7 +187,7 @@ set_property(TARGET netconsoleServer PROPERTY FOLDER "examples")
set_property(TARGET netconsoleTee PROPERTY FOLDER "examples")
if(WITH_TESTS)
wpilib_add_test(wpinet src/test/native/cpp)
wpilib_add_test_catch2(wpinet src/test/native/cpp)
target_include_directories(wpinet_test PRIVATE src/test/native/include src/main/native/cpp)
target_link_libraries(wpinet_test wpinet ${LIBUTIL} googletest wpiutil_testlib)
target_link_libraries(wpinet_test wpinet ${LIBUTIL} wpiutil_testlib)
endif()

View File

@@ -4,6 +4,7 @@ ext {
nativeName = 'wpinet'
devMain = 'org.wpilib.net.DevMain'
nativeTestSuiteName = "${nativeName}Catch2Test"
splitSetup = {
it.sources {

View File

@@ -4,203 +4,203 @@
#include "wpi/net/HttpParser.hpp"
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net {
TEST(HttpParserTest, UrlMethodHeadersComplete) {
TEST_CASE("HttpParserTest UrlMethodHeadersComplete", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.url.connect([&](std::string_view path) {
ASSERT_EQ(path, "/foo/bar");
ASSERT_EQ(p.GetUrl(), "/foo/bar");
REQUIRE(path == "/foo/bar");
REQUIRE(p.GetUrl() == "/foo/bar");
++callbacks;
});
p.Execute("GET /foo");
p.Execute("/bar");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute(" HTTP/1.1\r\n\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_EQ(p.GetUrl(), "/foo/bar");
ASSERT_EQ(p.GetMethod(), HTTP_GET);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE(p.GetUrl() == "/foo/bar");
REQUIRE(p.GetMethod() == HTTP_GET);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, UrlMethodHeader) {
TEST_CASE("HttpParserTest UrlMethodHeader", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.url.connect([&](std::string_view path) {
ASSERT_EQ(path, "/foo/bar");
ASSERT_EQ(p.GetUrl(), "/foo/bar");
REQUIRE(path == "/foo/bar");
REQUIRE(p.GetUrl() == "/foo/bar");
++callbacks;
});
p.Execute("GET /foo");
p.Execute("/bar");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute(" HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("F");
ASSERT_EQ(callbacks, 1);
ASSERT_EQ(p.GetUrl(), "/foo/bar");
ASSERT_EQ(p.GetMethod(), HTTP_GET);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE(p.GetUrl() == "/foo/bar");
REQUIRE(p.GetMethod() == HTTP_GET);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, StatusHeadersComplete) {
TEST_CASE("HttpParserTest StatusHeadersComplete", "[http][parser]") {
HttpParser p{HttpParser::Type::RESPONSE};
int callbacks = 0;
p.status.connect([&](std::string_view status) {
ASSERT_EQ(status, "OK");
ASSERT_EQ(p.GetStatusCode(), 200u);
REQUIRE(status == "OK");
REQUIRE(p.GetStatusCode() == 200u);
++callbacks;
});
p.Execute("HTTP/1.1 200");
p.Execute(" OK");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_EQ(p.GetStatusCode(), 200u);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE(p.GetStatusCode() == 200u);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, StatusHeader) {
TEST_CASE("HttpParserTest StatusHeader", "[http][parser]") {
HttpParser p{HttpParser::Type::RESPONSE};
int callbacks = 0;
p.status.connect([&](std::string_view status) {
ASSERT_EQ(status, "OK");
ASSERT_EQ(p.GetStatusCode(), 200u);
REQUIRE(status == "OK");
REQUIRE(p.GetStatusCode() == 200u);
++callbacks;
});
p.Execute("HTTP/1.1 200");
p.Execute(" OK\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("F");
ASSERT_EQ(callbacks, 1);
ASSERT_EQ(p.GetStatusCode(), 200u);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE(p.GetStatusCode() == 200u);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeaderFieldComplete) {
TEST_CASE("HttpParserTest HeaderFieldComplete", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.header.connect([&](std::string_view name, std::string_view value) {
ASSERT_EQ(name, "Foo");
ASSERT_EQ(value, "Bar");
REQUIRE(name == "Foo");
REQUIRE(value == "Bar");
++callbacks;
});
p.Execute("GET / HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Fo");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("o: ");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Bar");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeaderFieldNext) {
TEST_CASE("HttpParserTest HeaderFieldNext", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.header.connect([&](std::string_view name, std::string_view value) {
ASSERT_EQ(name, "Foo");
ASSERT_EQ(value, "Bar");
REQUIRE(name == "Foo");
REQUIRE(value == "Bar");
++callbacks;
});
p.Execute("GET / HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Fo");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("o: ");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Bar");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("F");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeadersComplete) {
TEST_CASE("HttpParserTest HeadersComplete", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.headersComplete.connect([&](bool keepAlive) {
ASSERT_EQ(keepAlive, false);
REQUIRE(keepAlive == false);
++callbacks;
});
p.Execute("GET / HTTP/1.0\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeadersCompleteHTTP11) {
TEST_CASE("HttpParserTest HeadersCompleteHTTP11", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.headersComplete.connect([&](bool keepAlive) {
ASSERT_EQ(keepAlive, true);
REQUIRE(keepAlive == true);
++callbacks;
});
p.Execute("GET / HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeadersCompleteKeepAlive) {
TEST_CASE("HttpParserTest HeadersCompleteKeepAlive", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.headersComplete.connect([&](bool keepAlive) {
ASSERT_EQ(keepAlive, true);
REQUIRE(keepAlive == true);
++callbacks;
});
p.Execute("GET / HTTP/1.0\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Connection: Keep-Alive\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, HeadersCompleteUpgrade) {
TEST_CASE("HttpParserTest HeadersCompleteUpgrade", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.headersComplete.connect([&](bool) {
ASSERT_TRUE(p.IsUpgrade());
REQUIRE(p.IsUpgrade());
++callbacks;
});
p.Execute("GET / HTTP/1.0\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("Connection: Upgrade\r\n");
p.Execute("Upgrade: websocket\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 1);
REQUIRE_FALSE(p.HasError());
}
TEST(HttpParserTest, Reset) {
TEST_CASE("HttpParserTest Reset", "[http][parser]") {
HttpParser p{HttpParser::Type::REQUEST};
int callbacks = 0;
p.headersComplete.connect([&](bool) { ++callbacks; });
p.Execute("GET / HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 0);
REQUIRE(callbacks == 0);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 1);
REQUIRE(callbacks == 1);
p.Reset(HttpParser::Type::REQUEST);
p.Execute("GET / HTTP/1.1\r\n");
ASSERT_EQ(callbacks, 1);
REQUIRE(callbacks == 1);
p.Execute("\r\n");
ASSERT_EQ(callbacks, 2);
ASSERT_FALSE(p.HasError());
REQUIRE(callbacks == 2);
REQUIRE_FALSE(p.HasError());
}
} // namespace wpi::net

View File

@@ -4,99 +4,99 @@
#include "wpi/net/HttpUtil.hpp"
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net {
TEST(HttpMultipartScannerTest, ExecuteExact) {
TEST_CASE("HttpMultipartScannerTest ExecuteExact", "[http][multipart]") {
HttpMultipartScanner scanner("foo");
EXPECT_TRUE(scanner.Execute("abcdefg---\r\n--foo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
EXPECT_TRUE(scanner.GetSkipped().empty());
CHECK(scanner.Execute("abcdefg---\r\n--foo\r\n").empty());
CHECK(scanner.IsDone());
CHECK(scanner.GetSkipped().empty());
}
TEST(HttpMultipartScannerTest, ExecutePartial) {
TEST_CASE("HttpMultipartScannerTest ExecutePartial", "[http][multipart]") {
HttpMultipartScanner scanner("foo");
EXPECT_TRUE(scanner.Execute("abcdefg--").empty());
EXPECT_FALSE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("-\r\n").empty());
EXPECT_FALSE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("--foo\r").empty());
EXPECT_FALSE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("\n").empty());
EXPECT_TRUE(scanner.IsDone());
CHECK(scanner.Execute("abcdefg--").empty());
CHECK_FALSE(scanner.IsDone());
CHECK(scanner.Execute("-\r\n").empty());
CHECK_FALSE(scanner.IsDone());
CHECK(scanner.Execute("--foo\r").empty());
CHECK_FALSE(scanner.IsDone());
CHECK(scanner.Execute("\n").empty());
CHECK(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, ExecuteTrailing) {
TEST_CASE("HttpMultipartScannerTest ExecuteTrailing", "[http][multipart]") {
HttpMultipartScanner scanner("foo");
EXPECT_EQ(scanner.Execute("abcdefg---\r\n--foo\r\nxyz"), "xyz");
CHECK(scanner.Execute("abcdefg---\r\n--foo\r\nxyz") == "xyz");
}
TEST(HttpMultipartScannerTest, ExecutePadding) {
TEST_CASE("HttpMultipartScannerTest ExecutePadding", "[http][multipart]") {
HttpMultipartScanner scanner("foo");
EXPECT_EQ(scanner.Execute("abcdefg---\r\n--foo \r\nxyz"), "xyz");
EXPECT_TRUE(scanner.IsDone());
CHECK(scanner.Execute("abcdefg---\r\n--foo \r\nxyz") == "xyz");
CHECK(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, SaveSkipped) {
TEST_CASE("HttpMultipartScannerTest SaveSkipped", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
scanner.Execute("abcdefg---\r\n--foo\r\n");
EXPECT_EQ(scanner.GetSkipped(), "abcdefg---\r\n--foo\r\n");
CHECK(scanner.GetSkipped() == "abcdefg---\r\n--foo\r\n");
}
TEST(HttpMultipartScannerTest, Reset) {
TEST_CASE("HttpMultipartScannerTest Reset", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
scanner.Execute("abcdefg---\r\n--foo\r\n");
EXPECT_TRUE(scanner.IsDone());
EXPECT_EQ(scanner.GetSkipped(), "abcdefg---\r\n--foo\r\n");
CHECK(scanner.IsDone());
CHECK(scanner.GetSkipped() == "abcdefg---\r\n--foo\r\n");
scanner.Reset(true);
EXPECT_FALSE(scanner.IsDone());
CHECK_FALSE(scanner.IsDone());
scanner.SetBoundary("bar");
scanner.Execute("--foo\r\n--bar\r\n");
EXPECT_TRUE(scanner.IsDone());
EXPECT_EQ(scanner.GetSkipped(), "--foo\r\n--bar\r\n");
CHECK(scanner.IsDone());
CHECK(scanner.GetSkipped() == "--foo\r\n--bar\r\n");
}
TEST(HttpMultipartScannerTest, WithoutDashes) {
TEST_CASE("HttpMultipartScannerTest WithoutDashes", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
EXPECT_TRUE(scanner.Execute("--\r\nfoo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
CHECK(scanner.Execute("--\r\nfoo\r\n").empty());
CHECK(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, SeqDashesDashes) {
TEST_CASE("HttpMultipartScannerTest SeqDashesDashes", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
EXPECT_TRUE(scanner.Execute("\r\n--foo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("\r\n--foo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
CHECK(scanner.Execute("\r\n--foo\r\n").empty());
CHECK(scanner.IsDone());
CHECK(scanner.Execute("\r\n--foo\r\n").empty());
CHECK(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, SeqDashesNoDashes) {
TEST_CASE("HttpMultipartScannerTest SeqDashesNoDashes", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
EXPECT_TRUE(scanner.Execute("\r\n--foo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("\r\nfoo\r\n").empty());
EXPECT_FALSE(scanner.IsDone());
CHECK(scanner.Execute("\r\n--foo\r\n").empty());
CHECK(scanner.IsDone());
CHECK(scanner.Execute("\r\nfoo\r\n").empty());
CHECK_FALSE(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, SeqNoDashesDashes) {
TEST_CASE("HttpMultipartScannerTest SeqNoDashesDashes", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
EXPECT_TRUE(scanner.Execute("\r\nfoo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("\r\n--foo\r\n").empty());
EXPECT_FALSE(scanner.IsDone());
CHECK(scanner.Execute("\r\nfoo\r\n").empty());
CHECK(scanner.IsDone());
CHECK(scanner.Execute("\r\n--foo\r\n").empty());
CHECK_FALSE(scanner.IsDone());
}
TEST(HttpMultipartScannerTest, SeqNoDashesNoDashes) {
TEST_CASE("HttpMultipartScannerTest SeqNoDashesNoDashes", "[http][multipart]") {
HttpMultipartScanner scanner("foo", true);
EXPECT_TRUE(scanner.Execute("\r\nfoo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
EXPECT_TRUE(scanner.Execute("\r\nfoo\r\n").empty());
EXPECT_TRUE(scanner.IsDone());
CHECK(scanner.Execute("\r\nfoo\r\n").empty());
CHECK(scanner.IsDone());
CHECK(scanner.Execute("\r\nfoo\r\n").empty());
CHECK(scanner.IsDone());
}
} // namespace wpi::net

View File

@@ -6,7 +6,7 @@
#include <memory>
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net {

View File

@@ -8,12 +8,13 @@
#include <thread>
#include <vector>
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/MulticastServiceAnnouncer.h"
#include "wpi/net/MulticastServiceResolver.h"
TEST(MulticastServiceAnnouncerTest, EmptyText) {
TEST_CASE("MulticastServiceAnnouncerTest EmptyText",
"[multicast][service-discovery]") {
const std::string_view serviceName = "TestServiceNoText";
const std::string_view serviceType = "_wpinotxt._tcp";
const int port = std::rand();
@@ -37,14 +38,15 @@ TEST(MulticastServiceAnnouncerTest, EmptyText) {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
ASSERT_GT(allData.size(), 0ul);
REQUIRE(allData.size() > 0ul);
resolver.Stop();
announcer.Stop();
}
}
TEST(MulticastServiceAnnouncerTest, SingleText) {
TEST_CASE("MulticastServiceAnnouncerTest SingleText",
"[multicast][service-discovery]") {
const std::string_view serviceName = "TestServiceSingle";
const std::string_view serviceType = "_wpitxt";
const int port = std::rand();

View File

@@ -11,6 +11,8 @@
#include <string>
#include <vector>
#include <catch2/generators/catch_generators.hpp>
#include "WebSocketTest.hpp"
#include "wpi/net/HttpParser.hpp"
#include "wpi/net/raw_uv_ostream.hpp"
@@ -80,7 +82,8 @@ class WebSocketClientTest : public WebSocketTest {
if (req.HasError()) {
Finish();
}
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
INFO(http_errno_name(req.GetError()));
REQUIRE(req.GetError() == HPE_OK);
if (data.empty()) {
return;
}
@@ -101,7 +104,8 @@ class WebSocketClientTest : public WebSocketTest {
std::function<void()> connected;
};
TEST_F(WebSocketClientTest, Open) {
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest Open",
"[websocket][client][handshake]") {
int gotOpen = 0;
clientPipe->Connect(pipeName, [&] {
@@ -109,25 +113,22 @@ TEST_F(WebSocketClientTest, Open) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->open.connect([&](std::string_view protocol) {
++gotOpen;
Finish();
ASSERT_TRUE(protocol.empty());
REQUIRE(protocol.empty());
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotOpen, 1);
REQUIRE(gotOpen == 1);
}
TEST_F(WebSocketClientTest, BadAccept) {
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest BadAccept",
"[websocket][client][handshake][protocol]") {
int gotClosed = 0;
mockBadAccept = true;
@@ -137,23 +138,21 @@ TEST_F(WebSocketClientTest, BadAccept) {
ws->closed.connect([&](uint16_t code, std::string_view msg) {
Finish();
++gotClosed;
ASSERT_EQ(code, 1002) << "Message: " << msg;
INFO("Message: " << msg);
REQUIRE(code == 1002);
});
ws->open.connect([&](std::string_view protocol) {
Finish();
FAIL() << "Got open";
FAIL("Got open");
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotClosed, 1);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketClientTest, ProtocolGood) {
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolGood",
"[websocket][client][protocol]") {
int gotOpen = 0;
mockProtocol = "myProtocol";
@@ -164,25 +163,22 @@ TEST_F(WebSocketClientTest, ProtocolGood) {
ws->closed.connect([&](uint16_t code, std::string_view msg) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << "Message: " << msg;
FAIL("Code: " << code << "Message: " << msg);
}
});
ws->open.connect([&](std::string_view protocol) {
++gotOpen;
Finish();
ASSERT_EQ(protocol, "myProtocol");
REQUIRE(protocol == "myProtocol");
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotOpen, 1);
REQUIRE(gotOpen == 1);
}
TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolRespNotReq",
"[websocket][client][protocol]") {
int gotClosed = 0;
mockProtocol = "myProtocol";
@@ -192,23 +188,21 @@ TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
ws->closed.connect([&](uint16_t code, std::string_view msg) {
Finish();
++gotClosed;
ASSERT_EQ(code, 1003) << "Message: " << msg;
INFO("Message: " << msg);
REQUIRE(code == 1003);
});
ws->open.connect([&](std::string_view protocol) {
Finish();
FAIL() << "Got open";
FAIL("Got open");
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotClosed, 1);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
TEST_CASE_METHOD(WebSocketClientTest, "WebSocketClientTest ProtocolReqNotResp",
"[websocket][client][protocol]") {
int gotClosed = 0;
clientPipe->Connect(pipeName, [&] {
@@ -217,20 +211,17 @@ TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
ws->closed.connect([&](uint16_t code, std::string_view msg) {
Finish();
++gotClosed;
ASSERT_EQ(code, 1002) << "Message: " << msg;
INFO("Message: " << msg);
REQUIRE(code == 1002);
});
ws->open.connect([&](std::string_view protocol) {
Finish();
FAIL() << "Got open";
FAIL("Got open");
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotClosed, 1);
REQUIRE(gotClosed == 1);
}
//
@@ -238,8 +229,7 @@ TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
// WebSocketServerTest, so only spot check differences like masking.
//
class WebSocketClientDataTest : public WebSocketClientTest,
public ::testing::WithParamInterface<size_t> {
class WebSocketClientDataTest : public WebSocketClientTest {
public:
WebSocketClientDataTest() {
clientPipe->Connect(pipeName, [&] {
@@ -254,19 +244,19 @@ class WebSocketClientDataTest : public WebSocketClientTest,
std::shared_ptr<WebSocket> ws;
};
INSTANTIATE_TEST_SUITE_P(WebSocketClientDataTests, WebSocketClientDataTest,
::testing::Values(0, 1, 125, 126, 65535, 65536));
TEST_P(WebSocketClientDataTest, SendBinary) {
TEST_CASE_METHOD(WebSocketClientDataTest, "WebSocketClientDataTest SendBinary",
"[websocket][client][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) {
ws->SendBinary({{data}}, [&](auto bufs, uv::Error) {
++gotCallback;
ws->Terminate();
ASSERT_FALSE(bufs.empty());
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
REQUIRE_FALSE(bufs.empty());
REQUIRE(bufs[0].base == reinterpret_cast<const char*>(data.data()));
});
});
};
@@ -275,20 +265,24 @@ TEST_P(WebSocketClientDataTest, SendBinary) {
auto expectData = BuildMessage(0x02, true, true, data);
AdjustMasking(wireData);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotCallback, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketClientDataTest, ReceiveBinary) {
TEST_CASE_METHOD(WebSocketClientDataTest,
"WebSocketClientDataTest ReceiveBinary",
"[websocket][client][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->binary.connect([&](auto inData, bool fin) {
++gotCallback;
ws->Terminate();
ASSERT_TRUE(fin);
REQUIRE(fin);
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(data, recvData);
REQUIRE(data == recvData);
});
};
auto message = BuildMessage(0x02, true, false, data);
@@ -296,24 +290,29 @@ TEST_P(WebSocketClientDataTest, ReceiveBinary) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
//
// The client must close the connection if a masked frame is received.
//
TEST_P(WebSocketClientDataTest, ReceiveMasked) {
TEST_CASE_METHOD(WebSocketClientDataTest,
"WebSocketClientDataTest ReceiveMasked",
"[websocket][client][data][protocol]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), ' ');
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
' ');
setupWebSocket = [&] {
ws->text.connect([&](std::string_view, bool) {
ws->Terminate();
FAIL() << "Should not have gotten masked message";
FAIL("Should not have gotten masked message");
});
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(0x01, true, true, data);
@@ -321,7 +320,7 @@ TEST_P(WebSocketClientDataTest, ReceiveMasked) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
} // namespace wpi::net

View File

@@ -15,7 +15,8 @@ namespace wpi::net {
class WebSocketIntegrationTest : public WebSocketTest {};
TEST_F(WebSocketIntegrationTest, Open) {
TEST_CASE_METHOD(WebSocketIntegrationTest, "WebSocketIntegrationTest Open",
"[websocket][integration][handshake]") {
int gotServerOpen = 0;
int gotClientOpen = 0;
@@ -24,7 +25,7 @@ TEST_F(WebSocketIntegrationTest, Open) {
auto server = WebSocketServer::Create(*conn);
server->connected.connect([&](std::string_view url, WebSocket&) {
++gotServerOpen;
ASSERT_EQ(url, "/test");
REQUIRE(url == "/test");
});
});
@@ -33,7 +34,7 @@ TEST_F(WebSocketIntegrationTest, Open) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->open.connect([&, s = ws.get()](std::string_view) {
@@ -44,11 +45,12 @@ TEST_F(WebSocketIntegrationTest, Open) {
loop->Run();
ASSERT_EQ(gotServerOpen, 1);
ASSERT_EQ(gotClientOpen, 1);
REQUIRE(gotServerOpen == 1);
REQUIRE(gotClientOpen == 1);
}
TEST_F(WebSocketIntegrationTest, Protocol) {
TEST_CASE_METHOD(WebSocketIntegrationTest, "WebSocketIntegrationTest Protocol",
"[websocket][integration][protocol]") {
int gotServerOpen = 0;
int gotClientOpen = 0;
@@ -57,7 +59,7 @@ TEST_F(WebSocketIntegrationTest, Protocol) {
auto server = WebSocketServer::Create(*conn, {"proto1", "proto2"});
server->connected.connect([&](std::string_view, WebSocket& ws) {
++gotServerOpen;
ASSERT_EQ(ws.GetProtocol(), "proto1");
REQUIRE(ws.GetProtocol() == "proto1");
});
});
@@ -67,23 +69,25 @@ TEST_F(WebSocketIntegrationTest, Protocol) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->open.connect([&, s = ws.get()](std::string_view protocol) {
++gotClientOpen;
s->Close();
ASSERT_EQ(protocol, "proto1");
REQUIRE(protocol == "proto1");
});
});
loop->Run();
ASSERT_EQ(gotServerOpen, 1);
ASSERT_EQ(gotClientOpen, 1);
REQUIRE(gotServerOpen == 1);
REQUIRE(gotClientOpen == 1);
}
TEST_F(WebSocketIntegrationTest, ServerSendBinary) {
TEST_CASE_METHOD(WebSocketIntegrationTest,
"WebSocketIntegrationTest ServerSendBinary",
"[websocket][integration][data]") {
int gotData = 0;
serverPipe->Listen([&]() {
@@ -100,23 +104,25 @@ TEST_F(WebSocketIntegrationTest, ServerSendBinary) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->binary.connect([&](auto data, bool) {
++gotData;
std::vector<uint8_t> recvData{data.begin(), data.end()};
std::vector<uint8_t> expectData{0x03, 0x04};
ASSERT_EQ(recvData, expectData);
REQUIRE(recvData == expectData);
});
});
loop->Run();
ASSERT_EQ(gotData, 1);
REQUIRE(gotData == 1);
}
TEST_F(WebSocketIntegrationTest, ClientSendText) {
TEST_CASE_METHOD(WebSocketIntegrationTest,
"WebSocketIntegrationTest ClientSendText",
"[websocket][integration][data]") {
int gotData = 0;
serverPipe->Listen([&]() {
@@ -125,7 +131,7 @@ TEST_F(WebSocketIntegrationTest, ClientSendText) {
server->connected.connect([&](std::string_view, WebSocket& ws) {
ws.text.connect([&](std::string_view data, bool) {
++gotData;
ASSERT_EQ(data, "hello");
REQUIRE(data == "hello");
});
});
});
@@ -135,7 +141,7 @@ TEST_F(WebSocketIntegrationTest, ClientSendText) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->open.connect([&, s = ws.get()](std::string_view) {
@@ -146,10 +152,12 @@ TEST_F(WebSocketIntegrationTest, ClientSendText) {
loop->Run();
ASSERT_EQ(gotData, 1);
REQUIRE(gotData == 1);
}
TEST_F(WebSocketIntegrationTest, ServerSendPing) {
TEST_CASE_METHOD(WebSocketIntegrationTest,
"WebSocketIntegrationTest ServerSendPing",
"[websocket][integration][control]") {
int gotPing = 0;
int gotPong = 0;
int gotData = 0;
@@ -166,7 +174,7 @@ TEST_F(WebSocketIntegrationTest, ServerSendPing) {
++gotPong;
std::vector<uint8_t> recvData{data.begin(), data.end()};
std::vector<uint8_t> expectData{0x03, 0x04};
ASSERT_EQ(recvData, expectData);
REQUIRE(recvData == expectData);
if (gotPong == 2) {
ws.Close();
}
@@ -179,26 +187,26 @@ TEST_F(WebSocketIntegrationTest, ServerSendPing) {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
Finish();
if (code != 1005 && code != 1006) {
FAIL() << "Code: " << code << " Reason: " << reason;
FAIL("Code: " << code << " Reason: " << reason);
}
});
ws->ping.connect([&](auto data) {
++gotPing;
std::vector<uint8_t> recvData{data.begin(), data.end()};
std::vector<uint8_t> expectData{0x03, 0x04};
ASSERT_EQ(recvData, expectData);
REQUIRE(recvData == expectData);
});
ws->text.connect([&](std::string_view data, bool) {
++gotData;
ASSERT_EQ(data, "hello");
REQUIRE(data == "hello");
});
});
loop->Run();
ASSERT_EQ(gotPing, 2);
ASSERT_EQ(gotPong, 2);
ASSERT_EQ(gotData, 2);
REQUIRE(gotPing == 2);
REQUIRE(gotPong == 2);
REQUIRE(gotData == 2);
}
} // namespace wpi::net

View File

@@ -8,48 +8,53 @@
#include <array>
#include <functional>
#include <memory>
#include <ostream>
#include <span>
#include <utility>
#include <vector>
#include <gmock/gmock.h>
#include <catch2/catch_test_macros.hpp>
#include "WebSocketTest.hpp"
#include "wpi/net/uv/Buffer.hpp"
#include "wpi/util/SpanMatcher.hpp"
using ::testing::_;
using ::testing::AnyOf;
using ::testing::ElementsAre;
using ::testing::Field;
using ::testing::Pointee;
using ::testing::Return;
namespace wpi::net::uv {
inline bool operator==(const Buffer& lhs, const Buffer& rhs) {
return lhs.len == rhs.len &&
std::equal(lhs.base, lhs.base + lhs.len, rhs.base);
}
inline void PrintTo(const Buffer& buf, ::std::ostream* os) {
::wpi::util::PrintTo(buf.bytes(), os);
}
} // namespace wpi::net::uv
namespace wpi::net {
inline bool operator==(const WebSocket::Frame& lhs,
const WebSocket::Frame& rhs) {
return lhs.opcode == rhs.opcode &&
return lhs.opcode == rhs.opcode && lhs.data.size() == rhs.data.size() &&
std::equal(lhs.data.begin(), lhs.data.end(), rhs.data.begin());
}
inline void PrintTo(const WebSocket::Frame& frame, ::std::ostream* os) {
*os << frame.opcode << ": ";
::wpi::util::PrintTo(frame.data, os);
}
} // namespace wpi::net
namespace wpi::net::detail {
template <typename T>
bool SpanEquals(std::span<const T> lhs, std::span<const T> rhs) {
return lhs.size() == rhs.size() &&
std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
}
template <typename T>
bool SpanEquals(std::span<const T> lhs, const std::vector<T>& rhs) {
return SpanEquals(lhs, std::span<const T>{rhs});
}
template <typename T, size_t N>
bool SpanEquals(std::span<const T> lhs, const std::array<T, N>& rhs) {
return SpanEquals(lhs, std::span<const T>{rhs});
}
template <typename T, size_t N>
bool SpanEquals(std::span<const T> lhs, const T (&rhs)[N]) {
return SpanEquals(lhs, std::span<const T>{rhs});
}
class MockWebSocketWriteReq
: public std::enable_shared_from_this<MockWebSocketWriteReq>,
public detail::WebSocketWriteReqBase {
@@ -60,22 +65,58 @@ class MockWebSocketWriteReq
class MockStream {
public:
MOCK_METHOD(int, TryWrite, (std::span<const uv::Buffer>));
~MockStream() {
CHECK(m_tryWriteCalls == m_expectedTryWriteCalls);
CHECK(m_writeCalls == m_expectedWrites.size());
}
void ExpectTryWrite(int result) {
m_tryWriteResult = result;
m_expectedTryWriteCalls = 1;
m_checkTryWriteBufs = false;
}
template <typename Range>
void ExpectTryWrite(const Range& bufs, int result) {
ExpectTryWrite(result);
m_expectedTryWriteBufs.assign(bufs.begin(), bufs.end());
m_checkTryWriteBufs = true;
}
template <typename Range>
void ExpectWrite(const Range& bufs) {
m_expectedWrites.emplace_back(bufs.begin(), bufs.end());
}
int TryWrite(std::span<const uv::Buffer> bufs) {
REQUIRE(m_tryWriteCalls < m_expectedTryWriteCalls);
if (m_checkTryWriteBufs) {
REQUIRE(SpanEquals(bufs, m_expectedTryWriteBufs));
}
++m_tryWriteCalls;
return m_tryWriteResult;
}
void Write(std::span<const uv::Buffer> bufs,
const std::shared_ptr<MockWebSocketWriteReq>& req) {
// std::cout << "Write(";
// PrintTo(bufs, &std::cout);
// std::cout << ")\n";
DoWrite(bufs, req);
REQUIRE(m_writeCalls < m_expectedWrites.size());
REQUIRE(SpanEquals(bufs, m_expectedWrites[m_writeCalls]));
++m_writeCalls;
}
MOCK_METHOD(void, DoWrite,
(std::span<const uv::Buffer> bufs,
const std::shared_ptr<MockWebSocketWriteReq>& req));
wpi::util::Logger* GetLogger() const { return nullptr; }
private:
int m_tryWriteResult = 0;
size_t m_expectedTryWriteCalls = 0;
size_t m_tryWriteCalls = 0;
bool m_checkTryWriteBufs = false;
std::vector<uv::Buffer> m_expectedTryWriteBufs;
std::vector<std::vector<uv::Buffer>> m_expectedWrites;
size_t m_writeCalls = 0;
};
class WebSocketWriteReqTest : public ::testing::Test {
class WebSocketWriteReqTest {
public:
WebSocketWriteReqTest() {
req->m_frames.m_bufs.emplace_back(m_buf0);
@@ -87,7 +128,7 @@ class WebSocketWriteReqTest : public ::testing::Test {
std::shared_ptr<MockWebSocketWriteReq> req =
std::make_shared<MockWebSocketWriteReq>([](auto, auto) {});
::testing::StrictMock<MockStream> stream;
MockStream stream;
static const uint8_t m_buf0[3];
static const uint8_t m_buf1[2];
static const uint8_t m_buf2[4];
@@ -97,92 +138,97 @@ const uint8_t WebSocketWriteReqTest::m_buf0[3] = {1, 2, 3};
const uint8_t WebSocketWriteReqTest::m_buf1[2] = {4, 5};
const uint8_t WebSocketWriteReqTest::m_buf2[4] = {6, 7, 8, 9};
TEST_F(WebSocketWriteReqTest, ContinueDone) {
TEST_CASE_METHOD(WebSocketWriteReqTest, "WebSocketWriteReqTest ContinueDone",
"[websocket][serializer][write-request]") {
req->m_continueBufPos = 3;
ASSERT_EQ(req->Continue(stream, req), 0);
REQUIRE(req->Continue(stream, req) == 0);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWriteComplete) {
EXPECT_CALL(stream, TryWrite(wpi::util::SpanEq(req->m_frames.m_bufs)))
.WillOnce(Return(9));
ASSERT_EQ(req->Continue(stream, req), 0);
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWriteComplete",
"[websocket][serializer][write-request]") {
stream.ExpectTryWrite(req->m_frames.m_bufs, 9);
REQUIRE(req->Continue(stream, req) == 0);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWriteNoProgress) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWriteNoProgress",
"[websocket][serializer][write-request]") {
// if TryWrite returns 0
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(0));
stream.ExpectTryWrite(0);
// Write should get called for all of next frame - make forward progress
uv::Buffer remBufs[2] = {uv::Buffer{m_buf0}, uv::Buffer{m_buf1}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWriteError) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWriteError",
"[websocket][serializer][write-request]") {
// if TryWrite returns -1, the error is passed along
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(-1));
ASSERT_EQ(req->Continue(stream, req), -1);
stream.ExpectTryWrite(-1);
REQUIRE(req->Continue(stream, req) == -1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf1) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWritePartialMidFrameMidBuf1",
"[websocket][serializer][write-request]") {
// stop partway through buf 0
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(2));
stream.ExpectTryWrite(2);
// Write should get called for remainder of buf 0 and all of buf 1
uv::Buffer remBufs[2] = {uv::Buffer{&m_buf0[2], 1}, uv::Buffer{m_buf1}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameBufBoundary) {
TEST_CASE_METHOD(
WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWritePartialMidFrameBufBoundary",
"[websocket][serializer][write-request]") {
// stop at end of buf 0
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(3));
stream.ExpectTryWrite(3);
// Write should get called for all of buf 1
uv::Buffer remBufs[1] = {uv::Buffer{m_buf1}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf2) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWritePartialMidFrameMidBuf2",
"[websocket][serializer][write-request]") {
// stop partway through buf 1
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(4));
stream.ExpectTryWrite(4);
// Write should get called for remainder of buf 1
uv::Buffer remBufs[1] = {uv::Buffer{&m_buf1[1], 1}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialFrameBoundary) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWritePartialFrameBoundary",
"[websocket][serializer][write-request]") {
// stop at end of buf 1
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(5));
stream.ExpectTryWrite(5);
// Write should get called for all of next frame
uv::Buffer remBufs[1] = {uv::Buffer{m_buf2}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
TEST_F(WebSocketWriteReqTest, ContinueTryWritePartialMidFrameMidBuf3) {
TEST_CASE_METHOD(WebSocketWriteReqTest,
"WebSocketWriteReqTest ContinueTryWritePartialMidFrameMidBuf3",
"[websocket][serializer][write-request]") {
// stop partway through buf 2
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(6));
stream.ExpectTryWrite(6);
// Write should get called for remainder of buf 2
uv::Buffer remBufs[1] = {uv::Buffer{&m_buf2[1], 3}};
EXPECT_CALL(
stream,
DoWrite(wpi::util::SpanEq(std::span<const uv::Buffer>(remBufs)), _));
ASSERT_EQ(req->Continue(stream, req), 1);
stream.ExpectWrite(std::span<const uv::Buffer>(remBufs));
REQUIRE(req->Continue(stream, req) == 1);
}
class WebSocketTrySendTest : public ::testing::Test {
class WebSocketTrySendTest {
public:
::testing::StrictMock<MockStream> stream;
MockStream stream;
std::shared_ptr<MockWebSocketWriteReq> req;
static const std::array<uint8_t, 3> m_buf0data;
static const std::array<uint8_t, 2> m_buf1data;
@@ -238,55 +284,57 @@ const std::array<uv::Buffer, 3> WebSocketTrySendTest::m_frameHeaders{
void WebSocketTrySendTest::CheckTrySendFrames(
std::span<const uv::Buffer> expectCbBufs,
std::span<const WebSocket::Frame> expectRet, int expectErr) {
ASSERT_THAT(
TrySendFrames(
true, stream, m_frames,
[&](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
++makeReqCalled;
req = std::make_shared<MockWebSocketWriteReq>(std::move(cb));
return req;
},
[&](auto bufs, auto err) {
++callbackCalled;
ASSERT_THAT(bufs, wpi::util::SpanEq(
std::span<const uv::Buffer>(expectCbBufs)));
ASSERT_EQ(err.code(), expectErr);
}),
wpi::util::SpanEq(expectRet));
auto remaining = TrySendFrames(
true, stream, m_frames,
[&](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
++makeReqCalled;
req = std::make_shared<MockWebSocketWriteReq>(std::move(cb));
return req;
},
[&](auto bufs, auto err) {
++callbackCalled;
REQUIRE(SpanEquals(std::span<const uv::Buffer>{bufs}, expectCbBufs));
REQUIRE(err.code() == expectErr);
});
REQUIRE(SpanEquals(std::span<const WebSocket::Frame>{remaining}, expectRet));
}
TEST_F(WebSocketTrySendTest, ServerComplete) {
TEST_CASE_METHOD(WebSocketTrySendTest, "WebSocketTrySendTest ServerComplete",
"[websocket][serializer][send]") {
// if trywrite sends everything
EXPECT_CALL(stream, TryWrite(_))
.WillOnce(Return(m_serialized[0].size() + m_serialized[1].size() +
m_serialized[2].size()));
stream.ExpectTryWrite(m_serialized[0].size() + m_serialized[1].size() +
m_serialized[2].size());
// return nothing, and call callback immediately
CheckTrySendFrames(m_bufs, {});
ASSERT_EQ(makeReqCalled, 0);
ASSERT_EQ(callbackCalled, 1);
REQUIRE(makeReqCalled == 0);
REQUIRE(callbackCalled == 1);
}
TEST_F(WebSocketTrySendTest, ServerNoProgress) {
TEST_CASE_METHOD(WebSocketTrySendTest, "WebSocketTrySendTest ServerNoProgress",
"[websocket][serializer][send]") {
// if trywrite sends nothing
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(0));
stream.ExpectTryWrite(0);
// we should get all the frames back (the callback may be called with an empty
// set of buffers)
CheckTrySendFrames({}, m_frames);
ASSERT_EQ(makeReqCalled, 0);
ASSERT_THAT(callbackCalled, AnyOf(0, 1));
REQUIRE(makeReqCalled == 0);
REQUIRE((callbackCalled == 0 || callbackCalled == 1));
}
TEST_F(WebSocketTrySendTest, ServerError) {
TEST_CASE_METHOD(WebSocketTrySendTest, "WebSocketTrySendTest ServerError",
"[websocket][serializer][send]") {
// if TryWrite returns -1, the error is passed along
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(-1));
stream.ExpectTryWrite(-1);
CheckTrySendFrames(m_bufs, m_frames, -1);
ASSERT_EQ(makeReqCalled, 0);
ASSERT_EQ(callbackCalled, 1);
REQUIRE(makeReqCalled == 0);
REQUIRE(callbackCalled == 1);
}
TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf0) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialMidFrameMidBuf0",
"[websocket][serializer][send]") {
// stop partway through buf 0 (first buf of frame 0)
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 2));
stream.ExpectTryWrite(m_frameHeaders[0].len + 2);
// Write should get called for remainder of buf 0 and all of buf 1
// buf 2 should get put into continuation because frame 0 is a fragment
// return will be frame 2 only
@@ -294,101 +342,114 @@ TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf0) {
m_bufs[1]};
std::array<uv::Buffer, 2> contBufs{m_frameHeaders[1], m_bufs[2]};
std::array<int, 1> contFrameOffs{static_cast<int>(m_serialized[1].size())};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 1);
ASSERT_THAT(req->m_frames.m_bufs, wpi::util::SpanEq(contBufs));
ASSERT_EQ(req->m_continueBufPos, 0u);
ASSERT_EQ(req->m_continueFramePos, 0u);
ASSERT_THAT(req->m_continueFrameOffs, wpi::util::SpanEq(contFrameOffs));
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(
SpanEquals(std::span<const uv::Buffer>{req->m_frames.m_bufs}, contBufs));
REQUIRE(req->m_continueBufPos == 0u);
REQUIRE(req->m_continueFramePos == 0u);
REQUIRE(SpanEquals(std::span<const int>{req->m_continueFrameOffs},
contFrameOffs));
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, ServerPartialMidFrameBufBoundary) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialMidFrameBufBoundary",
"[websocket][serializer][send]") {
// stop at end of buf 0 (first buf of frame 0)
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 3));
stream.ExpectTryWrite(m_frameHeaders[0].len + 3);
// Write should get called for all of buf 1
// buf 2 should get put into continuation because frame 0 is a fragment
// return will be frame 2 only
std::array<uv::Buffer, 1> remBufs{m_bufs[1]};
std::array<uv::Buffer, 2> contBufs{m_frameHeaders[1], m_bufs[2]};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 1);
ASSERT_THAT(req->m_frames.m_bufs, wpi::util::SpanEq(contBufs));
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(
SpanEquals(std::span<const uv::Buffer>{req->m_frames.m_bufs}, contBufs));
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf1) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialMidFrameMidBuf1",
"[websocket][serializer][send]") {
// stop partway through buf 1 (second buf of frame 0)
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(m_frameHeaders[0].len + 4));
stream.ExpectTryWrite(m_frameHeaders[0].len + 4);
// Write should get called for remainder of buf 1
// buf 2 should get put into continuation because frame 0 is a fragment
// return will be frame 2 only
std::array<uv::Buffer, 1> remBufs{std::span{m_buf1data}.subspan(1)};
std::array<uv::Buffer, 2> contBufs{m_frameHeaders[1], m_bufs[2]};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 1);
ASSERT_THAT(req->m_frames.m_bufs, wpi::util::SpanEq(contBufs));
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(
SpanEquals(std::span<const uv::Buffer>{req->m_frames.m_bufs}, contBufs));
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, ServerPartialFrameBoundary) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialFrameBoundary",
"[websocket][serializer][send]") {
// stop at end of buf 1 (end of frame 0)
EXPECT_CALL(stream, TryWrite(_))
.WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 5));
stream.ExpectTryWrite(m_frameHeaders[0].len + m_frameHeaders[1].len + 5);
// Write should get called for all of buf 2 because frame 0 is a fragment
// no continuation
// return will be frame 2 only
std::array<uv::Buffer, 1> remBufs{m_bufs[2]};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 1);
ASSERT_TRUE(req->m_frames.m_bufs.empty());
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(req->m_frames.m_bufs.empty());
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, ServerPartialMidFrameMidBuf2) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialMidFrameMidBuf2",
"[websocket][serializer][send]") {
// stop partway through buf 2 (frame 1 buf)
EXPECT_CALL(stream, TryWrite(_))
.WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 6));
stream.ExpectTryWrite(m_frameHeaders[0].len + m_frameHeaders[1].len + 6);
// Write should get called for remainder of buf 2; no continuation
// return will be frame 2 only
std::array<uv::Buffer, 1> remBufs{std::span{m_buf2data}.subspan(1)};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 1);
ASSERT_TRUE(req->m_frames.m_bufs.empty());
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(req->m_frames.m_bufs.empty());
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, ServerFrameBoundary) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerFrameBoundary",
"[websocket][serializer][send]") {
// stop at end of buf 2 (end of frame 1)
EXPECT_CALL(stream, TryWrite(_))
.WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len + 9));
stream.ExpectTryWrite(m_frameHeaders[0].len + m_frameHeaders[1].len + 9);
// call callback immediately for bufs 0-2 and return frame 2
CheckTrySendFrames(std::span{m_bufs}.subspan(0, 3),
std::span{m_frames}.subspan(2));
ASSERT_EQ(makeReqCalled, 0);
ASSERT_EQ(callbackCalled, 1);
REQUIRE(makeReqCalled == 0);
REQUIRE(callbackCalled == 1);
}
TEST_F(WebSocketTrySendTest, ServerPartialLastFrame) {
TEST_CASE_METHOD(WebSocketTrySendTest,
"WebSocketTrySendTest ServerPartialLastFrame",
"[websocket][serializer][send]") {
// stop partway through buf 3
EXPECT_CALL(stream, TryWrite(_))
.WillOnce(Return(m_frameHeaders[0].len + m_frameHeaders[1].len +
m_frameHeaders[2].len + 10));
stream.ExpectTryWrite(m_frameHeaders[0].len + m_frameHeaders[1].len +
m_frameHeaders[2].len + 10);
// Write called for remainder of buf 3; no continuation
std::array<uv::Buffer, 1> remBufs{std::span{m_buf3data}.subspan(1)};
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
CheckTrySendFrames({}, {});
ASSERT_EQ(makeReqCalled, 1);
ASSERT_TRUE(req->m_frames.m_bufs.empty());
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(req->m_frames.m_bufs.empty());
REQUIRE(callbackCalled == 0);
}
TEST_F(WebSocketTrySendTest, Big) {
TEST_CASE_METHOD(WebSocketTrySendTest, "WebSocketTrySendTest Big",
"[websocket][serializer][send]") {
std::vector<uv::Buffer> bufs;
for (int i = 0; i < 100000;) {
i += 1430;
@@ -396,7 +457,7 @@ TEST_F(WebSocketTrySendTest, Big) {
uv::Buffer::Allocate(i < 100000 ? 1430 : (100000 - (i - 1430))));
}
WebSocket::Frame frame{WebSocket::OP_BINARY | WebSocket::FLAG_FIN, bufs};
EXPECT_CALL(stream, TryWrite(_)).WillOnce(Return(7681));
stream.ExpectTryWrite(7681);
// Write called for remainder of buffers
std::vector<uv::Buffer> remBufs;
@@ -406,24 +467,23 @@ TEST_F(WebSocketTrySendTest, Big) {
for (size_t i = 6; i < bufs.size(); ++i) {
remBufs.emplace_back(bufs[i]);
}
EXPECT_CALL(stream, DoWrite(wpi::util::SpanEq(remBufs), _));
stream.ExpectWrite(remBufs);
ASSERT_TRUE(
TrySendFrames(
true, stream, {{frame}},
[&](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
++makeReqCalled;
req = std::make_shared<MockWebSocketWriteReq>(std::move(cb));
return req;
},
[&](auto bufs, auto err) { ++callbackCalled; })
.empty());
REQUIRE(TrySendFrames(
true, stream, {{frame}},
[&](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
++makeReqCalled;
req = std::make_shared<MockWebSocketWriteReq>(std::move(cb));
return req;
},
[&](auto bufs, auto err) { ++callbackCalled; })
.empty());
for (auto& buf : bufs) {
buf.Deallocate();
}
ASSERT_EQ(makeReqCalled, 1);
ASSERT_TRUE(req->m_frames.m_bufs.empty());
ASSERT_EQ(callbackCalled, 0);
REQUIRE(makeReqCalled == 1);
REQUIRE(req->m_frames.m_bufs.empty());
REQUIRE(callbackCalled == 0);
}
} // namespace wpi::net::detail

View File

@@ -10,6 +10,8 @@
#include <memory>
#include <vector>
#include <catch2/generators/catch_generators.hpp>
#include "WebSocketTest.hpp"
#include "wpi/net/HttpParser.hpp"
#include "wpi/util/Base64.hpp"
@@ -39,8 +41,8 @@ class WebSocketServerTest : public WebSocketTest {
if (resp.HasError()) {
Finish();
}
ASSERT_EQ(resp.GetError(), HPE_OK)
<< http_errno_name(resp.GetError());
INFO(http_errno_name(resp.GetError()));
REQUIRE(resp.GetError() == HPE_OK);
if (data.empty()) {
return;
}
@@ -66,66 +68,73 @@ class WebSocketServerTest : public WebSocketTest {
// Terminate closes the endpoint but doesn't send a close frame.
//
TEST_F(WebSocketServerTest, Terminate) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest Terminate",
"[websocket][server][terminate]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Terminate(); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1006) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1006);
});
};
loop->Run();
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData.empty()); // terminate doesn't send data
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, TerminateCode) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest TerminateCode",
"[websocket][server][terminate]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Terminate(1000); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1000);
});
};
loop->Run();
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData.empty()); // terminate doesn't send data
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, TerminateReason) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest TerminateReason",
"[websocket][server][terminate]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Terminate(1000, "reason"); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000);
ASSERT_EQ(reason, "reason");
REQUIRE(code == 1000);
REQUIRE(reason == "reason");
});
};
loop->Run();
ASSERT_TRUE(wireData.empty()); // terminate doesn't send data
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData.empty()); // terminate doesn't send data
REQUIRE(gotClosed == 1);
}
//
// Close() sends a close frame.
//
TEST_F(WebSocketServerTest, CloseBasic) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest CloseBasic",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Close(); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1005) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1005);
});
};
// need to respond with close for server to finish shutdown
@@ -137,17 +146,19 @@ TEST_F(WebSocketServerTest, CloseBasic) {
loop->Run();
auto expectData = BuildMessage(0x08, true, false, {});
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, CloseCode) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest CloseCode",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Close(1000); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1000);
});
};
// need to respond with close for server to finish shutdown
@@ -160,18 +171,19 @@ TEST_F(WebSocketServerTest, CloseCode) {
loop->Run();
auto expectData = BuildMessage(0x08, true, false, contents);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, CloseReason) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest CloseReason",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) { ws->Close(1000, "hangup"); });
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000);
ASSERT_EQ(reason, "remote close: hangup");
REQUIRE(code == 1000);
REQUIRE(reason == "remote close: hangup");
});
};
// need to respond with close for server to finish shutdown
@@ -184,20 +196,22 @@ TEST_F(WebSocketServerTest, CloseReason) {
loop->Run();
auto expectData = BuildMessage(0x08, true, false, contents);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
//
// Receiving a close frame results in closure and echoing the close frame.
//
TEST_F(WebSocketServerTest, ReceiveCloseBasic) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest ReceiveCloseBasic",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1005) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1005);
});
};
auto message = BuildMessage(0x08, true, true, {});
@@ -209,16 +223,18 @@ TEST_F(WebSocketServerTest, ReceiveCloseBasic) {
// the endpoint should echo the message
auto expectData = BuildMessage(0x08, true, false, {});
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, ReceiveCloseCode) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest ReceiveCloseCode",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1000);
});
};
const uint8_t contents[] = {0x03u, 0xe8u};
@@ -231,17 +247,18 @@ TEST_F(WebSocketServerTest, ReceiveCloseCode) {
// the endpoint should echo the message
auto expectData = BuildMessage(0x08, true, false, contents);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketServerTest, ReceiveCloseReason) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest ReceiveCloseReason",
"[websocket][server][close]") {
int gotClosed = 0;
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotClosed;
ASSERT_EQ(code, 1000);
ASSERT_EQ(reason, "remote close: hangup");
REQUIRE(code == 1000);
REQUIRE(reason == "remote close: hangup");
});
};
const uint8_t contents[] = {0x03u, 0xe8u, 'h', 'a', 'n', 'g', 'u', 'p'};
@@ -254,8 +271,8 @@ TEST_F(WebSocketServerTest, ReceiveCloseReason) {
// the endpoint should echo the message
auto expectData = BuildMessage(0x08, true, false, contents);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotClosed, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotClosed == 1);
}
//
@@ -263,63 +280,59 @@ TEST_F(WebSocketServerTest, ReceiveCloseReason) {
// WebSocket Connection_.
//
class WebSocketServerBadOpcodeTest
: public WebSocketServerTest,
public ::testing::WithParamInterface<uint8_t> {};
class WebSocketServerBadOpcodeTest : public WebSocketServerTest {};
INSTANTIATE_TEST_SUITE_P(WebSocketServerBadOpcodeTests,
WebSocketServerBadOpcodeTest,
::testing::Values(3, 4, 5, 6, 7, 0xb, 0xc, 0xd, 0xe,
0xf));
TEST_P(WebSocketServerBadOpcodeTest, Receive) {
TEST_CASE_METHOD(WebSocketServerBadOpcodeTest,
"WebSocketServerBadOpcodeTest Receive",
"[websocket][server][protocol]") {
int gotCallback = 0;
auto opcode =
static_cast<uint8_t>(GENERATE(3, 4, 5, 6, 7, 0xb, 0xc, 0xd, 0xe, 0xf));
std::vector<uint8_t> data(4, 0x03);
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(GetParam(), true, true, data);
auto message = BuildMessage(opcode, true, true, data);
resp.headersComplete.connect([&](bool) {
clientPipe->Write({{message}}, [&](auto bufs, uv::Error) {});
});
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
//
// Control frames themselves MUST NOT be fragmented.
//
class WebSocketServerControlFrameTest
: public WebSocketServerTest,
public ::testing::WithParamInterface<uint8_t> {};
class WebSocketServerControlFrameTest : public WebSocketServerTest {};
INSTANTIATE_TEST_SUITE_P(WebSocketServerControlFrameTests,
WebSocketServerControlFrameTest,
::testing::Values(0x8, 0x9, 0xa));
TEST_P(WebSocketServerControlFrameTest, ReceiveFragment) {
TEST_CASE_METHOD(WebSocketServerControlFrameTest,
"WebSocketServerControlFrameTest ReceiveFragment",
"[websocket][server][control][fragment]") {
int gotCallback = 0;
auto opcode = static_cast<uint8_t>(GENERATE(0x8, 0x9, 0xa));
std::vector<uint8_t> data(4, 0x03);
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(GetParam(), false, true, data);
auto message = BuildMessage(opcode, false, true, data);
resp.headersComplete.connect([&](bool) {
clientPipe->Write({{message}}, [&](auto bufs, uv::Error) {});
});
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
//
@@ -330,13 +343,16 @@ TEST_P(WebSocketServerControlFrameTest, ReceiveFragment) {
//
// No previous message
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFrame) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveFragmentInvalidNoPrevFrame",
"[websocket][server][fragment][protocol]") {
int gotCallback = 0;
std::vector<uint8_t> data(4, 0x03);
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(0x00, false, true, data);
@@ -346,17 +362,20 @@ TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFrame) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
// No previous message with FIN=1.
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFragment) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveFragmentInvalidNoPrevFragment",
"[websocket][server][fragment][protocol]") {
int gotCallback = 0;
std::vector<uint8_t> data(4, 0x03);
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(0x01, true, true, {}); // FIN=1
@@ -367,16 +386,19 @@ TEST_F(WebSocketServerTest, ReceiveFragmentInvalidNoPrevFragment) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
// Incomplete fragment
TEST_F(WebSocketServerTest, ReceiveFragmentInvalidIncomplete) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveFragmentInvalidIncomplete",
"[websocket][server][fragment][protocol]") {
int gotCallback = 0;
setupWebSocket = [&] {
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(0x01, false, true, {});
@@ -389,11 +411,12 @@ TEST_F(WebSocketServerTest, ReceiveFragmentInvalidIncomplete) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
// Normally fragments are combined into a single callback
TEST_F(WebSocketServerTest, ReceiveFragment) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest ReceiveFragment",
"[websocket][server][fragment]") {
int gotCallback = 0;
std::vector<uint8_t> data(4, 0x03);
@@ -414,9 +437,9 @@ TEST_F(WebSocketServerTest, ReceiveFragment) {
ws->binary.connect([&](auto inData, bool fin) {
++gotCallback;
ws->Terminate();
ASSERT_TRUE(fin);
REQUIRE(fin);
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(combData, recvData);
REQUIRE(combData == recvData);
});
};
@@ -430,11 +453,13 @@ TEST_F(WebSocketServerTest, ReceiveFragment) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
// But can be configured for multiple callbacks
TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveFragmentSeparate",
"[websocket][server][fragment]") {
int gotCallback = 0;
std::vector<uint8_t> data(4, 0x03);
@@ -450,20 +475,20 @@ TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) {
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
switch (++gotCallback) {
case 1:
ASSERT_FALSE(fin);
ASSERT_EQ(data, recvData);
REQUIRE_FALSE(fin);
REQUIRE(data == recvData);
break;
case 2:
ASSERT_FALSE(fin);
ASSERT_EQ(data2, recvData);
REQUIRE_FALSE(fin);
REQUIRE(data2 == recvData);
break;
case 3:
ws->Terminate();
ASSERT_TRUE(fin);
ASSERT_EQ(data3, recvData);
REQUIRE(fin);
REQUIRE(data3 == recvData);
break;
default:
FAIL() << "too many callbacks";
FAIL("too many callbacks");
break;
}
});
@@ -479,11 +504,13 @@ TEST_F(WebSocketServerTest, ReceiveFragmentSeparate) {
loop->Run();
ASSERT_EQ(gotCallback, 3);
REQUIRE(gotCallback == 3);
}
// Control frames can happen in the middle of a fragmented message
TEST_F(WebSocketServerTest, ReceiveFragmentWithControl) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveFragmentWithControl",
"[websocket][server][fragment][control]") {
int gotCallback = 0;
int gotPongCallback = 0;
@@ -504,15 +531,15 @@ TEST_F(WebSocketServerTest, ReceiveFragmentWithControl) {
setupWebSocket = [&] {
ws->binary.connect([&](auto inData, bool fin) {
ASSERT_TRUE(gotPongCallback);
REQUIRE(gotPongCallback);
++gotCallback;
ws->Terminate();
ASSERT_TRUE(fin);
REQUIRE(fin);
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(combData, recvData);
REQUIRE(combData == recvData);
});
ws->pong.connect([&](auto inData) {
ASSERT_FALSE(gotCallback);
REQUIRE_FALSE(gotCallback);
++gotPongCallback;
});
};
@@ -528,8 +555,8 @@ TEST_F(WebSocketServerTest, ReceiveFragmentWithControl) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
ASSERT_EQ(gotPongCallback, 1);
REQUIRE(gotCallback == 1);
REQUIRE(gotPongCallback == 1);
}
//
@@ -537,18 +564,20 @@ TEST_F(WebSocketServerTest, ReceiveFragmentWithControl) {
//
// Single message
TEST_F(WebSocketServerTest, ReceiveTooLarge) {
TEST_CASE_METHOD(WebSocketServerTest, "WebSocketServerTest ReceiveTooLarge",
"[websocket][server][limits]") {
int gotCallback = 0;
std::vector<uint8_t> data(2048, 0x03u);
setupWebSocket = [&] {
ws->SetMaxMessageSize(1024);
ws->binary.connect([&](auto, bool) {
ws->Terminate();
FAIL() << "Should not have gotten unmasked message";
FAIL("Should not have gotten unmasked message");
});
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1009) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1009);
});
};
auto message = BuildMessage(0x01, true, true, data);
@@ -558,22 +587,25 @@ TEST_F(WebSocketServerTest, ReceiveTooLarge) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
// Applied across fragments if combining
TEST_F(WebSocketServerTest, ReceiveTooLargeFragmented) {
TEST_CASE_METHOD(WebSocketServerTest,
"WebSocketServerTest ReceiveTooLargeFragmented",
"[websocket][server][limits][fragment]") {
int gotCallback = 0;
std::vector<uint8_t> data(768, 0x03u);
setupWebSocket = [&] {
ws->SetMaxMessageSize(1024);
ws->binary.connect([&](auto, bool) {
ws->Terminate();
FAIL() << "Should not have gotten unmasked message";
FAIL("Should not have gotten unmasked message");
});
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1009) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1009);
});
};
auto message = BuildMessage(0x01, false, true, data);
@@ -584,29 +616,28 @@ TEST_F(WebSocketServerTest, ReceiveTooLargeFragmented) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
//
// Send and receive data.
//
class WebSocketServerDataTest : public WebSocketServerTest,
public ::testing::WithParamInterface<size_t> {};
class WebSocketServerDataTest : public WebSocketServerTest {};
INSTANTIATE_TEST_SUITE_P(WebSocketServerDataTests, WebSocketServerDataTest,
::testing::Values(0, 1, 125, 126, 65535, 65536));
TEST_P(WebSocketServerDataTest, SendText) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest SendText",
"[websocket][server][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), ' ');
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
' ');
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) {
ws->SendText({{data}}, [&](auto bufs, uv::Error) {
++gotCallback;
ws->Terminate();
ASSERT_FALSE(bufs.empty());
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
REQUIRE_FALSE(bufs.empty());
REQUIRE(bufs[0].base == reinterpret_cast<const char*>(data.data()));
});
});
};
@@ -614,20 +645,23 @@ TEST_P(WebSocketServerDataTest, SendText) {
loop->Run();
auto expectData = BuildMessage(0x01, true, false, data);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotCallback, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, SendBinary) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest SendBinary",
"[websocket][server][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) {
ws->SendBinary({{data}}, [&](auto bufs, uv::Error) {
++gotCallback;
ws->Terminate();
ASSERT_FALSE(bufs.empty());
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
REQUIRE_FALSE(bufs.empty());
REQUIRE(bufs[0].base == reinterpret_cast<const char*>(data.data()));
});
});
};
@@ -635,20 +669,23 @@ TEST_P(WebSocketServerDataTest, SendBinary) {
loop->Run();
auto expectData = BuildMessage(0x02, true, false, data);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotCallback, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, SendPing) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest SendPing",
"[websocket][server][control]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) {
ws->SendPing({{data}}, [&](auto bufs, uv::Error) {
++gotCallback;
ws->Terminate();
ASSERT_FALSE(bufs.empty());
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
REQUIRE_FALSE(bufs.empty());
REQUIRE(bufs[0].base == reinterpret_cast<const char*>(data.data()));
});
});
};
@@ -656,20 +693,23 @@ TEST_P(WebSocketServerDataTest, SendPing) {
loop->Run();
auto expectData = BuildMessage(0x09, true, false, data);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotCallback, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, SendPong) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest SendPong",
"[websocket][server][control]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->open.connect([&](std::string_view) {
ws->SendPong({{data}}, [&](auto bufs, uv::Error) {
++gotCallback;
ws->Terminate();
ASSERT_FALSE(bufs.empty());
ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
REQUIRE_FALSE(bufs.empty());
REQUIRE(bufs[0].base == reinterpret_cast<const char*>(data.data()));
});
});
};
@@ -677,21 +717,24 @@ TEST_P(WebSocketServerDataTest, SendPong) {
loop->Run();
auto expectData = BuildMessage(0x0a, true, false, data);
ASSERT_EQ(wireData, expectData);
ASSERT_EQ(gotCallback, 1);
REQUIRE(wireData == expectData);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, ReceiveText) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest ReceiveText",
"[websocket][server][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), ' ');
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
' ');
setupWebSocket = [&] {
ws->text.connect([&](std::string_view inData, bool fin) {
++gotCallback;
ws->Terminate();
ASSERT_TRUE(fin);
REQUIRE(fin);
std::vector<uint8_t> recvData;
recvData.insert(recvData.end(), inData.begin(), inData.end());
ASSERT_EQ(data, recvData);
REQUIRE(data == recvData);
});
};
auto message = BuildMessage(0x01, true, true, data);
@@ -701,19 +744,23 @@ TEST_P(WebSocketServerDataTest, ReceiveText) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, ReceiveBinary) {
TEST_CASE_METHOD(WebSocketServerDataTest,
"WebSocketServerDataTest ReceiveBinary",
"[websocket][server][data]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->binary.connect([&](auto inData, bool fin) {
++gotCallback;
ws->Terminate();
ASSERT_TRUE(fin);
REQUIRE(fin);
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(data, recvData);
REQUIRE(data == recvData);
});
};
auto message = BuildMessage(0x02, true, true, data);
@@ -723,18 +770,21 @@ TEST_P(WebSocketServerDataTest, ReceiveBinary) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, ReceivePing) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest ReceivePing",
"[websocket][server][control]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->ping.connect([&](auto inData) {
++gotCallback;
ws->Terminate();
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(data, recvData);
REQUIRE(data == recvData);
});
};
auto message = BuildMessage(0x09, true, true, data);
@@ -744,18 +794,21 @@ TEST_P(WebSocketServerDataTest, ReceivePing) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
TEST_P(WebSocketServerDataTest, ReceivePong) {
TEST_CASE_METHOD(WebSocketServerDataTest, "WebSocketServerDataTest ReceivePong",
"[websocket][server][control]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), 0x03u);
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
0x03u);
setupWebSocket = [&] {
ws->pong.connect([&](auto inData) {
++gotCallback;
ws->Terminate();
std::vector<uint8_t> recvData{inData.begin(), inData.end()};
ASSERT_EQ(data, recvData);
REQUIRE(data == recvData);
});
};
auto message = BuildMessage(0x0a, true, true, data);
@@ -765,24 +818,29 @@ TEST_P(WebSocketServerDataTest, ReceivePong) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
//
// The server must close the connection if an unmasked frame is received.
//
TEST_P(WebSocketServerDataTest, ReceiveUnmasked) {
TEST_CASE_METHOD(WebSocketServerDataTest,
"WebSocketServerDataTest ReceiveUnmasked",
"[websocket][server][data][protocol]") {
int gotCallback = 0;
std::vector<uint8_t> data(GetParam(), ' ');
std::vector<uint8_t> data(GENERATE(size_t{0}, size_t{1}, size_t{125},
size_t{126}, size_t{65535}, size_t{65536}),
' ');
setupWebSocket = [&] {
ws->text.connect([&](std::string_view, bool) {
ws->Terminate();
FAIL() << "Should not have gotten unmasked message";
FAIL("Should not have gotten unmasked message");
});
ws->closed.connect([&](uint16_t code, std::string_view reason) {
++gotCallback;
ASSERT_EQ(code, 1002) << "reason: " << reason;
INFO("reason: " << reason);
REQUIRE(code == 1002);
});
};
auto message = BuildMessage(0x01, true, false, data);
@@ -792,7 +850,7 @@ TEST_P(WebSocketServerDataTest, ReceiveUnmasked) {
loop->Run();
ASSERT_EQ(gotCallback, 1);
REQUIRE(gotCallback == 1);
}
} // namespace wpi::net

View File

@@ -4,6 +4,10 @@
#include "wpi/net/WebSocket.hpp"
#ifndef _WIN32
#include <unistd.h>
#endif
#include <utility>
#include <vector>
@@ -20,7 +24,7 @@ const char* WebSocketTest::pipeName = "/tmp/websocket-unit-test";
#endif
const uint8_t WebSocketTest::testMask[4] = {0x11, 0x22, 0x33, 0x44};
void WebSocketTest::SetUpTestCase() {
void WebSocketTest::UnlinkPipe() {
#ifndef _WIN32
unlink(pipeName);
#endif
@@ -100,7 +104,8 @@ void WebSocketTest::AdjustMasking(std::span<uint8_t> message) {
}
}
TEST_F(WebSocketTest, CreateClientBasic) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateClientBasic",
"[websocket][client][handshake]") {
int gotHost = 0;
int gotUpgrade = 0;
int gotConnection = 0;
@@ -108,24 +113,24 @@ TEST_F(WebSocketTest, CreateClientBasic) {
int gotVersion = 0;
HttpParser req{HttpParser::Type::REQUEST};
req.url.connect([](std::string_view url) { ASSERT_EQ(url, "/test"); });
req.url.connect([](std::string_view url) { REQUIRE(url == "/test"); });
req.header.connect([&](std::string_view name, std::string_view value) {
if (wpi::util::equals_lower(name, "host")) {
ASSERT_EQ(value, pipeName);
REQUIRE(value == pipeName);
++gotHost;
} else if (wpi::util::equals_lower(name, "upgrade")) {
ASSERT_EQ(value, "websocket");
REQUIRE(value == "websocket");
++gotUpgrade;
} else if (wpi::util::equals_lower(name, "connection")) {
ASSERT_EQ(value, "Upgrade");
REQUIRE(value == "Upgrade");
++gotConnection;
} else if (wpi::util::equals_lower(name, "sec-websocket-key")) {
++gotKey;
} else if (wpi::util::equals_lower(name, "sec-websocket-version")) {
ASSERT_EQ(value, "13");
REQUIRE(value == "13");
++gotVersion;
} else {
FAIL() << "unexpected header " << name;
FAIL("unexpected header " << name);
}
});
req.headersComplete.connect([&](bool) { Finish(); });
@@ -138,7 +143,8 @@ TEST_F(WebSocketTest, CreateClientBasic) {
if (req.HasError()) {
Finish();
}
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
INFO(http_errno_name(req.GetError()));
REQUIRE(req.GetError() == HPE_OK);
});
});
clientPipe->Connect(pipeName, [&]() {
@@ -146,27 +152,24 @@ TEST_F(WebSocketTest, CreateClientBasic) {
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotHost, 1);
ASSERT_EQ(gotUpgrade, 1);
ASSERT_EQ(gotConnection, 1);
ASSERT_EQ(gotKey, 1);
ASSERT_EQ(gotVersion, 1);
REQUIRE(gotHost == 1);
REQUIRE(gotUpgrade == 1);
REQUIRE(gotConnection == 1);
REQUIRE(gotKey == 1);
REQUIRE(gotVersion == 1);
}
TEST_F(WebSocketTest, CreateClientExtraHeaders) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateClientExtraHeaders",
"[websocket][client][handshake]") {
int gotExtra1 = 0;
int gotExtra2 = 0;
HttpParser req{HttpParser::Type::REQUEST};
req.header.connect([&](std::string_view name, std::string_view value) {
if (wpi::util::equals(name, "Extra1")) {
ASSERT_EQ(value, "Data1");
REQUIRE(value == "Data1");
++gotExtra1;
} else if (wpi::util::equals(name, "Extra2")) {
ASSERT_EQ(value, "Data2");
REQUIRE(value == "Data2");
++gotExtra2;
}
});
@@ -180,7 +183,8 @@ TEST_F(WebSocketTest, CreateClientExtraHeaders) {
if (req.HasError()) {
Finish();
}
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
INFO(http_errno_name(req.GetError()));
REQUIRE(req.GetError() == HPE_OK);
});
});
clientPipe->Connect(pipeName, [&]() {
@@ -195,15 +199,12 @@ TEST_F(WebSocketTest, CreateClientExtraHeaders) {
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotExtra1, 1);
ASSERT_EQ(gotExtra2, 1);
REQUIRE(gotExtra1 == 1);
REQUIRE(gotExtra2 == 1);
}
TEST_F(WebSocketTest, CreateClientTimeout) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateClientTimeout",
"[websocket][client][handshake]") {
int gotClosed = 0;
serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); });
clientPipe->Connect(pipeName, [&]() {
@@ -214,19 +215,16 @@ TEST_F(WebSocketTest, CreateClientTimeout) {
ws->closed.connect([&](uint16_t code, std::string_view) {
Finish();
++gotClosed;
ASSERT_EQ(code, 1006);
REQUIRE(code == 1006);
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotClosed, 1);
REQUIRE(gotClosed == 1);
}
TEST_F(WebSocketTest, CreateServerBasic) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateServerBasic",
"[websocket][server][handshake]") {
int gotStatus = 0;
int gotUpgrade = 0;
int gotConnection = 0;
@@ -236,20 +234,21 @@ TEST_F(WebSocketTest, CreateServerBasic) {
HttpParser resp{HttpParser::Type::RESPONSE};
resp.status.connect([&](std::string_view status) {
++gotStatus;
ASSERT_EQ(resp.GetStatusCode(), 101u) << "status: " << status;
INFO("status: " << status);
REQUIRE(resp.GetStatusCode() == 101u);
});
resp.header.connect([&](std::string_view name, std::string_view value) {
if (wpi::util::equals_lower(name, "upgrade")) {
ASSERT_EQ(value, "websocket");
REQUIRE(value == "websocket");
++gotUpgrade;
} else if (wpi::util::equals_lower(name, "connection")) {
ASSERT_EQ(value, "Upgrade");
REQUIRE(value == "Upgrade");
++gotConnection;
} else if (wpi::util::equals_lower(name, "sec-websocket-accept")) {
ASSERT_EQ(value, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
REQUIRE(value == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
++gotAccept;
} else {
FAIL() << "unexpected header " << name;
FAIL("unexpected header " << name);
}
});
resp.headersComplete.connect([&](bool) { Finish(); });
@@ -259,7 +258,7 @@ TEST_F(WebSocketTest, CreateServerBasic) {
auto ws = WebSocket::CreateServer(*conn, "dGhlIHNhbXBsZSBub25jZQ==", "13");
ws->open.connect([&](std::string_view protocol) {
++gotOpen;
ASSERT_TRUE(protocol.empty());
REQUIRE(protocol.empty());
});
});
clientPipe->Connect(pipeName, [&] {
@@ -269,23 +268,21 @@ TEST_F(WebSocketTest, CreateServerBasic) {
if (resp.HasError()) {
Finish();
}
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
INFO(http_errno_name(resp.GetError()));
REQUIRE(resp.GetError() == HPE_OK);
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotStatus, 1);
ASSERT_EQ(gotUpgrade, 1);
ASSERT_EQ(gotConnection, 1);
ASSERT_EQ(gotAccept, 1);
ASSERT_EQ(gotOpen, 1);
REQUIRE(gotStatus == 1);
REQUIRE(gotUpgrade == 1);
REQUIRE(gotConnection == 1);
REQUIRE(gotAccept == 1);
REQUIRE(gotOpen == 1);
}
TEST_F(WebSocketTest, CreateServerProtocol) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateServerProtocol",
"[websocket][server][handshake][protocol]") {
int gotProtocol = 0;
int gotOpen = 0;
@@ -293,7 +290,7 @@ TEST_F(WebSocketTest, CreateServerProtocol) {
resp.header.connect([&](std::string_view name, std::string_view value) {
if (wpi::util::equals_lower(name, "sec-websocket-protocol")) {
++gotProtocol;
ASSERT_EQ(value, "myProtocol");
REQUIRE(value == "myProtocol");
}
});
resp.headersComplete.connect([&](bool) { Finish(); });
@@ -303,7 +300,7 @@ TEST_F(WebSocketTest, CreateServerProtocol) {
auto ws = WebSocket::CreateServer(*conn, "foo", "13", "myProtocol");
ws->open.connect([&](std::string_view protocol) {
++gotOpen;
ASSERT_EQ(protocol, "myProtocol");
REQUIRE(protocol == "myProtocol");
});
});
clientPipe->Connect(pipeName, [&] {
@@ -313,20 +310,18 @@ TEST_F(WebSocketTest, CreateServerProtocol) {
if (resp.HasError()) {
Finish();
}
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
INFO(http_errno_name(resp.GetError()));
REQUIRE(resp.GetError() == HPE_OK);
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotProtocol, 1);
ASSERT_EQ(gotOpen, 1);
REQUIRE(gotProtocol == 1);
REQUIRE(gotOpen == 1);
}
TEST_F(WebSocketTest, CreateServerBadVersion) {
TEST_CASE_METHOD(WebSocketTest, "WebSocketTest CreateServerBadVersion",
"[websocket][server][handshake][protocol]") {
int gotStatus = 0;
int gotVersion = 0;
int gotUpgrade = 0;
@@ -334,17 +329,18 @@ TEST_F(WebSocketTest, CreateServerBadVersion) {
HttpParser resp{HttpParser::Type::RESPONSE};
resp.status.connect([&](std::string_view status) {
++gotStatus;
ASSERT_EQ(resp.GetStatusCode(), 426u) << "status: " << status;
INFO("status: " << status);
REQUIRE(resp.GetStatusCode() == 426u);
});
resp.header.connect([&](std::string_view name, std::string_view value) {
if (wpi::util::equals_lower(name, "sec-websocket-version")) {
++gotVersion;
ASSERT_EQ(value, "13");
REQUIRE(value == "13");
} else if (wpi::util::equals_lower(name, "upgrade")) {
++gotUpgrade;
ASSERT_EQ(value, "WebSocket");
REQUIRE(value == "WebSocket");
} else {
FAIL() << "unexpected header " << name;
FAIL("unexpected header " << name);
}
});
resp.headersComplete.connect([&](bool) { Finish(); });
@@ -364,18 +360,15 @@ TEST_F(WebSocketTest, CreateServerBadVersion) {
if (resp.HasError()) {
Finish();
}
ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
INFO(http_errno_name(resp.GetError()));
REQUIRE(resp.GetError() == HPE_OK);
});
});
loop->Run();
if (HasFatalFailure()) {
return;
}
ASSERT_EQ(gotStatus, 1);
ASSERT_EQ(gotVersion, 1);
ASSERT_EQ(gotUpgrade, 1);
REQUIRE(gotStatus == 1);
REQUIRE(gotVersion == 1);
REQUIRE(gotUpgrade == 1);
}
} // namespace wpi::net

View File

@@ -9,7 +9,7 @@
#include <span>
#include <vector>
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Loop.hpp"
#include "wpi/net/uv/Pipe.hpp"
@@ -17,13 +17,15 @@
namespace wpi::net {
class WebSocketTest : public ::testing::Test {
class WebSocketTest {
public:
static const char* pipeName;
static void SetUpTestCase();
static void UnlinkPipe();
WebSocketTest() {
UnlinkPipe();
loop = uv::Loop::Create();
clientPipe = uv::Pipe::Create(loop);
serverPipe = uv::Pipe::Create(loop);
@@ -43,13 +45,13 @@ class WebSocketTest : public ::testing::Test {
auto failTimer = uv::Timer::Create(loop);
failTimer->timeout.connect([this] {
loop->Stop();
FAIL() << "loop failed to terminate";
FAIL("loop failed to terminate");
});
failTimer->Start(uv::Timer::Time{1000});
failTimer->Unreference();
}
~WebSocketTest() override { Finish(); }
~WebSocketTest() { Finish(); }
void Finish() {
loop->Walk([](uv::Handle& it) { it.Close(); });

View File

@@ -4,7 +4,7 @@
#include "wpi/net/WorkerThread.hpp"
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/EventLoopRunner.hpp"
#include "wpi/net/uv/Loop.hpp"
@@ -13,30 +13,33 @@
namespace wpi::net {
TEST(WorkerThreadTest, Future) {
TEST_CASE("WorkerThreadTest Future", "[worker][thread]") {
WorkerThread<int(bool)> worker;
wpi::util::future<int> f =
worker.QueueWork([](bool v) -> int { return v ? 1 : 2; }, true);
ASSERT_EQ(f.get(), 1);
REQUIRE(f.get() == 1);
}
TEST(WorkerThreadTest, FutureVoid) {
TEST_CASE("WorkerThreadTest FutureVoid", "[worker][thread]") {
int callbacks = 0;
bool v3_check = false;
WorkerThread<void(int)> worker;
wpi::util::future<void> f = worker.QueueWork(
[&](int v) {
++callbacks;
ASSERT_EQ(v, 3);
v3_check = v == 3;
},
3);
f.get();
ASSERT_EQ(callbacks, 1);
REQUIRE(callbacks == 1);
REQUIRE(v3_check);
}
TEST(WorkerThreadTest, Loop) {
TEST_CASE("WorkerThreadTest Loop", "[worker][thread]") {
wpi::util::mutex m;
wpi::util::condition_variable cv;
int callbacks = 0;
bool v2_check = false;
WorkerThread<int(bool)> worker;
EventLoopRunner runner;
@@ -46,18 +49,19 @@ TEST(WorkerThreadTest, Loop) {
std::scoped_lock lock{m};
++callbacks;
cv.notify_all();
ASSERT_EQ(v2, 1);
v2_check = v2 == 1;
},
true);
auto f = worker.QueueWork([&](bool) -> int { return 2; }, true);
ASSERT_EQ(f.get(), 2);
REQUIRE(f.get() == 2);
std::unique_lock lock{m};
cv.wait(lock, [&] { return callbacks == 1; });
ASSERT_EQ(callbacks, 1);
REQUIRE(callbacks == 1);
REQUIRE(v2_check);
}
TEST(WorkerThreadTest, LoopVoid) {
TEST_CASE("WorkerThreadTest LoopVoid", "[worker][thread]") {
wpi::util::mutex m;
wpi::util::condition_variable cv;
int callbacks = 0;
@@ -77,7 +81,7 @@ TEST(WorkerThreadTest, LoopVoid) {
std::unique_lock lock{m};
cv.wait(lock, [&] { return callbacks == 1; });
ASSERT_EQ(callbacks, 1);
REQUIRE(callbacks == 1);
}
} // namespace wpi::net

View File

@@ -4,21 +4,21 @@
#include "wpi/net/hostname.hpp"
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/util/SmallString.hpp"
#include "wpi/util/SmallVector.hpp"
namespace wpi::net {
TEST(HostNameTest, HostNameNotEmpty) {
ASSERT_NE(GetHostname(), "");
TEST_CASE("HostNameTest HostNameNotEmpty", "[hostname]") {
REQUIRE(GetHostname() != "");
}
TEST(HostNameTest, HostNameNotEmptySmallVector) {
TEST_CASE("HostNameTest HostNameNotEmptySmallVector", "[hostname]") {
wpi::util::SmallVector<char, 256> name;
ASSERT_NE(GetHostname(name), "");
REQUIRE(GetHostname(name) != "");
}
TEST(HostNameTest, HostNameEq) {
TEST_CASE("HostNameTest HostNameEq", "[hostname]") {
wpi::util::SmallVector<char, 256> nameBuf;
ASSERT_EQ(GetHostname(nameBuf), GetHostname());
REQUIRE(GetHostname(nameBuf) == GetHostname());
}
} // namespace wpi::net

View File

@@ -2,10 +2,8 @@
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
#include <catch2/catch_session.hpp>
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
return ret;
return Catch::Session().run(argc, argv);
}

View File

@@ -4,64 +4,64 @@
#include "wpi/net/raw_uv_ostream.hpp"
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net {
TEST(RawUvOstreamTest, BasicWrite) {
TEST_CASE("RawUvOstreamTest BasicWrite", "[uv][ostream]") {
wpi::util::SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os(bufs, 1024);
os << "12";
os << "34";
ASSERT_EQ(bufs.size(), 1u);
ASSERT_EQ(bufs[0].len, 4u);
ASSERT_EQ(bufs[0].base[0], '1');
ASSERT_EQ(bufs[0].base[1], '2');
ASSERT_EQ(bufs[0].base[2], '3');
ASSERT_EQ(bufs[0].base[3], '4');
REQUIRE(bufs.size() == 1u);
REQUIRE(bufs[0].len == 4u);
REQUIRE(bufs[0].base[0] == '1');
REQUIRE(bufs[0].base[1] == '2');
REQUIRE(bufs[0].base[2] == '3');
REQUIRE(bufs[0].base[3] == '4');
for (auto& buf : bufs) {
buf.Deallocate();
}
}
TEST(RawUvOstreamTest, BoundaryWrite) {
TEST_CASE("RawUvOstreamTest BoundaryWrite", "[uv][ostream]") {
wpi::util::SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os(bufs, 4);
ASSERT_EQ(bufs.size(), 0u);
REQUIRE(bufs.size() == 0u);
os << "12";
ASSERT_EQ(bufs.size(), 1u);
REQUIRE(bufs.size() == 1u);
os << "34";
ASSERT_EQ(bufs.size(), 1u);
REQUIRE(bufs.size() == 1u);
os << "56";
ASSERT_EQ(bufs.size(), 2u);
REQUIRE(bufs.size() == 2u);
for (auto& buf : bufs) {
buf.Deallocate();
}
}
TEST(RawUvOstreamTest, LargeWrite) {
TEST_CASE("RawUvOstreamTest LargeWrite", "[uv][ostream]") {
wpi::util::SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os(bufs, 4);
os << "123456";
ASSERT_EQ(bufs.size(), 2u);
ASSERT_EQ(bufs[1].len, 2u);
ASSERT_EQ(bufs[1].base[0], '5');
REQUIRE(bufs.size() == 2u);
REQUIRE(bufs[1].len == 2u);
REQUIRE(bufs[1].base[0] == '5');
for (auto& buf : bufs) {
buf.Deallocate();
}
}
TEST(RawUvOstreamTest, PrevDataWrite) {
TEST_CASE("RawUvOstreamTest PrevDataWrite", "[uv][ostream]") {
wpi::util::SmallVector<uv::Buffer, 4> bufs;
bufs.emplace_back(uv::Buffer::Allocate(1024));
raw_uv_ostream os(bufs, 1024);
os << "1234";
ASSERT_EQ(bufs.size(), 2u);
ASSERT_EQ(bufs[0].len, 1024u);
ASSERT_EQ(bufs[1].len, 4u);
REQUIRE(bufs.size() == 2u);
REQUIRE(bufs[0].len == 1024u);
REQUIRE(bufs[1].len == 4u);
for (auto& buf : bufs) {
buf.Deallocate();

View File

@@ -6,21 +6,25 @@
#include "wpi/net/uv/AsyncFunction.hpp"
// clang-format on
#include <atomic>
#include <memory>
#include <thread>
#include <utility>
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Loop.hpp"
#include "wpi/net/uv/Prepare.hpp"
namespace wpi::net::uv {
TEST(UvAsyncFunctionTest, Basic) {
TEST_CASE("UvAsyncFunctionTest Basic", "[uv][async-function]") {
int prepare_cb_called = 0;
int async_cb_called[2] = {0, 0};
int close_cb_called = 0;
std::atomic_bool fail{false};
std::atomic_bool call0_check{false};
std::atomic_bool call1_check{false};
std::thread theThread;
@@ -28,9 +32,9 @@ TEST(UvAsyncFunctionTest, Basic) {
auto async = AsyncFunction<int(int)>::Create(loop);
auto prepare = Prepare::Create(loop);
loop->error.connect([](Error) { FAIL(); });
loop->error.connect([&](Error) { fail = true; });
prepare->error.connect([](Error) { FAIL(); });
prepare->error.connect([&](Error) { fail = true; });
prepare->prepare.connect([&] {
if (prepare_cb_called++) {
return;
@@ -38,13 +42,13 @@ TEST(UvAsyncFunctionTest, Basic) {
theThread = std::thread([&] {
auto call0 = async->Call(0);
auto call1 = async->Call(1);
ASSERT_EQ(call0.get(), 1);
ASSERT_EQ(call1.get(), 2);
call0_check = call0.get() == 1;
call1_check = call1.get() == 2;
});
});
prepare->Start();
async->error.connect([](Error) { FAIL(); });
async->error.connect([&](Error) { fail = true; });
async->closed.connect([&] { close_cb_called++; });
async->wakeup = [&](wpi::util::promise<int> out, int v) {
++async_cb_called[v];
@@ -57,18 +61,25 @@ TEST(UvAsyncFunctionTest, Basic) {
loop->Run();
ASSERT_EQ(async_cb_called[0], 1);
ASSERT_EQ(async_cb_called[1], 1);
ASSERT_EQ(close_cb_called, 1);
if (fail) {
FAIL();
}
REQUIRE(async_cb_called[0] == 1);
REQUIRE(async_cb_called[1] == 1);
REQUIRE(close_cb_called == 1);
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(call0_check);
REQUIRE(call1_check);
}
TEST(UvAsyncFunctionTest, Ref) {
TEST_CASE("UvAsyncFunctionTest Ref", "[uv][async-function]") {
int prepare_cb_called = 0;
int val = 0;
std::atomic_bool call_check{false};
std::thread theThread;
@@ -80,7 +91,8 @@ TEST(UvAsyncFunctionTest, Ref) {
if (prepare_cb_called++) {
return;
}
theThread = std::thread([&] { ASSERT_EQ(async->Call(1, val).get(), 2); });
theThread =
std::thread([&] { call_check = async->Call(1, val).get() == 2; });
});
prepare->Start();
@@ -93,15 +105,17 @@ TEST(UvAsyncFunctionTest, Ref) {
loop->Run();
ASSERT_EQ(val, 1);
REQUIRE(val == 1);
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(call_check);
}
TEST(UvAsyncFunctionTest, Movable) {
TEST_CASE("UvAsyncFunctionTest Movable", "[uv][async-function]") {
int prepare_cb_called = 0;
std::atomic_bool val2_check{false};
std::thread theThread;
@@ -117,8 +131,7 @@ TEST(UvAsyncFunctionTest, Movable) {
theThread = std::thread([&] {
auto val = std::make_unique<int>(1);
auto val2 = async->Call(std::move(val)).get();
ASSERT_NE(val2, nullptr);
ASSERT_EQ(*val2, 1);
val2_check = val2 != nullptr && *val2 == 1;
});
});
prepare->Start();
@@ -135,9 +148,10 @@ TEST(UvAsyncFunctionTest, Movable) {
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(val2_check);
}
TEST(UvAsyncFunctionTest, CallIgnoreResult) {
TEST_CASE("UvAsyncFunctionTest CallIgnoreResult", "[uv][async-function]") {
int prepare_cb_called = 0;
std::thread theThread;
@@ -169,7 +183,7 @@ TEST(UvAsyncFunctionTest, CallIgnoreResult) {
}
}
TEST(UvAsyncFunctionTest, VoidCall) {
TEST_CASE("UvAsyncFunctionTest VoidCall", "[uv][async-function]") {
int prepare_cb_called = 0;
std::thread theThread;
@@ -199,8 +213,9 @@ TEST(UvAsyncFunctionTest, VoidCall) {
}
}
TEST(UvAsyncFunctionTest, WaitFor) {
TEST_CASE("UvAsyncFunctionTest WaitFor", "[uv][async-function]") {
int prepare_cb_called = 0;
std::atomic_bool call_check{false};
std::thread theThread;
@@ -213,7 +228,7 @@ TEST(UvAsyncFunctionTest, WaitFor) {
return;
}
theThread = std::thread([&] {
ASSERT_FALSE(async->Call().wait_for(std::chrono::milliseconds(10)));
call_check = !async->Call().wait_for(std::chrono::milliseconds(10));
});
});
prepare->Start();
@@ -230,10 +245,12 @@ TEST(UvAsyncFunctionTest, WaitFor) {
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(call_check);
}
TEST(UvAsyncFunctionTest, VoidWaitFor) {
TEST_CASE("UvAsyncFunctionTest VoidWaitFor", "[uv][async-function]") {
int prepare_cb_called = 0;
std::atomic_bool call_check{false};
std::thread theThread;
@@ -246,7 +263,7 @@ TEST(UvAsyncFunctionTest, VoidWaitFor) {
return;
}
theThread = std::thread([&] {
ASSERT_FALSE(async->Call().wait_for(std::chrono::milliseconds(10)));
call_check = !async->Call().wait_for(std::chrono::milliseconds(10));
});
});
prepare->Start();
@@ -263,6 +280,7 @@ TEST(UvAsyncFunctionTest, VoidWaitFor) {
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(call_check);
}
} // namespace wpi::net::uv

View File

@@ -31,7 +31,7 @@
#include <functional>
#include <thread>
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Loop.hpp"
#include "wpi/net/uv/Prepare.hpp"
@@ -39,10 +39,11 @@
namespace wpi::net::uv {
TEST(UvAsyncTest, CallbackOnly) {
TEST_CASE("UvAsyncTest CallbackOnly", "[uv][async]") {
std::atomic_int async_cb_called{0};
int prepare_cb_called = 0;
int close_cb_called = 0;
std::atomic_bool fail{false};
wpi::util::mutex mutex;
mutex.lock();
@@ -53,9 +54,9 @@ TEST(UvAsyncTest, CallbackOnly) {
auto async = Async<>::Create(loop);
auto prepare = Prepare::Create(loop);
loop->error.connect([](Error) { FAIL(); });
loop->error.connect([&](Error) { fail = true; });
prepare->error.connect([](Error) { FAIL(); });
prepare->error.connect([&](Error) { fail = true; });
prepare->closed.connect([&] { close_cb_called++; });
prepare->prepare.connect([&] {
if (prepare_cb_called++) {
@@ -80,7 +81,7 @@ TEST(UvAsyncTest, CallbackOnly) {
});
prepare->Start();
async->error.connect([](Error) { FAIL(); });
async->error.connect([&](Error) { fail = true; });
async->closed.connect([&] { close_cb_called++; });
async->wakeup.connect([&] {
mutex.lock();
@@ -95,19 +96,26 @@ TEST(UvAsyncTest, CallbackOnly) {
loop->Run();
ASSERT_GT(prepare_cb_called, 0);
ASSERT_EQ(async_cb_called, 3);
ASSERT_EQ(close_cb_called, 2);
if (fail) {
FAIL();
}
REQUIRE(prepare_cb_called > 0);
REQUIRE(async_cb_called == 3);
REQUIRE(close_cb_called == 2);
if (theThread.joinable()) {
theThread.join();
}
}
TEST(UvAsyncTest, Data) {
TEST_CASE("UvAsyncTest Data", "[uv][async]") {
int prepare_cb_called = 0;
int async_cb_called[2] = {0, 0};
int close_cb_called = 0;
std::atomic_bool fail{false};
std::atomic_bool v0_check{false};
std::atomic_bool v1_check{false};
std::thread theThread;
@@ -115,20 +123,20 @@ TEST(UvAsyncTest, Data) {
auto async = Async<int, std::function<void(int)>>::Create(loop);
auto prepare = Prepare::Create(loop);
loop->error.connect([](Error) { FAIL(); });
loop->error.connect([&](Error) { fail = true; });
prepare->error.connect([](Error) { FAIL(); });
prepare->error.connect([&](Error) { fail = true; });
prepare->prepare.connect([&] {
if (prepare_cb_called++) {
return;
}
theThread = std::thread([&] {
async->Send(0, [&](int v) {
ASSERT_EQ(v, 0);
v0_check = v == 0;
++async_cb_called[0];
});
async->Send(1, [&](int v) {
ASSERT_EQ(v, 1);
v1_check = v == 1;
++async_cb_called[1];
async->Close();
prepare->Close();
@@ -137,22 +145,28 @@ TEST(UvAsyncTest, Data) {
});
prepare->Start();
async->error.connect([](Error) { FAIL(); });
async->error.connect([&](Error) { fail = true; });
async->closed.connect([&] { close_cb_called++; });
async->wakeup.connect([&](int v, std::function<void(int)> f) { f(v); });
loop->Run();
ASSERT_EQ(async_cb_called[0], 1);
ASSERT_EQ(async_cb_called[1], 1);
ASSERT_EQ(close_cb_called, 1);
if (fail) {
FAIL();
}
REQUIRE(async_cb_called[0] == 1);
REQUIRE(async_cb_called[1] == 1);
REQUIRE(close_cb_called == 1);
if (theThread.joinable()) {
theThread.join();
}
REQUIRE(v0_check);
REQUIRE(v1_check);
}
TEST(UvAsyncTest, DataRef) {
TEST_CASE("UvAsyncTest DataRef", "[uv][async]") {
int prepare_cb_called = 0;
int val = 0;
@@ -178,7 +192,7 @@ TEST(UvAsyncTest, DataRef) {
loop->Run();
ASSERT_EQ(val, 1);
REQUIRE(val == 1);
if (theThread.joinable()) {
theThread.join();

View File

@@ -6,45 +6,45 @@
#include "wpi/net/uv/Buffer.hpp"
// clang-format on
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net::uv {
TEST(UvSimpleBufferPoolTest, ConstructDefault) {
TEST_CASE("UvSimpleBufferPoolTest ConstructDefault", "[uv][buffer]") {
SimpleBufferPool<> pool;
auto buf1 = pool.Allocate();
ASSERT_EQ(buf1.len, 4096u); // NOLINT
REQUIRE(buf1.len == 4096u); // NOLINT
pool.Release({&buf1, 1});
}
TEST(UvSimpleBufferPoolTest, ConstructSize) {
TEST_CASE("UvSimpleBufferPoolTest ConstructSize", "[uv][buffer]") {
SimpleBufferPool<4> pool{8192};
auto buf1 = pool.Allocate();
ASSERT_EQ(buf1.len, 8192u); // NOLINT
REQUIRE(buf1.len == 8192u); // NOLINT
pool.Release({&buf1, 1});
}
TEST(UvSimpleBufferPoolTest, ReleaseReuse) {
TEST_CASE("UvSimpleBufferPoolTest ReleaseReuse", "[uv][buffer]") {
SimpleBufferPool<4> pool;
auto buf1 = pool.Allocate();
auto buf1copy = buf1;
auto origSize = buf1.len;
buf1.len = 8;
pool.Release({&buf1, 1});
ASSERT_EQ(buf1.base, nullptr);
REQUIRE(buf1.base == nullptr);
auto buf2 = pool.Allocate();
ASSERT_EQ(buf1copy.base, buf2.base);
ASSERT_EQ(buf2.len, origSize);
REQUIRE(buf1copy.base == buf2.base);
REQUIRE(buf2.len == origSize);
pool.Release({&buf2, 1});
}
TEST(UvSimpleBufferPoolTest, ClearRemaining) {
TEST_CASE("UvSimpleBufferPoolTest ClearRemaining", "[uv][buffer]") {
SimpleBufferPool<4> pool;
auto buf1 = pool.Allocate();
pool.Release({&buf1, 1});
ASSERT_EQ(pool.Remaining(), 1u);
REQUIRE(pool.Remaining() == 1u);
pool.Clear();
ASSERT_EQ(pool.Remaining(), 0u);
REQUIRE(pool.Remaining() == 0u);
}
} // namespace wpi::net::uv

View File

@@ -27,7 +27,7 @@
#include "wpi/net/uv/GetAddrInfo.hpp"
// clang-format on
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Loop.hpp"
@@ -35,37 +35,37 @@
namespace wpi::net::uv {
TEST(UvGetAddrInfoTest, BothNull) {
TEST_CASE("UvGetAddrInfoTest BothNull", "[uv][dns][addrinfo]") {
int fail_cb_called = 0;
auto loop = Loop::Create();
loop->error.connect([&](Error err) {
ASSERT_EQ(err.code(), UV_EINVAL);
REQUIRE(err.code() == UV_EINVAL);
fail_cb_called++;
});
GetAddrInfo(loop, [](const addrinfo&) { FAIL(); }, "");
loop->Run();
ASSERT_EQ(fail_cb_called, 1);
REQUIRE(fail_cb_called == 1);
}
TEST(UvGetAddrInfoTest, DISABLED_FailedLookup) {
TEST_CASE("UvGetAddrInfoTest FailedLookup", "[uv][dns][addrinfo][.]") {
int fail_cb_called = 0;
auto loop = Loop::Create();
loop->error.connect([&](Error err) {
ASSERT_EQ(fail_cb_called, 0);
ASSERT_LT(err.code(), 0);
REQUIRE(fail_cb_called == 0);
REQUIRE(err.code() < 0);
fail_cb_called++;
});
// Use a FQDN by ending in a period
GetAddrInfo(loop, [](const addrinfo&) { FAIL(); }, "xyzzy.xyzzy.xyzzy.");
loop->Run();
ASSERT_EQ(fail_cb_called, 1);
REQUIRE(fail_cb_called == 1);
}
TEST(UvGetAddrInfoTest, Basic) {
TEST_CASE("UvGetAddrInfoTest Basic", "[uv][dns][addrinfo]") {
int getaddrinfo_cbs = 0;
auto loop = Loop::Create();
@@ -75,11 +75,11 @@ TEST(UvGetAddrInfoTest, Basic) {
loop->Run();
ASSERT_EQ(getaddrinfo_cbs, 1);
REQUIRE(getaddrinfo_cbs == 1);
}
#ifndef _WIN32
TEST(UvGetAddrInfoTest, Concurrent) {
TEST_CASE("UvGetAddrInfoTest Concurrent", "[uv][dns][addrinfo]") {
int getaddrinfo_cbs = 0;
int callback_counts[CONCURRENT_COUNT];
@@ -100,7 +100,7 @@ TEST(UvGetAddrInfoTest, Concurrent) {
loop->Run();
for (int i = 0; i < CONCURRENT_COUNT; i++) {
ASSERT_EQ(callback_counts[i], 1);
REQUIRE(callback_counts[i] == 1);
}
}
#endif

View File

@@ -27,13 +27,13 @@
#include "wpi/net/uv/GetNameInfo.hpp"
// clang-format on
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Loop.hpp"
namespace wpi::net::uv {
TEST(UvGetNameInfoTest, BasicIp4) {
TEST_CASE("UvGetNameInfoTest BasicIp4", "[uv][dns][nameinfo]") {
int getnameinfo_cbs = 0;
auto loop = Loop::Create();
@@ -42,18 +42,18 @@ TEST(UvGetNameInfoTest, BasicIp4) {
GetNameInfo4(
loop,
[&](const char* hostname, const char* service) {
ASSERT_NE(hostname, nullptr);
ASSERT_NE(service, nullptr);
REQUIRE(hostname != nullptr);
REQUIRE(service != nullptr);
getnameinfo_cbs++;
},
"127.0.0.1", 80);
loop->Run();
ASSERT_EQ(getnameinfo_cbs, 1);
REQUIRE(getnameinfo_cbs == 1);
}
TEST(UvGetNameInfoTest, BasicIp6) {
TEST_CASE("UvGetNameInfoTest BasicIp6", "[uv][dns][nameinfo]") {
int getnameinfo_cbs = 0;
auto loop = Loop::Create();
@@ -62,15 +62,15 @@ TEST(UvGetNameInfoTest, BasicIp6) {
GetNameInfo6(
loop,
[&](const char* hostname, const char* service) {
ASSERT_NE(hostname, nullptr);
ASSERT_NE(service, nullptr);
REQUIRE(hostname != nullptr);
REQUIRE(service != nullptr);
getnameinfo_cbs++;
},
"::1", 80);
loop->Run();
ASSERT_EQ(getnameinfo_cbs, 1);
REQUIRE(getnameinfo_cbs == 1);
}
} // namespace wpi::net::uv

View File

@@ -27,13 +27,13 @@
#include "wpi/net/uv/Loop.hpp"
// clang-format on
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
#include "wpi/net/uv/Timer.hpp"
namespace wpi::net::uv {
TEST(UvLoopTest, Walk) {
TEST_CASE("UvLoopTest Walk", "[uv][loop]") {
int seen_timer_handle = 0;
auto loop = Loop::Create();
@@ -54,9 +54,9 @@ TEST(UvLoopTest, Walk) {
timer->Start(Timer::Time{1});
// Start event loop, expect to see the timer handle
ASSERT_EQ(seen_timer_handle, 0);
REQUIRE(seen_timer_handle == 0);
loop->Run();
ASSERT_EQ(seen_timer_handle, 1);
REQUIRE(seen_timer_handle == 1);
// Loop is finished, should not see our timer handle
seen_timer_handle = 0;
@@ -65,7 +65,7 @@ TEST(UvLoopTest, Walk) {
seen_timer_handle++;
}
});
ASSERT_EQ(seen_timer_handle, 0);
REQUIRE(seen_timer_handle == 0);
}
} // namespace wpi::net::uv

View File

@@ -6,11 +6,11 @@
#include "wpi/net/uv/Timer.hpp"
// clang-format on
#include <gtest/gtest.h>
#include <catch2/catch_test_macros.hpp>
namespace wpi::net::uv {
TEST(UvTimerTest, StartAndStop) {
TEST_CASE("UvTimerTest StartAndStop", "[uv][timer]") {
auto loop = Loop::Create();
auto handleNoRepeat = Timer::Create(loop);
auto handleRepeat = Timer::Create(loop);
@@ -23,11 +23,11 @@ TEST(UvTimerTest, StartAndStop) {
handleNoRepeat->timeout.connect(
[&checkTimerNoRepeatEvent, handle = handleNoRepeat.get()] {
ASSERT_FALSE(checkTimerNoRepeatEvent);
REQUIRE_FALSE(checkTimerNoRepeatEvent);
checkTimerNoRepeatEvent = true;
handle->Stop();
handle->Close();
ASSERT_TRUE(handle->IsClosing());
REQUIRE(handle->IsClosing());
});
handleRepeat->timeout.connect(
@@ -35,34 +35,34 @@ TEST(UvTimerTest, StartAndStop) {
if (checkTimerRepeatEvent) {
handle->Stop();
handle->Close();
ASSERT_TRUE(handle->IsClosing());
REQUIRE(handle->IsClosing());
} else {
checkTimerRepeatEvent = true;
ASSERT_FALSE(handle->IsClosing());
REQUIRE_FALSE(handle->IsClosing());
}
});
handleNoRepeat->Start(Timer::Time{0}, Timer::Time{0});
handleRepeat->Start(Timer::Time{0}, Timer::Time{1});
ASSERT_TRUE(handleNoRepeat->IsActive());
ASSERT_FALSE(handleNoRepeat->IsClosing());
REQUIRE(handleNoRepeat->IsActive());
REQUIRE_FALSE(handleNoRepeat->IsClosing());
ASSERT_TRUE(handleRepeat->IsActive());
ASSERT_FALSE(handleRepeat->IsClosing());
REQUIRE(handleRepeat->IsActive());
REQUIRE_FALSE(handleRepeat->IsClosing());
loop->Run();
ASSERT_TRUE(checkTimerNoRepeatEvent);
ASSERT_TRUE(checkTimerRepeatEvent);
REQUIRE(checkTimerNoRepeatEvent);
REQUIRE(checkTimerRepeatEvent);
}
TEST(UvTimerTest, Repeat) {
TEST_CASE("UvTimerTest Repeat", "[uv][timer]") {
auto loop = Loop::Create();
auto handle = Timer::Create(loop);
handle->SetRepeat(Timer::Time{42});
ASSERT_EQ(handle->GetRepeat(), Timer::Time{42});
REQUIRE(handle->GetRepeat() == Timer::Time{42});
handle->Close();
loop->Run(); // forces close callback to run

View File

@@ -165,16 +165,9 @@ if(WITH_TESTS)
target_sources(wpiutil_test PRIVATE ${wpiutil_nanopb_test_src})
target_link_libraries(wpiutil_test wpiutil googletest wpiutil_testlib)
add_executable(wpiutil_catch2_test ${catch2_test_src})
set_property(TARGET wpiutil_catch2_test PROPERTY FOLDER "tests")
wpilib_target_warnings(wpiutil_catch2_test)
if(MSVC)
target_compile_options(wpiutil_catch2_test PRIVATE /wd4101 /wd4251)
endif()
target_link_libraries(wpiutil_catch2_test wpiutil catch2 wpiutil_testlib)
catch_discover_tests(wpiutil_catch2_test)
wpilib_add_test_catch2(wpiutil_catch2 ${catch2_test_src})
target_link_libraries(wpiutil_catch2_test wpiutil wpiutil_testlib)
if(MSVC)
target_compile_options(wpiutil_test PRIVATE /utf-8)
target_compile_options(wpiutil_catch2_test PRIVATE /utf-8)
endif()
endif()