diff --git a/ntcore/src/main/native/cpp/LocalStorage.cpp b/ntcore/src/main/native/cpp/LocalStorage.cpp index 75806bfbaa..9ec17d6240 100644 --- a/ntcore/src/main/native/cpp/LocalStorage.cpp +++ b/ntcore/src/main/native/cpp/LocalStorage.cpp @@ -34,6 +34,15 @@ static constexpr size_t kMaxListeners = 512; namespace { +static constexpr bool IsSpecial(std::string_view name) { + return name.empty() ? false : name.front() == '$'; +} + +static constexpr bool PrefixMatch(std::string_view name, + std::string_view prefix, bool special) { + return (!special || !prefix.empty()) && wpi::starts_with(name, prefix); +} + // Utility wrapper for making a set-like vector template class VectorSet : public std::vector { @@ -66,7 +75,7 @@ struct TopicData { static constexpr auto kType = Handle::kTopic; TopicData(NT_Topic handle, std::string_view name) - : handle{handle}, name{name} {} + : handle{handle}, name{name}, special{IsSpecial(name)} {} bool Exists() const { return onNetwork || !localPublishers.empty(); } @@ -75,6 +84,7 @@ struct TopicData { // invariants wpi::SignalObject handle; std::string name; + bool special; Value lastValue; // also stores timestamp Value lastValueNetwork; @@ -179,6 +189,8 @@ struct MultiSubscriberData { } } + bool Matches(std::string_view name, bool special); + // invariants wpi::SignalObject handle; std::vector prefixes; @@ -188,6 +200,15 @@ struct MultiSubscriberData { VectorSet valueListeners; }; +bool MultiSubscriberData::Matches(std::string_view name, bool special) { + for (auto&& prefix : prefixes) { + if (PrefixMatch(name, prefix, special)) { + return true; + } + } + return false; +} + struct ListenerData { ListenerData(NT_Listener handle, SubscriberData* subscriber, unsigned int eventMask, bool subscriberOwned) @@ -403,6 +424,7 @@ void SubscriberData::UpdateActive() { } void LSImpl::NotifyTopic(TopicData* topic, unsigned int eventFlags) { + DEBUG4("NotifyTopic({}, {})\n", topic->name, eventFlags); auto topicInfo = topic->GetTopicInfo(); if (!topic->listeners.empty()) { m_listenerStorage.Notify(topic->listeners, eventFlags, topicInfo); @@ -410,12 +432,9 @@ void LSImpl::NotifyTopic(TopicData* topic, unsigned int eventFlags) { wpi::SmallVector listeners; for (auto listener : m_topicPrefixListeners) { - if (listener->multiSubscriber) { - for (auto&& prefix : listener->multiSubscriber->prefixes) { - if (wpi::starts_with(topic->name, prefix)) { - listeners.emplace_back(listener->handle); - } - } + if (listener->multiSubscriber && + listener->multiSubscriber->Matches(topic->name, topic->special)) { + listeners.emplace_back(listener->handle); } } if (!listeners.empty()) { @@ -874,7 +893,7 @@ MultiSubscriberData* LSImpl::AddMultiSubscriber( // subscribe to any already existing topics for (auto&& topic : m_topics) { for (auto&& prefix : prefixes) { - if (wpi::starts_with(topic->name, prefix)) { + if (PrefixMatch(topic->name, prefix, topic->special)) { topic->multiSubscribers.Add(subscriber); break; } @@ -995,10 +1014,8 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, if ((eventMask & NT_EVENT_IMMEDIATE) != 0 && (eventMask & (NT_EVENT_PUBLISH | NT_EVENT_VALUE_ALL)) != 0) { for (auto&& topic : m_topics) { - for (auto&& prefix : subscriber->prefixes) { - if (wpi::starts_with(topic->name, prefix) && topic->Exists()) { - topics.emplace_back(topic.get()); - } + if (topic->Exists() && subscriber->Matches(topic->name, topic->special)) { + topics.emplace_back(topic.get()); } } } @@ -1124,11 +1141,8 @@ TopicData* LSImpl::GetOrCreateTopic(std::string_view name) { topic = m_topics.Add(m_inst, name); // attach multi-subscribers for (auto&& sub : m_multiSubscribers) { - for (auto&& prefix : sub->prefixes) { - if (wpi::starts_with(name, prefix)) { - topic->multiSubscribers.Add(sub.get()); - break; - } + if (sub->Matches(name, topic->special)) { + topic->multiSubscribers.Add(sub.get()); } } } diff --git a/ntcore/src/test/native/cpp/LocalStorageTest.cpp b/ntcore/src/test/native/cpp/LocalStorageTest.cpp index cb4d22b325..ebe098fe66 100644 --- a/ntcore/src/test/native/cpp/LocalStorageTest.cpp +++ b/ntcore/src/test/native/cpp/LocalStorageTest.cpp @@ -789,6 +789,34 @@ TEST_F(LocalStorageNumberVariantsTest, ReadQueue) { } } +TEST_F(LocalStorageTest, MultiSubSpecial) { + EXPECT_CALL(network, Subscribe(_, _, _)).Times(2); + EXPECT_CALL(network, Publish(_, _, _, _, _, _)).Times(2); + EXPECT_CALL(network, SetValue(_, _)).Times(2); + EXPECT_CALL(listenerStorage, Activate(_, _, _)).Times(2); + + auto subnormal = storage.SubscribeMultiple({{""}}, {}); + auto subspecial = storage.SubscribeMultiple({{"", "$"}}, {}); + auto pubnormal = storage.Publish(fooTopic, NT_DOUBLE, "double", {}, {}); + auto specialTopic = storage.GetTopic("$topic"); + auto pubspecial = storage.Publish(specialTopic, NT_DOUBLE, "double", {}, {}); + storage.AddListener(1, subnormal, NT_EVENT_VALUE_ALL); + storage.AddListener(2, subspecial, NT_EVENT_VALUE_ALL); + + EXPECT_CALL( + listenerStorage, + Notify(wpi::SpanEq(std::span{{2}}), _, _, _, _)); + storage.SetEntryValue(pubspecial, Value::MakeDouble(1.0, 30)); + + EXPECT_CALL( + listenerStorage, + Notify(wpi::SpanEq(std::span{{1}}), _, _, _, _)); + EXPECT_CALL( + listenerStorage, + Notify(wpi::SpanEq(std::span{{2}}), _, _, _, _)); + storage.SetEntryValue(pubnormal, Value::MakeDouble(2.0, 40)); +} + TEST_F(LocalStorageTest, NetworkDuplicateDetect) { EXPECT_CALL(network, Publish(_, _, _, _, _, _)); auto pub = storage.Publish(fooTopic, NT_DOUBLE, "double", {}, {});