diff --git a/ntcore/src/generate/cpp/ntcore_cpp_types.cpp.jinja b/ntcore/src/generate/cpp/ntcore_cpp_types.cpp.jinja index ba8bca7c0d..cdaf27fb6e 100644 --- a/ntcore/src/generate/cpp/ntcore_cpp_types.cpp.jinja +++ b/ntcore/src/generate/cpp/ntcore_cpp_types.cpp.jinja @@ -7,69 +7,124 @@ #include "Handle.h" #include "InstanceImpl.h" +namespace { +template +struct ValuesType { + using Vector = + std::vector>::Value>; +}; + +template <> +struct ValuesType { + using Vector = std::vector; +}; +} // namespace + namespace nt { -{% for t in types %} -bool Set{{ t.TypeName }}(NT_Handle pubentry, {{ t.cpp.ParamType }} value, int64_t time) { + +template +static inline bool Set(NT_Handle pubentry, typename TypeInfo::View value, + int64_t time) { if (auto ii = InstanceImpl::Get(Handle{pubentry}.GetInst())) { - return ii->localStorage.SetEntryValue(pubentry, - Value::Make{{ t.TypeName }}(value, time == 0 ? Now() : time)); + return ii->localStorage.SetEntryValue( + pubentry, MakeValue(value, time == 0 ? Now() : time)); } else { return {}; } } -bool SetDefault{{ t.TypeName }}(NT_Handle pubentry, {{ t.cpp.ParamType }} defaultValue) { +template +static inline bool SetDefault(NT_Handle pubentry, + typename TypeInfo::View defaultValue) { if (auto ii = InstanceImpl::Get(Handle{pubentry}.GetInst())) { return ii->localStorage.SetDefaultEntryValue(pubentry, - Value::Make{{ t.TypeName }}(defaultValue, 1)); + MakeValue(defaultValue, 1)); } else { return {}; } } -{{ t.cpp.ValueType }} Get{{ t.TypeName }}(NT_Handle subentry, {{ t.cpp.ParamType }} defaultValue) { - return GetAtomic{{ t.TypeName }}(subentry, defaultValue).value; -} - -Timestamped{{ t.TypeName }} GetAtomic{{ t.TypeName }}(NT_Handle subentry, {{ t.cpp.ParamType }} defaultValue) { +template +static inline Timestamped::Value> GetAtomic( + NT_Handle subentry, typename TypeInfo::View defaultValue) { if (auto ii = InstanceImpl::Get(Handle{subentry}.GetInst())) { - return ii->localStorage.GetAtomic{{ t.TypeName }}(subentry, defaultValue); + return ii->localStorage.GetAtomic(subentry, defaultValue); } else { return {}; } } -std::vector ReadQueue{{ t.TypeName }}(NT_Handle subentry) { +template +inline Timestamped::SmallRet> GetAtomic( + NT_Handle subentry, + wpi::SmallVectorImpl::SmallElem>& buf, + typename TypeInfo::View defaultValue) { if (auto ii = InstanceImpl::Get(Handle{subentry}.GetInst())) { - return ii->localStorage.ReadQueue{{ t.TypeName }}(subentry); + return ii->localStorage.GetAtomic(subentry, buf, defaultValue); } else { return {}; } } -std::vector<{% if t.cpp.ValueType == "bool" %}int{% else %}{{ t.cpp.ValueType }}{% endif %}> ReadQueueValues{{ t.TypeName }}(NT_Handle subentry) { - std::vector<{% if t.cpp.ValueType == "bool" %}int{% else %}{{ t.cpp.ValueType }}{% endif %}> rv; - auto arr = ReadQueue{{ t.TypeName }}(subentry); +template +static inline std::vector::Value>> ReadQueue( + NT_Handle subentry) { + if (auto ii = InstanceImpl::Get(Handle{subentry}.GetInst())) { + return ii->localStorage.ReadQueue(subentry); + } else { + return {}; + } +} + +template +static inline typename ValuesType::Vector ReadQueueValues( + NT_Handle subentry) { + typename ValuesType::Vector rv; + auto arr = ReadQueue(subentry); rv.reserve(arr.size()); for (auto&& elem : arr) { rv.emplace_back(std::move(elem.value)); } return rv; } +{% for t in types %} +bool Set{{ t.TypeName }}(NT_Handle pubentry, {{ t.cpp.ParamType }} value, int64_t time) { + return Set<{{ t.cpp.TemplateType }}>(pubentry, value, time); +} + +bool SetDefault{{ t.TypeName }}(NT_Handle pubentry, {{ t.cpp.ParamType }} defaultValue) { + return SetDefault<{{ t.cpp.TemplateType }}>(pubentry, defaultValue); +} + +{{ t.cpp.ValueType }} Get{{ t.TypeName }}(NT_Handle subentry, {{ t.cpp.ParamType }} defaultValue) { + return GetAtomic<{{ t.cpp.TemplateType }}>(subentry, defaultValue).value; +} + +Timestamped{{ t.TypeName }} GetAtomic{{ t.TypeName }}( + NT_Handle subentry, {{ t.cpp.ParamType }} defaultValue) { + return GetAtomic<{{ t.cpp.TemplateType }}>(subentry, defaultValue); +} + +std::vector ReadQueue{{ t.TypeName }}(NT_Handle subentry) { + return ReadQueue<{{ t.cpp.TemplateType }}>(subentry); +} + +std::vector<{% if t.cpp.ValueType == "bool" %}int{% else %}{{ t.cpp.ValueType }}{% endif %}> ReadQueueValues{{ t.TypeName }}(NT_Handle subentry) { + return ReadQueueValues<{{ t.cpp.TemplateType }}>(subentry); +} {% if t.cpp.SmallRetType and t.cpp.SmallElemType %} -{{ t.cpp.SmallRetType }} Get{{ t.TypeName }}(NT_Handle subentry, wpi::SmallVectorImpl<{{ t.cpp.SmallElemType }}>& buf, {{ t.cpp.ParamType }} defaultValue) { - return GetAtomic{{ t.TypeName }}(subentry, buf, defaultValue).value; +{{ t.cpp.SmallRetType }} Get{{ t.TypeName }}( + NT_Handle subentry, + wpi::SmallVectorImpl<{{ t.cpp.SmallElemType }}>& buf, + {{ t.cpp.ParamType }} defaultValue) { + return GetAtomic<{{ t.cpp.TemplateType }}>(subentry, buf, defaultValue).value; } Timestamped{{ t.TypeName }}View GetAtomic{{ t.TypeName }}( NT_Handle subentry, wpi::SmallVectorImpl<{{ t.cpp.SmallElemType }}>& buf, {{ t.cpp.ParamType }} defaultValue) { - if (auto ii = InstanceImpl::Get(Handle{subentry}.GetInst())) { - return ii->localStorage.GetAtomic{{ t.TypeName }}(subentry, buf, defaultValue); - } else { - return {}; - } + return GetAtomic<{{ t.cpp.TemplateType }}>(subentry, buf, defaultValue); } {% endif %} {% endfor %} diff --git a/ntcore/src/generate/include/ntcore_cpp_types.h.jinja b/ntcore/src/generate/include/ntcore_cpp_types.h.jinja index e987186aa6..df919de05a 100644 --- a/ntcore/src/generate/include/ntcore_cpp_types.h.jinja +++ b/ntcore/src/generate/include/ntcore_cpp_types.h.jinja @@ -20,56 +20,43 @@ class SmallVectorImpl; } // namespace wpi namespace nt { +/** + * Timestamped value. + * @ingroup ntcore_cpp_handle_api + */ +template +struct Timestamped { + Timestamped() = default; + Timestamped(int64_t time, int64_t serverTime, T value) + : time{time}, serverTime{serverTime}, value{std::move(value)} {} + + /** + * Time in local time base. + */ + int64_t time = 0; + + /** + * Time in server time base. May be 0 or 1 for locally set values. + */ + int64_t serverTime = 0; + + /** + * Value. + */ + T value = {}; +}; {% for t in types %} /** * Timestamped {{ t.TypeName }}. * @ingroup ntcore_cpp_handle_api */ -struct Timestamped{{ t.TypeName }} { - Timestamped{{ t.TypeName }}() = default; - Timestamped{{ t.TypeName }}(int64_t time, int64_t serverTime, {{ t.cpp.ValueType }} value) - : time{time}, serverTime{serverTime}, value{std::move(value)} {} - - /** - * Time in local time base. - */ - int64_t time = 0; - - /** - * Time in server time base. May be 0 or 1 for locally set values. - */ - int64_t serverTime = 0; - - /** - * Value. - */ - {{ t.cpp.ValueType }} value = {}; -}; +using Timestamped{{ t.TypeName }} = Timestamped<{{ t.cpp.ValueType }}>; {% if t.cpp.SmallRetType %} /** * Timestamped {{ t.TypeName }} view (for SmallVector-taking functions). * @ingroup ntcore_cpp_handle_api */ -struct Timestamped{{ t.TypeName }}View { - Timestamped{{ t.TypeName }}View() = default; - Timestamped{{ t.TypeName }}View(int64_t time, int64_t serverTime, {{ t.cpp.SmallRetType }} value) - : time{time}, serverTime{serverTime}, value{std::move(value)} {} - - /** - * Time in local time base. - */ - int64_t time = 0; - - /** - * Time in server time base. May be 0 or 1 for locally set values. - */ - int64_t serverTime = 0; - - /** - * Value. - */ - {{ t.cpp.SmallRetType }} value = {}; -}; +using Timestamped{{ t.TypeName }}View = Timestamped<{{ t.cpp.SmallRetType }}>; {% endif %} /** * @defgroup ntcore_{{ t.TypeName }}_func {{ t.TypeName }} Functions diff --git a/ntcore/src/generate/types.json b/ntcore/src/generate/types.json index c9d874bcfa..55f5741119 100644 --- a/ntcore/src/generate/types.json +++ b/ntcore/src/generate/types.json @@ -9,6 +9,7 @@ "cpp": { "ValueType": "bool", "ParamType": "bool", + "TemplateType": "bool", "TYPE_NAME": "BOOLEAN" }, "java": { @@ -40,6 +41,7 @@ "cpp": { "ValueType": "int64_t", "ParamType": "int64_t", + "TemplateType": "int64_t", "TYPE_NAME": "INTEGER" }, "java": { @@ -71,6 +73,7 @@ "cpp": { "ValueType": "float", "ParamType": "float", + "TemplateType": "float", "TYPE_NAME": "FLOAT" }, "java": { @@ -104,6 +107,7 @@ "cpp": { "ValueType": "double", "ParamType": "double", + "TemplateType": "double", "TYPE_NAME": "DOUBLE" }, "java": { @@ -136,6 +140,7 @@ "cpp": { "ValueType": "std::string", "ParamType": "std::string_view", + "TemplateType": "std::string", "TYPE_NAME": "STRING", "INCLUDES": "#include \n#include \n#include ", "SmallRetType": "std::string_view", @@ -168,6 +173,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "uint8_t[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "RAW", "INCLUDES": "#include ", @@ -202,6 +208,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "bool[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "BOOLEAN_ARRAY", "INCLUDES": "#include ", @@ -237,6 +244,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "int64_t[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "INTEGER_ARRAY", "INCLUDES": "#include ", @@ -272,6 +280,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "float[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "FLOAT_ARRAY", "INCLUDES": "#include ", @@ -307,6 +316,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "double[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "DOUBLE_ARRAY", "INCLUDES": "#include ", @@ -342,6 +352,7 @@ "cpp": { "ValueType": "std::vector", "ParamType": "std::span", + "TemplateType": "std::string[]", "DefaultValueCopy": "defaultValue.begin(), defaultValue.end()", "TYPE_NAME": "STRING_ARRAY", "INCLUDES": "#include " diff --git a/ntcore/src/main/native/cpp/ListenerStorage.cpp b/ntcore/src/main/native/cpp/ListenerStorage.cpp index fdbf03bc09..4270d0b919 100644 --- a/ntcore/src/main/native/cpp/ListenerStorage.cpp +++ b/ntcore/src/main/native/cpp/ListenerStorage.cpp @@ -6,25 +6,12 @@ #include -#include #include #include "ntcore_c.h" using namespace nt; -class ListenerStorage::Thread final : public wpi::SafeThreadEvent { - public: - explicit Thread(NT_ListenerPoller poller) : m_poller{poller} {} - - void Main() final; - - NT_ListenerPoller m_poller; - wpi::DenseMap m_callbacks; - wpi::Event m_waitQueueWakeup; - wpi::Event m_waitQueueWaiter; -}; - void ListenerStorage::Thread::Main() { while (m_active) { WPI_Handle signaledBuf[3]; @@ -55,10 +42,6 @@ void ListenerStorage::Thread::Main() { } } -ListenerStorage::ListenerStorage(int inst) : m_inst{inst} {} - -ListenerStorage::~ListenerStorage() = default; - void ListenerStorage::Activate(NT_Listener listenerHandle, unsigned int mask, FinishEventFunc finishEvent) { std::scoped_lock lock{m_mutex}; diff --git a/ntcore/src/main/native/cpp/ListenerStorage.h b/ntcore/src/main/native/cpp/ListenerStorage.h index 44b4ceaa12..49dc8ce082 100644 --- a/ntcore/src/main/native/cpp/ListenerStorage.h +++ b/ntcore/src/main/native/cpp/ListenerStorage.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -19,16 +20,16 @@ #include "Handle.h" #include "HandleMap.h" #include "IListenerStorage.h" +#include "VectorSet.h" #include "ntcore_cpp.h" namespace nt { class ListenerStorage final : public IListenerStorage { public: - explicit ListenerStorage(int inst); + explicit ListenerStorage(int inst) : m_inst{inst} {} ListenerStorage(const ListenerStorage&) = delete; ListenerStorage& operator=(const ListenerStorage&) = delete; - ~ListenerStorage() final; // IListenerStorage interface void Activate(NT_Listener listenerHandle, unsigned int mask, @@ -97,21 +98,23 @@ class ListenerStorage final : public IListenerStorage { }; HandleMap m_listeners; - // Utility wrapper for making a set-like vector - template - class VectorSet : public std::vector { - public: - void Add(T value) { this->push_back(value); } - void Remove(T value) { std::erase(*this, value); } - }; - VectorSet m_connListeners; VectorSet m_topicListeners; VectorSet m_valueListeners; VectorSet m_logListeners; VectorSet m_timeSyncListeners; - class Thread; + class Thread final : public wpi::SafeThreadEvent { + public: + explicit Thread(NT_ListenerPoller poller) : m_poller{poller} {} + + void Main() final; + + NT_ListenerPoller m_poller; + wpi::DenseMap m_callbacks; + wpi::Event m_waitQueueWakeup; + wpi::Event m_waitQueueWaiter; + }; wpi::SafeThreadOwner m_thread; }; diff --git a/ntcore/src/main/native/cpp/LocalStorage.cpp b/ntcore/src/main/native/cpp/LocalStorage.cpp index bf4b643bed..4d3ea71e15 100644 --- a/ntcore/src/main/native/cpp/LocalStorage.cpp +++ b/ntcore/src/main/native/cpp/LocalStorage.cpp @@ -8,21 +8,13 @@ #include #include -#include -#include -#include -#include #include -#include "Handle.h" -#include "HandleMap.h" #include "IListenerStorage.h" #include "Log.h" -#include "PubSubOptions.h" #include "Types_internal.h" #include "Value_internal.h" #include "networktables/NetworkTableValue.h" -#include "ntcore_c.h" using namespace nt; @@ -32,178 +24,18 @@ static constexpr size_t kMaxSubscribers = 512; static constexpr size_t kMaxMultiSubscribers = 512; static constexpr size_t kMaxListeners = 512; -namespace { - -static constexpr bool IsSpecial(std::string_view name) { - return name.empty() ? false : name.front() == '$'; -} - static constexpr bool PrefixMatch(std::string_view name, std::string_view prefix, bool special) { return (!special || !prefix.empty()) && wpi::starts_with(name, prefix); } -// Utility wrapper for making a set-like vector -template -class VectorSet : public std::vector { - public: - void Add(T value) { this->push_back(value); } - void Remove(T value) { std::erase(*this, value); } -}; +std::string LocalStorage::DataLoggerEntry::MakeMetadata( + std::string_view properties) { + return fmt::format("{{\"properties\":{},\"source\":\"NT\"}}", properties); +} -struct EntryData; -struct PublisherData; -struct SubscriberData; -struct MultiSubscriberData; - -struct DataLoggerEntry { - DataLoggerEntry(wpi::log::DataLog& log, int entry, NT_DataLogger logger) - : log{&log}, entry{entry}, logger{logger} {} - - static std::string MakeMetadata(std::string_view properties) { - return fmt::format("{{\"properties\":{},\"source\":\"NT\"}}", properties); - } - - void Append(const Value& v); - - wpi::log::DataLog* log; - int entry; - NT_DataLogger logger; -}; - -struct TopicData { - static constexpr auto kType = Handle::kTopic; - - TopicData(NT_Topic handle, std::string_view name) - : handle{handle}, name{name}, special{IsSpecial(name)} {} - - bool Exists() const { return onNetwork || !localPublishers.empty(); } - - TopicInfo GetTopicInfo() const; - - // invariants - wpi::SignalObject handle; - std::string name; - bool special; - - Value lastValue; // also stores timestamp - Value lastValueNetwork; - NT_Type type{NT_UNASSIGNED}; - std::string typeStr; - unsigned int flags{0}; // for NT3 APIs - std::string propertiesStr{"{}"}; // cached string for GetTopicInfo() et al - wpi::json properties = wpi::json::object(); - NT_Entry entry{0}; // cached entry for GetEntry() - - bool onNetwork{false}; // true if there are any remote publishers - bool lastValueFromNetwork{false}; - - wpi::SmallVector datalogs; - NT_Type datalogType{NT_UNASSIGNED}; - - VectorSet localPublishers; - VectorSet localSubscribers; - VectorSet multiSubscribers; - VectorSet entries; - VectorSet listeners; -}; - -struct PubSubConfig : public PubSubOptionsImpl { - PubSubConfig() = default; - PubSubConfig(NT_Type type, std::string_view typeStr, - const PubSubOptions& options) - : PubSubOptionsImpl{options}, type{type}, typeStr{typeStr} { - prefixMatch = false; - } - - NT_Type type{NT_UNASSIGNED}; - std::string typeStr; -}; - -struct PublisherData { - static constexpr auto kType = Handle::kPublisher; - - PublisherData(NT_Publisher handle, TopicData* topic, PubSubConfig config) - : handle{handle}, topic{topic}, config{std::move(config)} {} - - void UpdateActive(); - - // invariants - wpi::SignalObject handle; - TopicData* topic; - PubSubConfig config; - - // whether or not the publisher should actually publish values - bool active{false}; -}; - -struct SubscriberData { - static constexpr auto kType = Handle::kSubscriber; - - SubscriberData(NT_Subscriber handle, TopicData* topic, PubSubConfig config) - : handle{handle}, - topic{topic}, - config{std::move(config)}, - pollStorage{config.pollStorage} {} - - void UpdateActive(); - - // invariants - wpi::SignalObject handle; - TopicData* topic; - PubSubConfig config; - - // whether or not the subscriber should actually receive values - bool active{false}; - - // polling storage - wpi::circular_buffer pollStorage; - - // value listeners - VectorSet valueListeners; -}; - -struct EntryData { - static constexpr auto kType = Handle::kEntry; - - EntryData(NT_Entry handle, SubscriberData* subscriber) - : handle{handle}, topic{subscriber->topic}, subscriber{subscriber} {} - - // invariants - wpi::SignalObject handle; - TopicData* topic; - SubscriberData* subscriber; - - // the publisher (created on demand) - PublisherData* publisher{nullptr}; -}; - -struct MultiSubscriberData { - static constexpr auto kType = Handle::kMultiSubscriber; - - MultiSubscriberData(NT_MultiSubscriber handle, - std::span prefixes, - const PubSubOptionsImpl& options) - : handle{handle}, options{options} { - this->options.prefixMatch = true; - this->prefixes.reserve(prefixes.size()); - for (auto&& prefix : prefixes) { - this->prefixes.emplace_back(prefix); - } - } - - bool Matches(std::string_view name, bool special); - - // invariants - wpi::SignalObject handle; - std::vector prefixes; - PubSubOptionsImpl options; - - // value listeners - VectorSet valueListeners; -}; - -bool MultiSubscriberData::Matches(std::string_view name, bool special) { +bool LocalStorage::MultiSubscriberData::Matches(std::string_view name, + bool special) { for (auto&& prefix : prefixes) { if (PrefixMatch(name, prefix, special)) { return true; @@ -212,156 +44,14 @@ bool MultiSubscriberData::Matches(std::string_view name, bool special) { return false; } -struct ListenerData { - ListenerData(NT_Listener handle, SubscriberData* subscriber, - unsigned int eventMask, bool subscriberOwned) - : handle{handle}, - eventMask{eventMask}, - subscriber{subscriber}, - subscriberOwned{subscriberOwned} {} - ListenerData(NT_Listener handle, MultiSubscriberData* subscriber, - unsigned int eventMask, bool subscriberOwned) - : handle{handle}, - eventMask{eventMask}, - multiSubscriber{subscriber}, - subscriberOwned{subscriberOwned} {} +int LocalStorage::DataLoggerData::Start(TopicData* topic, int64_t time) { + return log.Start(fmt::format("{}{}", logPrefix, + wpi::drop_front(topic->name, prefix.size())), + topic->typeStr == "int" ? "int64" : topic->typeStr, + DataLoggerEntry::MakeMetadata(topic->propertiesStr), time); +} - NT_Listener handle; - unsigned int eventMask; - SubscriberData* subscriber{nullptr}; - MultiSubscriberData* multiSubscriber{nullptr}; - bool subscriberOwned; -}; - -struct DataLoggerData { - static constexpr auto kType = Handle::kDataLogger; - - DataLoggerData(NT_DataLogger handle, wpi::log::DataLog& log, - std::string_view prefix, std::string_view logPrefix) - : handle{handle}, log{log}, prefix{prefix}, logPrefix{logPrefix} {} - - int Start(TopicData* topic, int64_t time) { - return log.Start(fmt::format("{}{}", logPrefix, - wpi::drop_front(topic->name, prefix.size())), - topic->typeStr == "int" ? "int64" : topic->typeStr, - DataLoggerEntry::MakeMetadata(topic->propertiesStr), time); - } - - NT_DataLogger handle; - wpi::log::DataLog& log; - std::string prefix; - std::string logPrefix; -}; - -struct LSImpl { - LSImpl(int inst, IListenerStorage& listenerStorage, wpi::Logger& logger) - : m_inst{inst}, m_listenerStorage{listenerStorage}, m_logger{logger} {} - - int m_inst; - IListenerStorage& m_listenerStorage; - wpi::Logger& m_logger; - net::NetworkInterface* m_network{nullptr}; - - // handle mappings - HandleMap m_topics; - HandleMap m_publishers; - HandleMap m_subscribers; - HandleMap m_entries; - HandleMap m_multiSubscribers; - HandleMap m_dataloggers; - - // name mappings - wpi::StringMap m_nameTopics; - - // listeners - wpi::DenseMap> m_listeners; - - // string-based listeners - VectorSet m_topicPrefixListeners; - - // topic functions - void NotifyTopic(TopicData* topic, unsigned int eventFlags); - - void CheckReset(TopicData* topic); - - bool SetValue(TopicData* topic, const Value& value, unsigned int eventFlags, - bool suppressIfDuplicate, bool isDuplicate, - const PublisherData* publisher); - void NotifyValue(TopicData* topic, unsigned int eventFlags, bool isDuplicate, - const PublisherData* publisher); - - void SetFlags(TopicData* topic, unsigned int flags); - void SetPersistent(TopicData* topic, bool value); - void SetRetained(TopicData* topic, bool value); - void SetProperties(TopicData* topic, const wpi::json& update, - bool sendNetwork); - void PropertiesUpdated(TopicData* topic, const wpi::json& update, - unsigned int eventFlags, bool sendNetwork, - bool updateFlags = true); - - void RefreshPubSubActive(TopicData* topic, bool warnOnSubMismatch); - - void NetworkAnnounce(TopicData* topic, std::string_view typeStr, - const wpi::json& properties, NT_Publisher pubHandle); - void RemoveNetworkPublisher(TopicData* topic); - void NetworkPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack); - - PublisherData* AddLocalPublisher(TopicData* topic, - const wpi::json& properties, - const PubSubConfig& options); - std::unique_ptr RemoveLocalPublisher(NT_Publisher pubHandle); - - SubscriberData* AddLocalSubscriber(TopicData* topic, - const PubSubConfig& options); - std::unique_ptr RemoveLocalSubscriber( - NT_Subscriber subHandle); - - EntryData* AddEntry(SubscriberData* subscriber); - std::unique_ptr RemoveEntry(NT_Entry entryHandle); - - MultiSubscriberData* AddMultiSubscriber( - std::span prefixes, const PubSubOptions& options); - std::unique_ptr RemoveMultiSubscriber( - NT_MultiSubscriber subHandle); - - void AddListenerImpl(NT_Listener listenerHandle, TopicData* topic, - unsigned int eventMask); - void AddListenerImpl(NT_Listener listenerHandle, SubscriberData* subscriber, - unsigned int eventMask, NT_Handle subentryHandle, - bool subscriberOwned); - void AddListenerImpl(NT_Listener listenerHandle, - MultiSubscriberData* subscriber, unsigned int eventMask, - bool subscriberOwned); - void AddListenerImpl(NT_Listener listenerHandle, - std::span prefixes, - unsigned int eventMask); - - void AddListener(NT_Listener listenerHandle, - std::span prefixes, - unsigned int mask); - void AddListener(NT_Listener listenerHandle, NT_Handle handle, - unsigned int mask); - void RemoveListener(NT_Listener listenerHandle, unsigned int mask); - - TopicData* GetOrCreateTopic(std::string_view name); - TopicData* GetTopic(NT_Handle handle); - SubscriberData* GetSubEntry(NT_Handle subentryHandle); - PublisherData* PublishEntry(EntryData* entry, NT_Type type); - Value* GetSubEntryValue(NT_Handle subentryHandle); - - bool PublishLocalValue(PublisherData* publisher, const Value& value, - bool force = false); - - bool SetEntryValue(NT_Handle pubentryHandle, const Value& value); - bool SetDefaultEntryValue(NT_Handle pubsubentryHandle, const Value& value); - - void RemoveSubEntry(NT_Handle subentryHandle); -}; - -} // namespace - -void DataLoggerEntry::Append(const Value& v) { +void LocalStorage::DataLoggerEntry::Append(const Value& v) { auto time = v.time(); switch (v.type()) { case NT_BOOLEAN: @@ -406,7 +96,7 @@ void DataLoggerEntry::Append(const Value& v) { } } -TopicInfo TopicData::GetTopicInfo() const { +TopicInfo LocalStorage::TopicData::GetTopicInfo() const { TopicInfo info; info.topic = handle; info.name = name; @@ -416,19 +106,8 @@ TopicInfo TopicData::GetTopicInfo() const { return info; } -void PublisherData::UpdateActive() { - active = config.type == topic->type && config.typeStr == topic->typeStr; -} - -void SubscriberData::UpdateActive() { - // for subscribers, unassigned is a wildcard - // also allow numerically compatible subscribers - active = config.type == NT_UNASSIGNED || - (config.type == topic->type && config.typeStr == topic->typeStr) || - IsNumericCompatible(config.type, topic->type); -} - -void LSImpl::NotifyTopic(TopicData* topic, unsigned int eventFlags) { +void LocalStorage::Impl::NotifyTopic(TopicData* topic, + unsigned int eventFlags) { DEBUG4("NotifyTopic({}, {})", topic->name, eventFlags); auto topicInfo = topic->GetTopicInfo(); if (!topic->listeners.empty()) { @@ -480,7 +159,7 @@ void LSImpl::NotifyTopic(TopicData* topic, unsigned int eventFlags) { } } -void LSImpl::CheckReset(TopicData* topic) { +void LocalStorage::Impl::CheckReset(TopicData* topic) { if (topic->Exists()) { return; } @@ -494,10 +173,10 @@ void LSImpl::CheckReset(TopicData* topic) { topic->propertiesStr = "{}"; } -bool LSImpl::SetValue(TopicData* topic, const Value& value, - unsigned int eventFlags, bool isDuplicate, - bool suppressIfDuplicate, - const PublisherData* publisher) { +bool LocalStorage::Impl::SetValue(TopicData* topic, const Value& value, + unsigned int eventFlags, bool isDuplicate, + bool suppressIfDuplicate, + const PublisherData* publisher) { DEBUG4("SetValue({}, {}, {}, {})", topic->name, value.time(), eventFlags, isDuplicate); if (topic->type != NT_UNASSIGNED && topic->type != value.type()) { @@ -522,8 +201,9 @@ bool LSImpl::SetValue(TopicData* topic, const Value& value, return true; } -void LSImpl::NotifyValue(TopicData* topic, unsigned int eventFlags, - bool isDuplicate, const PublisherData* publisher) { +void LocalStorage::Impl::NotifyValue(TopicData* topic, unsigned int eventFlags, + bool isDuplicate, + const PublisherData* publisher) { bool isNetwork = (eventFlags & NT_EVENT_VALUE_REMOTE) != 0; for (auto&& subscriber : topic->localSubscribers) { if (subscriber->active && @@ -552,7 +232,7 @@ void LSImpl::NotifyValue(TopicData* topic, unsigned int eventFlags, } } -void LSImpl::SetFlags(TopicData* topic, unsigned int flags) { +void LocalStorage::Impl::SetFlags(TopicData* topic, unsigned int flags) { wpi::json update = wpi::json::object(); if ((flags & NT_PERSISTENT) != 0) { topic->properties["persistent"] = true; @@ -574,7 +254,7 @@ void LSImpl::SetFlags(TopicData* topic, unsigned int flags) { } } -void LSImpl::SetPersistent(TopicData* topic, bool value) { +void LocalStorage::Impl::SetPersistent(TopicData* topic, bool value) { wpi::json update = wpi::json::object(); if (value) { topic->flags |= NT_PERSISTENT; @@ -588,7 +268,7 @@ void LSImpl::SetPersistent(TopicData* topic, bool value) { PropertiesUpdated(topic, update, NT_EVENT_NONE, true, false); } -void LSImpl::SetRetained(TopicData* topic, bool value) { +void LocalStorage::Impl::SetRetained(TopicData* topic, bool value) { wpi::json update = wpi::json::object(); if (value) { topic->flags |= NT_RETAINED; @@ -602,8 +282,9 @@ void LSImpl::SetRetained(TopicData* topic, bool value) { PropertiesUpdated(topic, update, NT_EVENT_NONE, true, false); } -void LSImpl::SetProperties(TopicData* topic, const wpi::json& update, - bool sendNetwork) { +void LocalStorage::Impl::SetProperties(TopicData* topic, + const wpi::json& update, + bool sendNetwork) { if (!update.is_object()) { return; } @@ -618,9 +299,10 @@ void LSImpl::SetProperties(TopicData* topic, const wpi::json& update, PropertiesUpdated(topic, update, NT_EVENT_NONE, sendNetwork); } -void LSImpl::PropertiesUpdated(TopicData* topic, const wpi::json& update, - unsigned int eventFlags, bool sendNetwork, - bool updateFlags) { +void LocalStorage::Impl::PropertiesUpdated(TopicData* topic, + const wpi::json& update, + unsigned int eventFlags, + bool sendNetwork, bool updateFlags) { DEBUG4("PropertiesUpdated({}, {}, {}, {}, {})", topic->name, update.dump(), eventFlags, sendNetwork, updateFlags); if (updateFlags) { @@ -655,7 +337,8 @@ void LSImpl::PropertiesUpdated(TopicData* topic, const wpi::json& update, } } -void LSImpl::RefreshPubSubActive(TopicData* topic, bool warnOnSubMismatch) { +void LocalStorage::Impl::RefreshPubSubActive(TopicData* topic, + bool warnOnSubMismatch) { for (auto&& publisher : topic->localPublishers) { publisher->UpdateActive(); } @@ -671,9 +354,10 @@ void LSImpl::RefreshPubSubActive(TopicData* topic, bool warnOnSubMismatch) { } } -void LSImpl::NetworkAnnounce(TopicData* topic, std::string_view typeStr, - const wpi::json& properties, - NT_Publisher pubHandle) { +void LocalStorage::Impl::NetworkAnnounce(TopicData* topic, + std::string_view typeStr, + const wpi::json& properties, + NT_Publisher pubHandle) { DEBUG4("LS NetworkAnnounce({}, {}, {}, {})", topic->name, typeStr, properties.dump(), pubHandle); if (pubHandle != 0) { @@ -725,7 +409,7 @@ void LSImpl::NetworkAnnounce(TopicData* topic, std::string_view typeStr, } } -void LSImpl::RemoveNetworkPublisher(TopicData* topic) { +void LocalStorage::Impl::RemoveNetworkPublisher(TopicData* topic) { DEBUG4("LS RemoveNetworkPublisher({}, {})", topic->handle, topic->name); // this acts as an unpublish bool didExist = topic->Exists(); @@ -755,8 +439,9 @@ void LSImpl::RemoveNetworkPublisher(TopicData* topic) { } } -void LSImpl::NetworkPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack) { +void LocalStorage::Impl::NetworkPropertiesUpdate(TopicData* topic, + const wpi::json& update, + bool ack) { DEBUG4("NetworkPropertiesUpdate({},{})", topic->name, ack); if (ack) { return; // ignore acks @@ -764,9 +449,8 @@ void LSImpl::NetworkPropertiesUpdate(TopicData* topic, const wpi::json& update, SetProperties(topic, update, false); } -PublisherData* LSImpl::AddLocalPublisher(TopicData* topic, - const wpi::json& properties, - const PubSubConfig& config) { +LocalStorage::PublisherData* LocalStorage::Impl::AddLocalPublisher( + TopicData* topic, const wpi::json& properties, const PubSubConfig& config) { bool didExist = topic->Exists(); auto publisher = m_publishers.Add(m_inst, topic, config); topic->localPublishers.Add(publisher); @@ -784,8 +468,7 @@ PublisherData* LSImpl::AddLocalPublisher(TopicData* topic, } else if (properties.is_object()) { topic->properties = properties; } else { - WARNING("ignoring non-object properties when publishing '{}'", - topic->name); + WARN("ignoring non-object properties when publishing '{}'", topic->name); topic->properties = wpi::json::object(); } @@ -813,8 +496,8 @@ PublisherData* LSImpl::AddLocalPublisher(TopicData* topic, return publisher; } -std::unique_ptr LSImpl::RemoveLocalPublisher( - NT_Publisher pubHandle) { +std::unique_ptr +LocalStorage::Impl::RemoveLocalPublisher(NT_Publisher pubHandle) { auto publisher = m_publishers.Remove(pubHandle); if (publisher) { auto topic = publisher->topic; @@ -849,8 +532,8 @@ std::unique_ptr LSImpl::RemoveLocalPublisher( return publisher; } -SubscriberData* LSImpl::AddLocalSubscriber(TopicData* topic, - const PubSubConfig& config) { +LocalStorage::SubscriberData* LocalStorage::Impl::AddLocalSubscriber( + TopicData* topic, const PubSubConfig& config) { DEBUG4("AddLocalSubscriber({})", topic->name); auto subscriber = m_subscribers.Add(m_inst, topic, config); topic->localSubscribers.Add(subscriber); @@ -881,8 +564,8 @@ SubscriberData* LSImpl::AddLocalSubscriber(TopicData* topic, return subscriber; } -std::unique_ptr LSImpl::RemoveLocalSubscriber( - NT_Subscriber subHandle) { +std::unique_ptr +LocalStorage::Impl::RemoveLocalSubscriber(NT_Subscriber subHandle) { auto subscriber = m_subscribers.Remove(subHandle); if (subscriber) { auto topic = subscriber->topic; @@ -899,13 +582,15 @@ std::unique_ptr LSImpl::RemoveLocalSubscriber( return subscriber; } -EntryData* LSImpl::AddEntry(SubscriberData* subscriber) { +LocalStorage::EntryData* LocalStorage::Impl::AddEntry( + SubscriberData* subscriber) { auto entry = m_entries.Add(m_inst, subscriber); subscriber->topic->entries.Add(entry); return entry; } -std::unique_ptr LSImpl::RemoveEntry(NT_Entry entryHandle) { +std::unique_ptr LocalStorage::Impl::RemoveEntry( + NT_Entry entryHandle) { auto entry = m_entries.Remove(entryHandle); if (entry) { entry->topic->entries.Remove(entry.get()); @@ -913,7 +598,7 @@ std::unique_ptr LSImpl::RemoveEntry(NT_Entry entryHandle) { return entry; } -MultiSubscriberData* LSImpl::AddMultiSubscriber( +LocalStorage::MultiSubscriberData* LocalStorage::Impl::AddMultiSubscriber( std::span prefixes, const PubSubOptions& options) { DEBUG4("AddMultiSubscriber({})", fmt::join(prefixes, ",")); auto subscriber = m_multiSubscribers.Add(m_inst, prefixes, options); @@ -934,8 +619,8 @@ MultiSubscriberData* LSImpl::AddMultiSubscriber( return subscriber; } -std::unique_ptr LSImpl::RemoveMultiSubscriber( - NT_MultiSubscriber subHandle) { +std::unique_ptr +LocalStorage::Impl::RemoveMultiSubscriber(NT_MultiSubscriber subHandle) { auto subscriber = m_multiSubscribers.Remove(subHandle); if (subscriber) { for (auto&& topic : m_topics) { @@ -953,11 +638,11 @@ std::unique_ptr LSImpl::RemoveMultiSubscriber( return subscriber; } -void LSImpl::AddListenerImpl(NT_Listener listenerHandle, TopicData* topic, - unsigned int eventMask) { +void LocalStorage::Impl::AddListenerImpl(NT_Listener listenerHandle, + TopicData* topic, + unsigned int eventMask) { if (topic->localSubscribers.size() >= kMaxSubscribers) { - ERROR( - "reached maximum number of subscribers to '{}', ignoring listener add", + ERR("reached maximum number of subscribers to '{}', ignoring listener add", topic->name); return; } @@ -968,9 +653,11 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, TopicData* topic, AddListenerImpl(listenerHandle, sub, eventMask, sub->handle, true); } -void LSImpl::AddListenerImpl(NT_Listener listenerHandle, - SubscriberData* subscriber, unsigned int eventMask, - NT_Handle subentryHandle, bool subscriberOwned) { +void LocalStorage::Impl::AddListenerImpl(NT_Listener listenerHandle, + SubscriberData* subscriber, + unsigned int eventMask, + NT_Handle subentryHandle, + bool subscriberOwned) { m_listeners.try_emplace(listenerHandle, std::make_unique( listenerHandle, subscriber, eventMask, subscriberOwned)); @@ -979,8 +666,8 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, if ((eventMask & NT_EVENT_TOPIC) != 0) { if (topic->listeners.size() >= kMaxListeners) { - ERROR("reached maximum number of listeners to '{}', not adding listener", - topic->name); + ERR("reached maximum number of listeners to '{}', not adding listener", + topic->name); return; } @@ -1001,8 +688,8 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, if ((eventMask & NT_EVENT_VALUE_ALL) != 0) { if (subscriber->valueListeners.size() >= kMaxListeners) { - ERROR("reached maximum number of listeners to '{}', not adding listener", - topic->name); + ERR("reached maximum number of listeners to '{}', not adding listener", + topic->name); return; } m_listenerStorage.Activate( @@ -1026,9 +713,10 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, } } -void LSImpl::AddListenerImpl(NT_Listener listenerHandle, - MultiSubscriberData* subscriber, - unsigned int eventMask, bool subscriberOwned) { +void LocalStorage::Impl::AddListenerImpl(NT_Listener listenerHandle, + MultiSubscriberData* subscriber, + unsigned int eventMask, + bool subscriberOwned) { auto listener = m_listeners .try_emplace(listenerHandle, std::make_unique( @@ -1050,7 +738,7 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, if ((eventMask & NT_EVENT_TOPIC) != 0) { if (m_topicPrefixListeners.size() >= kMaxListeners) { - ERROR("reached maximum number of listeners, not adding listener"); + ERR("reached maximum number of listeners, not adding listener"); return; } @@ -1076,7 +764,7 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, if ((eventMask & NT_EVENT_VALUE_ALL) != 0) { if (subscriber->valueListeners.size() >= kMaxListeners) { - ERROR("reached maximum number of listeners, not adding listener"); + ERR("reached maximum number of listeners, not adding listener"); return; } @@ -1106,61 +794,8 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, } } -void LSImpl::AddListener(NT_Listener listenerHandle, - std::span prefixes, - unsigned int eventMask) { - if (m_multiSubscribers.size() >= kMaxMultiSubscribers) { - ERROR("reached maximum number of multi-subscribers, not adding listener"); - return; - } - // subscribe to make sure topic updates are received - auto sub = AddMultiSubscriber( - prefixes, {.topicsOnly = (eventMask & NT_EVENT_VALUE_ALL) == 0}); - AddListenerImpl(listenerHandle, sub, eventMask, true); -} - -void LSImpl::AddListener(NT_Listener listenerHandle, NT_Handle handle, - unsigned int mask) { - if (auto topic = m_topics.Get(handle)) { - AddListenerImpl(listenerHandle, topic, mask); - } else if (auto sub = m_multiSubscribers.Get(handle)) { - AddListenerImpl(listenerHandle, sub, mask, false); - } else if (auto sub = m_subscribers.Get(handle)) { - AddListenerImpl(listenerHandle, sub, mask, sub->handle, false); - } else if (auto entry = m_entries.Get(handle)) { - AddListenerImpl(listenerHandle, entry->subscriber, mask, entry->handle, - false); - } -} - -void LSImpl::RemoveListener(NT_Listener listenerHandle, unsigned int mask) { - auto listenerIt = m_listeners.find(listenerHandle); - if (listenerIt == m_listeners.end()) { - return; - } - auto listener = std::move(listenerIt->getSecond()); - m_listeners.erase(listenerIt); - if (!listener) { - return; - } - - m_topicPrefixListeners.Remove(listener.get()); - if (listener->subscriber) { - listener->subscriber->valueListeners.Remove(listenerHandle); - listener->subscriber->topic->listeners.Remove(listenerHandle); - if (listener->subscriberOwned) { - RemoveLocalSubscriber(listener->subscriber->handle); - } - } - if (listener->multiSubscriber) { - listener->multiSubscriber->valueListeners.Remove(listenerHandle); - if (listener->subscriberOwned) { - RemoveMultiSubscriber(listener->multiSubscriber->handle); - } - } -} - -TopicData* LSImpl::GetOrCreateTopic(std::string_view name) { +LocalStorage::TopicData* LocalStorage::Impl::GetOrCreateTopic( + std::string_view name) { auto& topic = m_nameTopics[name]; // create if it does not already exist if (!topic) { @@ -1175,7 +810,7 @@ TopicData* LSImpl::GetOrCreateTopic(std::string_view name) { return topic; } -TopicData* LSImpl::GetTopic(NT_Handle handle) { +LocalStorage::TopicData* LocalStorage::Impl::GetTopic(NT_Handle handle) { switch (Handle{handle}.GetType()) { case Handle::kEntry: { if (auto entry = m_entries.Get(handle)) { @@ -1203,7 +838,8 @@ TopicData* LSImpl::GetTopic(NT_Handle handle) { return {}; } -SubscriberData* LSImpl::GetSubEntry(NT_Handle subentryHandle) { +LocalStorage::SubscriberData* LocalStorage::Impl::GetSubEntry( + NT_Handle subentryHandle) { Handle h{subentryHandle}; if (h.IsType(Handle::kSubscriber)) { return m_subscribers.Get(subentryHandle); @@ -1215,7 +851,8 @@ SubscriberData* LSImpl::GetSubEntry(NT_Handle subentryHandle) { } } -PublisherData* LSImpl::PublishEntry(EntryData* entry, NT_Type type) { +LocalStorage::PublisherData* LocalStorage::Impl::PublishEntry(EntryData* entry, + NT_Type type) { if (entry->publisher) { return entry->publisher; } @@ -1227,8 +864,8 @@ PublisherData* LSImpl::PublishEntry(EntryData* entry, NT_Type type) { entry->subscriber->config.typeStr != typeStr) { if (!IsNumericCompatible(type, entry->subscriber->config.type)) { // don't allow dynamically changing the type of an entry - ERROR("cannot publish entry {} as type {}, previously subscribed as {}", - entry->topic->name, typeStr, entry->subscriber->config.typeStr); + ERR("cannot publish entry {} as type {}, previously subscribed as {}", + entry->topic->name, typeStr, entry->subscriber->config.typeStr); return nullptr; } } @@ -1242,16 +879,8 @@ PublisherData* LSImpl::PublishEntry(EntryData* entry, NT_Type type) { return entry->publisher; } -Value* LSImpl::GetSubEntryValue(NT_Handle subentryHandle) { - if (auto subscriber = GetSubEntry(subentryHandle)) { - return &subscriber->topic->lastValue; - } else { - return nullptr; - } -} - -bool LSImpl::PublishLocalValue(PublisherData* publisher, const Value& value, - bool force) { +bool LocalStorage::Impl::PublishLocalValue(PublisherData* publisher, + const Value& value, bool force) { if (!value) { return false; } @@ -1284,7 +913,8 @@ bool LSImpl::PublishLocalValue(PublisherData* publisher, const Value& value, } } -bool LSImpl::SetEntryValue(NT_Handle pubentryHandle, const Value& value) { +bool LocalStorage::Impl::SetEntryValue(NT_Handle pubentryHandle, + const Value& value) { if (!value) { return false; } @@ -1300,8 +930,8 @@ bool LSImpl::SetEntryValue(NT_Handle pubentryHandle, const Value& value) { return PublishLocalValue(publisher, value); } -bool LSImpl::SetDefaultEntryValue(NT_Handle pubsubentryHandle, - const Value& value) { +bool LocalStorage::Impl::SetDefaultEntryValue(NT_Handle pubsubentryHandle, + const Value& value) { DEBUG4("SetDefaultEntryValue({}, {})", pubsubentryHandle, static_cast(value.type())); if (!value) { @@ -1341,7 +971,7 @@ bool LSImpl::SetDefaultEntryValue(NT_Handle pubsubentryHandle, return false; } -void LSImpl::RemoveSubEntry(NT_Handle subentryHandle) { +void LocalStorage::Impl::RemoveSubEntry(NT_Handle subentryHandle) { Handle h{subentryHandle}; if (h.IsType(Handle::kSubscriber)) { RemoveLocalSubscriber(subentryHandle); @@ -1357,15 +987,9 @@ void LSImpl::RemoveSubEntry(NT_Handle subentryHandle) { } } -class LocalStorage::Impl : public LSImpl { - public: - Impl(int inst, IListenerStorage& listenerStorage, wpi::Logger& logger) - : LSImpl{inst, listenerStorage, logger} {} -}; - -LocalStorage::LocalStorage(int inst, IListenerStorage& listenerStorage, - wpi::Logger& logger) - : m_impl{std::make_unique(inst, listenerStorage, logger)} {} +LocalStorage::Impl::Impl(int inst, IListenerStorage& listenerStorage, + wpi::Logger& logger) + : m_inst{inst}, m_listenerStorage{listenerStorage}, m_logger{logger} {} LocalStorage::~LocalStorage() = default; @@ -1374,31 +998,31 @@ NT_Topic LocalStorage::NetworkAnnounce(std::string_view name, const wpi::json& properties, NT_Publisher pubHandle) { std::scoped_lock lock{m_mutex}; - auto topic = m_impl->GetOrCreateTopic(name); - m_impl->NetworkAnnounce(topic, typeStr, properties, pubHandle); + auto topic = m_impl.GetOrCreateTopic(name); + m_impl.NetworkAnnounce(topic, typeStr, properties, pubHandle); return topic->handle; } void LocalStorage::NetworkUnannounce(std::string_view name) { std::scoped_lock lock{m_mutex}; - auto topic = m_impl->GetOrCreateTopic(name); - m_impl->RemoveNetworkPublisher(topic); + auto topic = m_impl.GetOrCreateTopic(name); + m_impl.RemoveNetworkPublisher(topic); } void LocalStorage::NetworkPropertiesUpdate(std::string_view name, const wpi::json& update, bool ack) { std::scoped_lock lock{m_mutex}; - auto it = m_impl->m_nameTopics.find(name); - if (it != m_impl->m_nameTopics.end()) { - m_impl->NetworkPropertiesUpdate(it->second, update, ack); + auto it = m_impl.m_nameTopics.find(name); + if (it != m_impl.m_nameTopics.end()) { + m_impl.NetworkPropertiesUpdate(it->second, update, ack); } } void LocalStorage::NetworkSetValue(NT_Topic topicHandle, const Value& value) { std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - if (m_impl->SetValue(topic, value, NT_EVENT_VALUE_REMOTE, - value == topic->lastValue, false, nullptr)) { + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + if (m_impl.SetValue(topic, value, NT_EVENT_VALUE_REMOTE, + value == topic->lastValue, false, nullptr)) { topic->lastValueNetwork = value; topic->lastValueFromNetwork = true; } @@ -1406,12 +1030,16 @@ void LocalStorage::NetworkSetValue(NT_Topic topicHandle, const Value& value) { } void LocalStorage::StartNetwork(net::NetworkInterface* network) { - WPI_DEBUG4(m_impl->m_logger, "StartNetwork()"); std::scoped_lock lock{m_mutex}; - m_impl->m_network = network; + m_impl.StartNetwork(network); +} + +void LocalStorage::Impl::StartNetwork(net::NetworkInterface* network) { + DEBUG4("StartNetwork()"); + m_network = network; // publish all active publishers to the network and send last values // only send value once per topic - for (auto&& topic : m_impl->m_topics) { + for (auto&& topic : m_topics) { PublisherData* anyPublisher = nullptr; for (auto&& publisher : topic->localPublishers) { if (publisher->active) { @@ -1424,23 +1052,23 @@ void LocalStorage::StartNetwork(net::NetworkInterface* network) { network->SetValue(anyPublisher->handle, topic->lastValue); } } - for (auto&& subscriber : m_impl->m_subscribers) { + for (auto&& subscriber : m_subscribers) { network->Subscribe(subscriber->handle, {{subscriber->topic->name}}, subscriber->config); } - for (auto&& subscriber : m_impl->m_multiSubscribers) { + for (auto&& subscriber : m_multiSubscribers) { network->Subscribe(subscriber->handle, subscriber->prefixes, subscriber->options); } } void LocalStorage::ClearNetwork() { - WPI_DEBUG4(m_impl->m_logger, "ClearNetwork()"); + WPI_DEBUG4(m_impl.m_logger, "ClearNetwork()"); std::scoped_lock lock{m_mutex}; - m_impl->m_network = nullptr; + m_impl.m_network = nullptr; // treat as an unannounce all from the network side - for (auto&& topic : m_impl->m_topics) { - m_impl->RemoveNetworkPublisher(topic.get()); + for (auto&& topic : m_impl.m_topics) { + m_impl.RemoveNetworkPublisher(topic.get()); } } @@ -1491,7 +1119,7 @@ std::vector LocalStorage::GetTopics(std::string_view prefix, unsigned int types) { std::scoped_lock lock(m_mutex); std::vector rv; - ForEachTopic(m_impl->m_topics, prefix, types, + ForEachTopic(m_impl.m_topics, prefix, types, [&](TopicData& topic) { rv.push_back(topic.handle); }); return rv; } @@ -1500,7 +1128,7 @@ std::vector LocalStorage::GetTopics( std::string_view prefix, std::span types) { std::scoped_lock lock(m_mutex); std::vector rv; - ForEachTopic(m_impl->m_topics, prefix, types, + ForEachTopic(m_impl.m_topics, prefix, types, [&](TopicData& topic) { rv.push_back(topic.handle); }); return rv; } @@ -1509,7 +1137,7 @@ std::vector LocalStorage::GetTopicInfo(std::string_view prefix, unsigned int types) { std::scoped_lock lock(m_mutex); std::vector rv; - ForEachTopic(m_impl->m_topics, prefix, types, [&](TopicData& topic) { + ForEachTopic(m_impl.m_topics, prefix, types, [&](TopicData& topic) { rv.emplace_back(topic.GetTopicInfo()); }); return rv; @@ -1519,99 +1147,16 @@ std::vector LocalStorage::GetTopicInfo( std::string_view prefix, std::span types) { std::scoped_lock lock(m_mutex); std::vector rv; - ForEachTopic(m_impl->m_topics, prefix, types, [&](TopicData& topic) { + ForEachTopic(m_impl.m_topics, prefix, types, [&](TopicData& topic) { rv.emplace_back(topic.GetTopicInfo()); }); return rv; } -NT_Topic LocalStorage::GetTopic(std::string_view name) { - if (name.empty()) { - return {}; - } - std::scoped_lock lock{m_mutex}; - return m_impl->GetOrCreateTopic(name)->handle; -} - -std::string LocalStorage::GetTopicName(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->name; - } else { - return {}; - } -} - -NT_Type LocalStorage::GetTopicType(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->type; - } else { - return {}; - } -} - -std::string LocalStorage::GetTopicTypeString(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->typeStr; - } else { - return {}; - } -} - -void LocalStorage::SetTopicPersistent(NT_Topic topicHandle, bool value) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - m_impl->SetPersistent(topic, value); - } -} - -bool LocalStorage::GetTopicPersistent(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return (topic->flags & NT_PERSISTENT) != 0; - } else { - return false; - } -} - -void LocalStorage::SetTopicRetained(NT_Topic topicHandle, bool value) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - m_impl->SetRetained(topic, value); - } -} - -bool LocalStorage::GetTopicRetained(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return (topic->flags & NT_RETAINED) != 0; - } else { - return false; - } -} - -bool LocalStorage::GetTopicExists(NT_Handle handle) { - std::scoped_lock lock{m_mutex}; - TopicData* topic = m_impl->GetTopic(handle); - return topic && topic->Exists(); -} - -wpi::json LocalStorage::GetTopicProperty(NT_Topic topicHandle, - std::string_view name) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->properties.value(name, wpi::json{}); - } else { - return {}; - } -} - void LocalStorage::SetTopicProperty(NT_Topic topicHandle, std::string_view name, const wpi::json& value) { std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { + if (auto topic = m_impl.m_topics.Get(topicHandle)) { if (value.is_null()) { topic->properties.erase(name); } else { @@ -1619,27 +1164,18 @@ void LocalStorage::SetTopicProperty(NT_Topic topicHandle, std::string_view name, } wpi::json update = wpi::json::object(); update[name] = value; - m_impl->PropertiesUpdated(topic, update, NT_EVENT_NONE, true); + m_impl.PropertiesUpdated(topic, update, NT_EVENT_NONE, true); } } void LocalStorage::DeleteTopicProperty(NT_Topic topicHandle, std::string_view name) { std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { + if (auto topic = m_impl.m_topics.Get(topicHandle)) { topic->properties.erase(name); wpi::json update = wpi::json::object(); update[name] = wpi::json(); - m_impl->PropertiesUpdated(topic, update, NT_EVENT_NONE, true); - } -} - -wpi::json LocalStorage::GetTopicProperties(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->properties; - } else { - return wpi::json::object(); + m_impl.PropertiesUpdated(topic, update, NT_EVENT_NONE, true); } } @@ -1649,67 +1185,48 @@ bool LocalStorage::SetTopicProperties(NT_Topic topicHandle, return false; } std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - m_impl->SetProperties(topic, update, true); + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + m_impl.SetProperties(topic, update, true); return true; } else { return {}; } } -TopicInfo LocalStorage::GetTopicInfo(NT_Topic topicHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->m_topics.Get(topicHandle)) { - return topic->GetTopicInfo(); - } else { - return {}; - } -} - NT_Subscriber LocalStorage::Subscribe(NT_Topic topicHandle, NT_Type type, std::string_view typeStr, const PubSubOptions& options) { std::scoped_lock lock{m_mutex}; // Get the topic - auto* topic = m_impl->m_topics.Get(topicHandle); + auto* topic = m_impl.m_topics.Get(topicHandle); if (!topic) { return 0; } if (topic->localSubscribers.size() >= kMaxSubscribers) { - WPI_ERROR(m_impl->m_logger, + WPI_ERROR(m_impl.m_logger, "reached maximum number of subscribers to '{}', not subscribing", topic->name); return 0; } // Create subscriber - return m_impl->AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}) + return m_impl.AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}) ->handle; } -void LocalStorage::Unsubscribe(NT_Subscriber subHandle) { - std::scoped_lock lock{m_mutex}; - m_impl->RemoveSubEntry(subHandle); -} - NT_MultiSubscriber LocalStorage::SubscribeMultiple( std::span prefixes, const PubSubOptions& options) { std::scoped_lock lock{m_mutex}; - if (m_impl->m_multiSubscribers.size() >= kMaxMultiSubscribers) { - WPI_ERROR(m_impl->m_logger, + if (m_impl.m_multiSubscribers.size() >= kMaxMultiSubscribers) { + WPI_ERROR(m_impl.m_logger, "reached maximum number of multi-subscribers, not subscribing"); return 0; } - return m_impl->AddMultiSubscriber(prefixes, options)->handle; -} - -void LocalStorage::UnsubscribeMultiple(NT_MultiSubscriber subHandle) { - std::scoped_lock lock{m_mutex}; - m_impl->RemoveMultiSubscriber(subHandle); + return m_impl.AddMultiSubscriber(prefixes, options)->handle; } NT_Publisher LocalStorage::Publish(NT_Topic topicHandle, NT_Type type, @@ -1719,31 +1236,31 @@ NT_Publisher LocalStorage::Publish(NT_Topic topicHandle, NT_Type type, std::scoped_lock lock{m_mutex}; // Get the topic - auto* topic = m_impl->m_topics.Get(topicHandle); + auto* topic = m_impl.m_topics.Get(topicHandle); if (!topic) { - WPI_ERROR(m_impl->m_logger, "trying to publish invalid topic handle ({})", + WPI_ERROR(m_impl.m_logger, "trying to publish invalid topic handle ({})", topicHandle); return 0; } if (type == NT_UNASSIGNED || typeStr.empty()) { WPI_ERROR( - m_impl->m_logger, + m_impl.m_logger, "cannot publish '{}' with an unassigned type or empty type string", topic->name); return 0; } if (topic->localPublishers.size() >= kMaxPublishers) { - WPI_ERROR(m_impl->m_logger, + WPI_ERROR(m_impl.m_logger, "reached maximum number of publishers to '{}', not publishing", topic->name); return 0; } return m_impl - ->AddLocalPublisher(topic, properties, - PubSubConfig{type, typeStr, options}) + .AddLocalPublisher(topic, properties, + PubSubConfig{type, typeStr, options}) ->handle; } @@ -1751,10 +1268,10 @@ void LocalStorage::Unpublish(NT_Handle pubentryHandle) { std::scoped_lock lock{m_mutex}; if (Handle{pubentryHandle}.IsType(Handle::kPublisher)) { - m_impl->RemoveLocalPublisher(pubentryHandle); - } else if (auto entry = m_impl->m_entries.Get(pubentryHandle)) { + m_impl.RemoveLocalPublisher(pubentryHandle); + } else if (auto entry = m_impl.m_entries.Get(pubentryHandle)) { if (entry->publisher) { - m_impl->RemoveLocalPublisher(entry->publisher->handle); + m_impl.RemoveLocalPublisher(entry->publisher->handle); entry->publisher = nullptr; } } else { @@ -1769,14 +1286,14 @@ NT_Entry LocalStorage::GetEntry(NT_Topic topicHandle, NT_Type type, std::scoped_lock lock{m_mutex}; // Get the topic - auto* topic = m_impl->m_topics.Get(topicHandle); + auto* topic = m_impl.m_topics.Get(topicHandle); if (!topic) { return 0; } if (topic->localSubscribers.size() >= kMaxSubscribers) { WPI_ERROR( - m_impl->m_logger, + m_impl.m_logger, "reached maximum number of subscribers to '{}', not creating entry", topic->name); return 0; @@ -1784,15 +1301,10 @@ NT_Entry LocalStorage::GetEntry(NT_Topic topicHandle, NT_Type type, // Create subscriber auto subscriber = - m_impl->AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}); + m_impl.AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}); // Create entry - return m_impl->AddEntry(subscriber)->handle; -} - -void LocalStorage::ReleaseEntry(NT_Entry entryHandle) { - std::scoped_lock lock{m_mutex}; - m_impl->RemoveSubEntry(entryHandle); + return m_impl.AddEntry(subscriber)->handle; } void LocalStorage::Release(NT_Handle pubsubentryHandle) { @@ -1814,324 +1326,9 @@ void LocalStorage::Release(NT_Handle pubsubentryHandle) { } } -NT_Topic LocalStorage::GetTopicFromHandle(NT_Handle pubsubentryHandle) { - std::scoped_lock lock{m_mutex}; - if (auto topic = m_impl->GetTopic(pubsubentryHandle)) { - return topic->handle; - } else { - return {}; - } -} - -bool LocalStorage::SetEntryValue(NT_Handle pubentryHandle, const Value& value) { - std::scoped_lock lock{m_mutex}; - return m_impl->SetEntryValue(pubentryHandle, value); -} - -bool LocalStorage::SetDefaultEntryValue(NT_Handle pubsubentryHandle, - const Value& value) { - std::scoped_lock lock{m_mutex}; - return m_impl->SetDefaultEntryValue(pubsubentryHandle, value); -} - -TimestampedBoolean LocalStorage::GetAtomicBoolean(NT_Handle subentryHandle, - bool defaultValue) { - std::scoped_lock lock{m_mutex}; - Value* value = m_impl->GetSubEntryValue(subentryHandle); - if (value && value->type() == NT_BOOLEAN) { - return {value->time(), value->server_time(), value->GetBoolean()}; - } else { - return {0, 0, defaultValue}; - } -} - -TimestampedString LocalStorage::GetAtomicString(NT_Handle subentryHandle, - std::string_view defaultValue) { - std::scoped_lock lock{m_mutex}; - Value* value = m_impl->GetSubEntryValue(subentryHandle); - if (value && value->type() == NT_STRING) { - return {value->time(), value->server_time(), - std::string{value->GetString()}}; - } else { - return {0, 0, std::string{defaultValue}}; - } -} - -TimestampedStringView LocalStorage::GetAtomicString( - NT_Handle subentryHandle, wpi::SmallVectorImpl& buf, - std::string_view defaultValue) { - std::scoped_lock lock{m_mutex}; - Value* value = m_impl->GetSubEntryValue(subentryHandle); - if (value && value->type() == NT_STRING) { - auto str = value->GetString(); - buf.assign(str.begin(), str.end()); - return {value->time(), value->server_time(), {buf.data(), buf.size()}}; - } else { - return {0, 0, defaultValue}; - } -} - -template -static T GetAtomicNumber(Value* value, U defaultValue) { - if (value && value->type() == NT_INTEGER) { - return {value->time(), value->server_time(), - static_cast(value->GetInteger())}; - } else if (value && value->type() == NT_FLOAT) { - return {value->time(), value->server_time(), - static_cast(value->GetFloat())}; - } else if (value && value->type() == NT_DOUBLE) { - return {value->time(), value->server_time(), - static_cast(value->GetDouble())}; - } else { - return {0, 0, defaultValue}; - } -} - -template -static T GetAtomicNumberArray(Value* value, std::span defaultValue) { - if (value && value->type() == NT_INTEGER_ARRAY) { - auto arr = value->GetIntegerArray(); - return {value->time(), value->server_time(), {arr.begin(), arr.end()}}; - } else if (value && value->type() == NT_FLOAT_ARRAY) { - auto arr = value->GetFloatArray(); - return {value->time(), value->server_time(), {arr.begin(), arr.end()}}; - } else if (value && value->type() == NT_DOUBLE_ARRAY) { - auto arr = value->GetDoubleArray(); - return {value->time(), value->server_time(), {arr.begin(), arr.end()}}; - } else { - return {0, 0, {defaultValue.begin(), defaultValue.end()}}; - } -} - -template -static T GetAtomicNumberArray(Value* value, wpi::SmallVectorImpl& buf, - std::span defaultValue) { - if (value && value->type() == NT_INTEGER_ARRAY) { - auto str = value->GetIntegerArray(); - buf.assign(str.begin(), str.end()); - return {value->time(), value->server_time(), {buf.data(), buf.size()}}; - } else if (value && value->type() == NT_FLOAT_ARRAY) { - auto str = value->GetFloatArray(); - buf.assign(str.begin(), str.end()); - return {value->time(), value->server_time(), {buf.data(), buf.size()}}; - } else if (value && value->type() == NT_DOUBLE_ARRAY) { - auto str = value->GetDoubleArray(); - buf.assign(str.begin(), str.end()); - return {value->time(), value->server_time(), {buf.data(), buf.size()}}; - } else { - buf.assign(defaultValue.begin(), defaultValue.end()); - return {0, 0, {buf.data(), buf.size()}}; - } -} - -#define GET_ATOMIC_NUMBER(Name, dtype) \ - Timestamped##Name LocalStorage::GetAtomic##Name(NT_Handle subentry, \ - dtype defaultValue) { \ - std::scoped_lock lock{m_mutex}; \ - return GetAtomicNumber( \ - m_impl->GetSubEntryValue(subentry), defaultValue); \ - } \ - \ - Timestamped##Name##Array LocalStorage::GetAtomic##Name##Array( \ - NT_Handle subentry, std::span defaultValue) { \ - std::scoped_lock lock{m_mutex}; \ - return GetAtomicNumberArray( \ - m_impl->GetSubEntryValue(subentry), defaultValue); \ - } \ - \ - Timestamped##Name##ArrayView LocalStorage::GetAtomic##Name##Array( \ - NT_Handle subentry, wpi::SmallVectorImpl& buf, \ - std::span defaultValue) { \ - std::scoped_lock lock{m_mutex}; \ - return GetAtomicNumberArray( \ - m_impl->GetSubEntryValue(subentry), buf, defaultValue); \ - } - -GET_ATOMIC_NUMBER(Integer, int64_t) -GET_ATOMIC_NUMBER(Float, float) -GET_ATOMIC_NUMBER(Double, double) - -#define GET_ATOMIC_ARRAY(Name, dtype) \ - Timestamped##Name LocalStorage::GetAtomic##Name( \ - NT_Handle subentry, std::span defaultValue) { \ - std::scoped_lock lock{m_mutex}; \ - Value* value = m_impl->GetSubEntryValue(subentry); \ - if (value && value->Is##Name()) { \ - auto arr = value->Get##Name(); \ - return {value->time(), value->server_time(), {arr.begin(), arr.end()}}; \ - } else { \ - return {0, 0, {defaultValue.begin(), defaultValue.end()}}; \ - } \ - } - -GET_ATOMIC_ARRAY(Raw, uint8_t) -GET_ATOMIC_ARRAY(BooleanArray, int) -GET_ATOMIC_ARRAY(StringArray, std::string) - -#define GET_ATOMIC_SMALL_ARRAY(Name, dtype) \ - Timestamped##Name##View LocalStorage::GetAtomic##Name( \ - NT_Handle subentry, wpi::SmallVectorImpl& buf, \ - std::span defaultValue) { \ - std::scoped_lock lock{m_mutex}; \ - Value* value = m_impl->GetSubEntryValue(subentry); \ - if (value && value->Is##Name()) { \ - auto str = value->Get##Name(); \ - buf.assign(str.begin(), str.end()); \ - return {value->time(), value->server_time(), {buf.data(), buf.size()}}; \ - } else { \ - buf.assign(defaultValue.begin(), defaultValue.end()); \ - return {0, 0, {buf.data(), buf.size()}}; \ - } \ - } - -GET_ATOMIC_SMALL_ARRAY(Raw, uint8_t) -GET_ATOMIC_SMALL_ARRAY(BooleanArray, int) - -std::vector LocalStorage::ReadQueueValue(NT_Handle subentry) { - std::scoped_lock lock{m_mutex}; - auto subscriber = m_impl->GetSubEntry(subentry); - if (!subscriber) { - return {}; - } - std::vector rv; - rv.reserve(subscriber->pollStorage.size()); - for (auto&& val : subscriber->pollStorage) { - rv.emplace_back(std::move(val)); - } - subscriber->pollStorage.reset(); - return rv; -} - -std::vector LocalStorage::ReadQueueBoolean( - NT_Handle subentry) { - std::scoped_lock lock{m_mutex}; - auto subscriber = m_impl->GetSubEntry(subentry); - if (!subscriber) { - return {}; - } - std::vector rv; - rv.reserve(subscriber->pollStorage.size()); - for (auto&& val : subscriber->pollStorage) { - if (val.IsBoolean()) { - rv.emplace_back(val.time(), val.server_time(), val.GetBoolean()); - } - } - subscriber->pollStorage.reset(); - return rv; -} - -std::vector LocalStorage::ReadQueueString( - NT_Handle subentry) { - std::scoped_lock lock{m_mutex}; - auto subscriber = m_impl->GetSubEntry(subentry); - if (!subscriber) { - return {}; - } - std::vector rv; - rv.reserve(subscriber->pollStorage.size()); - for (auto&& val : subscriber->pollStorage) { - if (val.IsString()) { - rv.emplace_back(val.time(), val.server_time(), - std::string{val.GetString()}); - } - } - subscriber->pollStorage.reset(); - return rv; -} - -#define READ_QUEUE_ARRAY(Name) \ - std::vector LocalStorage::ReadQueue##Name( \ - NT_Handle subentry) { \ - std::scoped_lock lock{m_mutex}; \ - auto subscriber = m_impl->GetSubEntry(subentry); \ - if (!subscriber) { \ - return {}; \ - } \ - std::vector rv; \ - rv.reserve(subscriber->pollStorage.size()); \ - for (auto&& val : subscriber->pollStorage) { \ - if (val.Is##Name()) { \ - auto arr = val.Get##Name(); \ - rv.emplace_back(Timestamped##Name{ \ - val.time(), val.server_time(), {arr.begin(), arr.end()}}); \ - } \ - } \ - subscriber->pollStorage.reset(); \ - return rv; \ - } - -READ_QUEUE_ARRAY(Raw) -READ_QUEUE_ARRAY(BooleanArray) -READ_QUEUE_ARRAY(StringArray) - -template -static std::vector ReadQueueNumber(SubscriberData* subscriber) { - if (!subscriber) { - return {}; - } - std::vector rv; - rv.reserve(subscriber->pollStorage.size()); - for (auto&& val : subscriber->pollStorage) { - auto ts = val.time(); - auto sts = val.server_time(); - if (val.IsInteger()) { - rv.emplace_back(T(ts, sts, val.GetInteger())); - } else if (val.IsFloat()) { - rv.emplace_back(T(ts, sts, val.GetFloat())); - } else if (val.IsDouble()) { - rv.emplace_back(T(ts, sts, val.GetDouble())); - } - } - subscriber->pollStorage.reset(); - return rv; -} - -template -static std::vector ReadQueueNumberArray(SubscriberData* subscriber) { - if (!subscriber) { - return {}; - } - std::vector rv; - rv.reserve(subscriber->pollStorage.size()); - for (auto&& val : subscriber->pollStorage) { - auto ts = val.time(); - auto sts = val.server_time(); - if (val.IsIntegerArray()) { - auto arr = val.GetIntegerArray(); - rv.emplace_back(T{ts, sts, {arr.begin(), arr.end()}}); - } else if (val.IsFloatArray()) { - auto arr = val.GetFloatArray(); - rv.emplace_back(T{ts, sts, {arr.begin(), arr.end()}}); - } else if (val.IsDoubleArray()) { - auto arr = val.GetDoubleArray(); - rv.emplace_back(T{ts, sts, {arr.begin(), arr.end()}}); - } - } - subscriber->pollStorage.reset(); - return rv; -} - -#define READ_QUEUE_NUMBER(Name) \ - std::vector LocalStorage::ReadQueue##Name( \ - NT_Handle subentry) { \ - std::scoped_lock lock{m_mutex}; \ - return ReadQueueNumber(m_impl->GetSubEntry(subentry)); \ - } \ - \ - std::vector LocalStorage::ReadQueue##Name##Array( \ - NT_Handle subentry) { \ - std::scoped_lock lock{m_mutex}; \ - return ReadQueueNumberArray( \ - m_impl->GetSubEntry(subentry)); \ - } - -READ_QUEUE_NUMBER(Integer) -READ_QUEUE_NUMBER(Float) -READ_QUEUE_NUMBER(Double) - Value LocalStorage::GetEntryValue(NT_Handle subentryHandle) { std::scoped_lock lock{m_mutex}; - if (auto subscriber = m_impl->GetSubEntry(subentryHandle)) { + if (auto subscriber = m_impl.GetSubEntry(subentryHandle)) { if (subscriber->config.type == NT_UNASSIGNED || !subscriber->topic->lastValue || subscriber->config.type == subscriber->topic->lastValue.type()) { @@ -2145,22 +1342,6 @@ Value LocalStorage::GetEntryValue(NT_Handle subentryHandle) { return {}; } -void LocalStorage::SetEntryFlags(NT_Entry entryHandle, unsigned int flags) { - std::scoped_lock lock{m_mutex}; - if (auto entry = m_impl->m_entries.Get(entryHandle)) { - m_impl->SetFlags(entry->subscriber->topic, flags); - } -} - -unsigned int LocalStorage::GetEntryFlags(NT_Entry entryHandle) { - std::scoped_lock lock{m_mutex}; - if (auto entry = m_impl->m_entries.Get(entryHandle)) { - return entry->subscriber->topic->flags; - } else { - return 0; - } -} - NT_Entry LocalStorage::GetEntry(std::string_view name) { if (name.empty()) { return {}; @@ -2169,72 +1350,87 @@ NT_Entry LocalStorage::GetEntry(std::string_view name) { std::scoped_lock lock{m_mutex}; // Get the topic data - auto* topic = m_impl->GetOrCreateTopic(name); + auto* topic = m_impl.GetOrCreateTopic(name); if (topic->entry == 0) { if (topic->localSubscribers.size() >= kMaxSubscribers) { WPI_ERROR( - m_impl->m_logger, + m_impl.m_logger, "reached maximum number of subscribers to '{}', not creating entry", topic->name); return 0; } // Create subscriber - auto* subscriber = m_impl->AddLocalSubscriber(topic, {}); + auto* subscriber = m_impl.AddLocalSubscriber(topic, {}); // Create entry - topic->entry = m_impl->AddEntry(subscriber)->handle; + topic->entry = m_impl.AddEntry(subscriber)->handle; } return topic->entry; } -std::string LocalStorage::GetEntryName(NT_Handle subentryHandle) { - std::scoped_lock lock{m_mutex}; - if (auto subscriber = m_impl->GetSubEntry(subentryHandle)) { - return subscriber->topic->name; - } else { - return {}; - } -} - -NT_Type LocalStorage::GetEntryType(NT_Handle subentryHandle) { - std::scoped_lock lock{m_mutex}; - if (auto subscriber = m_impl->GetSubEntry(subentryHandle)) { - return subscriber->topic->type; - } else { - return {}; - } -} - -int64_t LocalStorage::GetEntryLastChange(NT_Handle subentryHandle) { - std::scoped_lock lock{m_mutex}; - if (auto subscriber = m_impl->GetSubEntry(subentryHandle)) { - return subscriber->topic->lastValue.time(); - } else { - return 0; - } -} - -void LocalStorage::AddListener(NT_Listener listener, +void LocalStorage::AddListener(NT_Listener listenerHandle, std::span prefixes, unsigned int mask) { mask &= (NT_EVENT_TOPIC | NT_EVENT_VALUE_ALL | NT_EVENT_IMMEDIATE); std::scoped_lock lock{m_mutex}; - m_impl->AddListener(listener, prefixes, mask); + if (m_impl.m_multiSubscribers.size() >= kMaxMultiSubscribers) { + WPI_ERROR( + m_impl.m_logger, + "reached maximum number of multi-subscribers, not adding listener"); + return; + } + // subscribe to make sure topic updates are received + auto sub = m_impl.AddMultiSubscriber( + prefixes, {.topicsOnly = (mask & NT_EVENT_VALUE_ALL) == 0}); + m_impl.AddListenerImpl(listenerHandle, sub, mask, true); } -void LocalStorage::AddListener(NT_Listener listener, NT_Handle handle, +void LocalStorage::AddListener(NT_Listener listenerHandle, NT_Handle handle, unsigned int mask) { mask &= (NT_EVENT_TOPIC | NT_EVENT_VALUE_ALL | NT_EVENT_IMMEDIATE); std::scoped_lock lock{m_mutex}; - m_impl->AddListener(listener, handle, mask); + if (auto topic = m_impl.m_topics.Get(handle)) { + m_impl.AddListenerImpl(listenerHandle, topic, mask); + } else if (auto sub = m_impl.m_multiSubscribers.Get(handle)) { + m_impl.AddListenerImpl(listenerHandle, sub, mask, false); + } else if (auto sub = m_impl.m_subscribers.Get(handle)) { + m_impl.AddListenerImpl(listenerHandle, sub, mask, sub->handle, false); + } else if (auto entry = m_impl.m_entries.Get(handle)) { + m_impl.AddListenerImpl(listenerHandle, entry->subscriber, mask, + entry->handle, false); + } } -void LocalStorage::RemoveListener(NT_Listener listener, unsigned int mask) { +void LocalStorage::RemoveListener(NT_Listener listenerHandle, + unsigned int mask) { std::scoped_lock lock{m_mutex}; - m_impl->RemoveListener(listener, mask); + auto listenerIt = m_impl.m_listeners.find(listenerHandle); + if (listenerIt == m_impl.m_listeners.end()) { + return; + } + auto listener = std::move(listenerIt->getSecond()); + m_impl.m_listeners.erase(listenerIt); + if (!listener) { + return; + } + + m_impl.m_topicPrefixListeners.Remove(listener.get()); + if (listener->subscriber) { + listener->subscriber->valueListeners.Remove(listenerHandle); + listener->subscriber->topic->listeners.Remove(listenerHandle); + if (listener->subscriberOwned) { + m_impl.RemoveLocalSubscriber(listener->subscriber->handle); + } + } + if (listener->multiSubscriber) { + listener->multiSubscriber->valueListeners.Remove(listenerHandle); + if (listener->subscriberOwned) { + m_impl.RemoveMultiSubscriber(listener->multiSubscriber->handle); + } + } } NT_DataLogger LocalStorage::StartDataLog(wpi::log::DataLog& log, @@ -2242,11 +1438,11 @@ NT_DataLogger LocalStorage::StartDataLog(wpi::log::DataLog& log, std::string_view logPrefix) { std::scoped_lock lock{m_mutex}; auto datalogger = - m_impl->m_dataloggers.Add(m_impl->m_inst, log, prefix, logPrefix); + m_impl.m_dataloggers.Add(m_impl.m_inst, log, prefix, logPrefix); // start logging any matching topics auto now = nt::Now(); - for (auto&& topic : m_impl->m_topics) { + for (auto&& topic : m_impl.m_topics) { if (!wpi::starts_with(topic->name, prefix) || topic->type == NT_UNASSIGNED || topic->typeStr.empty()) { continue; @@ -2267,10 +1463,10 @@ NT_DataLogger LocalStorage::StartDataLog(wpi::log::DataLog& log, void LocalStorage::StopDataLog(NT_DataLogger logger) { std::scoped_lock lock{m_mutex}; - if (auto datalogger = m_impl->m_dataloggers.Remove(logger)) { + if (auto datalogger = m_impl.m_dataloggers.Remove(logger)) { // finish any active entries auto now = Now(); - for (auto&& topic : m_impl->m_topics) { + for (auto&& topic : m_impl.m_topics) { auto it = std::find_if(topic->datalogs.begin(), topic->datalogs.end(), [&](const auto& elem) { return elem.logger == logger; }); @@ -2284,6 +1480,14 @@ void LocalStorage::StopDataLog(NT_DataLogger logger) { void LocalStorage::Reset() { std::scoped_lock lock{m_mutex}; - m_impl = std::make_unique(m_impl->m_inst, m_impl->m_listenerStorage, - m_impl->m_logger); + m_impl.m_network = nullptr; + m_impl.m_topics.clear(); + m_impl.m_publishers.clear(); + m_impl.m_subscribers.clear(); + m_impl.m_entries.clear(); + m_impl.m_multiSubscribers.clear(); + m_impl.m_dataloggers.clear(); + m_impl.m_nameTopics.clear(); + m_impl.m_listeners.clear(); + m_impl.m_topicPrefixListeners.clear(); } diff --git a/ntcore/src/main/native/cpp/LocalStorage.h b/ntcore/src/main/native/cpp/LocalStorage.h index a93adb0c5e..086c0574eb 100644 --- a/ntcore/src/main/native/cpp/LocalStorage.h +++ b/ntcore/src/main/native/cpp/LocalStorage.h @@ -14,8 +14,18 @@ #include #include +#include +#include +#include +#include #include +#include "Handle.h" +#include "HandleMap.h" +#include "PubSubOptions.h" +#include "Types_internal.h" +#include "ValueCircularBuffer.h" +#include "VectorSet.h" #include "net/NetworkInterface.h" #include "ntcore_cpp.h" @@ -29,8 +39,8 @@ class IListenerStorage; class LocalStorage final : public net::ILocalStorage { public: - LocalStorage(int inst, IListenerStorage& listenerStorage, - wpi::Logger& logger); + LocalStorage(int inst, IListenerStorage& listenerStorage, wpi::Logger& logger) + : m_impl{inst, listenerStorage, logger} {} LocalStorage(const LocalStorage&) = delete; LocalStorage& operator=(const LocalStorage&) = delete; ~LocalStorage() final; @@ -59,47 +69,129 @@ class LocalStorage final : public net::ILocalStorage { std::vector GetTopicInfo(std::string_view prefix, std::span types); - NT_Topic GetTopic(std::string_view name); + NT_Topic GetTopic(std::string_view name) { + if (name.empty()) { + return {}; + } + std::scoped_lock lock{m_mutex}; + return m_impl.GetOrCreateTopic(name)->handle; + } - std::string GetTopicName(NT_Topic topic); + std::string GetTopicName(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->name; + } else { + return {}; + } + } - NT_Type GetTopicType(NT_Topic topic); + NT_Type GetTopicType(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->type; + } else { + return {}; + } + } - std::string GetTopicTypeString(NT_Topic topic); + std::string GetTopicTypeString(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->typeStr; + } else { + return {}; + } + } - void SetTopicPersistent(NT_Topic topic, bool value); + void SetTopicPersistent(NT_Topic topicHandle, bool value) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + m_impl.SetPersistent(topic, value); + } + } - bool GetTopicPersistent(NT_Topic topic); + bool GetTopicPersistent(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return (topic->flags & NT_PERSISTENT) != 0; + } else { + return false; + } + } - void SetTopicRetained(NT_Topic topic, bool value); + void SetTopicRetained(NT_Topic topicHandle, bool value) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + m_impl.SetRetained(topic, value); + } + } - bool GetTopicRetained(NT_Topic topic); + bool GetTopicRetained(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return (topic->flags & NT_RETAINED) != 0; + } else { + return false; + } + } - bool GetTopicExists(NT_Handle handle); + bool GetTopicExists(NT_Handle handle) { + std::scoped_lock lock{m_mutex}; + TopicData* topic = m_impl.GetTopic(handle); + return topic && topic->Exists(); + } - wpi::json GetTopicProperty(NT_Topic topic, std::string_view name); + wpi::json GetTopicProperty(NT_Topic topicHandle, std::string_view name) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->properties.value(name, wpi::json{}); + } else { + return {}; + } + } void SetTopicProperty(NT_Topic topic, std::string_view name, const wpi::json& value); void DeleteTopicProperty(NT_Topic topic, std::string_view name); - wpi::json GetTopicProperties(NT_Topic topic); + wpi::json GetTopicProperties(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->properties; + } else { + return wpi::json::object(); + } + } bool SetTopicProperties(NT_Topic topic, const wpi::json& update); - TopicInfo GetTopicInfo(NT_Topic topic); + TopicInfo GetTopicInfo(NT_Topic topicHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.m_topics.Get(topicHandle)) { + return topic->GetTopicInfo(); + } else { + return {}; + } + } NT_Subscriber Subscribe(NT_Topic topic, NT_Type type, std::string_view typeStr, const PubSubOptions& options); - void Unsubscribe(NT_Subscriber sub); + void Unsubscribe(NT_Subscriber subHandle) { + std::scoped_lock lock{m_mutex}; + m_impl.RemoveSubEntry(subHandle); + } NT_MultiSubscriber SubscribeMultiple( std::span prefixes, const PubSubOptions& options); - void UnsubscribeMultiple(NT_MultiSubscriber subHandle); + void UnsubscribeMultiple(NT_MultiSubscriber subHandle) { + std::scoped_lock lock{m_mutex}; + m_impl.RemoveMultiSubscriber(subHandle); + } NT_Publisher Publish(NT_Topic topic, NT_Type type, std::string_view typeStr, const wpi::json& properties, @@ -110,84 +202,106 @@ class LocalStorage final : public net::ILocalStorage { NT_Entry GetEntry(NT_Topic topic, NT_Type type, std::string_view typeStr, const PubSubOptions& options); - void ReleaseEntry(NT_Entry entry); + void ReleaseEntry(NT_Entry entryHandle) { + std::scoped_lock lock{m_mutex}; + m_impl.RemoveSubEntry(entryHandle); + } void Release(NT_Handle pubsubentry); - NT_Topic GetTopicFromHandle(NT_Handle pubsubentry); + NT_Topic GetTopicFromHandle(NT_Handle pubsubentryHandle) { + std::scoped_lock lock{m_mutex}; + if (auto topic = m_impl.GetTopic(pubsubentryHandle)) { + return topic->handle; + } else { + return {}; + } + } - bool SetEntryValue(NT_Handle pubentry, const Value& value); + bool SetEntryValue(NT_Handle pubentryHandle, const Value& value) { + std::scoped_lock lock{m_mutex}; + return m_impl.SetEntryValue(pubentryHandle, value); + } - bool SetDefaultEntryValue(NT_Handle pubsubentry, const Value& value); + bool SetDefaultEntryValue(NT_Handle pubsubentryHandle, const Value& value) { + std::scoped_lock lock{m_mutex}; + return m_impl.SetDefaultEntryValue(pubsubentryHandle, value); + } - TimestampedBoolean GetAtomicBoolean(NT_Handle subentry, bool defaultValue); - TimestampedInteger GetAtomicInteger(NT_Handle subentry, int64_t defaultValue); - TimestampedFloat GetAtomicFloat(NT_Handle subentry, float defaultValue); - TimestampedDouble GetAtomicDouble(NT_Handle subentry, double defaultValue); - TimestampedString GetAtomicString(NT_Handle subentry, - std::string_view defaultValue); - TimestampedRaw GetAtomicRaw(NT_Handle subentry, - std::span defaultValue); - TimestampedBooleanArray GetAtomicBooleanArray( - NT_Handle subentry, std::span defaultValue); - TimestampedIntegerArray GetAtomicIntegerArray( - NT_Handle subentry, std::span defaultValue); - TimestampedFloatArray GetAtomicFloatArray( - NT_Handle subentry, std::span defaultValue); - TimestampedDoubleArray GetAtomicDoubleArray( - NT_Handle subentry, std::span defaultValue); - TimestampedStringArray GetAtomicStringArray( - NT_Handle subentry, std::span defaultValue); + template + Timestamped::Value> GetAtomic( + NT_Handle subentry, typename TypeInfo::View defaultValue); - TimestampedStringView GetAtomicString(NT_Handle subentry, - wpi::SmallVectorImpl& buf, - std::string_view defaultValue); - TimestampedRawView GetAtomicRaw(NT_Handle subentry, - wpi::SmallVectorImpl& buf, - std::span defaultValue); - TimestampedBooleanArrayView GetAtomicBooleanArray( - NT_Handle subentry, wpi::SmallVectorImpl& buf, - std::span defaultValue); - TimestampedIntegerArrayView GetAtomicIntegerArray( - NT_Handle subentry, wpi::SmallVectorImpl& buf, - std::span defaultValue); - TimestampedFloatArrayView GetAtomicFloatArray( - NT_Handle subentry, wpi::SmallVectorImpl& buf, - std::span defaultValue); - TimestampedDoubleArrayView GetAtomicDoubleArray( - NT_Handle subentry, wpi::SmallVectorImpl& buf, - std::span defaultValue); + template + Timestamped::SmallRet> GetAtomic( + NT_Handle subentry, + wpi::SmallVectorImpl::SmallElem>& buf, + typename TypeInfo::View defaultValue); - std::vector ReadQueueValue(NT_Handle subentry); + std::vector ReadQueueValue(NT_Handle subentry) { + std::scoped_lock lock{m_mutex}; + auto subscriber = m_impl.GetSubEntry(subentry); + if (!subscriber) { + return {}; + } + return subscriber->pollStorage.ReadValue(); + } - std::vector ReadQueueBoolean(NT_Handle subentry); - std::vector ReadQueueInteger(NT_Handle subentry); - std::vector ReadQueueFloat(NT_Handle subentry); - std::vector ReadQueueDouble(NT_Handle subentry); - std::vector ReadQueueString(NT_Handle subentry); - std::vector ReadQueueRaw(NT_Handle subentry); - std::vector ReadQueueBooleanArray( + template + std::vector::Value>> ReadQueue( NT_Handle subentry); - std::vector ReadQueueIntegerArray( - NT_Handle subentry); - std::vector ReadQueueFloatArray(NT_Handle subentry); - std::vector ReadQueueDoubleArray(NT_Handle subentry); - std::vector ReadQueueStringArray(NT_Handle subentry); // // Backwards compatible user functions // Value GetEntryValue(NT_Handle subentry); - void SetEntryFlags(NT_Entry entry, unsigned int flags); - unsigned int GetEntryFlags(NT_Entry entry); + + void SetEntryFlags(NT_Entry entryHandle, unsigned int flags) { + std::scoped_lock lock{m_mutex}; + if (auto entry = m_impl.m_entries.Get(entryHandle)) { + m_impl.SetFlags(entry->subscriber->topic, flags); + } + } + + unsigned int GetEntryFlags(NT_Entry entryHandle) { + std::scoped_lock lock{m_mutex}; + if (auto entry = m_impl.m_entries.Get(entryHandle)) { + return entry->subscriber->topic->flags; + } else { + return 0; + } + } // Index-only NT_Entry GetEntry(std::string_view name); - std::string GetEntryName(NT_Entry entry); - NT_Type GetEntryType(NT_Entry entry); - int64_t GetEntryLastChange(NT_Entry entry); + std::string GetEntryName(NT_Entry subentryHandle) { + std::scoped_lock lock{m_mutex}; + if (auto subscriber = m_impl.GetSubEntry(subentryHandle)) { + return subscriber->topic->name; + } else { + return {}; + } + } + + NT_Type GetEntryType(NT_Entry subentryHandle) { + std::scoped_lock lock{m_mutex}; + if (auto subscriber = m_impl.GetSubEntry(subentryHandle)) { + return subscriber->topic->type; + } else { + return {}; + } + } + + int64_t GetEntryLastChange(NT_Entry subentryHandle) { + std::scoped_lock lock{m_mutex}; + if (auto subscriber = m_impl.GetSubEntry(subentryHandle)) { + return subscriber->topic->lastValue.time(); + } else { + return 0; + } + } // // Listener functions @@ -210,10 +324,352 @@ class LocalStorage final : public net::ILocalStorage { void Reset(); private: - class Impl; - std::unique_ptr m_impl; + static constexpr bool IsSpecial(std::string_view name) { + return name.empty() ? false : name.front() == '$'; + } + + struct EntryData; + struct PublisherData; + struct SubscriberData; + struct MultiSubscriberData; + + struct DataLoggerEntry { + DataLoggerEntry(wpi::log::DataLog& log, int entry, NT_DataLogger logger) + : log{&log}, entry{entry}, logger{logger} {} + + static std::string MakeMetadata(std::string_view properties); + + void Append(const Value& v); + + wpi::log::DataLog* log; + int entry; + NT_DataLogger logger; + }; + + struct TopicData { + static constexpr auto kType = Handle::kTopic; + + TopicData(NT_Topic handle, std::string_view name) + : handle{handle}, name{name}, special{IsSpecial(name)} {} + + bool Exists() const { return onNetwork || !localPublishers.empty(); } + + TopicInfo GetTopicInfo() const; + + // invariants + wpi::SignalObject handle; + std::string name; + bool special; + + Value lastValue; // also stores timestamp + Value lastValueNetwork; + NT_Type type{NT_UNASSIGNED}; + std::string typeStr; + unsigned int flags{0}; // for NT3 APIs + std::string propertiesStr{"{}"}; // cached string for GetTopicInfo() et al + wpi::json properties = wpi::json::object(); + NT_Entry entry{0}; // cached entry for GetEntry() + + bool onNetwork{false}; // true if there are any remote publishers + bool lastValueFromNetwork{false}; + + wpi::SmallVector datalogs; + NT_Type datalogType{NT_UNASSIGNED}; + + VectorSet localPublishers; + VectorSet localSubscribers; + VectorSet multiSubscribers; + VectorSet entries; + VectorSet listeners; + }; + + struct PubSubConfig : public PubSubOptionsImpl { + PubSubConfig() = default; + PubSubConfig(NT_Type type, std::string_view typeStr, + const PubSubOptions& options) + : PubSubOptionsImpl{options}, type{type}, typeStr{typeStr} { + prefixMatch = false; + } + + NT_Type type{NT_UNASSIGNED}; + std::string typeStr; + }; + + struct PublisherData { + static constexpr auto kType = Handle::kPublisher; + + PublisherData(NT_Publisher handle, TopicData* topic, PubSubConfig config) + : handle{handle}, topic{topic}, config{std::move(config)} {} + + void UpdateActive() { + active = config.type == topic->type && config.typeStr == topic->typeStr; + } + + // invariants + wpi::SignalObject handle; + TopicData* topic; + PubSubConfig config; + + // whether or not the publisher should actually publish values + bool active{false}; + }; + + struct SubscriberData { + static constexpr auto kType = Handle::kSubscriber; + + SubscriberData(NT_Subscriber handle, TopicData* topic, PubSubConfig config) + : handle{handle}, + topic{topic}, + config{std::move(config)}, + pollStorage{config.pollStorage} {} + + void UpdateActive() { + // for subscribers, unassigned is a wildcard + // also allow numerically compatible subscribers + active = + config.type == NT_UNASSIGNED || + (config.type == topic->type && config.typeStr == topic->typeStr) || + IsNumericCompatible(config.type, topic->type); + } + + // invariants + wpi::SignalObject handle; + TopicData* topic; + PubSubConfig config; + + // whether or not the subscriber should actually receive values + bool active{false}; + + // polling storage + ValueCircularBuffer pollStorage; + + // value listeners + VectorSet valueListeners; + }; + + struct EntryData { + static constexpr auto kType = Handle::kEntry; + + EntryData(NT_Entry handle, SubscriberData* subscriber) + : handle{handle}, topic{subscriber->topic}, subscriber{subscriber} {} + + // invariants + wpi::SignalObject handle; + TopicData* topic; + SubscriberData* subscriber; + + // the publisher (created on demand) + PublisherData* publisher{nullptr}; + }; + + struct MultiSubscriberData { + static constexpr auto kType = Handle::kMultiSubscriber; + + MultiSubscriberData(NT_MultiSubscriber handle, + std::span prefixes, + const PubSubOptionsImpl& options) + : handle{handle}, options{options} { + this->options.prefixMatch = true; + this->prefixes.reserve(prefixes.size()); + for (auto&& prefix : prefixes) { + this->prefixes.emplace_back(prefix); + } + } + + bool Matches(std::string_view name, bool special); + + // invariants + wpi::SignalObject handle; + std::vector prefixes; + PubSubOptionsImpl options; + + // value listeners + VectorSet valueListeners; + }; + + struct ListenerData { + ListenerData(NT_Listener handle, SubscriberData* subscriber, + unsigned int eventMask, bool subscriberOwned) + : handle{handle}, + eventMask{eventMask}, + subscriber{subscriber}, + subscriberOwned{subscriberOwned} {} + ListenerData(NT_Listener handle, MultiSubscriberData* subscriber, + unsigned int eventMask, bool subscriberOwned) + : handle{handle}, + eventMask{eventMask}, + multiSubscriber{subscriber}, + subscriberOwned{subscriberOwned} {} + + NT_Listener handle; + unsigned int eventMask; + SubscriberData* subscriber{nullptr}; + MultiSubscriberData* multiSubscriber{nullptr}; + bool subscriberOwned; + }; + + struct DataLoggerData { + static constexpr auto kType = Handle::kDataLogger; + + DataLoggerData(NT_DataLogger handle, wpi::log::DataLog& log, + std::string_view prefix, std::string_view logPrefix) + : handle{handle}, log{log}, prefix{prefix}, logPrefix{logPrefix} {} + + int Start(TopicData* topic, int64_t time); + + NT_DataLogger handle; + wpi::log::DataLog& log; + std::string prefix; + std::string logPrefix; + }; + + // inner struct to protect against accidentally deadlocking on the mutex + struct Impl { + Impl(int inst, IListenerStorage& listenerStorage, wpi::Logger& logger); + + int m_inst; + IListenerStorage& m_listenerStorage; + wpi::Logger& m_logger; + net::NetworkInterface* m_network{nullptr}; + + // handle mappings + HandleMap m_topics; + HandleMap m_publishers; + HandleMap m_subscribers; + HandleMap m_entries; + HandleMap m_multiSubscribers; + HandleMap m_dataloggers; + + // name mappings + wpi::StringMap m_nameTopics; + + // listeners + wpi::DenseMap> m_listeners; + + // string-based listeners + VectorSet m_topicPrefixListeners; + + // topic functions + void NotifyTopic(TopicData* topic, unsigned int eventFlags); + + void CheckReset(TopicData* topic); + + bool SetValue(TopicData* topic, const Value& value, unsigned int eventFlags, + bool isDuplicate, bool suppressIfDuplicate, + const PublisherData* publisher); + void NotifyValue(TopicData* topic, unsigned int eventFlags, + bool isDuplicate, const PublisherData* publisher); + + void SetFlags(TopicData* topic, unsigned int flags); + void SetPersistent(TopicData* topic, bool value); + void SetRetained(TopicData* topic, bool value); + void SetProperties(TopicData* topic, const wpi::json& update, + bool sendNetwork); + void PropertiesUpdated(TopicData* topic, const wpi::json& update, + unsigned int eventFlags, bool sendNetwork, + bool updateFlags = true); + + void RefreshPubSubActive(TopicData* topic, bool warnOnSubMismatch); + + void NetworkAnnounce(TopicData* topic, std::string_view typeStr, + const wpi::json& properties, NT_Publisher pubHandle); + void RemoveNetworkPublisher(TopicData* topic); + void NetworkPropertiesUpdate(TopicData* topic, const wpi::json& update, + bool ack); + void StartNetwork(net::NetworkInterface* network); + + PublisherData* AddLocalPublisher(TopicData* topic, + const wpi::json& properties, + const PubSubConfig& options); + std::unique_ptr RemoveLocalPublisher(NT_Publisher pubHandle); + + SubscriberData* AddLocalSubscriber(TopicData* topic, + const PubSubConfig& options); + std::unique_ptr RemoveLocalSubscriber( + NT_Subscriber subHandle); + + EntryData* AddEntry(SubscriberData* subscriber); + std::unique_ptr RemoveEntry(NT_Entry entryHandle); + + MultiSubscriberData* AddMultiSubscriber( + std::span prefixes, + const PubSubOptions& options); + std::unique_ptr RemoveMultiSubscriber( + NT_MultiSubscriber subHandle); + + void AddListenerImpl(NT_Listener listenerHandle, TopicData* topic, + unsigned int eventMask); + void AddListenerImpl(NT_Listener listenerHandle, SubscriberData* subscriber, + unsigned int eventMask, NT_Handle subentryHandle, + bool subscriberOwned); + void AddListenerImpl(NT_Listener listenerHandle, + MultiSubscriberData* subscriber, + unsigned int eventMask, bool subscriberOwned); + void AddListenerImpl(NT_Listener listenerHandle, + std::span prefixes, + unsigned int eventMask); + + TopicData* GetOrCreateTopic(std::string_view name); + TopicData* GetTopic(NT_Handle handle); + SubscriberData* GetSubEntry(NT_Handle subentryHandle); + PublisherData* PublishEntry(EntryData* entry, NT_Type type); + + Value* GetSubEntryValue(NT_Handle subentryHandle) { + if (auto subscriber = GetSubEntry(subentryHandle)) { + return &subscriber->topic->lastValue; + } else { + return nullptr; + } + } + + bool PublishLocalValue(PublisherData* publisher, const Value& value, + bool force = false); + + bool SetEntryValue(NT_Handle pubentryHandle, const Value& value); + bool SetDefaultEntryValue(NT_Handle pubsubentryHandle, const Value& value); + + void RemoveSubEntry(NT_Handle subentryHandle); + }; wpi::mutex m_mutex; + Impl m_impl; }; +template +Timestamped::Value> LocalStorage::GetAtomic( + NT_Handle subentry, typename TypeInfo::View defaultValue) { + std::scoped_lock lock{m_mutex}; + Value* value = m_impl.GetSubEntryValue(subentry); + if (value && (IsNumericConvertibleTo(*value) || IsType(*value))) { + return GetTimestamped(*value); + } else { + return {0, 0, CopyValue(defaultValue)}; + } +} + +template +Timestamped::SmallRet> LocalStorage::GetAtomic( + NT_Handle subentry, + wpi::SmallVectorImpl::SmallElem>& buf, + typename TypeInfo::View defaultValue) { + std::scoped_lock lock{m_mutex}; + Value* value = m_impl.GetSubEntryValue(subentry); + if (value && (IsNumericConvertibleTo(*value) || IsType(*value))) { + return GetTimestamped(*value, buf); + } else { + return {0, 0, CopyValue(defaultValue, buf)}; + } +} + +template +std::vector::Value>> LocalStorage::ReadQueue( + NT_Handle subentry) { + std::scoped_lock lock{m_mutex}; + auto subscriber = m_impl.GetSubEntry(subentry); + if (!subscriber) { + return {}; + } + return subscriber->pollStorage.Read(); +} + } // namespace nt diff --git a/ntcore/src/main/native/cpp/Log.h b/ntcore/src/main/native/cpp/Log.h index 7e052f92a2..ef9743bc92 100644 --- a/ntcore/src/main/native/cpp/Log.h +++ b/ntcore/src/main/native/cpp/Log.h @@ -9,10 +9,8 @@ #define LOG(level, format, ...) \ WPI_LOG(m_logger, level, format __VA_OPT__(, ) __VA_ARGS__) -#undef ERROR -#define ERROR(format, ...) \ - WPI_ERROR(m_logger, format __VA_OPT__(, ) __VA_ARGS__) -#define WARNING(format, ...) \ +#define ERR(format, ...) WPI_ERROR(m_logger, format __VA_OPT__(, ) __VA_ARGS__) +#define WARN(format, ...) \ WPI_WARNING(m_logger, format __VA_OPT__(, ) __VA_ARGS__) #define INFO(format, ...) WPI_INFO(m_logger, format __VA_OPT__(, ) __VA_ARGS__) diff --git a/ntcore/src/main/native/cpp/NetworkClient.cpp b/ntcore/src/main/native/cpp/NetworkClient.cpp index 4391d733ed..1affb0082f 100644 --- a/ntcore/src/main/native/cpp/NetworkClient.cpp +++ b/ntcore/src/main/native/cpp/NetworkClient.cpp @@ -13,25 +13,13 @@ #include #include #include -#include -#include #include -#include -#include -#include #include #include -#include #include #include "IConnectionList.h" #include "Log.h" -#include "net/ClientImpl.h" -#include "net/Message.h" -#include "net/NetworkLoopQueue.h" -#include "net/WebSocketConnection.h" -#include "net3/ClientImpl3.h" -#include "net3/UvStreamConnection3.h" using namespace nt; namespace uv = wpi::uv; @@ -41,97 +29,10 @@ static constexpr uv::Timer::Time kWebsocketHandshakeTimeout{500}; // use a larger max message size for websockets static constexpr size_t kMaxMessageSize = 2 * 1024 * 1024; -namespace { - -class NCImpl { - public: - NCImpl(int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger); - virtual ~NCImpl() = default; - - // user-facing functions - void SetServers(std::span> servers, - unsigned int defaultPort); - void StartDSClient(unsigned int port); - void StopDSClient(); - - virtual void TcpConnected(uv::Tcp& tcp) = 0; - virtual void ForceDisconnect(std::string_view reason) = 0; - virtual void Disconnect(std::string_view reason); - - // invariants - int m_inst; - net::ILocalStorage& m_localStorage; - IConnectionList& m_connList; - wpi::Logger& m_logger; - std::string m_id; - - // used only from loop - std::shared_ptr m_parallelConnect; - std::shared_ptr m_readLocalTimer; - std::shared_ptr m_sendValuesTimer; - std::shared_ptr> m_flushLocal; - std::shared_ptr> m_flush; - - std::vector m_localMsgs; - - std::vector> m_servers; - - std::pair m_dsClientServer{"", 0}; - std::shared_ptr m_dsClient; - - // shared with user - std::atomic*> m_flushLocalAtomic{nullptr}; - std::atomic*> m_flushAtomic{nullptr}; - - net::NetworkLoopQueue m_localQueue; - - int m_connHandle = 0; - - wpi::EventLoopRunner m_loopRunner; - uv::Loop& m_loop; -}; - -class NCImpl3 : public NCImpl { - public: - NCImpl3(int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger); - ~NCImpl3() override; - - void HandleLocal(); - void TcpConnected(uv::Tcp& tcp) final; - void ForceDisconnect(std::string_view reason) override; - void Disconnect(std::string_view reason) override; - - std::shared_ptr m_wire; - std::shared_ptr m_clientImpl; -}; - -class NCImpl4 : public NCImpl { - public: - NCImpl4( - int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function - timeSyncUpdated); - ~NCImpl4() override; - - void HandleLocal(); - void TcpConnected(uv::Tcp& tcp) final; - void WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp); - void ForceDisconnect(std::string_view reason) override; - void Disconnect(std::string_view reason) override; - - std::function - m_timeSyncUpdated; - std::shared_ptr m_wire; - std::unique_ptr m_clientImpl; -}; - -} // namespace - -NCImpl::NCImpl(int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger) +NetworkClientBase::NetworkClientBase(int inst, std::string_view id, + net::ILocalStorage& localStorage, + IConnectionList& connList, + wpi::Logger& logger) : m_inst{inst}, m_localStorage{localStorage}, m_connList{connList}, @@ -144,28 +45,17 @@ NCImpl::NCImpl(int inst, std::string_view id, net::ILocalStorage& localStorage, INFO("starting network client"); } -void NCImpl::SetServers( - std::span> servers, - unsigned int defaultPort) { - std::vector> serversCopy; - serversCopy.reserve(servers.size()); - for (auto&& server : servers) { - serversCopy.emplace_back(wpi::trim(server.first), - server.second == 0 ? defaultPort : server.second); - } - - m_loopRunner.ExecAsync( - [this, servers = std::move(serversCopy)](uv::Loop&) mutable { - m_servers = std::move(servers); - if (m_dsClientServer.first.empty()) { - if (m_parallelConnect) { - m_parallelConnect->SetServers(m_servers); - } - } - }); +NetworkClientBase::~NetworkClientBase() { + m_localStorage.ClearNetwork(); + m_connList.ClearConnections(); } -void NCImpl::StartDSClient(unsigned int port) { +void NetworkClientBase::Disconnect() { + m_loopRunner.ExecAsync( + [this](auto&) { ForceDisconnect("requested by application"); }); +} + +void NetworkClientBase::StartDSClient(unsigned int port) { m_loopRunner.ExecAsync([=, this](uv::Loop& loop) { if (m_dsClient) { return; @@ -189,7 +79,7 @@ void NCImpl::StartDSClient(unsigned int port) { }); } -void NCImpl::StopDSClient() { +void NetworkClientBase::StopDSClient() { m_loopRunner.ExecAsync([this](uv::Loop& loop) { if (m_dsClient) { m_dsClient->Close(); @@ -198,7 +88,40 @@ void NCImpl::StopDSClient() { }); } -void NCImpl::Disconnect(std::string_view reason) { +void NetworkClientBase::FlushLocal() { + if (auto async = m_flushLocalAtomic.load(std::memory_order_relaxed)) { + async->UnsafeSend(); + } +} + +void NetworkClientBase::Flush() { + if (auto async = m_flushAtomic.load(std::memory_order_relaxed)) { + async->UnsafeSend(); + } +} + +void NetworkClientBase::DoSetServers( + std::span> servers, + unsigned int defaultPort) { + std::vector> serversCopy; + serversCopy.reserve(servers.size()); + for (auto&& server : servers) { + serversCopy.emplace_back(wpi::trim(server.first), + server.second == 0 ? defaultPort : server.second); + } + + m_loopRunner.ExecAsync( + [this, servers = std::move(serversCopy)](uv::Loop&) mutable { + m_servers = std::move(servers); + if (m_dsClientServer.first.empty()) { + if (m_parallelConnect) { + m_parallelConnect->SetServers(m_servers); + } + } + }); +} + +void NetworkClientBase::DoDisconnect(std::string_view reason) { if (m_readLocalTimer) { m_readLocalTimer->Stop(); } @@ -218,10 +141,10 @@ void NCImpl::Disconnect(std::string_view reason) { }); } -NCImpl3::NCImpl3(int inst, std::string_view id, - net::ILocalStorage& localStorage, IConnectionList& connList, - wpi::Logger& logger) - : NCImpl{inst, id, localStorage, connList, logger} { +NetworkClient3::NetworkClient3(int inst, std::string_view id, + net::ILocalStorage& localStorage, + IConnectionList& connList, wpi::Logger& logger) + : NetworkClientBase{inst, id, localStorage, connList, logger} { m_loopRunner.ExecAsync([this](uv::Loop& loop) { m_parallelConnect = wpi::ParallelTcpConnector::Create( loop, kReconnectRate, m_logger, @@ -257,7 +180,7 @@ NCImpl3::NCImpl3(int inst, std::string_view id, }); } -NCImpl3::~NCImpl3() { +NetworkClient3::~NetworkClient3() { // must explicitly destroy these on loop m_loopRunner.ExecSync([&](auto&) { m_clientImpl.reset(); @@ -267,14 +190,14 @@ NCImpl3::~NCImpl3() { m_loopRunner.Stop(); } -void NCImpl3::HandleLocal() { +void NetworkClient3::HandleLocal() { m_localQueue.ReadQueue(&m_localMsgs); if (m_clientImpl) { m_clientImpl->HandleLocal(m_localMsgs); } } -void NCImpl3::TcpConnected(uv::Tcp& tcp) { +void NetworkClient3::TcpConnected(uv::Tcp& tcp) { tcp.SetNoDelay(true); // create as shared_ptr and capture in lambda because there may be multiple @@ -319,19 +242,19 @@ void NCImpl3::TcpConnected(uv::Tcp& tcp) { tcp.error.connect([this, &tcp](uv::Error err) { DEBUG3("NT3 TCP error {}", err.str()); if (!tcp.IsLoopClosing()) { - Disconnect(err.str()); + DoDisconnect(err.str()); } }); tcp.end.connect([this, &tcp] { DEBUG3("NT3 TCP read ended"); if (!tcp.IsLoopClosing()) { - Disconnect("remote end closed connection"); + DoDisconnect("remote end closed connection"); } }); tcp.closed.connect([this, &tcp] { DEBUG3("NT3 TCP connection closed"); if (!tcp.IsLoopClosing()) { - Disconnect(m_wire ? m_wire->GetDisconnectReason() : "unknown"); + DoDisconnect(m_wire ? m_wire->GetDisconnectReason() : "unknown"); } }); @@ -349,25 +272,25 @@ void NCImpl3::TcpConnected(uv::Tcp& tcp) { tcp.StartRead(); } -void NCImpl3::ForceDisconnect(std::string_view reason) { +void NetworkClient3::ForceDisconnect(std::string_view reason) { if (m_wire) { m_wire->Disconnect(reason); } } -void NCImpl3::Disconnect(std::string_view reason) { +void NetworkClient3::DoDisconnect(std::string_view reason) { INFO("DISCONNECTED NT3 connection: {}", reason); m_clientImpl.reset(); m_wire.reset(); - NCImpl::Disconnect(reason); + NetworkClientBase::DoDisconnect(reason); } -NCImpl4::NCImpl4( +NetworkClient::NetworkClient( int inst, std::string_view id, net::ILocalStorage& localStorage, IConnectionList& connList, wpi::Logger& logger, std::function timeSyncUpdated) - : NCImpl{inst, id, localStorage, connList, logger}, + : NetworkClientBase{inst, id, localStorage, connList, logger}, m_timeSyncUpdated{std::move(timeSyncUpdated)} { m_loopRunner.ExecAsync([this](uv::Loop& loop) { m_parallelConnect = wpi::ParallelTcpConnector::Create( @@ -415,7 +338,7 @@ NCImpl4::NCImpl4( }); } -NCImpl4::~NCImpl4() { +NetworkClient::~NetworkClient() { // must explicitly destroy these on loop m_loopRunner.ExecSync([&](auto&) { m_clientImpl.reset(); @@ -425,14 +348,14 @@ NCImpl4::~NCImpl4() { m_loopRunner.Stop(); } -void NCImpl4::HandleLocal() { +void NetworkClient::HandleLocal() { m_localQueue.ReadQueue(&m_localMsgs); if (m_clientImpl) { m_clientImpl->HandleLocal(std::move(m_localMsgs)); } } -void NCImpl4::TcpConnected(uv::Tcp& tcp) { +void NetworkClient::TcpConnected(uv::Tcp& tcp) { tcp.SetNoDelay(true); // Start the WS client if (m_logger.min_level() >= wpi::WPI_LOG_DEBUG4) { @@ -457,7 +380,7 @@ void NCImpl4::TcpConnected(uv::Tcp& tcp) { }); } -void NCImpl4::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp) { +void NetworkClient::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp) { if (m_parallelConnect) { m_parallelConnect->Succeeded(tcp); } @@ -485,7 +408,7 @@ void NCImpl4::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp) { m_clientImpl->SendInitial(); ws.closed.connect([this, &ws](uint16_t, std::string_view reason) { if (!ws.GetStream().IsLoopClosing()) { - Disconnect(reason); + DoDisconnect(reason); } }); ws.text.connect([this](std::string_view data, bool) { @@ -500,13 +423,13 @@ void NCImpl4::WsConnected(wpi::WebSocket& ws, uv::Tcp& tcp) { }); } -void NCImpl4::ForceDisconnect(std::string_view reason) { +void NetworkClient::ForceDisconnect(std::string_view reason) { if (m_wire) { m_wire->Disconnect(reason); } } -void NCImpl4::Disconnect(std::string_view reason) { +void NetworkClient::DoDisconnect(std::string_view reason) { std::string realReason; if (m_wire) { realReason = m_wire->GetDisconnectReason(); @@ -515,107 +438,6 @@ void NCImpl4::Disconnect(std::string_view reason) { realReason.empty() ? reason : realReason); m_clientImpl.reset(); m_wire.reset(); - NCImpl::Disconnect(reason); + NetworkClientBase::DoDisconnect(reason); m_timeSyncUpdated(0, 0, false); } - -class NetworkClient::Impl final : public NCImpl4 { - public: - Impl(int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function - timeSyncUpdated) - : NCImpl4{inst, id, localStorage, - connList, logger, std::move(timeSyncUpdated)} {} -}; - -NetworkClient::NetworkClient( - int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function - timeSyncUpdated) - : m_impl{std::make_unique(inst, id, localStorage, connList, logger, - std::move(timeSyncUpdated))} {} - -NetworkClient::~NetworkClient() { - m_impl->m_localStorage.ClearNetwork(); - m_impl->m_connList.ClearConnections(); -} - -void NetworkClient::SetServers( - std::span> servers) { - m_impl->SetServers(servers, NT_DEFAULT_PORT4); -} - -void NetworkClient::Disconnect() { - m_impl->m_loopRunner.ExecAsync( - [this](auto&) { m_impl->ForceDisconnect("requested by application"); }); -} - -void NetworkClient::StartDSClient(unsigned int port) { - m_impl->StartDSClient(port); -} - -void NetworkClient::StopDSClient() { - m_impl->StopDSClient(); -} - -void NetworkClient::FlushLocal() { - if (auto async = m_impl->m_flushLocalAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} - -void NetworkClient::Flush() { - if (auto async = m_impl->m_flushAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} - -class NetworkClient3::Impl final : public NCImpl3 { - public: - Impl(int inst, std::string_view id, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger) - : NCImpl3{inst, id, localStorage, connList, logger} {} -}; - -NetworkClient3::NetworkClient3(int inst, std::string_view id, - net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger) - : m_impl{std::make_unique(inst, id, localStorage, connList, logger)} { -} - -NetworkClient3::~NetworkClient3() { - m_impl->m_localStorage.ClearNetwork(); - m_impl->m_connList.ClearConnections(); -} - -void NetworkClient3::SetServers( - std::span> servers) { - m_impl->SetServers(servers, NT_DEFAULT_PORT3); -} - -void NetworkClient3::Disconnect() { - m_impl->m_loopRunner.ExecAsync( - [this](auto&) { m_impl->ForceDisconnect("requested by application"); }); -} - -void NetworkClient3::StartDSClient(unsigned int port) { - m_impl->StartDSClient(port); -} - -void NetworkClient3::StopDSClient() { - m_impl->StopDSClient(); -} - -void NetworkClient3::FlushLocal() { - if (auto async = m_impl->m_flushLocalAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} - -void NetworkClient3::Flush() { - if (auto async = m_impl->m_flushAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} diff --git a/ntcore/src/main/native/cpp/NetworkClient.h b/ntcore/src/main/native/cpp/NetworkClient.h index f2a7e652d4..7839db7393 100644 --- a/ntcore/src/main/native/cpp/NetworkClient.h +++ b/ntcore/src/main/native/cpp/NetworkClient.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -11,8 +12,22 @@ #include #include #include +#include + +#include +#include +#include +#include +#include +#include #include "INetworkClient.h" +#include "net/ClientImpl.h" +#include "net/Message.h" +#include "net/NetworkLoopQueue.h" +#include "net/WebSocketConnection.h" +#include "net3/ClientImpl3.h" +#include "net3/UvStreamConnection3.h" #include "ntcore_cpp.h" namespace wpi { @@ -27,7 +42,86 @@ namespace nt { class IConnectionList; -class NetworkClient final : public INetworkClient { +class NetworkClientBase : public INetworkClient { + public: + NetworkClientBase(int inst, std::string_view id, + net::ILocalStorage& localStorage, IConnectionList& connList, + wpi::Logger& logger); + ~NetworkClientBase() override; + + void Disconnect() override; + + void StartDSClient(unsigned int port) override; + void StopDSClient() override; + + void FlushLocal() override; + void Flush() override; + + protected: + void DoSetServers( + std::span> servers, + unsigned int defaultPort); + + virtual void TcpConnected(wpi::uv::Tcp& tcp) = 0; + virtual void ForceDisconnect(std::string_view reason) = 0; + virtual void DoDisconnect(std::string_view reason); + + // invariants + int m_inst; + net::ILocalStorage& m_localStorage; + IConnectionList& m_connList; + wpi::Logger& m_logger; + std::string m_id; + + // used only from loop + std::shared_ptr m_parallelConnect; + std::shared_ptr m_readLocalTimer; + std::shared_ptr m_sendValuesTimer; + std::shared_ptr> m_flushLocal; + std::shared_ptr> m_flush; + + std::vector m_localMsgs; + + std::vector> m_servers; + + std::pair m_dsClientServer{"", 0}; + std::shared_ptr m_dsClient; + + // shared with user + std::atomic*> m_flushLocalAtomic{nullptr}; + std::atomic*> m_flushAtomic{nullptr}; + + net::NetworkLoopQueue m_localQueue; + + int m_connHandle = 0; + + wpi::EventLoopRunner m_loopRunner; + wpi::uv::Loop& m_loop; +}; + +class NetworkClient3 final : public NetworkClientBase { + public: + NetworkClient3(int inst, std::string_view id, + net::ILocalStorage& localStorage, IConnectionList& connList, + wpi::Logger& logger); + ~NetworkClient3() final; + + void SetServers( + std::span> servers) final { + DoSetServers(servers, NT_DEFAULT_PORT3); + } + + private: + void HandleLocal(); + void TcpConnected(wpi::uv::Tcp& tcp) final; + void ForceDisconnect(std::string_view reason) override; + void DoDisconnect(std::string_view reason) override; + + std::shared_ptr m_wire; + std::shared_ptr m_clientImpl; +}; + +class NetworkClient final : public NetworkClientBase { public: NetworkClient( int inst, std::string_view id, net::ILocalStorage& localStorage, @@ -37,40 +131,21 @@ class NetworkClient final : public INetworkClient { ~NetworkClient() final; void SetServers( - std::span> servers) final; - void Disconnect() final; - - void StartDSClient(unsigned int port) final; - void StopDSClient() final; - - void FlushLocal() final; - void Flush() final; + std::span> servers) final { + DoSetServers(servers, NT_DEFAULT_PORT4); + } private: - class Impl; - std::unique_ptr m_impl; -}; + void HandleLocal(); + void TcpConnected(wpi::uv::Tcp& tcp) final; + void WsConnected(wpi::WebSocket& ws, wpi::uv::Tcp& tcp); + void ForceDisconnect(std::string_view reason) override; + void DoDisconnect(std::string_view reason) override; -class NetworkClient3 final : public INetworkClient { - public: - NetworkClient3(int inst, std::string_view id, - net::ILocalStorage& localStorage, IConnectionList& connList, - wpi::Logger& logger); - ~NetworkClient3() final; - - void SetServers( - std::span> servers) final; - void Disconnect() final; - - void StartDSClient(unsigned int port) final; - void StopDSClient() final; - - void FlushLocal() final; - void Flush() final; - - private: - class Impl; - std::unique_ptr m_impl; + std::function + m_timeSyncUpdated; + std::shared_ptr m_wire; + std::unique_ptr m_clientImpl; }; } // namespace nt diff --git a/ntcore/src/main/native/cpp/NetworkServer.cpp b/ntcore/src/main/native/cpp/NetworkServer.cpp index 484bf9f76e..f3e26535df 100644 --- a/ntcore/src/main/native/cpp/NetworkServer.cpp +++ b/ntcore/src/main/native/cpp/NetworkServer.cpp @@ -17,11 +17,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include @@ -29,9 +27,6 @@ #include "IConnectionList.h" #include "InstanceImpl.h" #include "Log.h" -#include "net/Message.h" -#include "net/NetworkLoopQueue.h" -#include "net/ServerImpl.h" #include "net/WebSocketConnection.h" #include "net3/UvStreamConnection3.h" @@ -41,14 +36,10 @@ namespace uv = wpi::uv; // use a larger max message size for websockets static constexpr size_t kMaxMessageSize = 2 * 1024 * 1024; -namespace { - -class NSImpl; - -class ServerConnection { +class NetworkServer::ServerConnection { public: - ServerConnection(NSImpl& server, std::string_view addr, unsigned int port, - wpi::Logger& logger) + ServerConnection(NetworkServer& server, std::string_view addr, + unsigned int port, wpi::Logger& logger) : m_server{server}, m_connInfo{fmt::format("{}:{}", addr, port)}, m_logger{logger} { @@ -63,7 +54,7 @@ class ServerConnection { void UpdatePeriodicTimer(uint32_t repeatMs); void ConnectionClosed(); - NSImpl& m_server; + NetworkServer& m_server; ConnectionInfo m_info; std::string m_connInfo; wpi::Logger& m_logger; @@ -73,11 +64,21 @@ class ServerConnection { std::shared_ptr m_sendValuesTimer; }; -class ServerConnection4 final +class NetworkServer::ServerConnection3 : public ServerConnection { + public: + ServerConnection3(std::shared_ptr stream, NetworkServer& server, + std::string_view addr, unsigned int port, + wpi::Logger& logger); + + private: + std::shared_ptr m_wire; +}; + +class NetworkServer::ServerConnection4 final : public ServerConnection, public wpi::HttpWebSocketServerConnection { public: - ServerConnection4(std::shared_ptr stream, NSImpl& server, + ServerConnection4(std::shared_ptr stream, NetworkServer& server, std::string_view addr, unsigned int port, wpi::Logger& logger) : ServerConnection{server, addr, port, logger}, @@ -92,71 +93,7 @@ class ServerConnection4 final std::shared_ptr m_wire; }; -class ServerConnection3 : public ServerConnection { - public: - ServerConnection3(std::shared_ptr stream, NSImpl& server, - std::string_view addr, unsigned int port, - wpi::Logger& logger); - - private: - std::shared_ptr m_wire; -}; - -class NSImpl { - public: - NSImpl(std::string_view persistFilename, std::string_view listenAddress, - unsigned int port3, unsigned int port4, - net::ILocalStorage& localStorage, IConnectionList& connList, - wpi::Logger& logger, std::function initDone); - ~NSImpl(); - - void HandleLocal(); - void LoadPersistent(); - void SavePersistent(std::string_view filename, std::string_view data); - void Init(); - void AddConnection(ServerConnection* conn, const ConnectionInfo& info); - void RemoveConnection(ServerConnection* conn); - - net::ILocalStorage& m_localStorage; - IConnectionList& m_connList; - wpi::Logger& m_logger; - std::function m_initDone; - std::string m_persistentData; - std::string m_persistentFilename; - std::string m_listenAddress; - unsigned int m_port3; - unsigned int m_port4; - - // used only from loop - std::shared_ptr m_readLocalTimer; - std::shared_ptr m_savePersistentTimer; - std::shared_ptr> m_flushLocal; - std::shared_ptr> m_flush; - bool m_shutdown = false; - - std::vector m_localMsgs; - - net::ServerImpl m_serverImpl; - - // shared with user (must be atomic or mutex-protected) - std::atomic*> m_flushLocalAtomic{nullptr}; - std::atomic*> m_flushAtomic{nullptr}; - mutable wpi::mutex m_mutex; - struct Connection { - ServerConnection* conn; - int connHandle; - }; - std::vector m_connections; - - net::NetworkLoopQueue m_localQueue; - - wpi::EventLoopRunner m_loopRunner; - wpi::uv::Loop& m_loop; -}; - -} // namespace - -void ServerConnection::SetupPeriodicTimer() { +void NetworkServer::ServerConnection::SetupPeriodicTimer() { m_sendValuesTimer = uv::Timer::Create(m_server.m_loop); m_sendValuesTimer->timeout.connect([this] { m_server.HandleLocal(); @@ -164,7 +101,7 @@ void ServerConnection::SetupPeriodicTimer() { }); } -void ServerConnection::UpdatePeriodicTimer(uint32_t repeatMs) { +void NetworkServer::ServerConnection::UpdatePeriodicTimer(uint32_t repeatMs) { if (repeatMs == UINT32_MAX) { m_sendValuesTimer->Stop(); } else { @@ -173,7 +110,7 @@ void ServerConnection::UpdatePeriodicTimer(uint32_t repeatMs) { } } -void ServerConnection::ConnectionClosed() { +void NetworkServer::ServerConnection::ConnectionClosed() { // don't call back into m_server if it's being destroyed if (!m_sendValuesTimer->IsLoopClosing()) { m_server.m_serverImpl.RemoveClient(m_clientId); @@ -182,7 +119,54 @@ void ServerConnection::ConnectionClosed() { m_sendValuesTimer->Close(); } -void ServerConnection4::ProcessRequest() { +NetworkServer::ServerConnection3::ServerConnection3( + std::shared_ptr stream, NetworkServer& server, + std::string_view addr, unsigned int port, wpi::Logger& logger) + : ServerConnection{server, addr, port, logger}, + m_wire{std::make_shared(*stream)} { + m_info.remote_ip = addr; + m_info.remote_port = port; + + // TODO: set local flag appropriately + m_clientId = m_server.m_serverImpl.AddClient3( + m_connInfo, false, *m_wire, + [this](std::string_view name, uint16_t proto) { + m_info.remote_id = name; + m_info.protocol_version = proto; + m_server.AddConnection(this, m_info); + INFO("CONNECTED NT3 client '{}' (from {})", name, m_connInfo); + }, + [this](uint32_t repeatMs) { UpdatePeriodicTimer(repeatMs); }); + + stream->error.connect([this](uv::Error err) { + if (!m_wire->GetDisconnectReason().empty()) { + return; + } + m_wire->Disconnect(fmt::format("stream error: {}", err.name())); + m_wire->GetStream().Shutdown([this] { m_wire->GetStream().Close(); }); + }); + stream->end.connect([this] { + if (!m_wire->GetDisconnectReason().empty()) { + return; + } + m_wire->Disconnect("remote end closed connection"); + m_wire->GetStream().Shutdown([this] { m_wire->GetStream().Close(); }); + }); + stream->closed.connect([this] { + INFO("DISCONNECTED NT3 client '{}' (from {}): {}", m_info.remote_id, + m_connInfo, m_wire->GetDisconnectReason()); + ConnectionClosed(); + }); + stream->data.connect([this](uv::Buffer& buf, size_t size) { + m_server.m_serverImpl.ProcessIncomingBinary( + m_clientId, {reinterpret_cast(buf.base), size}); + }); + stream->StartRead(); + + SetupPeriodicTimer(); +} + +void NetworkServer::ServerConnection4::ProcessRequest() { DEBUG1("HTTP request: '{}'", m_request.GetUrl()); wpi::UrlParser url{m_request.GetUrl(), m_request.GetMethod() == wpi::HTTP_CONNECT}; @@ -219,7 +203,7 @@ void ServerConnection4::ProcessRequest() { } } -void ServerConnection4::ProcessWsUpgrade() { +void NetworkServer::ServerConnection4::ProcessWsUpgrade() { // get name from URL wpi::UrlParser url{m_request.GetUrl(), false}; std::string_view path; @@ -271,58 +255,12 @@ void ServerConnection4::ProcessWsUpgrade() { }); } -ServerConnection3::ServerConnection3(std::shared_ptr stream, - NSImpl& server, std::string_view addr, - unsigned int port, wpi::Logger& logger) - : ServerConnection{server, addr, port, logger}, - m_wire{std::make_shared(*stream)} { - m_info.remote_ip = addr; - m_info.remote_port = port; - - // TODO: set local flag appropriately - m_clientId = m_server.m_serverImpl.AddClient3( - m_connInfo, false, *m_wire, - [this](std::string_view name, uint16_t proto) { - m_info.remote_id = name; - m_info.protocol_version = proto; - m_server.AddConnection(this, m_info); - INFO("CONNECTED NT3 client '{}' (from {})", name, m_connInfo); - }, - [this](uint32_t repeatMs) { UpdatePeriodicTimer(repeatMs); }); - - stream->error.connect([this](uv::Error err) { - if (!m_wire->GetDisconnectReason().empty()) { - return; - } - m_wire->Disconnect(fmt::format("stream error: {}", err.name())); - m_wire->GetStream().Shutdown([this] { m_wire->GetStream().Close(); }); - }); - stream->end.connect([this] { - if (!m_wire->GetDisconnectReason().empty()) { - return; - } - m_wire->Disconnect("remote end closed connection"); - m_wire->GetStream().Shutdown([this] { m_wire->GetStream().Close(); }); - }); - stream->closed.connect([this] { - INFO("DISCONNECTED NT3 client '{}' (from {}): {}", m_info.remote_id, - m_connInfo, m_wire->GetDisconnectReason()); - ConnectionClosed(); - }); - stream->data.connect([this](uv::Buffer& buf, size_t size) { - m_server.m_serverImpl.ProcessIncomingBinary( - m_clientId, {reinterpret_cast(buf.base), size}); - }); - stream->StartRead(); - - SetupPeriodicTimer(); -} - -NSImpl::NSImpl(std::string_view persistentFilename, - std::string_view listenAddress, unsigned int port3, - unsigned int port4, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function initDone) +NetworkServer::NetworkServer(std::string_view persistentFilename, + std::string_view listenAddress, unsigned int port3, + unsigned int port4, + net::ILocalStorage& localStorage, + IConnectionList& connList, wpi::Logger& logger, + std::function initDone) : m_localStorage{localStorage}, m_connList{connList}, m_logger{logger}, @@ -347,16 +285,30 @@ NSImpl::NSImpl(std::string_view persistentFilename, }); } -NSImpl::~NSImpl() { +NetworkServer::~NetworkServer() { m_loopRunner.ExecAsync([this](uv::Loop&) { m_shutdown = true; }); + m_localStorage.ClearNetwork(); + m_connList.ClearConnections(); } -void NSImpl::HandleLocal() { +void NetworkServer::FlushLocal() { + if (auto async = m_flushLocalAtomic.load(std::memory_order_relaxed)) { + async->UnsafeSend(); + } +} + +void NetworkServer::Flush() { + if (auto async = m_flushAtomic.load(std::memory_order_relaxed)) { + async->UnsafeSend(); + } +} + +void NetworkServer::HandleLocal() { m_localQueue.ReadQueue(&m_localMsgs); m_serverImpl.HandleLocal(m_localMsgs); } -void NSImpl::LoadPersistent() { +void NetworkServer::LoadPersistent() { std::error_code ec; auto size = fs::file_size(m_persistentFilename, ec); wpi::raw_fd_istream is{m_persistentFilename, ec}; @@ -376,12 +328,13 @@ void NSImpl::LoadPersistent() { is.readinto(m_persistentData, size); DEBUG4("read data: {}", m_persistentData); if (is.has_error()) { - WARNING("error reading persistent file"); + WARN("error reading persistent file"); return; } } -void NSImpl::SavePersistent(std::string_view filename, std::string_view data) { +void NetworkServer::SavePersistent(std::string_view filename, + std::string_view data) { // write to temporary file auto tmp = fmt::format("{}.tmp", filename); std::error_code ec; @@ -409,13 +362,13 @@ void NSImpl::SavePersistent(std::string_view filename, std::string_view data) { } } -void NSImpl::Init() { +void NetworkServer::Init() { if (m_shutdown) { return; } auto errs = m_serverImpl.LoadPersistent(m_persistentData); if (!errs.empty()) { - WARNING("error reading persistent file: {}", errs); + WARN("error reading persistent file: {}", errs); } // set up timers @@ -535,13 +488,14 @@ void NSImpl::Init() { } } -void NSImpl::AddConnection(ServerConnection* conn, const ConnectionInfo& info) { +void NetworkServer::AddConnection(ServerConnection* conn, + const ConnectionInfo& info) { std::scoped_lock lock{m_mutex}; m_connections.emplace_back(Connection{conn, m_connList.AddConnection(info)}); m_serverImpl.ConnectionsChanged(m_connList.GetConnections()); } -void NSImpl::RemoveConnection(ServerConnection* conn) { +void NetworkServer::RemoveConnection(ServerConnection* conn) { std::scoped_lock lock{m_mutex}; auto it = std::find_if(m_connections.begin(), m_connections.end(), [=](auto&& c) { return c.conn == conn; }); @@ -551,40 +505,3 @@ void NSImpl::RemoveConnection(ServerConnection* conn) { m_serverImpl.ConnectionsChanged(m_connList.GetConnections()); } } - -class NetworkServer::Impl final : public NSImpl { - public: - Impl(std::string_view persistFilename, std::string_view listenAddress, - unsigned int port3, unsigned int port4, net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function initDone) - : NSImpl{persistFilename, listenAddress, port3, port4, - localStorage, connList, logger, std::move(initDone)} {} -}; - -NetworkServer::NetworkServer(std::string_view persistFilename, - std::string_view listenAddress, unsigned int port3, - unsigned int port4, - net::ILocalStorage& localStorage, - IConnectionList& connList, wpi::Logger& logger, - std::function initDone) - : m_impl{std::make_unique(persistFilename, listenAddress, port3, - port4, localStorage, connList, logger, - std::move(initDone))} {} - -NetworkServer::~NetworkServer() { - m_impl->m_localStorage.ClearNetwork(); - m_impl->m_connList.ClearConnections(); -} - -void NetworkServer::FlushLocal() { - if (auto async = m_impl->m_flushLocalAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} - -void NetworkServer::Flush() { - if (auto async = m_impl->m_flushAtomic.load(std::memory_order_relaxed)) { - async->UnsafeSend(); - } -} diff --git a/ntcore/src/main/native/cpp/NetworkServer.h b/ntcore/src/main/native/cpp/NetworkServer.h index b70c968d61..3f5a0947c7 100644 --- a/ntcore/src/main/native/cpp/NetworkServer.h +++ b/ntcore/src/main/native/cpp/NetworkServer.h @@ -4,10 +4,20 @@ #pragma once +#include #include #include +#include #include +#include +#include +#include +#include + +#include "net/Message.h" +#include "net/NetworkLoopQueue.h" +#include "net/ServerImpl.h" #include "ntcore_cpp.h" namespace wpi { @@ -35,8 +45,52 @@ class NetworkServer { void Flush(); private: - class Impl; - std::unique_ptr m_impl; + class ServerConnection; + class ServerConnection3; + class ServerConnection4; + + void HandleLocal(); + void LoadPersistent(); + void SavePersistent(std::string_view filename, std::string_view data); + void Init(); + void AddConnection(ServerConnection* conn, const ConnectionInfo& info); + void RemoveConnection(ServerConnection* conn); + + net::ILocalStorage& m_localStorage; + IConnectionList& m_connList; + wpi::Logger& m_logger; + std::function m_initDone; + std::string m_persistentData; + std::string m_persistentFilename; + std::string m_listenAddress; + unsigned int m_port3; + unsigned int m_port4; + + // used only from loop + std::shared_ptr m_readLocalTimer; + std::shared_ptr m_savePersistentTimer; + std::shared_ptr> m_flushLocal; + std::shared_ptr> m_flush; + bool m_shutdown = false; + + std::vector m_localMsgs; + + net::ServerImpl m_serverImpl; + + // shared with user (must be atomic or mutex-protected) + std::atomic*> m_flushLocalAtomic{nullptr}; + std::atomic*> m_flushAtomic{nullptr}; + mutable wpi::mutex m_mutex; + struct Connection { + ServerConnection* conn; + int connHandle; + }; + std::vector m_connections; + + net::NetworkLoopQueue m_localQueue; + + wpi::EventLoopRunner m_loopRunner; + wpi::uv::Loop& m_loop; }; } // namespace nt diff --git a/ntcore/src/main/native/cpp/ValueCircularBuffer.cpp b/ntcore/src/main/native/cpp/ValueCircularBuffer.cpp new file mode 100644 index 0000000000..f611a335d9 --- /dev/null +++ b/ntcore/src/main/native/cpp/ValueCircularBuffer.cpp @@ -0,0 +1,17 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include "ValueCircularBuffer.h" + +using namespace nt; + +std::vector ValueCircularBuffer::ReadValue() { + std::vector rv; + rv.reserve(m_storage.size()); + for (auto&& val : m_storage) { + rv.emplace_back(std::move(val)); + } + m_storage.reset(); + return rv; +} diff --git a/ntcore/src/main/native/cpp/ValueCircularBuffer.h b/ntcore/src/main/native/cpp/ValueCircularBuffer.h new file mode 100644 index 0000000000..b80b5db6bb --- /dev/null +++ b/ntcore/src/main/native/cpp/ValueCircularBuffer.h @@ -0,0 +1,49 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include +#include + +#include + +#include "Value_internal.h" +#include "networktables/NetworkTableValue.h" +#include "ntcore_cpp_types.h" + +namespace nt { + +class ValueCircularBuffer { + public: + explicit ValueCircularBuffer(size_t size) : m_storage{size} {} + + template + void emplace_back(Args&&... args) { + m_storage.emplace_back(std::forward(args...)); + } + + std::vector ReadValue(); + template + std::vector::Value>> Read(); + + private: + wpi::circular_buffer m_storage; +}; + +template +std::vector::Value>> +ValueCircularBuffer::Read() { + std::vector::Value>> rv; + rv.reserve(m_storage.size()); + for (auto&& val : m_storage) { + if (IsNumericConvertibleTo(val) || IsType(val)) { + rv.emplace_back(GetTimestamped(val)); + } + } + m_storage.reset(); + return rv; +} + +} // namespace nt diff --git a/ntcore/src/main/native/cpp/Value_internal.cpp b/ntcore/src/main/native/cpp/Value_internal.cpp index 2003d31f57..c947bdcbe9 100644 --- a/ntcore/src/main/native/cpp/Value_internal.cpp +++ b/ntcore/src/main/native/cpp/Value_internal.cpp @@ -26,19 +26,19 @@ Value nt::ConvertNumericValue(const Value& value, NT_Type type) { return newval; } case NT_INTEGER_ARRAY: { - Value newval = Value::MakeIntegerArray(GetNumericArrayAs(value), - value.time()); + Value newval = Value::MakeIntegerArray( + GetNumericArrayAs(value), value.time()); newval.SetServerTime(value.server_time()); return newval; } case NT_FLOAT_ARRAY: { - Value newval = - Value::MakeFloatArray(GetNumericArrayAs(value), value.time()); + Value newval = Value::MakeFloatArray(GetNumericArrayAs(value), + value.time()); newval.SetServerTime(value.server_time()); return newval; } case NT_DOUBLE_ARRAY: { - Value newval = Value::MakeDoubleArray(GetNumericArrayAs(value), + Value newval = Value::MakeDoubleArray(GetNumericArrayAs(value), value.time()); newval.SetServerTime(value.server_time()); return newval; diff --git a/ntcore/src/main/native/cpp/Value_internal.h b/ntcore/src/main/native/cpp/Value_internal.h index 03532ac61f..8f2c8fb138 100644 --- a/ntcore/src/main/native/cpp/Value_internal.h +++ b/ntcore/src/main/native/cpp/Value_internal.h @@ -4,20 +4,413 @@ #pragma once +#include #include #include +#include #include #include +#include #include #include #include "networktables/NetworkTableValue.h" #include "ntcore_c.h" +#include "ntcore_cpp_types.h" namespace nt { -class Value; +template +struct TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_BOOLEAN; + + using Value = bool; + using View = bool; +}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_INTEGER; + + using Value = int64_t; + using View = int64_t; +}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_FLOAT; + + using Value = float; + using View = float; +}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_DOUBLE; + + using Value = double; + using View = double; +}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_STRING; + + using Value = std::string; + using View = std::string_view; + + using SmallRet = std::string_view; + using SmallElem = char; +}; + +template <> +struct TypeInfo : public TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_RAW; + + using Value = std::vector; + using View = std::span; + + using SmallRet = std::span; + using SmallElem = uint8_t; +}; + +template <> +struct TypeInfo> : public TypeInfo {}; +template <> +struct TypeInfo> : public TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_BOOLEAN_ARRAY; + using ElementType = bool; + + using Value = std::vector; + using View = std::span; + + using SmallRet = std::span; + using SmallElem = int; +}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_INTEGER_ARRAY; + using ElementType = int64_t; + + using Value = std::vector; + using View = std::span; + + using SmallRet = std::span; + using SmallElem = int64_t; +}; + +template <> +struct TypeInfo> : public TypeInfo {}; +template <> +struct TypeInfo> : public TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_FLOAT_ARRAY; + using ElementType = float; + + using Value = std::vector; + using View = std::span; + + using SmallRet = std::span; + using SmallElem = float; +}; + +template <> +struct TypeInfo> : public TypeInfo {}; +template <> +struct TypeInfo> : public TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_DOUBLE_ARRAY; + using ElementType = double; + + using Value = std::vector; + using View = std::span; + + using SmallRet = std::span; + using SmallElem = double; +}; + +template <> +struct TypeInfo> : public TypeInfo {}; +template <> +struct TypeInfo> : public TypeInfo {}; + +template <> +struct TypeInfo { + static constexpr NT_Type kType = NT_STRING_ARRAY; + using ElementType = std::string; + + using Value = std::vector; + using View = std::span; +}; + +template <> +struct TypeInfo> : public TypeInfo {}; +template <> +struct TypeInfo> : public TypeInfo { +}; + +template +concept ValidType = requires { + { TypeInfo>::kType } -> std::convertible_to; + typename TypeInfo>::Value; + typename TypeInfo>::View; +}; + +static_assert(ValidType); +static_assert(!ValidType); +static_assert(ValidType); +static_assert(ValidType>); + +template +constexpr bool IsNTType = TypeInfo>::kType == type; + +static_assert(IsNTType); +static_assert(!IsNTType); + +template +concept ArrayType = + requires { typename TypeInfo>::ElementType; }; + +static_assert(ArrayType); +static_assert(!ArrayType); + +template +concept SmallArrayType = requires { + typename TypeInfo>::SmallRet; + typename TypeInfo>::SmallElem; +}; + +static_assert(SmallArrayType); +static_assert(!SmallArrayType); + +template +concept NumericType = + IsNTType || IsNTType || IsNTType; + +static_assert(NumericType); +static_assert(NumericType); +static_assert(NumericType); +static_assert(!NumericType); +static_assert(!NumericType); +static_assert(!NumericType); +static_assert(!NumericType); +static_assert(!NumericType); +static_assert(!NumericType); + +template +concept NumericArrayType = + ArrayType && + NumericType>::ElementType>; + +static_assert(NumericArrayType); +static_assert(NumericArrayType); +static_assert(NumericArrayType); +static_assert(!NumericArrayType); + +template +inline typename TypeInfo::Value GetNumericAs(const Value& value) { + if (value.IsInteger()) { + return static_cast::Value>(value.GetInteger()); + } else if (value.IsFloat()) { + return static_cast::Value>(value.GetFloat()); + } else if (value.IsDouble()) { + return static_cast::Value>(value.GetDouble()); + } else { + return {}; + } +} + +template +typename TypeInfo::Value GetNumericArrayAs(const Value& value) { + if (value.IsIntegerArray()) { + auto arr = value.GetIntegerArray(); + return {arr.begin(), arr.end()}; + } else if (value.IsFloatArray()) { + auto arr = value.GetFloatArray(); + return {arr.begin(), arr.end()}; + } else if (value.IsDoubleArray()) { + auto arr = value.GetDoubleArray(); + return {arr.begin(), arr.end()}; + } else { + return {}; + } +} + +template +inline bool IsType(const Value& value) { + return value.type() == TypeInfo::kType; +} + +template +inline bool IsNumericConvertibleTo(const Value& value) { + if constexpr (NumericType) { + return value.IsInteger() || value.IsFloat() || value.IsDouble(); + } else if constexpr (NumericArrayType) { + return value.IsIntegerArray() || value.IsFloatArray() || + value.IsDoubleArray(); + } else { + return false; + } +} + +template +inline typename TypeInfo::View GetValueView(const Value& value) { + if constexpr (IsNTType) { + return value.GetBoolean(); + } else if constexpr (IsNTType) { + return value.GetInteger(); + } else if constexpr (IsNTType) { + return value.GetFloat(); + } else if constexpr (IsNTType) { + return value.GetDouble(); + } else if constexpr (IsNTType) { + return value.GetString(); + } else if constexpr (IsNTType) { + return value.GetRaw(); + } else if constexpr (IsNTType) { + return value.GetBooleanArray(); + } else if constexpr (IsNTType) { + return value.GetIntegerArray(); + } else if constexpr (IsNTType) { + return value.GetFloatArray(); + } else if constexpr (IsNTType) { + return value.GetDoubleArray(); + } else if constexpr (IsNTType) { + return value.GetStringArray(); + } +} + +template +inline Value MakeValue(typename TypeInfo::View value, int64_t time) { + if constexpr (IsNTType) { + return Value::MakeBoolean(value, time); + } else if constexpr (IsNTType) { + return Value::MakeInteger(value, time); + } else if constexpr (IsNTType) { + return Value::MakeFloat(value, time); + } else if constexpr (IsNTType) { + return Value::MakeDouble(value, time); + } else if constexpr (IsNTType) { + return Value::MakeString(value, time); + } else if constexpr (IsNTType) { + return Value::MakeRaw(value, time); + } else if constexpr (IsNTType) { + return Value::MakeBooleanArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeIntegerArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeFloatArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeDoubleArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeStringArray(value, time); + } +} + +template + requires ArrayType || IsNTType || IsNTType +inline Value MakeValue(typename TypeInfo::Value&& value, int64_t time) { + if constexpr (IsNTType) { + return Value::MakeString(value, time); + } else if constexpr (IsNTType) { + return Value::MakeRaw(value, time); + } else if constexpr (IsNTType) { + return Value::MakeBooleanArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeIntegerArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeFloatArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeDoubleArray(value, time); + } else if constexpr (IsNTType) { + return Value::MakeStringArray(value, time); + } +} + +template +inline typename TypeInfo::Value CopyValue(typename TypeInfo::View value) { + if constexpr (ArrayType || IsNTType) { + return {value.begin(), value.end()}; + } else if constexpr (IsNTType) { + return std::string{value}; + } else { + return value; + } +} + +template +inline typename TypeInfo::SmallRet CopyValue( + typename TypeInfo::View arr, + wpi::SmallVectorImpl::SmallElem>& buf) { + buf.assign(arr.begin(), arr.end()); + return {buf.data(), buf.size()}; +} + +template +inline typename TypeInfo::Value GetValueCopy(const Value& value) { + if constexpr (ConvertNumeric && NumericType) { + return GetNumericAs(value); + } else if constexpr (ConvertNumeric && NumericArrayType) { + return GetNumericArrayAs(value); + } else { + return CopyValue(GetValueView(value)); + } +} + +template +inline typename TypeInfo::SmallRet GetValueCopy( + const Value& value, + wpi::SmallVectorImpl::SmallElem>& buf) { + if constexpr (ConvertNumeric && NumericArrayType) { + if (value.IsIntegerArray()) { + auto arr = value.GetIntegerArray(); + buf.assign(arr.begin(), arr.end()); + return {buf.data(), buf.size()}; + } else if (value.IsFloatArray()) { + auto arr = value.GetFloatArray(); + buf.assign(arr.begin(), arr.end()); + return {buf.data(), buf.size()}; + } else if (value.IsDoubleArray()) { + auto arr = value.GetDoubleArray(); + buf.assign(arr.begin(), arr.end()); + return {buf.data(), buf.size()}; + } else { + return {}; + } + } else { + return CopyValue(GetValueView(value), buf); + } +} + +template +inline Timestamped::Value> GetTimestamped( + const Value& value) { + return {value.time(), value.server_time(), + GetValueCopy(value)}; +} + +template +inline Timestamped::SmallRet> GetTimestamped( + const Value& value, + wpi::SmallVectorImpl::SmallElem>& buf) { + return {value.time(), value.server_time(), + GetValueCopy(value, buf)}; +} template inline void ConvertToC(const T& in, T* out) { @@ -57,35 +450,6 @@ O* ConvertToC(const std::basic_string& in, size_t* out_len) { return out; } -template -T GetNumericAs(const Value& value) { - if (value.IsInteger()) { - return static_cast(value.GetInteger()); - } else if (value.IsFloat()) { - return static_cast(value.GetFloat()); - } else if (value.IsDouble()) { - return static_cast(value.GetDouble()); - } else { - return {}; - } -} - -template -std::vector GetNumericArrayAs(const Value& value) { - if (value.IsIntegerArray()) { - auto arr = value.GetIntegerArray(); - return {arr.begin(), arr.end()}; - } else if (value.IsFloatArray()) { - auto arr = value.GetFloatArray(); - return {arr.begin(), arr.end()}; - } else if (value.IsDoubleArray()) { - auto arr = value.GetDoubleArray(); - return {arr.begin(), arr.end()}; - } else { - return {}; - } -} - Value ConvertNumericValue(const Value& value, NT_Type type); } // namespace nt diff --git a/ntcore/src/main/native/cpp/VectorSet.h b/ntcore/src/main/native/cpp/VectorSet.h new file mode 100644 index 0000000000..9e13490736 --- /dev/null +++ b/ntcore/src/main/native/cpp/VectorSet.h @@ -0,0 +1,21 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +namespace nt { + +// Utility wrapper for making a set-like vector +template +class VectorSet : public std::vector { + public: + using iterator = typename std::vector::iterator; + void Add(T value) { this->push_back(value); } + // returns true if element was present + bool Remove(T value) { return std::erase(*this, value) != 0; } +}; + +} // namespace nt diff --git a/ntcore/src/main/native/cpp/net/ClientImpl.cpp b/ntcore/src/main/native/cpp/net/ClientImpl.cpp index 17ec807edb..71849ad5c9 100644 --- a/ntcore/src/main/native/cpp/net/ClientImpl.cpp +++ b/ntcore/src/main/native/cpp/net/ClientImpl.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -19,9 +18,7 @@ #include "Log.h" #include "Message.h" #include "NetworkInterface.h" -#include "PubSubOptions.h" #include "WireConnection.h" -#include "WireDecoder.h" #include "WireEncoder.h" #include "networktables/NetworkTableValue.h" @@ -34,79 +31,7 @@ static constexpr uint32_t kMinPeriodMs = 5; // transmission before we close the connection static constexpr uint32_t kWireMaxNotReadyUs = 1000000; -namespace { - -struct PublisherData { - NT_Publisher handle; - PubSubOptionsImpl options; - // in options as double, but copy here as integer; rounded to the nearest - // 10 ms - uint32_t periodMs; - uint64_t nextSendMs{0}; - std::vector outValues; // outgoing values -}; - -class CImpl : public ServerMessageHandler { - public: - CImpl(uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger, - std::function - timeSyncUpdated, - std::function setPeriodic); - - void ProcessIncomingBinary(uint64_t curTimeMs, std::span data); - void HandleLocal(std::vector&& msgs); - bool SendControl(uint64_t curTimeMs); - void SendValues(uint64_t curTimeMs, bool flush); - void SendInitialValues(); - bool CheckNetworkReady(uint64_t curTimeMs); - - // ServerMessageHandler interface - void ServerAnnounce(std::string_view name, int64_t id, - std::string_view typeStr, const wpi::json& properties, - std::optional pubuid) final; - void ServerUnannounce(std::string_view name, int64_t id) final; - void ServerPropertiesUpdate(std::string_view name, const wpi::json& update, - bool ack) final; - - void Publish(NT_Publisher pubHandle, NT_Topic topicHandle, - std::string_view name, std::string_view typeStr, - const wpi::json& properties, const PubSubOptionsImpl& options); - bool Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle); - void SetValue(NT_Publisher pubHandle, const Value& value); - - int m_inst; - WireConnection& m_wire; - wpi::Logger& m_logger; - LocalInterface* m_local{nullptr}; - std::function - m_timeSyncUpdated; - std::function m_setPeriodic; - - // indexed by publisher index - std::vector> m_publishers; - - // indexed by server-provided topic id - wpi::DenseMap m_topicMap; - - // timestamp handling - static constexpr uint32_t kPingIntervalMs = 3000; - uint64_t m_nextPingTimeMs{0}; - uint64_t m_pongTimeMs{0}; - uint32_t m_rtt2Us{UINT32_MAX}; - bool m_haveTimeOffset{false}; - int64_t m_serverTimeOffsetUs{0}; - - // periodic sweep handling - uint32_t m_periodMs{kPingIntervalMs + 10}; - uint64_t m_lastSendMs{0}; - - // outgoing queue - std::vector m_outgoing; -}; - -} // namespace - -CImpl::CImpl( +ClientImpl::ClientImpl( uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger, std::function timeSyncUpdated, @@ -126,8 +51,8 @@ CImpl::CImpl( m_setPeriodic(m_periodMs); } -void CImpl::ProcessIncomingBinary(uint64_t curTimeMs, - std::span data) { +void ClientImpl::ProcessIncomingBinary(uint64_t curTimeMs, + std::span data) { for (;;) { if (data.empty()) { break; @@ -138,7 +63,7 @@ void CImpl::ProcessIncomingBinary(uint64_t curTimeMs, Value value; std::string error; if (!WireDecodeBinary(&data, &id, &value, &error, -m_serverTimeOffsetUs)) { - ERROR("binary decode error: {}", error); + ERR("binary decode error: {}", error); break; // FIXME } DEBUG4("BinaryMessage({})", id); @@ -146,8 +71,8 @@ void CImpl::ProcessIncomingBinary(uint64_t curTimeMs, // handle RTT ping response if (id == -1) { if (!value.IsInteger()) { - WARNING("RTT ping response with non-integer type {}", - static_cast(value.type())); + WARN("RTT ping response with non-integer type {}", + static_cast(value.type())); continue; } DEBUG4("RTT ping response time {} value {}", value.time(), @@ -168,7 +93,7 @@ void CImpl::ProcessIncomingBinary(uint64_t curTimeMs, // otherwise it's a value message, get the local topic handle for it auto topicIt = m_topicMap.find(id); if (topicIt == m_topicMap.end()) { - WARNING("received unknown id {}", id); + WARN("received unknown id {}", id); continue; } @@ -179,7 +104,7 @@ void CImpl::ProcessIncomingBinary(uint64_t curTimeMs, } } -void CImpl::HandleLocal(std::vector&& msgs) { +void ClientImpl::HandleLocal(std::vector&& msgs) { DEBUG4("HandleLocal()"); for (auto&& elem : msgs) { // common case is value @@ -200,7 +125,7 @@ void CImpl::HandleLocal(std::vector&& msgs) { } } -bool CImpl::SendControl(uint64_t curTimeMs) { +bool ClientImpl::DoSendControl(uint64_t curTimeMs) { DEBUG4("SendControl({})", curTimeMs); // rate limit sends @@ -246,7 +171,7 @@ bool CImpl::SendControl(uint64_t curTimeMs) { return true; } -void CImpl::SendValues(uint64_t curTimeMs, bool flush) { +void ClientImpl::DoSendValues(uint64_t curTimeMs, bool flush) { DEBUG4("SendValues({})", curTimeMs); // can't send value updates until we have a RTT @@ -255,7 +180,7 @@ void CImpl::SendValues(uint64_t curTimeMs, bool flush) { } // ensure all control messages are sent ahead of value updates - if (!SendControl(curTimeMs)) { + if (!DoSendControl(curTimeMs)) { return; } @@ -291,11 +216,11 @@ void CImpl::SendValues(uint64_t curTimeMs, bool flush) { } } -void CImpl::SendInitialValues() { +void ClientImpl::SendInitialValues() { DEBUG4("SendInitialValues()"); // ensure all control messages are sent ahead of value updates - if (!SendControl(0)) { + if (!DoSendControl(0)) { return; } @@ -321,7 +246,7 @@ void CImpl::SendInitialValues() { } } -bool CImpl::CheckNetworkReady(uint64_t curTimeMs) { +bool ClientImpl::CheckNetworkReady(uint64_t curTimeMs) { if (!m_wire.Ready()) { uint64_t lastFlushTime = m_wire.GetLastFlushTime(); uint64_t now = wpi::Now(); @@ -333,10 +258,10 @@ bool CImpl::CheckNetworkReady(uint64_t curTimeMs) { return true; } -void CImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, - std::string_view name, std::string_view typeStr, - const wpi::json& properties, - const PubSubOptionsImpl& options) { +void ClientImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, + std::string_view name, std::string_view typeStr, + const wpi::json& properties, + const PubSubOptionsImpl& options) { unsigned int index = Handle{pubHandle}.GetIndex(); if (index >= m_publishers.size()) { m_publishers.resize(index + 1); @@ -360,7 +285,7 @@ void CImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, m_setPeriodic(m_periodMs); } -bool CImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { +bool ClientImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { unsigned int index = Handle{pubHandle}.GetIndex(); if (index >= m_publishers.size()) { return false; @@ -400,7 +325,7 @@ bool CImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { return doSend; } -void CImpl::SetValue(NT_Publisher pubHandle, const Value& value) { +void ClientImpl::SetValue(NT_Publisher pubHandle, const Value& value) { DEBUG4("SetValue({}, time={}, server_time={}, st_off={})", pubHandle, value.time(), value.server_time(), m_serverTimeOffsetUs); unsigned int index = Handle{pubHandle}.GetIndex(); @@ -415,10 +340,10 @@ void CImpl::SetValue(NT_Publisher pubHandle, const Value& value) { } } -void CImpl::ServerAnnounce(std::string_view name, int64_t id, - std::string_view typeStr, - const wpi::json& properties, - std::optional pubuid) { +void ClientImpl::ServerAnnounce(std::string_view name, int64_t id, + std::string_view typeStr, + const wpi::json& properties, + std::optional pubuid) { DEBUG4("ServerAnnounce({}, {}, {})", name, id, typeStr); assert(m_local); NT_Publisher pubHandle{0}; @@ -429,76 +354,38 @@ void CImpl::ServerAnnounce(std::string_view name, int64_t id, m_local->NetworkAnnounce(name, typeStr, properties, pubHandle); } -void CImpl::ServerUnannounce(std::string_view name, int64_t id) { +void ClientImpl::ServerUnannounce(std::string_view name, int64_t id) { DEBUG4("ServerUnannounce({}, {})", name, id); assert(m_local); m_local->NetworkUnannounce(name); m_topicMap.erase(id); } -void CImpl::ServerPropertiesUpdate(std::string_view name, - const wpi::json& update, bool ack) { +void ClientImpl::ServerPropertiesUpdate(std::string_view name, + const wpi::json& update, bool ack) { DEBUG4("ServerProperties({}, {}, {})", name, update.dump(), ack); assert(m_local); m_local->NetworkPropertiesUpdate(name, update, ack); } -class ClientImpl::Impl final : public CImpl { - public: - Impl(uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger, - std::function - timeSyncUpdated, - std::function setPeriodic) - : CImpl{curTimeMs, - inst, - wire, - logger, - std::move(timeSyncUpdated), - std::move(setPeriodic)} {} -}; - -ClientImpl::ClientImpl( - uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger, - std::function - timeSyncUpdated, - std::function setPeriodic) - : m_impl{std::make_unique(curTimeMs, inst, wire, logger, - std::move(timeSyncUpdated), - std::move(setPeriodic))} {} - -ClientImpl::~ClientImpl() = default; - void ClientImpl::ProcessIncomingText(std::string_view data) { - if (!m_impl->m_local) { + if (!m_local) { return; } - WireDecodeText(data, *m_impl, m_impl->m_logger); -} - -void ClientImpl::ProcessIncomingBinary(uint64_t curTimeMs, - std::span data) { - m_impl->ProcessIncomingBinary(curTimeMs, data); -} - -void ClientImpl::HandleLocal(std::vector&& msgs) { - m_impl->HandleLocal(std::move(msgs)); + WireDecodeText(data, *this, m_logger); } void ClientImpl::SendControl(uint64_t curTimeMs) { - m_impl->SendControl(curTimeMs); - m_impl->m_wire.Flush(); + DoSendControl(curTimeMs); + m_wire.Flush(); } void ClientImpl::SendValues(uint64_t curTimeMs, bool flush) { - m_impl->SendValues(curTimeMs, flush); - m_impl->m_wire.Flush(); -} - -void ClientImpl::SetLocal(LocalInterface* local) { - m_impl->m_local = local; + DoSendValues(curTimeMs, flush); + m_wire.Flush(); } void ClientImpl::SendInitial() { - m_impl->SendInitialValues(); - m_impl->m_wire.Flush(); + SendInitialValues(); + m_wire.Flush(); } diff --git a/ntcore/src/main/native/cpp/net/ClientImpl.h b/ntcore/src/main/native/cpp/net/ClientImpl.h index d6eb1d939c..6e97e8dd39 100644 --- a/ntcore/src/main/native/cpp/net/ClientImpl.h +++ b/ntcore/src/main/native/cpp/net/ClientImpl.h @@ -13,8 +13,12 @@ #include #include +#include + #include "NetworkInterface.h" +#include "PubSubOptions.h" #include "WireConnection.h" +#include "WireDecoder.h" namespace wpi { class Logger; @@ -30,14 +34,13 @@ namespace nt::net { struct ClientMessage; class WireConnection; -class ClientImpl { +class ClientImpl final : private ServerMessageHandler { public: ClientImpl( uint64_t curTimeMs, int inst, WireConnection& wire, wpi::Logger& logger, std::function timeSyncUpdated, std::function setPeriodic); - ~ClientImpl(); void ProcessIncomingText(std::string_view data); void ProcessIncomingBinary(uint64_t curTimeMs, std::span data); @@ -46,12 +49,67 @@ class ClientImpl { void SendControl(uint64_t curTimeMs); void SendValues(uint64_t curTimeMs, bool flush); - void SetLocal(LocalInterface* local); + void SetLocal(LocalInterface* local) { m_local = local; } void SendInitial(); private: - class Impl; - std::unique_ptr m_impl; + struct PublisherData { + NT_Publisher handle; + PubSubOptionsImpl options; + // in options as double, but copy here as integer; rounded to the nearest + // 10 ms + uint32_t periodMs; + uint64_t nextSendMs{0}; + std::vector outValues; // outgoing values + }; + + bool DoSendControl(uint64_t curTimeMs); + void DoSendValues(uint64_t curTimeMs, bool flush); + void SendInitialValues(); + bool CheckNetworkReady(uint64_t curTimeMs); + + // ServerMessageHandler interface + void ServerAnnounce(std::string_view name, int64_t id, + std::string_view typeStr, const wpi::json& properties, + std::optional pubuid) final; + void ServerUnannounce(std::string_view name, int64_t id) final; + void ServerPropertiesUpdate(std::string_view name, const wpi::json& update, + bool ack) final; + + void Publish(NT_Publisher pubHandle, NT_Topic topicHandle, + std::string_view name, std::string_view typeStr, + const wpi::json& properties, const PubSubOptionsImpl& options); + bool Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle); + void SetValue(NT_Publisher pubHandle, const Value& value); + + int m_inst; + WireConnection& m_wire; + wpi::Logger& m_logger; + LocalInterface* m_local{nullptr}; + std::function + m_timeSyncUpdated; + std::function m_setPeriodic; + + // indexed by publisher index + std::vector> m_publishers; + + // indexed by server-provided topic id + wpi::DenseMap m_topicMap; + + // timestamp handling + static constexpr uint32_t kPingIntervalMs = 3000; + uint64_t m_nextPingTimeMs{0}; + uint64_t m_pongTimeMs{0}; + uint32_t m_rtt2Us{UINT32_MAX}; + bool m_haveTimeOffset{false}; + int64_t m_serverTimeOffsetUs{0}; + + // periodic sweep handling + uint32_t m_periodMs{kPingIntervalMs + 10}; + uint64_t m_lastSendMs{0}; + + // outgoing queue + std::vector m_outgoing; }; } // namespace nt::net diff --git a/ntcore/src/main/native/cpp/net/ServerImpl.cpp b/ntcore/src/main/native/cpp/net/ServerImpl.cpp index 3bc9165171..14dcfed2c4 100644 --- a/ntcore/src/main/native/cpp/net/ServerImpl.cpp +++ b/ntcore/src/main/native/cpp/net/ServerImpl.cpp @@ -14,30 +14,18 @@ #include #include -#include #include #include #include -#include -#include -#include #include #include #include #include "IConnectionList.h" #include "Log.h" -#include "Message.h" #include "NetworkInterface.h" -#include "PubSubOptions.h" #include "Types_internal.h" -#include "WireConnection.h" -#include "WireDecoder.h" -#include "WireEncoder.h" -#include "net3/Message3.h" -#include "net3/SequenceNumber.h" #include "net3/WireConnection3.h" -#include "net3/WireDecoder3.h" #include "net3/WireEncoder3.h" #include "networktables/NetworkTableValue.h" #include "ntcore_c.h" @@ -46,396 +34,11 @@ using namespace nt; using namespace nt::net; using namespace mpack; -static constexpr uint32_t kMinPeriodMs = 5; - // maximum amount of time the wire can be not ready to send another // transmission before we close the connection static constexpr uint32_t kWireMaxNotReadyUs = 1000000; namespace { - -// Utility wrapper for making a set-like vector -template -class VectorSet : public std::vector { - public: - using iterator = typename std::vector::iterator; - void Add(T value) { this->push_back(value); } - // returns true if element was present - bool Remove(T value) { - auto removeIt = std::remove(this->begin(), this->end(), value); - if (removeIt == this->end()) { - return false; - } - this->erase(removeIt, this->end()); - return true; - } -}; - -struct PublisherData; -struct SubscriberData; -struct TopicData; -class SImpl; - -class ClientData { - public: - ClientData(std::string_view name, std::string_view connInfo, bool local, - ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, int id, - wpi::Logger& logger) - : m_name{name}, - m_connInfo{connInfo}, - m_local{local}, - m_setPeriodic{std::move(setPeriodic)}, - m_server{server}, - m_id{id}, - m_logger{logger} {} - virtual ~ClientData() = default; - - virtual void ProcessIncomingText(std::string_view data) = 0; - virtual void ProcessIncomingBinary(std::span data) = 0; - - enum SendMode { kSendDisabled = 0, kSendAll, kSendNormal, kSendImmNoFlush }; - - virtual void SendValue(TopicData* topic, const Value& value, - SendMode mode) = 0; - virtual void SendAnnounce(TopicData* topic, - std::optional pubuid) = 0; - virtual void SendUnannounce(TopicData* topic) = 0; - virtual void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack) = 0; - virtual void SendOutgoing(uint64_t curTimeMs) = 0; - virtual void Flush() = 0; - - void UpdateMetaClientPub(); - void UpdateMetaClientSub(); - - std::span GetSubscribers( - std::string_view name, bool special, - wpi::SmallVectorImpl& buf); - - std::string_view GetName() const { return m_name; } - int GetId() const { return m_id; } - - protected: - std::string m_name; - std::string m_connInfo; - bool m_local; // local to machine - ServerImpl::SetPeriodicFunc m_setPeriodic; - // TODO: make this per-topic? - uint32_t m_periodMs{UINT32_MAX}; - uint64_t m_lastSendMs{0}; - SImpl& m_server; - int m_id; - - wpi::Logger& m_logger; - - wpi::DenseMap> m_publishers; - wpi::DenseMap> m_subscribers; - - public: - // meta topics - TopicData* m_metaPub = nullptr; - TopicData* m_metaSub = nullptr; -}; - -class ClientData4Base : public ClientData, protected ClientMessageHandler { - public: - ClientData4Base(std::string_view name, std::string_view connInfo, bool local, - ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, - int id, wpi::Logger& logger) - : ClientData{name, connInfo, local, setPeriodic, server, id, logger} {} - - protected: - // ClientMessageHandler interface - void ClientPublish(int64_t pubuid, std::string_view name, - std::string_view typeStr, - const wpi::json& properties) final; - void ClientUnpublish(int64_t pubuid) final; - void ClientSetProperties(std::string_view name, - const wpi::json& update) final; - void ClientSubscribe(int64_t subuid, std::span topicNames, - const PubSubOptionsImpl& options) final; - void ClientUnsubscribe(int64_t subuid) final; - - void ClientSetValue(int64_t pubuid, const Value& value); - - wpi::DenseMap m_announceSent; -}; - -class ClientDataLocal final : public ClientData4Base { - public: - ClientDataLocal(SImpl& server, int id, wpi::Logger& logger) - : ClientData4Base{"", "", true, [](uint32_t) {}, server, id, logger} {} - - void ProcessIncomingText(std::string_view data) final {} - void ProcessIncomingBinary(std::span data) final {} - - void SendValue(TopicData* topic, const Value& value, SendMode mode) final; - void SendAnnounce(TopicData* topic, std::optional pubuid) final; - void SendUnannounce(TopicData* topic) final; - void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack) final; - void SendOutgoing(uint64_t curTimeMs) final {} - void Flush() final {} - - void HandleLocal(std::span msgs); -}; - -class ClientData4 final : public ClientData4Base { - public: - ClientData4(std::string_view name, std::string_view connInfo, bool local, - WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic, - SImpl& server, int id, wpi::Logger& logger) - : ClientData4Base{name, connInfo, local, setPeriodic, server, id, logger}, - m_wire{wire} {} - - void ProcessIncomingText(std::string_view data) final; - void ProcessIncomingBinary(std::span data) final; - - void SendValue(TopicData* topic, const Value& value, SendMode mode) final; - void SendAnnounce(TopicData* topic, std::optional pubuid) final; - void SendUnannounce(TopicData* topic) final; - void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack) final; - void SendOutgoing(uint64_t curTimeMs) final; - - void Flush() final; - - public: - WireConnection& m_wire; - - private: - std::vector m_outgoing; - wpi::DenseMap m_outgoingValueMap; - - bool WriteBinary(int64_t id, int64_t time, const Value& value) { - return WireEncodeBinary(SendBinary().Add(), id, time, value); - } - - TextWriter& SendText() { - m_outBinary.reset(); // ensure proper interleaving of text and binary - if (!m_outText) { - m_outText = m_wire.SendText(); - } - return *m_outText; - } - - BinaryWriter& SendBinary() { - m_outText.reset(); // ensure proper interleaving of text and binary - if (!m_outBinary) { - m_outBinary = m_wire.SendBinary(); - } - return *m_outBinary; - } - - // valid when we are actively writing to this client - std::optional m_outText; - std::optional m_outBinary; -}; - -class ClientData3 final : public ClientData, private net3::MessageHandler3 { - public: - ClientData3(std::string_view connInfo, bool local, - net3::WireConnection3& wire, ServerImpl::Connected3Func connected, - ServerImpl::SetPeriodicFunc setPeriodic, SImpl& server, int id, - wpi::Logger& logger) - : ClientData{"", connInfo, local, setPeriodic, server, id, logger}, - m_connected{std::move(connected)}, - m_wire{wire}, - m_decoder{*this} {} - - void ProcessIncomingText(std::string_view data) final {} - void ProcessIncomingBinary(std::span data) final; - - void SendValue(TopicData* topic, const Value& value, SendMode mode) final; - void SendAnnounce(TopicData* topic, std::optional pubuid) final; - void SendUnannounce(TopicData* topic) final; - void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, - bool ack) final; - void SendOutgoing(uint64_t curTimeMs) final; - - void Flush() final { m_wire.Flush(); } - - private: - // MessageHandler3 interface - void KeepAlive() final; - void ServerHelloDone() final; - void ClientHelloDone() final; - void ClearEntries() final; - void ProtoUnsup(unsigned int proto_rev) final; - void ClientHello(std::string_view self_id, unsigned int proto_rev) final; - void ServerHello(unsigned int flags, std::string_view self_id) final; - void EntryAssign(std::string_view name, unsigned int id, unsigned int seq_num, - const Value& value, unsigned int flags) final; - void EntryUpdate(unsigned int id, unsigned int seq_num, - const Value& value) final; - void FlagsUpdate(unsigned int id, unsigned int flags) final; - void EntryDelete(unsigned int id) final; - void ExecuteRpc(unsigned int id, unsigned int uid, - std::span params) final {} - void RpcResponse(unsigned int id, unsigned int uid, - std::span result) final {} - - ServerImpl::Connected3Func m_connected; - net3::WireConnection3& m_wire; - - enum State { kStateInitial, kStateServerHelloComplete, kStateRunning }; - State m_state{kStateInitial}; - net3::WireDecoder3 m_decoder; - - std::vector m_outgoing; - wpi::DenseMap m_outgoingValueMap; - int64_t m_nextPubUid{1}; - - struct TopicData3 { - explicit TopicData3(TopicData* topic) { UpdateFlags(topic); } - - unsigned int flags{0}; - net3::SequenceNumber seqNum; - bool sentAssign{false}; - bool published{false}; - int64_t pubuid{0}; - - bool UpdateFlags(TopicData* topic); - }; - wpi::DenseMap m_topics3; - TopicData3* GetTopic3(TopicData* topic) { - return &m_topics3.try_emplace(topic, topic).first->second; - } -}; - -struct TopicData { - TopicData(std::string_view name, std::string_view typeStr) - : name{name}, typeStr{typeStr} {} - TopicData(std::string_view name, std::string_view typeStr, - wpi::json properties) - : name{name}, typeStr{typeStr}, properties(std::move(properties)) { - RefreshProperties(); - } - - bool IsPublished() const { - return persistent || retained || !publishers.empty(); - } - - // returns true if properties changed - bool SetProperties(const wpi::json& update); - void RefreshProperties(); - bool SetFlags(unsigned int flags_); - - std::string name; - unsigned int id; - Value lastValue; - ClientData* lastValueClient = nullptr; - std::string typeStr; - wpi::json properties = wpi::json::object(); - bool persistent{false}; - bool retained{false}; - bool special{false}; - NT_Topic localHandle{0}; - - VectorSet publishers; - VectorSet subscribers; - - // meta topics - TopicData* metaPub = nullptr; - TopicData* metaSub = nullptr; -}; - -struct PublisherData { - PublisherData(ClientData* client, TopicData* topic, int64_t pubuid) - : client{client}, topic{topic}, pubuid{pubuid} {} - - ClientData* client; - TopicData* topic; - int64_t pubuid; -}; - -struct SubscriberData { - SubscriberData(ClientData* client, std::span topicNames, - int64_t subuid, const PubSubOptionsImpl& options) - : client{client}, - topicNames{topicNames.begin(), topicNames.end()}, - subuid{subuid}, - options{options}, - periodMs(std::lround(options.periodicMs / 10.0) * 10) { - if (periodMs < kMinPeriodMs) { - periodMs = kMinPeriodMs; - } - } - - void Update(std::span topicNames_, - const PubSubOptionsImpl& options_) { - topicNames = {topicNames_.begin(), topicNames_.end()}; - options = options_; - periodMs = std::lround(options_.periodicMs / 10.0) * 10; - if (periodMs < kMinPeriodMs) { - periodMs = kMinPeriodMs; - } - } - - bool Matches(std::string_view name, bool special); - - ClientData* client; - std::vector topicNames; - int64_t subuid; - PubSubOptionsImpl options; - // in options as double, but copy here as integer; rounded to the nearest - // 10 ms - uint32_t periodMs; -}; - -class SImpl { - public: - explicit SImpl(wpi::Logger& logger); - - wpi::Logger& m_logger; - LocalInterface* m_local{nullptr}; - bool m_controlReady{false}; - - ClientDataLocal* m_localClient; - std::vector> m_clients; - wpi::UidVector, 16> m_topics; - wpi::StringMap m_nameTopics; - bool m_persistentChanged{false}; - - // global meta topics (other meta topics are linked to from the specific - // client or topic) - TopicData* m_metaClients; - - // ServerImpl interface - std::pair AddClient( - std::string_view name, std::string_view connInfo, bool local, - WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic); - int AddClient3(std::string_view connInfo, bool local, - net3::WireConnection3& wire, - ServerImpl::Connected3Func connected, - ServerImpl::SetPeriodicFunc setPeriodic); - void RemoveClient(int clientId); - - bool PersistentChanged(); - void DumpPersistent(wpi::raw_ostream& os); - std::string LoadPersistent(std::string_view in); - - // helper functions - TopicData* CreateTopic(ClientData* client, std::string_view name, - std::string_view typeStr, const wpi::json& properties, - bool special = false); - TopicData* CreateMetaTopic(std::string_view name); - void DeleteTopic(TopicData* topic); - void SetProperties(ClientData* client, TopicData* topic, - const wpi::json& update); - void SetFlags(ClientData* client, TopicData* topic, unsigned int flags); - void SetValue(ClientData* client, TopicData* topic, const Value& value); - - // update meta topic values from data structures - void UpdateMetaClients(const std::vector& conns); - void UpdateMetaTopicPub(TopicData* topic); - void UpdateMetaTopicSub(TopicData* topic); - - private: - void PropertiesChanged(ClientData* client, TopicData* topic, - const wpi::json& update); -}; - struct Writer : public mpack_writer_t { Writer() { mpack_writer_init(this, buf, sizeof(buf)); @@ -477,7 +80,7 @@ static void WriteOptions(mpack_writer_t& w, const PubSubOptionsImpl& options) { mpack_finish_map(&w); } -void ClientData::UpdateMetaClientPub() { +void ServerImpl::ClientData::UpdateMetaClientPub() { if (!m_metaPub) { return; } @@ -497,7 +100,7 @@ void ClientData::UpdateMetaClientPub() { } } -void ClientData::UpdateMetaClientSub() { +void ServerImpl::ClientData::UpdateMetaClientSub() { if (!m_metaSub) { return; } @@ -523,7 +126,7 @@ void ClientData::UpdateMetaClientSub() { } } -std::span ClientData::GetSubscribers( +std::span ServerImpl::ClientData::GetSubscribers( std::string_view name, bool special, wpi::SmallVectorImpl& buf) { buf.resize(0); @@ -536,9 +139,10 @@ std::span ClientData::GetSubscribers( return {buf.data(), buf.size()}; } -void ClientData4Base::ClientPublish(int64_t pubuid, std::string_view name, - std::string_view typeStr, - const wpi::json& properties) { +void ServerImpl::ClientData4Base::ClientPublish(int64_t pubuid, + std::string_view name, + std::string_view typeStr, + const wpi::json& properties) { DEBUG3("ClientPublish({}, {}, {}, {})", m_id, name, pubuid, typeStr); auto topic = m_server.CreateTopic(this, name, typeStr, properties); @@ -546,7 +150,7 @@ void ClientData4Base::ClientPublish(int64_t pubuid, std::string_view name, auto [publisherIt, isNew] = m_publishers.try_emplace( pubuid, std::make_unique(this, topic, pubuid)); if (!isNew) { - WARNING("client {} duplicate publish of pubuid {}", m_id, pubuid); + WARN("client {} duplicate publish of pubuid {}", m_id, pubuid); } // add publisher to topic @@ -561,7 +165,7 @@ void ClientData4Base::ClientPublish(int64_t pubuid, std::string_view name, SendAnnounce(topic, pubuid); } -void ClientData4Base::ClientUnpublish(int64_t pubuid) { +void ServerImpl::ClientData4Base::ClientUnpublish(int64_t pubuid) { DEBUG3("ClientUnpublish({}, {})", m_id, pubuid); auto publisherIt = m_publishers.find(pubuid); if (publisherIt == m_publishers.end()) { @@ -586,13 +190,13 @@ void ClientData4Base::ClientUnpublish(int64_t pubuid) { } } -void ClientData4Base::ClientSetProperties(std::string_view name, - const wpi::json& update) { +void ServerImpl::ClientData4Base::ClientSetProperties(std::string_view name, + const wpi::json& update) { DEBUG4("ClientSetProperties({}, {}, {})", m_id, name, update.dump()); auto topicIt = m_server.m_nameTopics.find(name); if (topicIt == m_server.m_nameTopics.end() || !topicIt->second->IsPublished()) { - WARNING( + WARN( "server ignoring SetProperties({}) from client {} on unpublished topic " "'{}'; publish or set a value first", update.dump(), m_id, name); @@ -600,17 +204,16 @@ void ClientData4Base::ClientSetProperties(std::string_view name, } auto topic = topicIt->second; if (topic->special) { - WARNING( - "server ignoring SetProperties({}) from client {} on meta topic '{}'", - update.dump(), m_id, name); + WARN("server ignoring SetProperties({}) from client {} on meta topic '{}'", + update.dump(), m_id, name); return; // nothing to do } m_server.SetProperties(nullptr, topic, update); } -void ClientData4Base::ClientSubscribe(int64_t subuid, - std::span topicNames, - const PubSubOptionsImpl& options) { +void ServerImpl::ClientData4Base::ClientSubscribe( + int64_t subuid, std::span topicNames, + const PubSubOptionsImpl& options) { DEBUG4("ClientSubscribe({}, ({}), {})", m_id, fmt::join(topicNames, ","), subuid); auto& sub = m_subscribers[subuid]; @@ -700,7 +303,7 @@ void ClientData4Base::ClientSubscribe(int64_t subuid, Flush(); } -void ClientData4Base::ClientUnsubscribe(int64_t subuid) { +void ServerImpl::ClientData4Base::ClientUnsubscribe(int64_t subuid) { DEBUG3("ClientUnsubscribe({}, {})", m_id, subuid); auto subIt = m_subscribers.find(subuid); if (subIt == m_subscribers.end() || !subIt->getSecond()) { @@ -734,26 +337,27 @@ void ClientData4Base::ClientUnsubscribe(int64_t subuid) { m_setPeriodic(m_periodMs); } -void ClientData4Base::ClientSetValue(int64_t pubuid, const Value& value) { +void ServerImpl::ClientData4Base::ClientSetValue(int64_t pubuid, + const Value& value) { DEBUG4("ClientSetValue({}, {})", m_id, pubuid); auto publisherIt = m_publishers.find(pubuid); if (publisherIt == m_publishers.end()) { - WARNING("unrecognized client {} pubuid {}, ignoring set", m_id, pubuid); + WARN("unrecognized client {} pubuid {}, ignoring set", m_id, pubuid); return; // ignore unrecognized pubuids } auto topic = publisherIt->getSecond().get()->topic; m_server.SetValue(this, topic, value); } -void ClientDataLocal::SendValue(TopicData* topic, const Value& value, - SendMode mode) { +void ServerImpl::ClientDataLocal::SendValue(TopicData* topic, + const Value& value, SendMode mode) { if (m_server.m_local) { m_server.m_local->NetworkSetValue(topic->localHandle, value); } } -void ClientDataLocal::SendAnnounce(TopicData* topic, - std::optional pubuid) { +void ServerImpl::ClientDataLocal::SendAnnounce(TopicData* topic, + std::optional pubuid) { if (m_server.m_local) { auto& sent = m_announceSent[topic]; if (sent) { @@ -766,7 +370,7 @@ void ClientDataLocal::SendAnnounce(TopicData* topic, } } -void ClientDataLocal::SendUnannounce(TopicData* topic) { +void ServerImpl::ClientDataLocal::SendUnannounce(TopicData* topic) { if (m_server.m_local) { auto& sent = m_announceSent[topic]; if (!sent) { @@ -777,8 +381,9 @@ void ClientDataLocal::SendUnannounce(TopicData* topic) { } } -void ClientDataLocal::SendPropertiesUpdate(TopicData* topic, - const wpi::json& update, bool ack) { +void ServerImpl::ClientDataLocal::SendPropertiesUpdate(TopicData* topic, + const wpi::json& update, + bool ack) { if (m_server.m_local) { if (!m_announceSent.lookup(topic)) { return; @@ -787,7 +392,8 @@ void ClientDataLocal::SendPropertiesUpdate(TopicData* topic, } } -void ClientDataLocal::HandleLocal(std::span msgs) { +void ServerImpl::ClientDataLocal::HandleLocal( + std::span msgs) { DEBUG4("HandleLocal()"); // just map as a normal client into client=0 calls for (const auto& elem : msgs) { // NOLINT @@ -808,11 +414,12 @@ void ClientDataLocal::HandleLocal(std::span msgs) { } } -void ClientData4::ProcessIncomingText(std::string_view data) { +void ServerImpl::ClientData4::ProcessIncomingText(std::string_view data) { WireDecodeText(data, *this, m_logger); } -void ClientData4::ProcessIncomingBinary(std::span data) { +void ServerImpl::ClientData4::ProcessIncomingBinary( + std::span data) { for (;;) { if (data.empty()) { break; @@ -844,8 +451,8 @@ void ClientData4::ProcessIncomingBinary(std::span data) { } } -void ClientData4::SendValue(TopicData* topic, const Value& value, - SendMode mode) { +void ServerImpl::ClientData4::SendValue(TopicData* topic, const Value& value, + SendMode mode) { if (m_local) { mode = ClientData::kSendImmNoFlush; // always send local immediately } @@ -881,8 +488,8 @@ void ClientData4::SendValue(TopicData* topic, const Value& value, } } -void ClientData4::SendAnnounce(TopicData* topic, - std::optional pubuid) { +void ServerImpl::ClientData4::SendAnnounce(TopicData* topic, + std::optional pubuid) { auto& sent = m_announceSent[topic]; if (sent) { return; @@ -900,7 +507,7 @@ void ClientData4::SendAnnounce(TopicData* topic, } } -void ClientData4::SendUnannounce(TopicData* topic) { +void ServerImpl::ClientData4::SendUnannounce(TopicData* topic) { auto& sent = m_announceSent[topic]; if (!sent) { return; @@ -917,8 +524,9 @@ void ClientData4::SendUnannounce(TopicData* topic) { } } -void ClientData4::SendPropertiesUpdate(TopicData* topic, - const wpi::json& update, bool ack) { +void ServerImpl::ClientData4::SendPropertiesUpdate(TopicData* topic, + const wpi::json& update, + bool ack) { if (!m_announceSent.lookup(topic)) { return; } @@ -933,7 +541,7 @@ void ClientData4::SendPropertiesUpdate(TopicData* topic, } } -void ClientData4::SendOutgoing(uint64_t curTimeMs) { +void ServerImpl::ClientData4::SendOutgoing(uint64_t curTimeMs) { if (m_outgoing.empty()) { return; // nothing to do } @@ -964,27 +572,28 @@ void ClientData4::SendOutgoing(uint64_t curTimeMs) { m_lastSendMs = curTimeMs; } -void ClientData4::Flush() { +void ServerImpl::ClientData4::Flush() { m_outText.reset(); m_outBinary.reset(); m_wire.Flush(); } -bool ClientData3::TopicData3::UpdateFlags(TopicData* topic) { +bool ServerImpl::ClientData3::TopicData3::UpdateFlags(TopicData* topic) { unsigned int newFlags = topic->persistent ? NT_PERSISTENT : 0; bool updated = flags != newFlags; flags = newFlags; return updated; } -void ClientData3::ProcessIncomingBinary(std::span data) { +void ServerImpl::ClientData3::ProcessIncomingBinary( + std::span data) { if (!m_decoder.Execute(&data)) { m_wire.Disconnect(m_decoder.GetError()); } } -void ClientData3::SendValue(TopicData* topic, const Value& value, - SendMode mode) { +void ServerImpl::ClientData3::SendValue(TopicData* topic, const Value& value, + SendMode mode) { if (m_state != kStateRunning) { if (mode == kSendImmNoFlush) { mode = kSendAll; @@ -1048,8 +657,8 @@ void ClientData3::SendValue(TopicData* topic, const Value& value, } } -void ClientData3::SendAnnounce(TopicData* topic, - std::optional pubuid) { +void ServerImpl::ClientData3::SendAnnounce(TopicData* topic, + std::optional pubuid) { // ignore if we've not yet built the subscriber if (m_subscribers.empty()) { return; @@ -1065,7 +674,7 @@ void ClientData3::SendAnnounce(TopicData* topic, // will get sent when the first value is sent (by SendValue). } -void ClientData3::SendUnannounce(TopicData* topic) { +void ServerImpl::ClientData3::SendUnannounce(TopicData* topic) { auto it = m_topics3.find(topic); if (it == m_topics3.end()) { return; // never sent to client @@ -1085,8 +694,9 @@ void ClientData3::SendUnannounce(TopicData* topic) { } } -void ClientData3::SendPropertiesUpdate(TopicData* topic, - const wpi::json& update, bool ack) { +void ServerImpl::ClientData3::SendPropertiesUpdate(TopicData* topic, + const wpi::json& update, + bool ack) { if (ack) { return; // we don't ack in NT3 } @@ -1110,7 +720,7 @@ void ClientData3::SendPropertiesUpdate(TopicData* topic, } } -void ClientData3::SendOutgoing(uint64_t curTimeMs) { +void ServerImpl::ClientData3::SendOutgoing(uint64_t curTimeMs) { if (m_outgoing.empty() || m_state != kStateRunning) { return; // nothing to do } @@ -1138,7 +748,7 @@ void ClientData3::SendOutgoing(uint64_t curTimeMs) { m_lastSendMs = curTimeMs; } -void ClientData3::KeepAlive() { +void ServerImpl::ClientData3::KeepAlive() { DEBUG4("KeepAlive({})", m_id); if (m_state != kStateRunning) { m_decoder.SetError("received unexpected KeepAlive message"); @@ -1147,12 +757,12 @@ void ClientData3::KeepAlive() { // ignore } -void ClientData3::ServerHelloDone() { +void ServerImpl::ClientData3::ServerHelloDone() { DEBUG4("ServerHelloDone({})", m_id); m_decoder.SetError("received unexpected ServerHelloDone message"); } -void ClientData3::ClientHelloDone() { +void ServerImpl::ClientData3::ClientHelloDone() { DEBUG4("ClientHelloDone({})", m_id); if (m_state != kStateServerHelloComplete) { m_decoder.SetError("received unexpected ClientHelloDone message"); @@ -1161,7 +771,7 @@ void ClientData3::ClientHelloDone() { m_state = kStateRunning; } -void ClientData3::ClearEntries() { +void ServerImpl::ClientData3::ClearEntries() { DEBUG4("ClearEntries({})", m_id); if (m_state != kStateRunning) { m_decoder.SetError("received unexpected ClearEntries message"); @@ -1196,13 +806,13 @@ void ClientData3::ClearEntries() { } } -void ClientData3::ProtoUnsup(unsigned int proto_rev) { +void ServerImpl::ClientData3::ProtoUnsup(unsigned int proto_rev) { DEBUG4("ProtoUnsup({})", m_id); m_decoder.SetError("received unexpected ProtoUnsup message"); } -void ClientData3::ClientHello(std::string_view self_id, - unsigned int proto_rev) { +void ServerImpl::ClientData3::ClientHello(std::string_view self_id, + unsigned int proto_rev) { DEBUG4("ClientHello({}, '{}', {:04x})", m_id, self_id, proto_rev); if (m_state != kStateInitial) { m_decoder.SetError("received unexpected ClientHello message"); @@ -1266,14 +876,16 @@ void ClientData3::ClientHello(std::string_view self_id, UpdateMetaClientSub(); } -void ClientData3::ServerHello(unsigned int flags, std::string_view self_id) { +void ServerImpl::ClientData3::ServerHello(unsigned int flags, + std::string_view self_id) { DEBUG4("ServerHello({}, {}, {})", m_id, flags, self_id); m_decoder.SetError("received unexpected ServerHello message"); } -void ClientData3::EntryAssign(std::string_view name, unsigned int id, - unsigned int seq_num, const Value& value, - unsigned int flags) { +void ServerImpl::ClientData3::EntryAssign(std::string_view name, + unsigned int id, unsigned int seq_num, + const Value& value, + unsigned int flags) { DEBUG4("EntryAssign({}, {}, {}, {}, {})", m_id, id, seq_num, static_cast(value.type()), flags); if (id != 0xffff) { @@ -1293,7 +905,7 @@ void ClientData3::EntryAssign(std::string_view name, unsigned int id, auto topic = m_server.CreateTopic(this, name, typeStr, properties); TopicData3* topic3 = GetTopic3(topic); if (topic3->published || topic3->sentAssign) { - WARNING("ignoring client {} duplicate publish of '{}'", m_id, name); + WARN("ignoring client {} duplicate publish of '{}'", m_id, name); return; } ++topic3->seqNum; @@ -1330,8 +942,8 @@ void ClientData3::EntryAssign(std::string_view name, unsigned int id, } } -void ClientData3::EntryUpdate(unsigned int id, unsigned int seq_num, - const Value& value) { +void ServerImpl::ClientData3::EntryUpdate(unsigned int id, unsigned int seq_num, + const Value& value) { DEBUG4("EntryUpdate({}, {}, {}, {})", m_id, id, seq_num, static_cast(value.type())); if (m_state != kStateRunning) { @@ -1372,7 +984,7 @@ void ClientData3::EntryUpdate(unsigned int id, unsigned int seq_num, m_server.SetValue(this, topic, value); } -void ClientData3::FlagsUpdate(unsigned int id, unsigned int flags) { +void ServerImpl::ClientData3::FlagsUpdate(unsigned int id, unsigned int flags) { DEBUG4("FlagsUpdate({}, {}, {})", m_id, id, flags); if (m_state != kStateRunning) { m_decoder.SetError("received unexpected FlagsUpdate message"); @@ -1394,7 +1006,7 @@ void ClientData3::FlagsUpdate(unsigned int id, unsigned int flags) { m_server.SetFlags(this, topic, flags); } -void ClientData3::EntryDelete(unsigned int id) { +void ServerImpl::ClientData3::EntryDelete(unsigned int id) { DEBUG4("EntryDelete({}, {})", m_id, id); if (m_state != kStateRunning) { m_decoder.SetError("received unexpected EntryDelete message"); @@ -1441,7 +1053,7 @@ void ClientData3::EntryDelete(unsigned int id) { m_server.SetProperties(this, topic, {{"retained", false}}); } -bool TopicData::SetProperties(const wpi::json& update) { +bool ServerImpl::TopicData::SetProperties(const wpi::json& update) { if (!update.is_object()) { return false; } @@ -1460,7 +1072,7 @@ bool TopicData::SetProperties(const wpi::json& update) { return updated; } -void TopicData::RefreshProperties() { +void ServerImpl::TopicData::RefreshProperties() { persistent = false; retained = false; @@ -1479,7 +1091,7 @@ void TopicData::RefreshProperties() { } } -bool TopicData::SetFlags(unsigned int flags_) { +bool ServerImpl::TopicData::SetFlags(unsigned int flags_) { bool updated; if ((flags_ & NT_PERSISTENT) != 0) { updated = !persistent; @@ -1493,7 +1105,7 @@ bool TopicData::SetFlags(unsigned int flags_) { return updated; } -bool SubscriberData::Matches(std::string_view name, bool special) { +bool ServerImpl::SubscriberData::Matches(std::string_view name, bool special) { for (auto&& topicName : topicNames) { if ((!options.prefixMatch && name == topicName) || (options.prefixMatch && (!special || !topicName.empty()) && @@ -1504,13 +1116,13 @@ bool SubscriberData::Matches(std::string_view name, bool special) { return false; } -SImpl::SImpl(wpi::Logger& logger) : m_logger{logger} { +ServerImpl::ServerImpl(wpi::Logger& logger) : m_logger{logger} { // local is client 0 m_clients.emplace_back(std::make_unique(*this, 0, logger)); m_localClient = static_cast(m_clients.back().get()); } -std::pair SImpl::AddClient( +std::pair ServerImpl::AddClient( std::string_view name, std::string_view connInfo, bool local, WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic) { if (name.empty()) { @@ -1553,10 +1165,10 @@ std::pair SImpl::AddClient( return {std::move(dedupName), index}; } -int SImpl::AddClient3(std::string_view connInfo, bool local, - net3::WireConnection3& wire, - ServerImpl::Connected3Func connected, - ServerImpl::SetPeriodicFunc setPeriodic) { +int ServerImpl::AddClient3(std::string_view connInfo, bool local, + net3::WireConnection3& wire, + ServerImpl::Connected3Func connected, + ServerImpl::SetPeriodicFunc setPeriodic) { size_t index = m_clients.size(); // find an empty slot; we can't check for duplicates until we get a hello. // just do a linear search as number of clients is typically small (<10) @@ -1578,7 +1190,7 @@ int SImpl::AddClient3(std::string_view connInfo, bool local, return index; } -void SImpl::RemoveClient(int clientId) { +void ServerImpl::RemoveClient(int clientId) { DEBUG3("RemoveClient({})", clientId); auto& client = m_clients[clientId]; @@ -1620,7 +1232,7 @@ void SImpl::RemoveClient(int clientId) { client.reset(); } -bool SImpl::PersistentChanged() { +bool ServerImpl::PersistentChanged() { bool rv = m_persistentChanged; m_persistentChanged = false; return rv; @@ -1738,7 +1350,7 @@ static void DumpValue(wpi::raw_ostream& os, const Value& value, } } -void SImpl::DumpPersistent(wpi::raw_ostream& os) { +void ServerImpl::DumpPersistent(wpi::raw_ostream& os) { wpi::json::serializer s{os, ' ', 16}; os << "[\n"; bool first = true; @@ -1778,7 +1390,7 @@ static std::string* ObjGetString(wpi::json::object_t& obj, std::string_view key, return val; } -std::string SImpl::LoadPersistent(std::string_view in) { +std::string ServerImpl::LoadPersistent(std::string_view in) { if (in.empty()) { return {}; } @@ -1999,15 +1611,17 @@ std::string SImpl::LoadPersistent(std::string_view in) { return allerrors; } -TopicData* SImpl::CreateTopic(ClientData* client, std::string_view name, - std::string_view typeStr, - const wpi::json& properties, bool special) { +ServerImpl::TopicData* ServerImpl::CreateTopic(ClientData* client, + std::string_view name, + std::string_view typeStr, + const wpi::json& properties, + bool special) { auto& topic = m_nameTopics[name]; if (topic) { if (typeStr != topic->typeStr) { if (client) { - WARNING("client {} publish '{}' conflicting type '{}' (currently '{}')", - client->GetName(), name, typeStr, topic->typeStr); + WARN("client {} publish '{}' conflicting type '{}' (currently '{}')", + client->GetName(), name, typeStr, topic->typeStr); } } } else { @@ -2056,11 +1670,11 @@ TopicData* SImpl::CreateTopic(ClientData* client, std::string_view name, return topic; } -TopicData* SImpl::CreateMetaTopic(std::string_view name) { +ServerImpl::TopicData* ServerImpl::CreateMetaTopic(std::string_view name) { return CreateTopic(nullptr, name, "msgpack", {{"retained", true}}, true); } -void SImpl::DeleteTopic(TopicData* topic) { +void ServerImpl::DeleteTopic(TopicData* topic) { if (!topic) { return; } @@ -2093,8 +1707,8 @@ void SImpl::DeleteTopic(TopicData* topic) { m_topics.erase(topic->id); } -void SImpl::SetProperties(ClientData* client, TopicData* topic, - const wpi::json& update) { +void ServerImpl::SetProperties(ClientData* client, TopicData* topic, + const wpi::json& update) { DEBUG4("SetProperties({}, {}, {})", client ? client->GetId() : -1, topic->name, update.dump()); bool wasPersistent = topic->persistent; @@ -2107,7 +1721,8 @@ void SImpl::SetProperties(ClientData* client, TopicData* topic, } } -void SImpl::SetFlags(ClientData* client, TopicData* topic, unsigned int flags) { +void ServerImpl::SetFlags(ClientData* client, TopicData* topic, + unsigned int flags) { bool wasPersistent = topic->persistent; if (topic->SetFlags(flags)) { // update persistentChanged flag @@ -2124,7 +1739,8 @@ void SImpl::SetFlags(ClientData* client, TopicData* topic, unsigned int flags) { } } -void SImpl::SetValue(ClientData* client, TopicData* topic, const Value& value) { +void ServerImpl::SetValue(ClientData* client, TopicData* topic, + const Value& value) { // update retained value if from same client or timestamp newer if (!topic->lastValue || topic->lastValueClient == client || topic->lastValue.time() == 0 || value.time() >= topic->lastValue.time()) { @@ -2169,7 +1785,7 @@ void SImpl::SetValue(ClientData* client, TopicData* topic, const Value& value) { } } -void SImpl::UpdateMetaClients(const std::vector& conns) { +void ServerImpl::UpdateMetaClients(const std::vector& conns) { Writer w; mpack_start_array(&w, conns.size()); for (auto&& conn : conns) { @@ -2190,7 +1806,7 @@ void SImpl::UpdateMetaClients(const std::vector& conns) { } } -void SImpl::UpdateMetaTopicPub(TopicData* topic) { +void ServerImpl::UpdateMetaTopicPub(TopicData* topic) { if (!topic->metaPub) { return; } @@ -2214,7 +1830,7 @@ void SImpl::UpdateMetaTopicPub(TopicData* topic) { } } -void SImpl::UpdateMetaTopicSub(TopicData* topic) { +void ServerImpl::UpdateMetaTopicSub(TopicData* topic) { if (!topic->metaSub) { return; } @@ -2240,8 +1856,8 @@ void SImpl::UpdateMetaTopicSub(TopicData* topic) { } } -void SImpl::PropertiesChanged(ClientData* client, TopicData* topic, - const wpi::json& update) { +void ServerImpl::PropertiesChanged(ClientData* client, TopicData* topic, + const wpi::json& update) { // removing some properties can result in the topic being unpublished if (!topic->IsPublished()) { DeleteTopic(topic); @@ -2263,23 +1879,13 @@ void SImpl::PropertiesChanged(ClientData* client, TopicData* topic, } } -class ServerImpl::Impl final : public SImpl { - public: - explicit Impl(wpi::Logger& logger) : SImpl{logger} {} -}; - -ServerImpl::ServerImpl(wpi::Logger& logger) - : m_impl{std::make_unique(logger)} {} - -ServerImpl::~ServerImpl() = default; - void ServerImpl::SendControl(uint64_t curTimeMs) { - if (!m_impl->m_controlReady) { + if (!m_controlReady) { return; } - m_impl->m_controlReady = false; + m_controlReady = false; - for (auto&& client : m_impl->m_clients) { + for (auto&& client : m_clients) { if (client) { // to ensure ordering, just send everything client->SendOutgoing(curTimeMs); @@ -2289,7 +1895,7 @@ void ServerImpl::SendControl(uint64_t curTimeMs) { } void ServerImpl::SendValues(int clientId, uint64_t curTimeMs) { - if (auto client = m_impl->m_clients[clientId].get()) { + if (auto client = m_clients[clientId].get()) { client->SendOutgoing(curTimeMs); client->Flush(); } @@ -2297,70 +1903,42 @@ void ServerImpl::SendValues(int clientId, uint64_t curTimeMs) { void ServerImpl::HandleLocal(std::span msgs) { // just map as a normal client into client=0 calls - m_impl->m_localClient->HandleLocal(msgs); + m_localClient->HandleLocal(msgs); } void ServerImpl::SetLocal(LocalInterface* local) { - WPI_DEBUG4(m_impl->m_logger, "SetLocal()"); - m_impl->m_local = local; + DEBUG4("SetLocal()"); + m_local = local; // create server meta topics - m_impl->m_metaClients = m_impl->CreateMetaTopic("$clients"); + m_metaClients = CreateMetaTopic("$clients"); // create local client meta topics - m_impl->m_localClient->m_metaPub = m_impl->CreateMetaTopic("$serverpub"); - m_impl->m_localClient->m_metaSub = m_impl->CreateMetaTopic("$serversub"); + m_localClient->m_metaPub = CreateMetaTopic("$serverpub"); + m_localClient->m_metaSub = CreateMetaTopic("$serversub"); // update meta topics - m_impl->m_localClient->UpdateMetaClientPub(); - m_impl->m_localClient->UpdateMetaClientSub(); + m_localClient->UpdateMetaClientPub(); + m_localClient->UpdateMetaClientSub(); } void ServerImpl::ProcessIncomingText(int clientId, std::string_view data) { - m_impl->m_clients[clientId]->ProcessIncomingText(data); + m_clients[clientId]->ProcessIncomingText(data); } void ServerImpl::ProcessIncomingBinary(int clientId, std::span data) { - m_impl->m_clients[clientId]->ProcessIncomingBinary(data); -} - -std::pair ServerImpl::AddClient(std::string_view name, - std::string_view connInfo, - bool local, - WireConnection& wire, - SetPeriodicFunc setPeriodic) { - return m_impl->AddClient(name, connInfo, local, wire, std::move(setPeriodic)); -} - -int ServerImpl::AddClient3(std::string_view connInfo, bool local, - net3::WireConnection3& wire, - Connected3Func connected, - SetPeriodicFunc setPeriodic) { - return m_impl->AddClient3(connInfo, local, wire, std::move(connected), - std::move(setPeriodic)); -} - -void ServerImpl::RemoveClient(int clientId) { - m_impl->RemoveClient(clientId); + m_clients[clientId]->ProcessIncomingBinary(data); } void ServerImpl::ConnectionsChanged(const std::vector& conns) { - m_impl->UpdateMetaClients(conns); -} - -bool ServerImpl::PersistentChanged() { - return m_impl->PersistentChanged(); + UpdateMetaClients(conns); } std::string ServerImpl::DumpPersistent() { std::string rv; wpi::raw_string_ostream os{rv}; - m_impl->DumpPersistent(os); + DumpPersistent(os); os.flush(); return rv; } - -std::string ServerImpl::LoadPersistent(std::string_view in) { - return m_impl->LoadPersistent(in); -} diff --git a/ntcore/src/main/native/cpp/net/ServerImpl.h b/ntcore/src/main/native/cpp/net/ServerImpl.h index 86607e9bd9..b1f367e33b 100644 --- a/ntcore/src/main/native/cpp/net/ServerImpl.h +++ b/ntcore/src/main/native/cpp/net/ServerImpl.h @@ -6,6 +6,7 @@ #include +#include #include #include #include @@ -14,11 +15,28 @@ #include #include +#include +#include +#include +#include + +#include "Message.h" #include "NetworkInterface.h" +#include "PubSubOptions.h" +#include "VectorSet.h" +#include "WireConnection.h" +#include "WireDecoder.h" +#include "WireEncoder.h" +#include "net3/Message3.h" +#include "net3/SequenceNumber.h" #include "net3/WireConnection3.h" +#include "net3/WireDecoder3.h" namespace wpi { class Logger; +template +class SmallVectorImpl; +class raw_ostream; } // namespace wpi namespace nt::net3 { @@ -38,7 +56,6 @@ class ServerImpl final { std::function; explicit ServerImpl(wpi::Logger& logger); - ~ServerImpl(); void SendControl(uint64_t curTimeMs); void SendValues(int clientId, uint64_t curTimeMs); @@ -69,8 +86,357 @@ class ServerImpl final { std::string LoadPersistent(std::string_view in); private: - class Impl; - std::unique_ptr m_impl; + static constexpr uint32_t kMinPeriodMs = 5; + + struct PublisherData; + struct SubscriberData; + struct TopicData; + + class ClientData { + public: + ClientData(std::string_view name, std::string_view connInfo, bool local, + ServerImpl::SetPeriodicFunc setPeriodic, ServerImpl& server, + int id, wpi::Logger& logger) + : m_name{name}, + m_connInfo{connInfo}, + m_local{local}, + m_setPeriodic{std::move(setPeriodic)}, + m_server{server}, + m_id{id}, + m_logger{logger} {} + virtual ~ClientData() = default; + + virtual void ProcessIncomingText(std::string_view data) = 0; + virtual void ProcessIncomingBinary(std::span data) = 0; + + enum SendMode { kSendDisabled = 0, kSendAll, kSendNormal, kSendImmNoFlush }; + + virtual void SendValue(TopicData* topic, const Value& value, + SendMode mode) = 0; + virtual void SendAnnounce(TopicData* topic, + std::optional pubuid) = 0; + virtual void SendUnannounce(TopicData* topic) = 0; + virtual void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, + bool ack) = 0; + virtual void SendOutgoing(uint64_t curTimeMs) = 0; + virtual void Flush() = 0; + + void UpdateMetaClientPub(); + void UpdateMetaClientSub(); + + std::span GetSubscribers( + std::string_view name, bool special, + wpi::SmallVectorImpl& buf); + + std::string_view GetName() const { return m_name; } + int GetId() const { return m_id; } + + protected: + std::string m_name; + std::string m_connInfo; + bool m_local; // local to machine + ServerImpl::SetPeriodicFunc m_setPeriodic; + // TODO: make this per-topic? + uint32_t m_periodMs{UINT32_MAX}; + uint64_t m_lastSendMs{0}; + ServerImpl& m_server; + int m_id; + + wpi::Logger& m_logger; + + wpi::DenseMap> m_publishers; + wpi::DenseMap> m_subscribers; + + public: + // meta topics + TopicData* m_metaPub = nullptr; + TopicData* m_metaSub = nullptr; + }; + + class ClientData4Base : public ClientData, protected ClientMessageHandler { + public: + ClientData4Base(std::string_view name, std::string_view connInfo, + bool local, ServerImpl::SetPeriodicFunc setPeriodic, + ServerImpl& server, int id, wpi::Logger& logger) + : ClientData{name, connInfo, local, setPeriodic, server, id, logger} {} + + protected: + // ClientMessageHandler interface + void ClientPublish(int64_t pubuid, std::string_view name, + std::string_view typeStr, + const wpi::json& properties) final; + void ClientUnpublish(int64_t pubuid) final; + void ClientSetProperties(std::string_view name, + const wpi::json& update) final; + void ClientSubscribe(int64_t subuid, + std::span topicNames, + const PubSubOptionsImpl& options) final; + void ClientUnsubscribe(int64_t subuid) final; + + void ClientSetValue(int64_t pubuid, const Value& value); + + wpi::DenseMap m_announceSent; + }; + + class ClientDataLocal final : public ClientData4Base { + public: + ClientDataLocal(ServerImpl& server, int id, wpi::Logger& logger) + : ClientData4Base{"", "", true, [](uint32_t) {}, server, id, logger} {} + + void ProcessIncomingText(std::string_view data) final {} + void ProcessIncomingBinary(std::span data) final {} + + void SendValue(TopicData* topic, const Value& value, SendMode mode) final; + void SendAnnounce(TopicData* topic, std::optional pubuid) final; + void SendUnannounce(TopicData* topic) final; + void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, + bool ack) final; + void SendOutgoing(uint64_t curTimeMs) final {} + void Flush() final {} + + void HandleLocal(std::span msgs); + }; + + class ClientData4 final : public ClientData4Base { + public: + ClientData4(std::string_view name, std::string_view connInfo, bool local, + WireConnection& wire, ServerImpl::SetPeriodicFunc setPeriodic, + ServerImpl& server, int id, wpi::Logger& logger) + : ClientData4Base{name, connInfo, local, setPeriodic, + server, id, logger}, + m_wire{wire} {} + + void ProcessIncomingText(std::string_view data) final; + void ProcessIncomingBinary(std::span data) final; + + void SendValue(TopicData* topic, const Value& value, SendMode mode) final; + void SendAnnounce(TopicData* topic, std::optional pubuid) final; + void SendUnannounce(TopicData* topic) final; + void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, + bool ack) final; + void SendOutgoing(uint64_t curTimeMs) final; + + void Flush() final; + + public: + WireConnection& m_wire; + + private: + std::vector m_outgoing; + wpi::DenseMap m_outgoingValueMap; + + bool WriteBinary(int64_t id, int64_t time, const Value& value) { + return WireEncodeBinary(SendBinary().Add(), id, time, value); + } + + TextWriter& SendText() { + m_outBinary.reset(); // ensure proper interleaving of text and binary + if (!m_outText) { + m_outText = m_wire.SendText(); + } + return *m_outText; + } + + BinaryWriter& SendBinary() { + m_outText.reset(); // ensure proper interleaving of text and binary + if (!m_outBinary) { + m_outBinary = m_wire.SendBinary(); + } + return *m_outBinary; + } + + // valid when we are actively writing to this client + std::optional m_outText; + std::optional m_outBinary; + }; + + class ClientData3 final : public ClientData, private net3::MessageHandler3 { + public: + ClientData3(std::string_view connInfo, bool local, + net3::WireConnection3& wire, + ServerImpl::Connected3Func connected, + ServerImpl::SetPeriodicFunc setPeriodic, ServerImpl& server, + int id, wpi::Logger& logger) + : ClientData{"", connInfo, local, setPeriodic, server, id, logger}, + m_connected{std::move(connected)}, + m_wire{wire}, + m_decoder{*this} {} + + void ProcessIncomingText(std::string_view data) final {} + void ProcessIncomingBinary(std::span data) final; + + void SendValue(TopicData* topic, const Value& value, SendMode mode) final; + void SendAnnounce(TopicData* topic, std::optional pubuid) final; + void SendUnannounce(TopicData* topic) final; + void SendPropertiesUpdate(TopicData* topic, const wpi::json& update, + bool ack) final; + void SendOutgoing(uint64_t curTimeMs) final; + + void Flush() final { m_wire.Flush(); } + + private: + // MessageHandler3 interface + void KeepAlive() final; + void ServerHelloDone() final; + void ClientHelloDone() final; + void ClearEntries() final; + void ProtoUnsup(unsigned int proto_rev) final; + void ClientHello(std::string_view self_id, unsigned int proto_rev) final; + void ServerHello(unsigned int flags, std::string_view self_id) final; + void EntryAssign(std::string_view name, unsigned int id, + unsigned int seq_num, const Value& value, + unsigned int flags) final; + void EntryUpdate(unsigned int id, unsigned int seq_num, + const Value& value) final; + void FlagsUpdate(unsigned int id, unsigned int flags) final; + void EntryDelete(unsigned int id) final; + void ExecuteRpc(unsigned int id, unsigned int uid, + std::span params) final {} + void RpcResponse(unsigned int id, unsigned int uid, + std::span result) final {} + + ServerImpl::Connected3Func m_connected; + net3::WireConnection3& m_wire; + + enum State { kStateInitial, kStateServerHelloComplete, kStateRunning }; + State m_state{kStateInitial}; + net3::WireDecoder3 m_decoder; + + std::vector m_outgoing; + wpi::DenseMap m_outgoingValueMap; + int64_t m_nextPubUid{1}; + + struct TopicData3 { + explicit TopicData3(TopicData* topic) { UpdateFlags(topic); } + + unsigned int flags{0}; + net3::SequenceNumber seqNum; + bool sentAssign{false}; + bool published{false}; + int64_t pubuid{0}; + + bool UpdateFlags(TopicData* topic); + }; + wpi::DenseMap m_topics3; + TopicData3* GetTopic3(TopicData* topic) { + return &m_topics3.try_emplace(topic, topic).first->second; + } + }; + + struct TopicData { + TopicData(std::string_view name, std::string_view typeStr) + : name{name}, typeStr{typeStr} {} + TopicData(std::string_view name, std::string_view typeStr, + wpi::json properties) + : name{name}, typeStr{typeStr}, properties(std::move(properties)) { + RefreshProperties(); + } + + bool IsPublished() const { + return persistent || retained || !publishers.empty(); + } + + // returns true if properties changed + bool SetProperties(const wpi::json& update); + void RefreshProperties(); + bool SetFlags(unsigned int flags_); + + std::string name; + unsigned int id; + Value lastValue; + ClientData* lastValueClient = nullptr; + std::string typeStr; + wpi::json properties = wpi::json::object(); + bool persistent{false}; + bool retained{false}; + bool special{false}; + NT_Topic localHandle{0}; + + VectorSet publishers; + VectorSet subscribers; + + // meta topics + TopicData* metaPub = nullptr; + TopicData* metaSub = nullptr; + }; + + struct PublisherData { + PublisherData(ClientData* client, TopicData* topic, int64_t pubuid) + : client{client}, topic{topic}, pubuid{pubuid} {} + + ClientData* client; + TopicData* topic; + int64_t pubuid; + }; + + struct SubscriberData { + SubscriberData(ClientData* client, std::span topicNames, + int64_t subuid, const PubSubOptionsImpl& options) + : client{client}, + topicNames{topicNames.begin(), topicNames.end()}, + subuid{subuid}, + options{options}, + periodMs(std::lround(options.periodicMs / 10.0) * 10) { + if (periodMs < kMinPeriodMs) { + periodMs = kMinPeriodMs; + } + } + + void Update(std::span topicNames_, + const PubSubOptionsImpl& options_) { + topicNames = {topicNames_.begin(), topicNames_.end()}; + options = options_; + periodMs = std::lround(options_.periodicMs / 10.0) * 10; + if (periodMs < kMinPeriodMs) { + periodMs = kMinPeriodMs; + } + } + + bool Matches(std::string_view name, bool special); + + ClientData* client; + std::vector topicNames; + int64_t subuid; + PubSubOptionsImpl options; + // in options as double, but copy here as integer; rounded to the nearest + // 10 ms + uint32_t periodMs; + }; + + wpi::Logger& m_logger; + LocalInterface* m_local{nullptr}; + bool m_controlReady{false}; + + ClientDataLocal* m_localClient; + std::vector> m_clients; + wpi::UidVector, 16> m_topics; + wpi::StringMap m_nameTopics; + bool m_persistentChanged{false}; + + // global meta topics (other meta topics are linked to from the specific + // client or topic) + TopicData* m_metaClients; + + void DumpPersistent(wpi::raw_ostream& os); + + // helper functions + TopicData* CreateTopic(ClientData* client, std::string_view name, + std::string_view typeStr, const wpi::json& properties, + bool special = false); + TopicData* CreateMetaTopic(std::string_view name); + void DeleteTopic(TopicData* topic); + void SetProperties(ClientData* client, TopicData* topic, + const wpi::json& update); + void SetFlags(ClientData* client, TopicData* topic, unsigned int flags); + void SetValue(ClientData* client, TopicData* topic, const Value& value); + + // update meta topic values from data structures + void UpdateMetaClients(const std::vector& conns); + void UpdateMetaTopicPub(TopicData* topic); + void UpdateMetaTopicSub(TopicData* topic); + + void PropertiesChanged(ClientData* client, TopicData* topic, + const wpi::json& update); }; } // namespace nt::net diff --git a/ntcore/src/main/native/cpp/net3/ClientImpl3.cpp b/ntcore/src/main/native/cpp/net3/ClientImpl3.cpp index 179b577857..a65dd89c87 100644 --- a/ntcore/src/main/native/cpp/net3/ClientImpl3.cpp +++ b/ntcore/src/main/native/cpp/net3/ClientImpl3.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include "Handle.h" @@ -20,10 +19,6 @@ #include "Types_internal.h" #include "net/Message.h" #include "net/NetworkInterface.h" -#include "net3/Message3.h" -#include "net3/SequenceNumber.h" -#include "net3/WireConnection3.h" -#include "net3/WireDecoder3.h" #include "net3/WireEncoder3.h" #include "networktables/NetworkTableValue.h" @@ -36,145 +31,7 @@ static constexpr uint32_t kMinPeriodMs = 5; // transmission before we close the connection static constexpr uint32_t kWireMaxNotReadyUs = 1000000; -namespace { - -struct Entry; - -struct PublisherData { - explicit PublisherData(Entry* entry) : entry{entry} {} - - Entry* entry; - NT_Publisher handle; - PubSubOptionsImpl options; - // in options as double, but copy here as integer; rounded to the nearest - // 10 ms - uint32_t periodMs; - uint64_t nextSendMs{0}; - std::vector outValues; // outgoing values -}; - -// data for each entry -struct Entry { - explicit Entry(std::string_view name_) : name(name_) {} - bool IsPersistent() const { return (flags & NT_PERSISTENT) != 0; } - wpi::json SetFlags(unsigned int flags_); - - std::string name; - - std::string typeStr; - NT_Type type{NT_UNASSIGNED}; - - wpi::json properties = wpi::json::object(); - - // The current value and flags - Value value; - unsigned int flags{0}; - - // Unique ID used in network messages; this is 0xffff until assigned - // by the server. - unsigned int id{0xffff}; - - // Sequence number for update resolution - SequenceNumber seqNum; - - // Local topic handle - NT_Topic topic{0}; - - // Local publishers - std::vector publishers; -}; - -class CImpl : public MessageHandler3 { - public: - CImpl(uint64_t curTimeMs, int inst, WireConnection3& wire, - wpi::Logger& logger, - std::function setPeriodic); - - void ProcessIncoming(std::span data); - void HandleLocal(std::span msgs); - void SendPeriodic(uint64_t curTimeMs, bool initial, bool flush); - void SendValue(Writer& out, Entry* entry, const Value& value); - bool CheckNetworkReady(uint64_t curTimeMs); - - // Outgoing handlers - void Publish(NT_Publisher pubHandle, NT_Topic topicHandle, - std::string_view name, std::string_view typeStr, - const wpi::json& properties, const PubSubOptionsImpl& options); - void Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle); - void SetProperties(NT_Topic topicHandle, std::string_view name, - const wpi::json& update); - void SetValue(NT_Publisher pubHandle, const Value& value); - - // MessageHandler interface - void KeepAlive() final; - void ServerHelloDone() final; - void ClientHelloDone() final; - void ClearEntries() final; - void ProtoUnsup(unsigned int proto_rev) final; - void ClientHello(std::string_view self_id, unsigned int proto_rev) final; - void ServerHello(unsigned int flags, std::string_view self_id) final; - void EntryAssign(std::string_view name, unsigned int id, unsigned int seq_num, - const Value& value, unsigned int flags) final; - void EntryUpdate(unsigned int id, unsigned int seq_num, - const Value& value) final; - void FlagsUpdate(unsigned int id, unsigned int flags) final; - void EntryDelete(unsigned int id) final; - void ExecuteRpc(unsigned int id, unsigned int uid, - std::span params) final {} - void RpcResponse(unsigned int id, unsigned int uid, - std::span result) final {} - - enum State { - kStateInitial, - kStateHelloSent, - kStateInitialAssignments, - kStateRunning - }; - - int m_inst; - WireConnection3& m_wire; - wpi::Logger& m_logger; - net::LocalInterface* m_local{nullptr}; - std::function m_setPeriodic; - uint64_t m_initTimeMs; - - // periodic sweep handling - static constexpr uint32_t kKeepAliveIntervalMs = 1000; - uint32_t m_periodMs{kKeepAliveIntervalMs + 10}; - uint64_t m_lastSendMs{0}; - uint64_t m_nextKeepAliveTimeMs; - - // indexed by publisher index - std::vector> m_publishers; - - State m_state{kStateInitial}; - WireDecoder3 m_decoder; - std::string m_remoteId; - std::function m_handshakeSucceeded; - - std::vector> m_outgoingFlags; - - using NameMap = wpi::StringMap>; - using IdMap = std::vector; - - NameMap m_nameMap; - IdMap m_idMap; - - Entry* GetOrNewEntry(std::string_view name) { - auto& entry = m_nameMap[name]; - if (!entry) { - entry = std::make_unique(name); - } - return entry.get(); - } - Entry* LookupId(unsigned int id) { - return id < m_idMap.size() ? m_idMap[id] : nullptr; - } -}; - -} // namespace - -wpi::json Entry::SetFlags(unsigned int flags_) { +wpi::json ClientImpl3::Entry::SetFlags(unsigned int flags_) { bool wasPersistent = IsPersistent(); flags = flags_; bool isPersistent = IsPersistent(); @@ -189,25 +46,28 @@ wpi::json Entry::SetFlags(unsigned int flags_) { } } -CImpl::CImpl(uint64_t curTimeMs, int inst, WireConnection3& wire, - wpi::Logger& logger, - std::function setPeriodic) - : m_inst{inst}, - m_wire{wire}, +ClientImpl3::ClientImpl3(uint64_t curTimeMs, int inst, WireConnection3& wire, + wpi::Logger& logger, + std::function setPeriodic) + : m_wire{wire}, m_logger{logger}, m_setPeriodic{std::move(setPeriodic)}, m_initTimeMs{curTimeMs}, m_nextKeepAliveTimeMs{curTimeMs + kKeepAliveIntervalMs}, m_decoder{*this} {} -void CImpl::ProcessIncoming(std::span data) { +ClientImpl3::~ClientImpl3() { + DEBUG4("NT3 ClientImpl destroyed"); +} + +void ClientImpl3::ProcessIncoming(std::span data) { DEBUG4("received {} bytes", data.size()); if (!m_decoder.Execute(&data)) { m_wire.Disconnect(m_decoder.GetError()); } } -void CImpl::HandleLocal(std::span msgs) { +void ClientImpl3::HandleLocal(std::span msgs) { for (const auto& elem : msgs) { // NOLINT // common case is value if (auto msg = std::get_if(&elem.contents)) { @@ -223,7 +83,7 @@ void CImpl::HandleLocal(std::span msgs) { } } -void CImpl::SendPeriodic(uint64_t curTimeMs, bool initial, bool flush) { +void ClientImpl3::DoSendPeriodic(uint64_t curTimeMs, bool initial, bool flush) { DEBUG4("SendPeriodic({})", curTimeMs); // rate limit sends @@ -283,7 +143,7 @@ void CImpl::SendPeriodic(uint64_t curTimeMs, bool initial, bool flush) { m_lastSendMs = curTimeMs; } -void CImpl::SendValue(Writer& out, Entry* entry, const Value& value) { +void ClientImpl3::SendValue(Writer& out, Entry* entry, const Value& value) { DEBUG4("sending value for '{}', seqnum {}", entry->name, entry->seqNum.value()); @@ -302,7 +162,7 @@ void CImpl::SendValue(Writer& out, Entry* entry, const Value& value) { } } -bool CImpl::CheckNetworkReady(uint64_t curTimeMs) { +bool ClientImpl3::CheckNetworkReady(uint64_t curTimeMs) { if (!m_wire.Ready()) { uint64_t lastFlushTime = m_wire.GetLastFlushTime(); uint64_t now = wpi::Now(); @@ -314,10 +174,10 @@ bool CImpl::CheckNetworkReady(uint64_t curTimeMs) { return true; } -void CImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, - std::string_view name, std::string_view typeStr, - const wpi::json& properties, - const PubSubOptionsImpl& options) { +void ClientImpl3::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, + std::string_view name, std::string_view typeStr, + const wpi::json& properties, + const PubSubOptionsImpl& options) { DEBUG4("Publish('{}', '{}')", name, typeStr); unsigned int index = Handle{pubHandle}.GetIndex(); if (index >= m_publishers.size()) { @@ -342,7 +202,7 @@ void CImpl::Publish(NT_Publisher pubHandle, NT_Topic topicHandle, m_setPeriodic(m_periodMs); } -void CImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { +void ClientImpl3::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { DEBUG4("Unpublish({}, {})", pubHandle, topicHandle); unsigned int index = Handle{pubHandle}.GetIndex(); if (index >= m_publishers.size()) { @@ -365,8 +225,8 @@ void CImpl::Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle) { m_setPeriodic(m_periodMs); } -void CImpl::SetProperties(NT_Topic topicHandle, std::string_view name, - const wpi::json& update) { +void ClientImpl3::SetProperties(NT_Topic topicHandle, std::string_view name, + const wpi::json& update) { DEBUG4("SetProperties({}, {}, {})", topicHandle, name, update.dump()); auto entry = GetOrNewEntry(name); bool updated = false; @@ -388,7 +248,7 @@ void CImpl::SetProperties(NT_Topic topicHandle, std::string_view name, } } -void CImpl::SetValue(NT_Publisher pubHandle, const Value& value) { +void ClientImpl3::SetValue(NT_Publisher pubHandle, const Value& value) { DEBUG4("SetValue({})", pubHandle); unsigned int index = Handle{pubHandle}.GetIndex(); assert(index < m_publishers.size() && m_publishers[index]); @@ -404,7 +264,7 @@ void CImpl::SetValue(NT_Publisher pubHandle, const Value& value) { } } -void CImpl::KeepAlive() { +void ClientImpl3::KeepAlive() { DEBUG4("KeepAlive()"); if (m_state != kStateRunning && m_state != kStateInitialAssignments) { m_decoder.SetError("received unexpected KeepAlive message"); @@ -413,7 +273,7 @@ void CImpl::KeepAlive() { // ignore } -void CImpl::ServerHelloDone() { +void ClientImpl3::ServerHelloDone() { DEBUG4("ServerHelloDone()"); if (m_state != kStateInitialAssignments) { m_decoder.SetError("received unexpected ServerHelloDone message"); @@ -421,28 +281,29 @@ void CImpl::ServerHelloDone() { } // send initial assignments - SendPeriodic(m_initTimeMs, true, true); + DoSendPeriodic(m_initTimeMs, true, true); m_state = kStateRunning; m_setPeriodic(m_periodMs); } -void CImpl::ClientHelloDone() { +void ClientImpl3::ClientHelloDone() { DEBUG4("ClientHelloDone()"); m_decoder.SetError("received unexpected ClientHelloDone message"); } -void CImpl::ProtoUnsup(unsigned int proto_rev) { +void ClientImpl3::ProtoUnsup(unsigned int proto_rev) { DEBUG4("ProtoUnsup({})", proto_rev); m_decoder.SetError(fmt::format("received ProtoUnsup(version={})", proto_rev)); } -void CImpl::ClientHello(std::string_view self_id, unsigned int proto_rev) { +void ClientImpl3::ClientHello(std::string_view self_id, + unsigned int proto_rev) { DEBUG4("ClientHello({}, {})", self_id, proto_rev); m_decoder.SetError("received unexpected ClientHello message"); } -void CImpl::ServerHello(unsigned int flags, std::string_view self_id) { +void ClientImpl3::ServerHello(unsigned int flags, std::string_view self_id) { DEBUG4("ServerHello({}, {})", flags, self_id); if (m_state != kStateHelloSent) { m_decoder.SetError("received unexpected ServerHello message"); @@ -454,9 +315,9 @@ void CImpl::ServerHello(unsigned int flags, std::string_view self_id) { m_handshakeSucceeded = nullptr; // no longer required } -void CImpl::EntryAssign(std::string_view name, unsigned int id, - unsigned int seq_num, const Value& value, - unsigned int flags) { +void ClientImpl3::EntryAssign(std::string_view name, unsigned int id, + unsigned int seq_num, const Value& value, + unsigned int flags) { DEBUG4("EntryAssign({}, {}, {}, value, {})", name, id, seq_num, flags); if (m_state != kStateInitialAssignments && m_state != kStateRunning) { m_decoder.SetError("received unexpected EntryAssign message"); @@ -513,8 +374,8 @@ void CImpl::EntryAssign(std::string_view name, unsigned int id, } } -void CImpl::EntryUpdate(unsigned int id, unsigned int seq_num, - const Value& value) { +void ClientImpl3::EntryUpdate(unsigned int id, unsigned int seq_num, + const Value& value) { DEBUG4("EntryUpdate({}, {}, value)", id, seq_num); if (m_state != kStateRunning) { m_decoder.SetError("received EntryUpdate message before ServerHelloDone"); @@ -528,7 +389,7 @@ void CImpl::EntryUpdate(unsigned int id, unsigned int seq_num, } } -void CImpl::FlagsUpdate(unsigned int id, unsigned int flags) { +void ClientImpl3::FlagsUpdate(unsigned int id, unsigned int flags) { DEBUG4("FlagsUpdate({}, {})", id, flags); if (m_state != kStateRunning) { m_decoder.SetError("received FlagsUpdate message before ServerHelloDone"); @@ -548,7 +409,7 @@ void CImpl::FlagsUpdate(unsigned int id, unsigned int flags) { m_outgoingFlags.end()); } -void CImpl::EntryDelete(unsigned int id) { +void ClientImpl3::EntryDelete(unsigned int id) { DEBUG4("EntryDelete({})", id); if (m_state != kStateRunning) { m_decoder.SetError("received EntryDelete message before ServerHelloDone"); @@ -573,7 +434,7 @@ void CImpl::EntryDelete(unsigned int id) { m_outgoingFlags.end()); } -void CImpl::ClearEntries() { +void ClientImpl3::ClearEntries() { DEBUG4("ClearEntries()"); if (m_state != kStateRunning) { m_decoder.SetError("received ClearEntries message before ServerHelloDone"); @@ -597,47 +458,14 @@ void CImpl::ClearEntries() { m_outgoingFlags.resize(0); } -class ClientImpl3::Impl final : public CImpl { - public: - Impl(uint64_t curTimeMs, int inst, WireConnection3& wire, wpi::Logger& logger, - std::function setPeriodic) - : CImpl{curTimeMs, inst, wire, logger, std::move(setPeriodic)} {} -}; - -ClientImpl3::ClientImpl3(uint64_t curTimeMs, int inst, WireConnection3& wire, - wpi::Logger& logger, - std::function setPeriodic) - : m_impl{std::make_unique(curTimeMs, inst, wire, logger, - std::move(setPeriodic))} {} - -ClientImpl3::~ClientImpl3() { - WPI_DEBUG4(m_impl->m_logger, "NT3 ClientImpl destroyed"); -} - void ClientImpl3::Start(std::string_view selfId, std::function succeeded) { - if (m_impl->m_state != CImpl::kStateInitial) { + if (m_state != kStateInitial) { return; } - m_impl->m_handshakeSucceeded = std::move(succeeded); - auto writer = m_impl->m_wire.Send(); + m_handshakeSucceeded = std::move(succeeded); + auto writer = m_wire.Send(); WireEncodeClientHello(writer.stream(), selfId, 0x0300); - m_impl->m_wire.Flush(); - m_impl->m_state = CImpl::kStateHelloSent; -} - -void ClientImpl3::ProcessIncoming(std::span data) { - m_impl->ProcessIncoming(data); -} - -void ClientImpl3::HandleLocal(std::span msgs) { - m_impl->HandleLocal(msgs); -} - -void ClientImpl3::SendPeriodic(uint64_t curTimeMs, bool flush) { - m_impl->SendPeriodic(curTimeMs, false, flush); -} - -void ClientImpl3::SetLocal(net::LocalInterface* local) { - m_impl->m_local = local; + m_wire.Flush(); + m_state = kStateHelloSent; } diff --git a/ntcore/src/main/native/cpp/net3/ClientImpl3.h b/ntcore/src/main/native/cpp/net3/ClientImpl3.h index 3ea5ac7f36..7b3e676127 100644 --- a/ntcore/src/main/native/cpp/net3/ClientImpl3.h +++ b/ntcore/src/main/native/cpp/net3/ClientImpl3.h @@ -11,8 +11,17 @@ #include #include #include +#include +#include +#include + +#include "PubSubOptions.h" #include "net/NetworkInterface.h" +#include "net3/Message3.h" +#include "net3/SequenceNumber.h" +#include "net3/WireConnection3.h" +#include "net3/WireDecoder3.h" namespace wpi { class Logger; @@ -27,24 +36,147 @@ namespace nt::net3 { class WireConnection3; -class ClientImpl3 { +class ClientImpl3 final : private MessageHandler3 { public: explicit ClientImpl3(uint64_t curTimeMs, int inst, WireConnection3& wire, wpi::Logger& logger, std::function setPeriodic); - ~ClientImpl3(); + ~ClientImpl3() final; void Start(std::string_view selfId, std::function succeeded); void ProcessIncoming(std::span data); void HandleLocal(std::span msgs); - void SendPeriodic(uint64_t curTimeMs, bool flush); + void SendPeriodic(uint64_t curTimeMs, bool flush) { + DoSendPeriodic(curTimeMs, false, flush); + } - void SetLocal(net::LocalInterface* local); + void SetLocal(net::LocalInterface* local) { m_local = local; } private: - class Impl; - std::unique_ptr m_impl; + struct Entry; + + struct PublisherData { + explicit PublisherData(Entry* entry) : entry{entry} {} + + Entry* entry; + NT_Publisher handle; + PubSubOptionsImpl options; + // in options as double, but copy here as integer; rounded to the nearest + // 10 ms + uint32_t periodMs; + uint64_t nextSendMs{0}; + std::vector outValues; // outgoing values + }; + + // data for each entry + struct Entry { + explicit Entry(std::string_view name_) : name(name_) {} + bool IsPersistent() const { return (flags & NT_PERSISTENT) != 0; } + wpi::json SetFlags(unsigned int flags_); + + std::string name; + + std::string typeStr; + NT_Type type{NT_UNASSIGNED}; + + wpi::json properties = wpi::json::object(); + + // The current value and flags + Value value; + unsigned int flags{0}; + + // Unique ID used in network messages; this is 0xffff until assigned + // by the server. + unsigned int id{0xffff}; + + // Sequence number for update resolution + SequenceNumber seqNum; + + // Local topic handle + NT_Topic topic{0}; + + // Local publishers + std::vector publishers; + }; + + void DoSendPeriodic(uint64_t curTimeMs, bool initial, bool flush); + void SendValue(Writer& out, Entry* entry, const Value& value); + bool CheckNetworkReady(uint64_t curTimeMs); + + // Outgoing handlers + void Publish(NT_Publisher pubHandle, NT_Topic topicHandle, + std::string_view name, std::string_view typeStr, + const wpi::json& properties, const PubSubOptionsImpl& options); + void Unpublish(NT_Publisher pubHandle, NT_Topic topicHandle); + void SetProperties(NT_Topic topicHandle, std::string_view name, + const wpi::json& update); + void SetValue(NT_Publisher pubHandle, const Value& value); + + // MessageHandler interface + void KeepAlive() final; + void ServerHelloDone() final; + void ClientHelloDone() final; + void ClearEntries() final; + void ProtoUnsup(unsigned int proto_rev) final; + void ClientHello(std::string_view self_id, unsigned int proto_rev) final; + void ServerHello(unsigned int flags, std::string_view self_id) final; + void EntryAssign(std::string_view name, unsigned int id, unsigned int seq_num, + const Value& value, unsigned int flags) final; + void EntryUpdate(unsigned int id, unsigned int seq_num, + const Value& value) final; + void FlagsUpdate(unsigned int id, unsigned int flags) final; + void EntryDelete(unsigned int id) final; + void ExecuteRpc(unsigned int id, unsigned int uid, + std::span params) final {} + void RpcResponse(unsigned int id, unsigned int uid, + std::span result) final {} + + enum State { + kStateInitial, + kStateHelloSent, + kStateInitialAssignments, + kStateRunning + }; + + WireConnection3& m_wire; + wpi::Logger& m_logger; + net::LocalInterface* m_local{nullptr}; + std::function m_setPeriodic; + uint64_t m_initTimeMs; + + // periodic sweep handling + static constexpr uint32_t kKeepAliveIntervalMs = 1000; + uint32_t m_periodMs{kKeepAliveIntervalMs + 10}; + uint64_t m_lastSendMs{0}; + uint64_t m_nextKeepAliveTimeMs; + + // indexed by publisher index + std::vector> m_publishers; + + State m_state{kStateInitial}; + WireDecoder3 m_decoder; + std::string m_remoteId; + std::function m_handshakeSucceeded; + + std::vector> m_outgoingFlags; + + using NameMap = wpi::StringMap>; + using IdMap = std::vector; + + NameMap m_nameMap; + IdMap m_idMap; + + Entry* GetOrNewEntry(std::string_view name) { + auto& entry = m_nameMap[name]; + if (!entry) { + entry = std::make_unique(name); + } + return entry.get(); + } + Entry* LookupId(unsigned int id) { + return id < m_idMap.size() ? m_idMap[id] : nullptr; + } }; } // namespace nt::net3 diff --git a/ntcore/src/main/native/cpp/net3/WireDecoder3.cpp b/ntcore/src/main/native/cpp/net3/WireDecoder3.cpp index ea08358ec1..04d79cd43b 100644 --- a/ntcore/src/main/native/cpp/net3/WireDecoder3.cpp +++ b/ntcore/src/main/native/cpp/net3/WireDecoder3.cpp @@ -5,147 +5,24 @@ #include "WireDecoder3.h" #include -#include #include -#include #include #include #include -#include #include "Message3.h" using namespace nt; using namespace nt::net3; -namespace { - -class SimpleValueReader { - public: - std::optional Read16(std::span* in); - std::optional Read32(std::span* in); - std::optional Read64(std::span* in); - std::optional ReadDouble(std::span* in); - - private: - uint64_t m_value = 0; - int m_count = 0; -}; - -struct StringReader { - void SetLen(uint64_t len_) { - len = len_; - buf.clear(); - } - - std::optional len; - std::string buf; -}; - -struct RawReader { - void SetLen(uint64_t len_) { - len = len_; - buf.clear(); - } - - std::optional len; - std::vector buf; -}; - -struct ValueReader { - ValueReader() = default; - explicit ValueReader(NT_Type type_) : type{type_} {} - - void SetSize(uint32_t size_) { - haveSize = true; - size = size_; - ints.clear(); - doubles.clear(); - strings.clear(); - } - - NT_Type type = NT_UNASSIGNED; - bool haveSize = false; - uint32_t size = 0; - std::vector ints; - std::vector doubles; - std::vector strings; -}; - -struct WDImpl { - explicit WDImpl(MessageHandler3& out) : m_out{out} {} - - MessageHandler3& m_out; - - // primary (message) decode state - enum { - kStart, - kClientHello_1ProtoRev, - kClientHello_2Id, - kProtoUnsup_1ProtoRev, - kServerHello_1Flags, - kServerHello_2Id, - kEntryAssign_1Name, - kEntryAssign_2Type, - kEntryAssign_3Id, - kEntryAssign_4SeqNum, - kEntryAssign_5Flags, - kEntryAssign_6Value, - kEntryUpdate_1Id, - kEntryUpdate_2SeqNum, - kEntryUpdate_3Type, - kEntryUpdate_4Value, - kFlagsUpdate_1Id, - kFlagsUpdate_2Flags, - kEntryDelete_1Id, - kClearEntries_1Magic, - kExecuteRpc_1Id, - kExecuteRpc_2Uid, - kExecuteRpc_3Params, - kRpcResponse_1Id, - kRpcResponse_2Uid, - kRpcResponse_3Result, - kError - } m_state = kStart; - - // detail decoders - wpi::Uleb128Reader m_ulebReader; - SimpleValueReader m_simpleReader; - StringReader m_stringReader; - RawReader m_rawReader; - ValueReader m_valueReader; - - std::string m_error; - - std::string m_str; - unsigned int m_id{0}; // also used for proto_rev - unsigned int m_flags{0}; - unsigned int m_seq_num_uid{0}; - - void Execute(std::span* in); - - std::nullopt_t EmitError(std::string_view msg) { - m_state = kError; - m_error = msg; - return std::nullopt; - } - - std::optional ReadString(std::span* in); - std::optional> ReadRaw(std::span* in); - std::optional ReadType(std::span* in); - std::optional ReadValue(std::span* in); -}; - -} // namespace - static uint8_t Read8(std::span* in) { uint8_t val = in->front(); *in = wpi::drop_front(*in); return val; } -std::optional SimpleValueReader::Read16( +std::optional WireDecoder3::SimpleValueReader::Read16( std::span* in) { while (!in->empty()) { m_value <<= 8; @@ -161,7 +38,7 @@ std::optional SimpleValueReader::Read16( return std::nullopt; } -std::optional SimpleValueReader::Read32( +std::optional WireDecoder3::SimpleValueReader::Read32( std::span* in) { while (!in->empty()) { m_value <<= 8; @@ -177,7 +54,7 @@ std::optional SimpleValueReader::Read32( return std::nullopt; } -std::optional SimpleValueReader::Read64( +std::optional WireDecoder3::SimpleValueReader::Read64( std::span* in) { while (!in->empty()) { m_value <<= 8; @@ -193,7 +70,7 @@ std::optional SimpleValueReader::Read64( return std::nullopt; } -std::optional SimpleValueReader::ReadDouble( +std::optional WireDecoder3::SimpleValueReader::ReadDouble( std::span* in) { if (auto val = Read64(in)) { return wpi::BitsToDouble(val.value()); @@ -202,7 +79,7 @@ std::optional SimpleValueReader::ReadDouble( } } -void WDImpl::Execute(std::span* in) { +void WireDecoder3::DoExecute(std::span* in) { while (!in->empty()) { switch (m_state) { case kStart: { @@ -417,7 +294,8 @@ void WDImpl::Execute(std::span* in) { } } -std::optional WDImpl::ReadString(std::span* in) { +std::optional WireDecoder3::ReadString( + std::span* in) { // string length if (!m_stringReader.len) { if (auto val = m_ulebReader.ReadOne(in)) { @@ -442,7 +320,7 @@ std::optional WDImpl::ReadString(std::span* in) { return std::nullopt; } -std::optional> WDImpl::ReadRaw( +std::optional> WireDecoder3::ReadRaw( std::span* in) { // string length if (!m_rawReader.len) { @@ -468,7 +346,7 @@ std::optional> WDImpl::ReadRaw( return std::nullopt; } -std::optional WDImpl::ReadType(std::span* in) { +std::optional WireDecoder3::ReadType(std::span* in) { // Convert from byte value to enum switch (Read8(in)) { case Message3::kBoolean: @@ -492,7 +370,7 @@ std::optional WDImpl::ReadType(std::span* in) { } } -std::optional WDImpl::ReadValue(std::span* in) { +std::optional WireDecoder3::ReadValue(std::span* in) { while (!in->empty()) { switch (m_valueReader.type) { case NT_BOOLEAN: @@ -577,24 +455,3 @@ std::optional WDImpl::ReadValue(std::span* in) { } return std::nullopt; } - -struct WireDecoder3::Impl : public WDImpl { - explicit Impl(MessageHandler3& out) : WDImpl{out} {} -}; - -WireDecoder3::WireDecoder3(MessageHandler3& out) : m_impl{new Impl{out}} {} - -WireDecoder3::~WireDecoder3() = default; - -bool WireDecoder3::Execute(std::span* in) { - m_impl->Execute(in); - return m_impl->m_state != Impl::kError; -} - -void WireDecoder3::SetError(std::string_view message) { - m_impl->EmitError(message); -} - -std::string WireDecoder3::GetError() const { - return m_impl->m_error; -} diff --git a/ntcore/src/main/native/cpp/net3/WireDecoder3.h b/ntcore/src/main/native/cpp/net3/WireDecoder3.h index e877833fca..48064f7998 100644 --- a/ntcore/src/main/native/cpp/net3/WireDecoder3.h +++ b/ntcore/src/main/native/cpp/net3/WireDecoder3.h @@ -7,8 +7,14 @@ #include #include +#include #include #include +#include + +#include + +#include "ntcore_c.h" namespace nt { class Value; @@ -18,6 +24,8 @@ namespace nt::net3 { class MessageHandler3 { public: + virtual ~MessageHandler3() = default; + virtual void KeepAlive() = 0; virtual void ServerHelloDone() = 0; virtual void ClientHelloDone() = 0; @@ -42,8 +50,7 @@ class MessageHandler3 { /* Decodes NT3 protocol into native representation. */ class WireDecoder3 { public: - explicit WireDecoder3(MessageHandler3& out); - ~WireDecoder3(); + explicit WireDecoder3(MessageHandler3& out) : m_out{out} {} /** * Executes the decoder. All input data will be consumed unless an error @@ -51,14 +58,126 @@ class WireDecoder3 { * @param in input data (updated during parse) * @return false if error occurred */ - bool Execute(std::span* in); + bool Execute(std::span* in) { + DoExecute(in); + return m_state != kError; + } - void SetError(std::string_view message); - std::string GetError() const; + void SetError(std::string_view message) { EmitError(message); } + std::string GetError() const { return m_error; } private: - struct Impl; - std::unique_ptr m_impl; + class SimpleValueReader { + public: + std::optional Read16(std::span* in); + std::optional Read32(std::span* in); + std::optional Read64(std::span* in); + std::optional ReadDouble(std::span* in); + + private: + uint64_t m_value = 0; + int m_count = 0; + }; + + struct StringReader { + void SetLen(uint64_t len_) { + len = len_; + buf.clear(); + } + + std::optional len; + std::string buf; + }; + + struct RawReader { + void SetLen(uint64_t len_) { + len = len_; + buf.clear(); + } + + std::optional len; + std::vector buf; + }; + + struct ValueReader { + ValueReader() = default; + explicit ValueReader(NT_Type type_) : type{type_} {} + + void SetSize(uint32_t size_) { + haveSize = true; + size = size_; + ints.clear(); + doubles.clear(); + strings.clear(); + } + + NT_Type type = NT_UNASSIGNED; + bool haveSize = false; + uint32_t size = 0; + std::vector ints; + std::vector doubles; + std::vector strings; + }; + + MessageHandler3& m_out; + + // primary (message) decode state + enum { + kStart, + kClientHello_1ProtoRev, + kClientHello_2Id, + kProtoUnsup_1ProtoRev, + kServerHello_1Flags, + kServerHello_2Id, + kEntryAssign_1Name, + kEntryAssign_2Type, + kEntryAssign_3Id, + kEntryAssign_4SeqNum, + kEntryAssign_5Flags, + kEntryAssign_6Value, + kEntryUpdate_1Id, + kEntryUpdate_2SeqNum, + kEntryUpdate_3Type, + kEntryUpdate_4Value, + kFlagsUpdate_1Id, + kFlagsUpdate_2Flags, + kEntryDelete_1Id, + kClearEntries_1Magic, + kExecuteRpc_1Id, + kExecuteRpc_2Uid, + kExecuteRpc_3Params, + kRpcResponse_1Id, + kRpcResponse_2Uid, + kRpcResponse_3Result, + kError + } m_state = kStart; + + // detail decoders + wpi::Uleb128Reader m_ulebReader; + SimpleValueReader m_simpleReader; + StringReader m_stringReader; + RawReader m_rawReader; + ValueReader m_valueReader; + + std::string m_error; + + std::string m_str; + unsigned int m_id{0}; // also used for proto_rev + unsigned int m_flags{0}; + unsigned int m_seq_num_uid{0}; + + void DoExecute(std::span* in); + + std::nullopt_t EmitError(std::string_view msg) { + m_state = kError; + m_error = msg; + return std::nullopt; + } + + std::optional ReadString(std::span* in); + std::optional> ReadRaw(std::span* in); + std::optional ReadType(std::span* in); + std::optional ReadValue(std::span* in); }; } // namespace nt::net3 diff --git a/ntcore/src/test/native/cpp/LocalStorageTest.cpp b/ntcore/src/test/native/cpp/LocalStorageTest.cpp index 4d92d105ac..ccdb26ed39 100644 --- a/ntcore/src/test/native/cpp/LocalStorageTest.cpp +++ b/ntcore/src/test/native/cpp/LocalStorageTest.cpp @@ -162,7 +162,7 @@ TEST_F(LocalStorageTest, SubscribeNoTypeLocalPubPost) { EXPECT_EQ(value.GetBoolean(), true); EXPECT_EQ(value.time(), 5); - auto vals = storage.ReadQueueBoolean(sub); + auto vals = storage.ReadQueue(sub); ASSERT_EQ(vals.size(), 1u); EXPECT_EQ(vals[0].value, true); EXPECT_EQ(vals[0].time, 5); @@ -171,7 +171,7 @@ TEST_F(LocalStorageTest, SubscribeNoTypeLocalPubPost) { EXPECT_CALL(network, SetValue(pub, val)); storage.SetEntryValue(pub, val); - auto vals2 = storage.ReadQueueInteger(sub); // mismatched type + auto vals2 = storage.ReadQueue(sub); // mismatched type ASSERT_TRUE(vals2.empty()); } @@ -221,7 +221,7 @@ TEST_F(LocalStorageTest, EntryNoTypeLocalSet) { EXPECT_EQ(value.GetBoolean(), true); EXPECT_EQ(value.time(), 5); - auto vals = storage.ReadQueueBoolean(entry); + auto vals = storage.ReadQueue(entry); ASSERT_EQ(vals.size(), 1u); EXPECT_EQ(vals[0].value, true); EXPECT_EQ(vals[0].time, 5); @@ -231,7 +231,7 @@ TEST_F(LocalStorageTest, EntryNoTypeLocalSet) { EXPECT_CALL(network, SetValue(_, val)); EXPECT_TRUE(storage.SetEntryValue(entry, val)); - auto vals2 = storage.ReadQueueInteger(entry); // mismatched type + auto vals2 = storage.ReadQueue(entry); // mismatched type ASSERT_TRUE(vals2.empty()); // cannot change type; won't generate network message @@ -241,7 +241,7 @@ TEST_F(LocalStorageTest, EntryNoTypeLocalSet) { EXPECT_EQ(storage.GetTopicType(fooTopic), NT_BOOLEAN); EXPECT_EQ(storage.GetTopicTypeString(fooTopic), "boolean"); - auto vals3 = storage.ReadQueueInteger(entry); // mismatched type + auto vals3 = storage.ReadQueue(entry); // mismatched type ASSERT_TRUE(vals3.empty()); } @@ -266,7 +266,7 @@ TEST_F(LocalStorageTest, PubUnpubPub) { EXPECT_EQ(storage.GetTopicTypeString(fooTopic), "boolean"); EXPECT_TRUE(storage.GetTopicExists(fooTopic)); - EXPECT_TRUE(storage.ReadQueueInteger(sub).empty()); + EXPECT_TRUE(storage.ReadQueue(sub).empty()); EXPECT_CALL(network, Unpublish(pub, fooTopic)); storage.Unpublish(pub); @@ -288,7 +288,7 @@ TEST_F(LocalStorageTest, PubUnpubPub) { EXPECT_EQ(storage.GetTopicTypeString(fooTopic), "int"); EXPECT_TRUE(storage.GetTopicExists(fooTopic)); - EXPECT_EQ(storage.ReadQueueInteger(sub).size(), 1u); + EXPECT_EQ(storage.ReadQueue(sub).size(), 1u); } TEST_F(LocalStorageTest, LocalPubConflict) { @@ -535,7 +535,7 @@ TEST_F(LocalStorageDuplicatesTest, Defaults) { SetValues(false); // verify 2nd update was dropped locally - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 2u); ASSERT_EQ(values[0].value, val1.GetDouble()); ASSERT_EQ(values[0].time, val1.time()); @@ -552,7 +552,7 @@ TEST_F(LocalStorageDuplicatesTest, KeepPub) { SetValues(true); // verify only 2 updates were received locally - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 2u); } @@ -565,7 +565,7 @@ TEST_F(LocalStorageDuplicatesTest, KeepSub) { SetValues(false); // verify 2 updates were received locally - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 2u); } @@ -579,7 +579,7 @@ TEST_F(LocalStorageDuplicatesTest, KeepPubSub) { SetValues(true); // verify all 3 updates were received locally - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 3u); } @@ -595,7 +595,7 @@ TEST_F(LocalStorageDuplicatesTest, FromNetworkDefault) { storage.NetworkSetValue(topic, val3); // verify 2nd update was dropped for local subscriber - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 2u); ASSERT_EQ(values[0].value, val1.GetDouble()); ASSERT_EQ(values[0].time, val1.time()); @@ -615,7 +615,7 @@ TEST_F(LocalStorageDuplicatesTest, FromNetworkKeepPub) { storage.NetworkSetValue(topic, val3); // verify 2nd update was dropped for local subscriber - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 2u); ASSERT_EQ(values[0].value, val1.GetDouble()); ASSERT_EQ(values[0].time, val1.time()); @@ -634,7 +634,7 @@ TEST_F(LocalStorageDuplicatesTest, FromNetworkKeepSub) { storage.NetworkSetValue(topic, val3); // verify 2nd update was received by local subscriber - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 3u); ASSERT_EQ(values[0].value, val1.GetDouble()); ASSERT_EQ(values[0].time, val1.time()); @@ -656,7 +656,7 @@ TEST_F(LocalStorageDuplicatesTest, FromNetworkKeepPubSub) { storage.NetworkSetValue(topic, val3); // verify 2nd update was received by local subscriber - auto values = storage.ReadQueueDouble(sub); + auto values = storage.ReadQueue(sub); ASSERT_EQ(values.size(), 3u); ASSERT_EQ(values[0].value, val1.GetDouble()); ASSERT_EQ(values[0].time, val1.time()); @@ -778,13 +778,13 @@ TEST_F(LocalStorageNumberVariantsTest, GetAtomic) { for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); - EXPECT_THAT(storage.GetAtomicDouble(subentry.subentry, 0), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, 0), TSEq(1.0, 50)); - EXPECT_THAT(storage.GetAtomicInteger(subentry.subentry, 0), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, 0), TSEq(1, 50)); - EXPECT_THAT(storage.GetAtomicFloat(subentry.subentry, 0), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, 0), TSEq(1.0, 50)); - EXPECT_THAT(storage.GetAtomicBoolean(subentry.subentry, false), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, false), TSEq(false, 0)); } } @@ -807,15 +807,15 @@ TEST_F(LocalStorageNumberVariantsTest, GetAtomicArray) { for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); double doubleVal = 1.0; - EXPECT_THAT(storage.GetAtomicDoubleArray(subentry.subentry, {}), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, {}), TSSpanEq(std::span{&doubleVal, 1}, 50)); int64_t intVal = 1; - EXPECT_THAT(storage.GetAtomicIntegerArray(subentry.subentry, {}), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, {}), TSSpanEq(std::span{&intVal, 1}, 50)); float floatVal = 1.0; - EXPECT_THAT(storage.GetAtomicFloatArray(subentry.subentry, {}), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, {}), TSSpanEq(std::span{&floatVal, 1}, 50)); - EXPECT_THAT(storage.GetAtomicBooleanArray(subentry.subentry, {}), + EXPECT_THAT(storage.GetAtomic(subentry.subentry, {}), TSSpanEq(std::span{}, 0)); } } @@ -831,9 +831,9 @@ TEST_F(LocalStorageNumberVariantsTest, ReadQueue) { for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); if (subentry.type == NT_BOOLEAN) { - EXPECT_THAT(storage.ReadQueueDouble(subentry.subentry), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subentry.subentry), IsEmpty()); } else { - EXPECT_THAT(storage.ReadQueueDouble(subentry.subentry), + EXPECT_THAT(storage.ReadQueue(subentry.subentry), ElementsAre(TSEq(1.0, 50))); } } @@ -842,9 +842,9 @@ TEST_F(LocalStorageNumberVariantsTest, ReadQueue) { for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); if (subentry.type == NT_BOOLEAN) { - EXPECT_THAT(storage.ReadQueueInteger(subentry.subentry), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subentry.subentry), IsEmpty()); } else { - EXPECT_THAT(storage.ReadQueueInteger(subentry.subentry), + EXPECT_THAT(storage.ReadQueue(subentry.subentry), ElementsAre(TSEq(2, 50))); } } @@ -853,9 +853,9 @@ TEST_F(LocalStorageNumberVariantsTest, ReadQueue) { for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); if (subentry.type == NT_BOOLEAN) { - EXPECT_THAT(storage.ReadQueueFloat(subentry.subentry), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subentry.subentry), IsEmpty()); } else { - EXPECT_THAT(storage.ReadQueueFloat(subentry.subentry), + EXPECT_THAT(storage.ReadQueue(subentry.subentry), ElementsAre(TSEq(3.0, 50))); } } @@ -863,7 +863,7 @@ TEST_F(LocalStorageNumberVariantsTest, ReadQueue) { storage.SetEntryValue(pub, Value::MakeDouble(4.0, 50)); for (auto&& subentry : subentries) { SCOPED_TRACE(subentry.name); - EXPECT_THAT(storage.ReadQueueBoolean(subentry.subentry), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subentry.subentry), IsEmpty()); } } @@ -930,19 +930,19 @@ TEST_F(LocalStorageTest, ReadQueueLocalRemote) { // local set EXPECT_CALL(network, SetValue(_, _)); storage.SetEntryValue(pub, Value::MakeDouble(1.0, 50)); - EXPECT_THAT(storage.ReadQueueDouble(subBoth), + EXPECT_THAT(storage.ReadQueue(subBoth), ElementsAre(TSEq(1.0, 50))); - EXPECT_THAT(storage.ReadQueueDouble(subLocal), + EXPECT_THAT(storage.ReadQueue(subLocal), ElementsAre(TSEq(1.0, 50))); - EXPECT_THAT(storage.ReadQueueDouble(subRemote), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subRemote), IsEmpty()); // network set storage.NetworkSetValue(remoteTopic, Value::MakeDouble(2.0, 60)); - EXPECT_THAT(storage.ReadQueueDouble(subBoth), + EXPECT_THAT(storage.ReadQueue(subBoth), ElementsAre(TSEq(2.0, 60))); - EXPECT_THAT(storage.ReadQueueDouble(subRemote), + EXPECT_THAT(storage.ReadQueue(subRemote), ElementsAre(TSEq(2.0, 60))); - EXPECT_THAT(storage.ReadQueueDouble(subLocal), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subLocal), IsEmpty()); } TEST_F(LocalStorageTest, SubExcludePub) { @@ -959,15 +959,15 @@ TEST_F(LocalStorageTest, SubExcludePub) { // local set EXPECT_CALL(network, SetValue(_, _)); storage.SetEntryValue(pub, Value::MakeDouble(1.0, 50)); - EXPECT_THAT(storage.ReadQueueDouble(subActive), + EXPECT_THAT(storage.ReadQueue(subActive), ElementsAre(TSEq(1.0, 50))); - EXPECT_THAT(storage.ReadQueueDouble(subExclude), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subExclude), IsEmpty()); // network set storage.NetworkSetValue(remoteTopic, Value::MakeDouble(2.0, 60)); - EXPECT_THAT(storage.ReadQueueDouble(subActive), + EXPECT_THAT(storage.ReadQueue(subActive), ElementsAre(TSEq(2.0, 60))); - EXPECT_THAT(storage.ReadQueueDouble(subExclude), + EXPECT_THAT(storage.ReadQueue(subExclude), ElementsAre(TSEq(2.0, 60))); } @@ -983,11 +983,11 @@ TEST_F(LocalStorageTest, EntryExcludeSelf) { EXPECT_CALL(network, Publish(_, _, _, _, _, _)); EXPECT_CALL(network, SetValue(_, _)); storage.SetEntryValue(entry, Value::MakeDouble(1.0, 50)); - EXPECT_THAT(storage.ReadQueueDouble(entry), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(entry), IsEmpty()); // network set storage.NetworkSetValue(remoteTopic, Value::MakeDouble(2.0, 60)); - EXPECT_THAT(storage.ReadQueueDouble(entry), + EXPECT_THAT(storage.ReadQueue(entry), ElementsAre(TSEq(2.0, 60))); } @@ -1006,11 +1006,11 @@ TEST_F(LocalStorageTest, ReadQueueInitialLocal) { auto subRemote = storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true}); - EXPECT_THAT(storage.ReadQueueDouble(subBoth), + EXPECT_THAT(storage.ReadQueue(subBoth), ElementsAre(TSEq(1.0, 50))); - EXPECT_THAT(storage.ReadQueueDouble(subLocal), + EXPECT_THAT(storage.ReadQueue(subLocal), ElementsAre(TSEq(1.0, 50))); - EXPECT_THAT(storage.ReadQueueDouble(subRemote), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subRemote), IsEmpty()); } TEST_F(LocalStorageTest, ReadQueueInitialRemote) { @@ -1028,11 +1028,11 @@ TEST_F(LocalStorageTest, ReadQueueInitialRemote) { storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true}); // network set - EXPECT_THAT(storage.ReadQueueDouble(subBoth), + EXPECT_THAT(storage.ReadQueue(subBoth), ElementsAre(TSEq(2.0, 60))); - EXPECT_THAT(storage.ReadQueueDouble(subRemote), + EXPECT_THAT(storage.ReadQueue(subRemote), ElementsAre(TSEq(2.0, 60))); - EXPECT_THAT(storage.ReadQueueDouble(subLocal), IsEmpty()); + EXPECT_THAT(storage.ReadQueue(subLocal), IsEmpty()); } } // namespace nt