diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp index 7cd1063665..34475561f2 100644 --- a/src/Dispatcher.cpp +++ b/src/Dispatcher.cpp @@ -18,9 +18,11 @@ using namespace nt; ATOMIC_STATIC_INIT(Dispatcher) -void Dispatcher::StartServer(const char* listen_address, unsigned int port) { - DispatcherBase::StartServer(std::unique_ptr( - new TCPAcceptor(static_cast(port), listen_address))); +void Dispatcher::StartServer(StringRef persist_filename, + const char* listen_address, unsigned int port) { + DispatcherBase::StartServer(persist_filename, + std::unique_ptr(new TCPAcceptor( + static_cast(port), listen_address))); } void Dispatcher::StartClient(const char* server_name, unsigned int port) { @@ -43,15 +45,31 @@ DispatcherBase::~DispatcherBase() { Stop(); } -void DispatcherBase::StartServer(std::unique_ptr acceptor) { +void DispatcherBase::StartServer(StringRef persist_filename, + std::unique_ptr acceptor) { { std::lock_guard lock(m_user_mutex); if (m_active) return; m_active = true; } m_server = true; + m_persist_filename = persist_filename; m_server_acceptor = std::move(acceptor); + // Load persistent file. Ignore errors, but pass along warnings. + if (!persist_filename.empty()) { + bool first = true; + m_storage.LoadPersistent( + persist_filename, [&](std::size_t line, const char* msg) { + if (first) { + first = false; + WARNING("When reading initial persistent values from '" + << persist_filename << "':"); + } + WARNING(persist_filename << ":" << line << ": " << msg); + }); + } + using namespace std::placeholders; m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3), m_server); @@ -144,7 +162,12 @@ std::vector DispatcherBase::GetConnections() const { void DispatcherBase::DispatchThreadMain() { auto timeout_time = std::chrono::steady_clock::now(); + + static const auto save_delta_time = std::chrono::seconds(1); + auto next_save_time = timeout_time + save_delta_time; + int count = 0; + std::unique_lock flush_lock(m_flush_mutex); while (m_active) { // handle loop taking too long @@ -159,6 +182,15 @@ void DispatcherBase::DispatchThreadMain() { m_do_flush = false; if (!m_active) break; // in case we were woken up to terminate + // perform periodic persistent save + if (m_server && !m_persist_filename.empty() && start > next_save_time) { + next_save_time += save_delta_time; + // handle loop taking too long + if (start > next_save_time) next_save_time = start + save_delta_time; + const char* err = m_storage.SavePersistent(m_persist_filename, true); + if (err) WARNING("periodic persistent save: " << err); + } + if (++count > 10) { DEBUG("dispatch running"); count = 0; @@ -397,7 +429,9 @@ bool DispatcherBase::ServerHandshake( if (msg->Is(Message::kClientHelloDone)) break; if (!msg->Is(Message::kEntryAssign)) { // unexpected message - DEBUG("server: received message (" << msg->type() << ") other than entry assignment during initial handshake"); + DEBUG("server: received message (" + << msg->type() + << ") other than entry assignment during initial handshake"); return false; } incoming.push_back(msg); diff --git a/src/Dispatcher.h b/src/Dispatcher.h index 0aed45cf69..36d7d96852 100644 --- a/src/Dispatcher.h +++ b/src/Dispatcher.h @@ -34,7 +34,8 @@ class DispatcherBase { public: virtual ~DispatcherBase(); - void StartServer(std::unique_ptr acceptor); + void StartServer(StringRef persist_filename, + std::unique_ptr acceptor); void StartClient(std::function()> connect); void Stop(); void SetUpdateRate(double interval); @@ -73,6 +74,7 @@ class DispatcherBase { Storage& m_storage; Notifier& m_notifier; bool m_server; + std::string m_persist_filename; std::thread m_dispatch_thread; std::thread m_clientserver_thread; std::thread m_notifier_thread; @@ -107,7 +109,8 @@ class Dispatcher : public DispatcherBase { return instance; } - void StartServer(const char* listen_address, unsigned int port); + void StartServer(StringRef persist_filename, const char* listen_address, + unsigned int port); void StartClient(const char* server_name, unsigned int port); private: diff --git a/src/Storage.cpp b/src/Storage.cpp index fd0d4ae7ee..c2e983b9bf 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -85,6 +85,9 @@ void Storage::ProcessIncoming(std::shared_ptr msg, entry->id = id; m_idmap.push_back(entry); + // update persistent dirty flag if it's persistent + if (entry->IsPersistent()) m_persistent_dirty = true; + // notify m_notifier.NotifyEntry(name, entry->value, true); @@ -160,13 +163,22 @@ void Storage::ProcessIncoming(std::shared_ptr msg, return; } + // don't update flags from a <3.0 remote (not part of message) + if (conn->proto_rev() >= 0x0300) { + // update persistent dirty flag if persistent flag changed + if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT)) + m_persistent_dirty = true; + entry->flags = msg->flags(); + } + + // update persistent dirty flag if the value changed and it's persistent + if (entry->IsPersistent() && *entry->value != *msg->value()) + m_persistent_dirty = true; + // update local entry->value = msg->value(); entry->seq_num = seq_num; - // don't update flags from a <3.0 remote (not part of message) - if (conn->proto_rev() >= 0x0300) entry->flags = msg->flags(); - // notify m_notifier.NotifyEntry(name, entry->value, false); @@ -201,6 +213,9 @@ void Storage::ProcessIncoming(std::shared_ptr msg, entry->value = msg->value(); entry->seq_num = seq_num; + // update persistent dirty flag if it's a persistent value + if (entry->IsPersistent()) m_persistent_dirty = true; + // notify m_notifier.NotifyEntry(entry->name, entry->value, false); @@ -224,6 +239,10 @@ void Storage::ProcessIncoming(std::shared_ptr msg, } Entry* entry = m_idmap[id]; + // update persistent dirty flag if persistent flag changed + if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT)) + m_persistent_dirty = true; + // update local entry->flags = msg->flags(); @@ -247,6 +266,9 @@ void Storage::ProcessIncoming(std::shared_ptr msg, } Entry* entry = m_idmap[id]; + // update persistent dirty flag if it's a persistent value + if (entry->IsPersistent()) m_persistent_dirty = true; + // update local m_entries.erase(entry->name); // erase from map m_idmap[id] = nullptr; // delete it from idmap too @@ -265,6 +287,9 @@ void Storage::ProcessIncoming(std::shared_ptr msg, m_entries.clear(); m_idmap.resize(0); + // set persistent dirty flag + m_persistent_dirty = true; + // broadcast to all other connections (note for client there won't // be any other connections, so don't bother) if (m_server && m_queue_outgoing) { @@ -419,6 +444,9 @@ bool Storage::SetEntryValue(StringRef name, std::shared_ptr value) { m_idmap.push_back(entry); } + // update persistent dirty flag if value changed and it's persistent + if (entry->IsPersistent() && *old_value != *value) m_persistent_dirty = true; + // generate message if (!m_queue_outgoing) return true; auto queue_outgoing = m_queue_outgoing; @@ -458,6 +486,9 @@ void Storage::SetEntryTypeValue(StringRef name, std::shared_ptr value) { m_idmap.push_back(entry); } + // update persistent dirty flag if it's a persistent value + if (entry->IsPersistent()) m_persistent_dirty = true; + // generate message if (!m_queue_outgoing) return; auto queue_outgoing = m_queue_outgoing; @@ -486,6 +517,11 @@ void Storage::SetEntryFlags(StringRef name, unsigned int flags) { if (i == m_entries.end()) return; Entry* entry = i->getValue().get(); if (entry->flags == flags) return; + + // update persistent dirty flag if persistent flag changed + if ((entry->flags & NT_PERSISTENT) != (flags & NT_PERSISTENT)) + m_persistent_dirty = true; + entry->flags = flags; // generate message @@ -512,6 +548,10 @@ void Storage::DeleteEntry(StringRef name) { Entry* entry = i->getValue().get(); unsigned int id = entry->id; bool had_value = entry->value != nullptr; + + // update persistent dirty flag if it's a persistent value + if (entry->IsPersistent()) m_persistent_dirty = true; + m_entries.erase(i); // erase from map if (id < m_idmap.size()) m_idmap[id] = nullptr; @@ -533,6 +573,9 @@ void Storage::DeleteAllEntries() { m_entries.clear(); m_idmap.resize(0); + // set persistent dirty flag + m_persistent_dirty = true; + // generate message if (!m_queue_outgoing) return; auto queue_outgoing = m_queue_outgoing; @@ -600,25 +643,37 @@ static void WriteString(std::ostream& os, llvm::StringRef str) { os << '"'; } -void Storage::SavePersistent(std::ostream& os) const { +bool Storage::GetPersistentEntries( + bool periodic, + std::vector>>* entries) + const { // copy values out of storage as quickly as possible so lock isn't held - typedef std::pair> NewEntry; - std::vector entries; { std::lock_guard lock(m_mutex); - entries.reserve(m_entries.size()); + // for periodic, don't re-save unless something has changed + if (periodic && !m_persistent_dirty) return false; + m_persistent_dirty = false; + entries->reserve(m_entries.size()); for (auto& i : m_entries) { Entry* entry = i.getValue().get(); // only write persistent-flagged values if (!entry->IsPersistent()) continue; - entries.push_back(std::make_pair(i.getKey(), entry->value)); + entries->emplace_back(i.getKey(), entry->value); } } // sort in name order - std::sort(entries.begin(), entries.end(), - [](const NewEntry& a, const NewEntry& b) { return a.first < b.first; }); + std::sort(entries->begin(), entries->end(), + [](const std::pair>& a, + const std::pair>& b) { + return a.first < b.first; + }); + return true; +} +static void SavePersistentImpl( + std::ostream& os, + llvm::ArrayRef>> entries) { std::string base64_encoded; // header @@ -711,6 +766,56 @@ void Storage::SavePersistent(std::ostream& os) const { } } +void Storage::SavePersistent(std::ostream& os, bool periodic) const { + std::vector>> entries; + if (!GetPersistentEntries(periodic, &entries)) return; + SavePersistentImpl(os, entries); +} + +const char* Storage::SavePersistent(StringRef filename, bool periodic) const { + std::string fn = filename; + std::string tmp = filename; + tmp += ".tmp"; + std::string bak = filename; + bak += ".bak"; + + // Get entries before creating file + std::vector>> entries; + if (!GetPersistentEntries(periodic, &entries)) return nullptr; + + const char* err = nullptr; + + // start by writing to temporary file + std::ofstream os(tmp); + if (!os) { + err = "could not open file"; + goto done; + } + DEBUG("saving persistent file '" << filename << "'"); + SavePersistentImpl(os, entries); + os.flush(); + if (!os) { + os.close(); + std::remove(tmp.c_str()); + err = "error saving file"; + goto done; + } + + // Safely move to real file. We ignore any failures related to the backup. + std::remove(bak.c_str()); + std::rename(fn.c_str(), bak.c_str()); + if (std::rename(tmp.c_str(), fn.c_str()) != 0) { + std::rename(bak.c_str(), fn.c_str()); // attempt to restore backup + err = "could not rename temp file to real file"; + goto done; + } + +done: + // try again if there was an error + if (err && periodic) m_persistent_dirty = true; + return err; +} + /* Extracts an escaped string token. Does not unescape the string. * If a string cannot be matched, an empty string is returned. * If the string is unterminated, an empty tail string is returned. @@ -1038,6 +1143,15 @@ next_line: return true; } +const char* Storage::LoadPersistent( + StringRef filename, + std::function warn) { + std::ifstream is(filename); + if (!is) return "could not open file"; + if (!LoadPersistent(is, warn)) return "error reading file"; + return nullptr; +} + void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) { if (name.empty() || def.empty() || !callback) return; std::unique_lock lock(m_mutex); diff --git a/src/Storage.h b/src/Storage.h index fdf84598f4..9cbf8b63db 100644 --- a/src/Storage.h +++ b/src/Storage.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -39,13 +40,19 @@ class Storage { } ~Storage(); - // Accessors required by Dispatcher. + // Accessors required by Dispatcher. A function pointer is used for + // generation of outgoing messages to break a dependency loop between + // Storage and Dispatcher; in operation this is always set to + // Dispatcher::QueueOutgoing. typedef std::function msg, NetworkConnection* only, NetworkConnection* except)> QueueOutgoingFunc; void SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server); void ClearOutgoing(); + // Required for wire protocol 2.0 to get the entry type of an entry when + // receiving entry updates (because the length/type is not provided in the + // message itself). Not used in wire protocol 3.0. NT_Type GetEntryType(unsigned int id) const; void ProcessIncoming(std::shared_ptr msg, NetworkConnection* conn, @@ -57,9 +64,8 @@ class Storage { bool new_server, std::vector>* out_msgs); - std::mutex& mutex() { return m_mutex; } - - // User functions + // User functions. These are the actual implementations of the corresponding + // user API functions in ntcore_cpp. std::shared_ptr GetEntryValue(StringRef name) const; bool SetEntryValue(StringRef name, std::shared_ptr value); void SetEntryTypeValue(StringRef name, std::shared_ptr value); @@ -70,7 +76,16 @@ class Storage { std::vector GetEntryInfo(StringRef prefix, unsigned int types); void NotifyEntries(StringRef prefix); - void SavePersistent(std::ostream& os) const; + // Filename-based save/load functions. Used both by periodic saves and + // accessible directly via the user API. + const char* SavePersistent(StringRef filename, bool periodic) const; + const char* LoadPersistent( + StringRef filename, + std::function warn); + + // Stream-based save/load functions (exposed for testing purposes). These + // implement the guts of the filename-based functions. + void SavePersistent(std::ostream& os, bool periodic) const; bool LoadPersistent( std::istream& is, std::function warn); @@ -89,17 +104,34 @@ class Storage { Storage(const Storage&) = delete; Storage& operator=(const Storage&) = delete; + // Data for each table entry. struct Entry { Entry(llvm::StringRef name_) : name(name_), flags(0), id(0xffff), rpc_call_uid(0) {} bool IsPersistent() const { return (flags & NT_PERSISTENT) != 0; } + // We redundantly store the name so that it's available when accessing the + // raw Entry* via the ID map. std::string name; + + // The current value and flags. std::shared_ptr value; unsigned int flags; + + // Unique ID for this entry as used in network messages. The value is + // assigned by the server, so on the client this is 0xffff until an + // entry assignment is received back from the server. unsigned int id; + + // Sequence number for update resolution. SequenceNumber seq_num; + + // RPC callback function. Null if either not an RPC or if the RPC is + // polled. RpcCallback rpc_callback; + + // Last UID used when calling this RPC (primarily for client use). This + // is incremented for each call. unsigned int rpc_call_uid; }; @@ -112,14 +144,26 @@ class Storage { EntriesMap m_entries; IdMap m_idmap; RpcResultMap m_rpc_results; + // If any persistent values have changed + mutable bool m_persistent_dirty = false; + + // condition variable and termination flag for blocking on a RPC result std::atomic_bool m_terminating; std::condition_variable m_rpc_results_cond; + // configured by dispatcher at startup QueueOutgoingFunc m_queue_outgoing; bool m_server = true; + + // references to singletons (we don't grab them directly for testing purposes) Notifier& m_notifier; RpcServer& m_rpc_server; + bool GetPersistentEntries( + bool periodic, + std::vector>>* entries) + const; + ATOMIC_STATIC_DECL(Storage) }; diff --git a/src/ntcore_cpp.cpp b/src/ntcore_cpp.cpp index a99cf992cc..4bd25472dc 100644 --- a/src/ntcore_cpp.cpp +++ b/src/ntcore_cpp.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include "Dispatcher.h" #include "Log.h" @@ -211,8 +210,7 @@ void SetNetworkIdentity(StringRef name) { void StartServer(StringRef persist_filename, const char *listen_address, unsigned int port) { - Dispatcher& dispatcher = Dispatcher::GetInstance(); - dispatcher.StartServer(listen_address, port); + Dispatcher::GetInstance().StartServer(persist_filename, listen_address, port); } void StopServer() { @@ -240,44 +238,13 @@ std::vector GetConnections() { */ const char* SavePersistent(StringRef filename) { - const Storage& storage = Storage::GetInstance(); - - std::string fn = filename; - std::string tmp = filename; - tmp += ".tmp"; - std::string bak = filename; - bak += ".bak"; - - // start by writing to temporary file - std::ofstream os(tmp); - if (!os) return "could not open file"; - storage.SavePersistent(os); - os.flush(); - if (!os) { - os.close(); - std::remove(tmp.c_str()); - return "error saving file"; - } - - // safely move to real file - std::remove(bak.c_str()); - if (std::rename(fn.c_str(), bak.c_str()) != 0) - return "could not rename real file to backup"; - if (std::rename(tmp.c_str(), fn.c_str()) != 0) { - std::rename(bak.c_str(), fn.c_str()); // attempt to restore backup - return "could not rename temp file to real file"; - } - return nullptr; + return Storage::GetInstance().SavePersistent(filename, false); } const char* LoadPersistent( StringRef filename, std::function warn) { - Storage& storage = Storage::GetInstance(); - std::ifstream is(filename); - if (!is) return "could not open file"; - if (!storage.LoadPersistent(is, warn)) return "error reading file"; - return nullptr; + return Storage::GetInstance().LoadPersistent(filename, warn); } void SetLogger(LogFunc func, unsigned int min_level) {