From 039edcc23ff078bb0e146c4b8b4ff8821e2bc2f4 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Fri, 13 Jan 2023 20:07:24 -0800 Subject: [PATCH] [ntcore] Queue current value on subscriber creation (#4938) This fixes a potential race condition in code that only uses readQueue. --- ntcore/src/main/native/cpp/LocalStorage.cpp | 15 ++++++ .../src/test/native/cpp/LocalStorageTest.cpp | 47 +++++++++++++++++-- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/ntcore/src/main/native/cpp/LocalStorage.cpp b/ntcore/src/main/native/cpp/LocalStorage.cpp index db8e065d47..ceaf011e6e 100644 --- a/ntcore/src/main/native/cpp/LocalStorage.cpp +++ b/ntcore/src/main/native/cpp/LocalStorage.cpp @@ -96,6 +96,7 @@ struct TopicData { NT_Entry entry{0}; // cached entry for GetEntry() bool onNetwork{false}; // true if there are any remote publishers + bool lastValueFromNetwork{false}; wpi::SmallVector datalogs; NT_Type datalogType{NT_UNASSIGNED}; @@ -484,6 +485,7 @@ void LSImpl::CheckReset(TopicData* topic) { } topic->lastValue = {}; topic->lastValueNetwork = {}; + topic->lastValueFromNetwork = false; topic->type = NT_UNASSIGNED; topic->typeStr.clear(); topic->flags = 0; @@ -503,6 +505,7 @@ bool LSImpl::SetValue(TopicData* topic, const Value& value, // TODO: notify option even if older value topic->type = value.type(); topic->lastValue = value; + topic->lastValueFromNetwork = false; NotifyValue(topic, eventFlags, isDuplicate, publisher); } if (!isDuplicate && topic->datalogType == value.type()) { @@ -858,6 +861,17 @@ SubscriberData* LSImpl::AddLocalSubscriber(TopicData* topic, DEBUG4("-> NetworkSubscribe({})", topic->name); m_network->Subscribe(subscriber->handle, {{topic->name}}, config); } + + // queue current value + if (subscriber->active) { + if (!topic->lastValueFromNetwork && !config.disableLocal) { + subscriber->pollStorage.emplace_back(topic->lastValue); + subscriber->handle.Set(); + } else if (topic->lastValueFromNetwork && !config.disableRemote) { + subscriber->pollStorage.emplace_back(topic->lastValueNetwork); + subscriber->handle.Set(); + } + } return subscriber; } @@ -1376,6 +1390,7 @@ void LocalStorage::NetworkSetValue(NT_Topic topicHandle, const Value& value) { if (m_impl->SetValue(topic, value, NT_EVENT_VALUE_REMOTE, value == topic->lastValue, nullptr)) { topic->lastValueNetwork = value; + topic->lastValueFromNetwork = true; } } } diff --git a/ntcore/src/test/native/cpp/LocalStorageTest.cpp b/ntcore/src/test/native/cpp/LocalStorageTest.cpp index 5734284cc3..112f1ea361 100644 --- a/ntcore/src/test/native/cpp/LocalStorageTest.cpp +++ b/ntcore/src/test/native/cpp/LocalStorageTest.cpp @@ -197,9 +197,6 @@ TEST_F(LocalStorageTest, SubscribeNoTypeLocalPubPre) { ASSERT_TRUE(value.IsBoolean()); EXPECT_EQ(value.GetBoolean(), true); EXPECT_EQ(value.time(), 5); - - auto vals = storage.ReadQueueValue(sub); // read queue won't get anything - ASSERT_TRUE(vals.empty()); } TEST_F(LocalStorageTest, EntryNoTypeLocalSet) { @@ -916,4 +913,48 @@ TEST_F(LocalStorageTest, EntryExcludeSelf) { ElementsAre(TSEq(2.0, 60))); } +TEST_F(LocalStorageTest, ReadQueueInitialLocal) { + EXPECT_CALL(network, Publish(_, _, _, _, _, _)); + EXPECT_CALL(network, SetValue(_, _)); + EXPECT_CALL(network, Subscribe(_, _, _)).Times(3); + + auto pub = storage.Publish(fooTopic, NT_DOUBLE, "double", {}, {}); + storage.SetEntryValue(pub, Value::MakeDouble(1.0, 50)); + + auto subBoth = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", kDefaultPubSubOptions); + auto subLocal = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableRemote = true}); + auto subRemote = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true}); + + EXPECT_THAT(storage.ReadQueueDouble(subBoth), + ElementsAre(TSEq(1.0, 50))); + EXPECT_THAT(storage.ReadQueueDouble(subLocal), + ElementsAre(TSEq(1.0, 50))); + EXPECT_THAT(storage.ReadQueueDouble(subRemote), IsEmpty()); +} + +TEST_F(LocalStorageTest, ReadQueueInitialRemote) { + EXPECT_CALL(network, Subscribe(_, _, _)).Times(3); + + auto remoteTopic = + storage.NetworkAnnounce("foo", "double", wpi::json::object(), 0); + storage.NetworkSetValue(remoteTopic, Value::MakeDouble(2.0, 60)); + + auto subBoth = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", kDefaultPubSubOptions); + auto subLocal = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableRemote = true}); + auto subRemote = + storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true}); + + // network set + EXPECT_THAT(storage.ReadQueueDouble(subBoth), + ElementsAre(TSEq(2.0, 60))); + EXPECT_THAT(storage.ReadQueueDouble(subRemote), + ElementsAre(TSEq(2.0, 60))); + EXPECT_THAT(storage.ReadQueueDouble(subLocal), IsEmpty()); +} + } // namespace nt