diff --git a/src/main/native/cpp/tcpsockets/TCPConnector_parallel.cpp b/src/main/native/cpp/tcpsockets/TCPConnector_parallel.cpp new file mode 100644 index 0000000000..c9b92924bc --- /dev/null +++ b/src/main/native/cpp/tcpsockets/TCPConnector_parallel.cpp @@ -0,0 +1,128 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) FIRST 2017. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +#include "tcpsockets/TCPConnector.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/SmallSet.h" + +using namespace wpi; + +// MSVC < 1900 doesn't have support for thread_local +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +// clang check for availability of thread_local +#if !defined(__has_feature) || __has_feature(cxx_thread_local) +#define HAVE_THREAD_LOCAL +#endif +#endif + +std::unique_ptr TCPConnector::connect_parallel( + llvm::ArrayRef> servers, Logger& logger, + int timeout) { + if (servers.empty()) return nullptr; + + // structure to make sure we don't start duplicate workers + struct GlobalState { + std::mutex mutex; +#ifdef HAVE_THREAD_LOCAL + llvm::SmallSet, 16> active; +#else + llvm::SmallSet, 16> active; +#endif + }; +#ifdef HAVE_THREAD_LOCAL + thread_local auto global = std::make_shared(); +#else + static auto global = std::make_shared(); + auto this_id = std::this_thread::get_id(); +#endif + auto local = global; // copy to an automatic variable for lambda capture + + // structure shared between threads and this function + struct Result { + std::mutex mutex; + std::condition_variable cv; + std::unique_ptr stream; + std::atomic count{0}; + std::atomic done{false}; + }; + auto result = std::make_shared(); + + // start worker threads; this is I/O bound so we don't limit to # of procs + Logger* plogger = &logger; + unsigned int num_workers = 0; + for (const auto& server : servers) { + std::pair server_copy{std::string{server.first}, + server.second}; +#ifdef HAVE_THREAD_LOCAL + const auto& active_tracker = server_copy; +#else + std::tuple active_tracker{ + this_id, server_copy.first, server_copy.second}; +#endif + + // don't start a new worker if we had a previously still-active connection + // attempt to the same server + { + std::lock_guard lock(local->mutex); + if (local->active.count(active_tracker) > 0) continue; // already in set + } + + ++num_workers; + + // start the worker + std::thread([=]() { + if (!result->done) { + // add to global state + { + std::lock_guard lock(local->mutex); + local->active.insert(active_tracker); + } + + // try to connect + auto stream = connect(server_copy.first.c_str(), server_copy.second, + *plogger, timeout); + + // remove from global state + { + std::lock_guard lock(local->mutex); + local->active.erase(active_tracker); + } + + // successful connection + if (stream) { + std::lock_guard lock(result->mutex); + if (!result->done.exchange(true)) result->stream = std::move(stream); + } + } + ++result->count; + result->cv.notify_all(); + }).detach(); + } + + // wait for a result, timeout, or all finished + std::unique_lock lock(result->mutex); + if (timeout == 0) { + result->cv.wait( + lock, [&] { return result->stream || result->count >= num_workers; }); + } else { + auto timeout_time = + std::chrono::steady_clock::now() + std::chrono::seconds(timeout); + result->cv.wait_until(lock, timeout_time, [&] { + return result->stream || result->count >= num_workers; + }); + } + + // no need to wait for remaining worker threads; shared_ptr will clean up + return std::move(result->stream); +} diff --git a/src/main/native/include/tcpsockets/TCPConnector.h b/src/main/native/include/tcpsockets/TCPConnector.h index 7afb0446af..6b056eca3f 100644 --- a/src/main/native/include/tcpsockets/TCPConnector.h +++ b/src/main/native/include/tcpsockets/TCPConnector.h @@ -26,6 +26,7 @@ #include +#include "llvm/ArrayRef.h" #include "tcpsockets/NetworkStream.h" namespace wpi { @@ -37,6 +38,9 @@ class TCPConnector { static std::unique_ptr connect(const char* server, int port, Logger& logger, int timeout = 0); + static std::unique_ptr connect_parallel( + llvm::ArrayRef> servers, Logger& logger, + int timeout = 0); }; } // namespace wpi