Implement automatic persistent saves.

Also loads persistent file on server start.
This commit is contained in:
Peter Johnson
2015-08-19 19:09:25 -07:00
parent a5ccafd924
commit ca9ce0f3a3
5 changed files with 220 additions and 58 deletions

View File

@@ -18,9 +18,11 @@ using namespace nt;
ATOMIC_STATIC_INIT(Dispatcher)
void Dispatcher::StartServer(const char* listen_address, unsigned int port) {
DispatcherBase::StartServer(std::unique_ptr<NetworkAcceptor>(
new TCPAcceptor(static_cast<int>(port), listen_address)));
void Dispatcher::StartServer(StringRef persist_filename,
const char* listen_address, unsigned int port) {
DispatcherBase::StartServer(persist_filename,
std::unique_ptr<NetworkAcceptor>(new TCPAcceptor(
static_cast<int>(port), listen_address)));
}
void Dispatcher::StartClient(const char* server_name, unsigned int port) {
@@ -43,15 +45,31 @@ DispatcherBase::~DispatcherBase() {
Stop();
}
void DispatcherBase::StartServer(std::unique_ptr<NetworkAcceptor> acceptor) {
void DispatcherBase::StartServer(StringRef persist_filename,
std::unique_ptr<NetworkAcceptor> acceptor) {
{
std::lock_guard<std::mutex> lock(m_user_mutex);
if (m_active) return;
m_active = true;
}
m_server = true;
m_persist_filename = persist_filename;
m_server_acceptor = std::move(acceptor);
// Load persistent file. Ignore errors, but pass along warnings.
if (!persist_filename.empty()) {
bool first = true;
m_storage.LoadPersistent(
persist_filename, [&](std::size_t line, const char* msg) {
if (first) {
first = false;
WARNING("When reading initial persistent values from '"
<< persist_filename << "':");
}
WARNING(persist_filename << ":" << line << ": " << msg);
});
}
using namespace std::placeholders;
m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
m_server);
@@ -144,7 +162,12 @@ std::vector<ConnectionInfo> DispatcherBase::GetConnections() const {
void DispatcherBase::DispatchThreadMain() {
auto timeout_time = std::chrono::steady_clock::now();
static const auto save_delta_time = std::chrono::seconds(1);
auto next_save_time = timeout_time + save_delta_time;
int count = 0;
std::unique_lock<std::mutex> flush_lock(m_flush_mutex);
while (m_active) {
// handle loop taking too long
@@ -159,6 +182,15 @@ void DispatcherBase::DispatchThreadMain() {
m_do_flush = false;
if (!m_active) break; // in case we were woken up to terminate
// perform periodic persistent save
if (m_server && !m_persist_filename.empty() && start > next_save_time) {
next_save_time += save_delta_time;
// handle loop taking too long
if (start > next_save_time) next_save_time = start + save_delta_time;
const char* err = m_storage.SavePersistent(m_persist_filename, true);
if (err) WARNING("periodic persistent save: " << err);
}
if (++count > 10) {
DEBUG("dispatch running");
count = 0;
@@ -397,7 +429,9 @@ bool DispatcherBase::ServerHandshake(
if (msg->Is(Message::kClientHelloDone)) break;
if (!msg->Is(Message::kEntryAssign)) {
// unexpected message
DEBUG("server: received message (" << msg->type() << ") other than entry assignment during initial handshake");
DEBUG("server: received message ("
<< msg->type()
<< ") other than entry assignment during initial handshake");
return false;
}
incoming.push_back(msg);

View File

@@ -34,7 +34,8 @@ class DispatcherBase {
public:
virtual ~DispatcherBase();
void StartServer(std::unique_ptr<NetworkAcceptor> acceptor);
void StartServer(StringRef persist_filename,
std::unique_ptr<NetworkAcceptor> acceptor);
void StartClient(std::function<std::unique_ptr<NetworkStream>()> connect);
void Stop();
void SetUpdateRate(double interval);
@@ -73,6 +74,7 @@ class DispatcherBase {
Storage& m_storage;
Notifier& m_notifier;
bool m_server;
std::string m_persist_filename;
std::thread m_dispatch_thread;
std::thread m_clientserver_thread;
std::thread m_notifier_thread;
@@ -107,7 +109,8 @@ class Dispatcher : public DispatcherBase {
return instance;
}
void StartServer(const char* listen_address, unsigned int port);
void StartServer(StringRef persist_filename, const char* listen_address,
unsigned int port);
void StartClient(const char* server_name, unsigned int port);
private:

View File

@@ -85,6 +85,9 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
entry->id = id;
m_idmap.push_back(entry);
// update persistent dirty flag if it's persistent
if (entry->IsPersistent()) m_persistent_dirty = true;
// notify
m_notifier.NotifyEntry(name, entry->value, true);
@@ -160,13 +163,22 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
return;
}
// don't update flags from a <3.0 remote (not part of message)
if (conn->proto_rev() >= 0x0300) {
// update persistent dirty flag if persistent flag changed
if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
m_persistent_dirty = true;
entry->flags = msg->flags();
}
// update persistent dirty flag if the value changed and it's persistent
if (entry->IsPersistent() && *entry->value != *msg->value())
m_persistent_dirty = true;
// update local
entry->value = msg->value();
entry->seq_num = seq_num;
// don't update flags from a <3.0 remote (not part of message)
if (conn->proto_rev() >= 0x0300) entry->flags = msg->flags();
// notify
m_notifier.NotifyEntry(name, entry->value, false);
@@ -201,6 +213,9 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
entry->value = msg->value();
entry->seq_num = seq_num;
// update persistent dirty flag if it's a persistent value
if (entry->IsPersistent()) m_persistent_dirty = true;
// notify
m_notifier.NotifyEntry(entry->name, entry->value, false);
@@ -224,6 +239,10 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
}
Entry* entry = m_idmap[id];
// update persistent dirty flag if persistent flag changed
if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
m_persistent_dirty = true;
// update local
entry->flags = msg->flags();
@@ -247,6 +266,9 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
}
Entry* entry = m_idmap[id];
// update persistent dirty flag if it's a persistent value
if (entry->IsPersistent()) m_persistent_dirty = true;
// update local
m_entries.erase(entry->name); // erase from map
m_idmap[id] = nullptr; // delete it from idmap too
@@ -265,6 +287,9 @@ void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
m_entries.clear();
m_idmap.resize(0);
// set persistent dirty flag
m_persistent_dirty = true;
// broadcast to all other connections (note for client there won't
// be any other connections, so don't bother)
if (m_server && m_queue_outgoing) {
@@ -419,6 +444,9 @@ bool Storage::SetEntryValue(StringRef name, std::shared_ptr<Value> value) {
m_idmap.push_back(entry);
}
// update persistent dirty flag if value changed and it's persistent
if (entry->IsPersistent() && *old_value != *value) m_persistent_dirty = true;
// generate message
if (!m_queue_outgoing) return true;
auto queue_outgoing = m_queue_outgoing;
@@ -458,6 +486,9 @@ void Storage::SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value) {
m_idmap.push_back(entry);
}
// update persistent dirty flag if it's a persistent value
if (entry->IsPersistent()) m_persistent_dirty = true;
// generate message
if (!m_queue_outgoing) return;
auto queue_outgoing = m_queue_outgoing;
@@ -486,6 +517,11 @@ void Storage::SetEntryFlags(StringRef name, unsigned int flags) {
if (i == m_entries.end()) return;
Entry* entry = i->getValue().get();
if (entry->flags == flags) return;
// update persistent dirty flag if persistent flag changed
if ((entry->flags & NT_PERSISTENT) != (flags & NT_PERSISTENT))
m_persistent_dirty = true;
entry->flags = flags;
// generate message
@@ -512,6 +548,10 @@ void Storage::DeleteEntry(StringRef name) {
Entry* entry = i->getValue().get();
unsigned int id = entry->id;
bool had_value = entry->value != nullptr;
// update persistent dirty flag if it's a persistent value
if (entry->IsPersistent()) m_persistent_dirty = true;
m_entries.erase(i); // erase from map
if (id < m_idmap.size()) m_idmap[id] = nullptr;
@@ -533,6 +573,9 @@ void Storage::DeleteAllEntries() {
m_entries.clear();
m_idmap.resize(0);
// set persistent dirty flag
m_persistent_dirty = true;
// generate message
if (!m_queue_outgoing) return;
auto queue_outgoing = m_queue_outgoing;
@@ -600,25 +643,37 @@ static void WriteString(std::ostream& os, llvm::StringRef str) {
os << '"';
}
void Storage::SavePersistent(std::ostream& os) const {
bool Storage::GetPersistentEntries(
bool periodic,
std::vector<std::pair<std::string, std::shared_ptr<Value>>>* entries)
const {
// copy values out of storage as quickly as possible so lock isn't held
typedef std::pair<std::string, std::shared_ptr<Value>> NewEntry;
std::vector<NewEntry> entries;
{
std::lock_guard<std::mutex> lock(m_mutex);
entries.reserve(m_entries.size());
// for periodic, don't re-save unless something has changed
if (periodic && !m_persistent_dirty) return false;
m_persistent_dirty = false;
entries->reserve(m_entries.size());
for (auto& i : m_entries) {
Entry* entry = i.getValue().get();
// only write persistent-flagged values
if (!entry->IsPersistent()) continue;
entries.push_back(std::make_pair(i.getKey(), entry->value));
entries->emplace_back(i.getKey(), entry->value);
}
}
// sort in name order
std::sort(entries.begin(), entries.end(),
[](const NewEntry& a, const NewEntry& b) { return a.first < b.first; });
std::sort(entries->begin(), entries->end(),
[](const std::pair<std::string, std::shared_ptr<Value>>& a,
const std::pair<std::string, std::shared_ptr<Value>>& b) {
return a.first < b.first;
});
return true;
}
static void SavePersistentImpl(
std::ostream& os,
llvm::ArrayRef<std::pair<std::string, std::shared_ptr<Value>>> entries) {
std::string base64_encoded;
// header
@@ -711,6 +766,56 @@ void Storage::SavePersistent(std::ostream& os) const {
}
}
void Storage::SavePersistent(std::ostream& os, bool periodic) const {
std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
if (!GetPersistentEntries(periodic, &entries)) return;
SavePersistentImpl(os, entries);
}
const char* Storage::SavePersistent(StringRef filename, bool periodic) const {
std::string fn = filename;
std::string tmp = filename;
tmp += ".tmp";
std::string bak = filename;
bak += ".bak";
// Get entries before creating file
std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
if (!GetPersistentEntries(periodic, &entries)) return nullptr;
const char* err = nullptr;
// start by writing to temporary file
std::ofstream os(tmp);
if (!os) {
err = "could not open file";
goto done;
}
DEBUG("saving persistent file '" << filename << "'");
SavePersistentImpl(os, entries);
os.flush();
if (!os) {
os.close();
std::remove(tmp.c_str());
err = "error saving file";
goto done;
}
// Safely move to real file. We ignore any failures related to the backup.
std::remove(bak.c_str());
std::rename(fn.c_str(), bak.c_str());
if (std::rename(tmp.c_str(), fn.c_str()) != 0) {
std::rename(bak.c_str(), fn.c_str()); // attempt to restore backup
err = "could not rename temp file to real file";
goto done;
}
done:
// try again if there was an error
if (err && periodic) m_persistent_dirty = true;
return err;
}
/* Extracts an escaped string token. Does not unescape the string.
* If a string cannot be matched, an empty string is returned.
* If the string is unterminated, an empty tail string is returned.
@@ -1038,6 +1143,15 @@ next_line:
return true;
}
const char* Storage::LoadPersistent(
StringRef filename,
std::function<void(std::size_t line, const char* msg)> warn) {
std::ifstream is(filename);
if (!is) return "could not open file";
if (!LoadPersistent(is, warn)) return "error reading file";
return nullptr;
}
void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) {
if (name.empty() || def.empty() || !callback) return;
std::unique_lock<std::mutex> lock(m_mutex);

View File

@@ -10,6 +10,7 @@
#include <atomic>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iosfwd>
#include <memory>
@@ -39,13 +40,19 @@ class Storage {
}
~Storage();
// Accessors required by Dispatcher.
// Accessors required by Dispatcher. A function pointer is used for
// generation of outgoing messages to break a dependency loop between
// Storage and Dispatcher; in operation this is always set to
// Dispatcher::QueueOutgoing.
typedef std::function<void(std::shared_ptr<Message> msg,
NetworkConnection* only,
NetworkConnection* except)> QueueOutgoingFunc;
void SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server);
void ClearOutgoing();
// Required for wire protocol 2.0 to get the entry type of an entry when
// receiving entry updates (because the length/type is not provided in the
// message itself). Not used in wire protocol 3.0.
NT_Type GetEntryType(unsigned int id) const;
void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn,
@@ -57,9 +64,8 @@ class Storage {
bool new_server,
std::vector<std::shared_ptr<Message>>* out_msgs);
std::mutex& mutex() { return m_mutex; }
// User functions
// User functions. These are the actual implementations of the corresponding
// user API functions in ntcore_cpp.
std::shared_ptr<Value> GetEntryValue(StringRef name) const;
bool SetEntryValue(StringRef name, std::shared_ptr<Value> value);
void SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value);
@@ -70,7 +76,16 @@ class Storage {
std::vector<EntryInfo> GetEntryInfo(StringRef prefix, unsigned int types);
void NotifyEntries(StringRef prefix);
void SavePersistent(std::ostream& os) const;
// Filename-based save/load functions. Used both by periodic saves and
// accessible directly via the user API.
const char* SavePersistent(StringRef filename, bool periodic) const;
const char* LoadPersistent(
StringRef filename,
std::function<void(std::size_t line, const char* msg)> warn);
// Stream-based save/load functions (exposed for testing purposes). These
// implement the guts of the filename-based functions.
void SavePersistent(std::ostream& os, bool periodic) const;
bool LoadPersistent(
std::istream& is,
std::function<void(std::size_t line, const char* msg)> warn);
@@ -89,17 +104,34 @@ class Storage {
Storage(const Storage&) = delete;
Storage& operator=(const Storage&) = delete;
// Data for each table entry.
struct Entry {
Entry(llvm::StringRef name_)
: name(name_), flags(0), id(0xffff), rpc_call_uid(0) {}
bool IsPersistent() const { return (flags & NT_PERSISTENT) != 0; }
// We redundantly store the name so that it's available when accessing the
// raw Entry* via the ID map.
std::string name;
// The current value and flags.
std::shared_ptr<Value> value;
unsigned int flags;
// Unique ID for this entry as used in network messages. The value is
// assigned by the server, so on the client this is 0xffff until an
// entry assignment is received back from the server.
unsigned int id;
// Sequence number for update resolution.
SequenceNumber seq_num;
// RPC callback function. Null if either not an RPC or if the RPC is
// polled.
RpcCallback rpc_callback;
// Last UID used when calling this RPC (primarily for client use). This
// is incremented for each call.
unsigned int rpc_call_uid;
};
@@ -112,14 +144,26 @@ class Storage {
EntriesMap m_entries;
IdMap m_idmap;
RpcResultMap m_rpc_results;
// If any persistent values have changed
mutable bool m_persistent_dirty = false;
// condition variable and termination flag for blocking on a RPC result
std::atomic_bool m_terminating;
std::condition_variable m_rpc_results_cond;
// configured by dispatcher at startup
QueueOutgoingFunc m_queue_outgoing;
bool m_server = true;
// references to singletons (we don't grab them directly for testing purposes)
Notifier& m_notifier;
RpcServer& m_rpc_server;
bool GetPersistentEntries(
bool periodic,
std::vector<std::pair<std::string, std::shared_ptr<Value>>>* entries)
const;
ATOMIC_STATIC_DECL(Storage)
};

View File

@@ -10,7 +10,6 @@
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include "Dispatcher.h"
#include "Log.h"
@@ -211,8 +210,7 @@ void SetNetworkIdentity(StringRef name) {
void StartServer(StringRef persist_filename, const char *listen_address,
unsigned int port) {
Dispatcher& dispatcher = Dispatcher::GetInstance();
dispatcher.StartServer(listen_address, port);
Dispatcher::GetInstance().StartServer(persist_filename, listen_address, port);
}
void StopServer() {
@@ -240,44 +238,13 @@ std::vector<ConnectionInfo> GetConnections() {
*/
const char* SavePersistent(StringRef filename) {
const Storage& storage = Storage::GetInstance();
std::string fn = filename;
std::string tmp = filename;
tmp += ".tmp";
std::string bak = filename;
bak += ".bak";
// start by writing to temporary file
std::ofstream os(tmp);
if (!os) return "could not open file";
storage.SavePersistent(os);
os.flush();
if (!os) {
os.close();
std::remove(tmp.c_str());
return "error saving file";
}
// safely move to real file
std::remove(bak.c_str());
if (std::rename(fn.c_str(), bak.c_str()) != 0)
return "could not rename real file to backup";
if (std::rename(tmp.c_str(), fn.c_str()) != 0) {
std::rename(bak.c_str(), fn.c_str()); // attempt to restore backup
return "could not rename temp file to real file";
}
return nullptr;
return Storage::GetInstance().SavePersistent(filename, false);
}
const char* LoadPersistent(
StringRef filename,
std::function<void(size_t line, const char* msg)> warn) {
Storage& storage = Storage::GetInstance();
std::ifstream is(filename);
if (!is) return "could not open file";
if (!storage.LoadPersistent(is, warn)) return "error reading file";
return nullptr;
return Storage::GetInstance().LoadPersistent(filename, warn);
}
void SetLogger(LogFunc func, unsigned int min_level) {