diff --git a/wpiutil/src/main/native/cpp/UDPClient.cpp b/wpiutil/src/main/native/cpp/UDPClient.cpp index f20ec417dc..aafdf477d2 100644 --- a/wpiutil/src/main/native/cpp/UDPClient.cpp +++ b/wpiutil/src/main/native/cpp/UDPClient.cpp @@ -26,13 +26,15 @@ using namespace wpi; UDPClient::UDPClient(Logger& logger) : UDPClient("", logger) {} UDPClient::UDPClient(const Twine& address, Logger& logger) - : m_lsd(0), m_address(address.str()), m_logger(logger) {} + : m_lsd(0), m_port(0), m_address(address.str()), m_logger(logger) {} UDPClient::UDPClient(UDPClient&& other) : m_lsd(other.m_lsd), + m_port(other.m_port), m_address(std::move(other.m_address)), m_logger(other.m_logger) { other.m_lsd = 0; + other.m_port = 0; } UDPClient::~UDPClient() { @@ -47,11 +49,15 @@ UDPClient& UDPClient::operator=(UDPClient&& other) { m_logger = other.m_logger; m_lsd = other.m_lsd; m_address = std::move(other.m_address); + m_port = other.m_port; other.m_lsd = 0; + other.m_port = 0; return *this; } -int UDPClient::start() { +int UDPClient::start() { return start(0); } + +int UDPClient::start(int port) { if (m_lsd > 0) return 0; #ifdef _WIN32 @@ -85,13 +91,26 @@ int UDPClient::start() { } else { addr.sin_addr.s_addr = INADDR_ANY; } - addr.sin_port = htons(0); + addr.sin_port = htons(port); + + if (port != 0) { +#ifdef _WIN32 + int optval = 1; + setsockopt(m_lsd, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&optval), sizeof optval); +#else + int optval = 1; + setsockopt(m_lsd, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&optval), sizeof optval); +#endif + } int result = bind(m_lsd, reinterpret_cast(&addr), sizeof(addr)); if (result != 0) { WPI_ERROR(m_logger, "bind() failed: " << SocketStrerror()); return result; } + m_port = port; return 0; } @@ -106,6 +125,7 @@ void UDPClient::shutdown() { close(m_lsd); #endif m_lsd = 0; + m_port = 0; } } @@ -167,3 +187,51 @@ int UDPClient::send(StringRef data, const Twine& server, int port) { reinterpret_cast(&addr), sizeof(addr)); return result; } + +int UDPClient::receive(uint8_t* data_received, int receive_len) { + if (m_port == 0) return -1; // return if not receiving + return recv(m_lsd, reinterpret_cast(data_received), receive_len, 0); +} + +int UDPClient::receive(uint8_t* data_received, int receive_len, + SmallVectorImpl* addr_received, + int* port_received) { + if (m_port == 0) return -1; // return if not receiving + + struct sockaddr_in remote; + socklen_t remote_len = sizeof(remote); + std::memset(&remote, 0, sizeof(remote)); + + int result = + recvfrom(m_lsd, reinterpret_cast(data_received), receive_len, 0, + reinterpret_cast(&remote), &remote_len); + + char ip[50]; +#ifdef _WIN32 + InetNtop(PF_INET, &(remote.sin_addr.s_addr), ip, sizeof(ip) - 1); +#else + inet_ntop(PF_INET, reinterpret_cast(&(remote.sin_addr.s_addr)), ip, + sizeof(ip) - 1); +#endif + + ip[49] = '\0'; + int addr_len = std::strlen(ip); + addr_received->clear(); + addr_received->append(&ip[0], &ip[addr_len]); + + *port_received = ntohs(remote.sin_port); + + return result; +} + +int UDPClient::set_timeout(double timeout) { + if (timeout < 0) return -1; + struct timeval tv; + tv.tv_sec = timeout; // truncating will give seconds + timeout -= tv.tv_sec; // remove seconds portion + tv.tv_usec = timeout * 1000000; // fractions of a second to us + int ret = setsockopt(m_lsd, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); + if (ret < 0) WPI_ERROR(m_logger, "set timeout failed"); + return ret; +} diff --git a/wpiutil/src/main/native/include/wpi/UDPClient.h b/wpiutil/src/main/native/include/wpi/UDPClient.h index 1f7d9b7a96..635eca84a6 100644 --- a/wpiutil/src/main/native/include/wpi/UDPClient.h +++ b/wpiutil/src/main/native/include/wpi/UDPClient.h @@ -11,6 +11,7 @@ #include #include "wpi/ArrayRef.h" +#include "wpi/SmallVector.h" #include "wpi/StringRef.h" #include "wpi/Twine.h" #include "wpi/mutex.h" @@ -21,6 +22,7 @@ class Logger; class UDPClient { int m_lsd; + int m_port; std::string m_address; Logger& m_logger; @@ -35,10 +37,15 @@ class UDPClient { UDPClient& operator=(UDPClient&& other); int start(); + int start(int port); void shutdown(); // The passed in address MUST be a resolved IP address. int send(ArrayRef data, const Twine& server, int port); int send(StringRef data, const Twine& server, int port); + int receive(uint8_t* data_received, int receive_len); + int receive(uint8_t* data_received, int receive_len, + SmallVectorImpl* addr_received, int* port_received); + int set_timeout(double timeout); }; } // namespace wpi