Fix connection notification races. (#130)

Use a mutex on the connection state and one-shot all change notifications.

Fixes #127.
This commit is contained in:
Peter Johnson
2016-10-21 19:40:56 -07:00
committed by GitHub
parent 5c1b7ecd17
commit 86c43df8d1
3 changed files with 38 additions and 19 deletions

View File

@@ -27,10 +27,10 @@ NetworkConnection::NetworkConnection(std::unique_ptr<wpi::NetworkStream> stream,
m_stream(std::move(stream)),
m_notifier(notifier),
m_handshake(handshake),
m_get_entry_type(get_entry_type) {
m_get_entry_type(get_entry_type),
m_state(kCreated) {
m_active = false;
m_proto_rev = 0x0300;
m_state = static_cast<int>(kCreated);
m_last_update = 0;
// turn off Nagle algorithm; we bundle packets for transmission
@@ -42,7 +42,7 @@ NetworkConnection::~NetworkConnection() { Stop(); }
void NetworkConnection::Start() {
if (m_active) return;
m_active = true;
m_state = static_cast<int>(kInit);
set_state(kInit);
// clear queue
while (!m_outgoing.empty()) m_outgoing.pop();
// reset shutdown flags
@@ -58,7 +58,7 @@ void NetworkConnection::Start() {
void NetworkConnection::Stop() {
DEBUG2("NetworkConnection stopping (" << this << ")");
m_state = static_cast<int>(kDead);
set_state(kDead);
m_active = false;
// closing the stream so the read thread terminates
if (m_stream) m_stream->close();
@@ -95,6 +95,29 @@ ConnectionInfo NetworkConnection::info() const {
m_last_update, m_proto_rev};
}
NetworkConnection::State NetworkConnection::state() const {
std::lock_guard<std::mutex> lock(m_state_mutex);
return m_state;
}
void NetworkConnection::set_state(State state) {
std::lock_guard<std::mutex> lock(m_state_mutex);
// Don't update state any more once we've died
if (m_state == kDead) return;
// One-shot notify state changes
if (m_state != kActive && state == kActive)
m_notifier.NotifyConnection(true, info());
if (m_state != kDead && state == kDead)
m_notifier.NotifyConnection(false, info());
m_state = state;
}
void NetworkConnection::NotifyIfActive(
ConnectionListenerCallback callback) const {
std::lock_guard<std::mutex> lock(m_state_mutex);
if (m_state == kActive) m_notifier.NotifyConnection(true, info(), callback);
}
std::string NetworkConnection::remote_id() const {
std::lock_guard<std::mutex> lock(m_remote_id_mutex);
return m_remote_id;
@@ -109,7 +132,7 @@ void NetworkConnection::ReadThreadMain() {
wpi::raw_socket_istream is(*m_stream);
WireDecoder decoder(is, m_proto_rev);
m_state = static_cast<int>(kHandshake);
set_state(kHandshake);
if (!m_handshake(*this,
[&] {
decoder.set_proto_rev(m_proto_rev);
@@ -121,13 +144,12 @@ void NetworkConnection::ReadThreadMain() {
[&](llvm::ArrayRef<std::shared_ptr<Message>> msgs) {
m_outgoing.emplace(msgs);
})) {
m_state = static_cast<int>(kDead);
set_state(kDead);
m_active = false;
goto done;
}
m_state = static_cast<int>(kActive);
m_notifier.NotifyConnection(true, info());
set_state(kActive);
while (m_active) {
if (!m_stream)
break;
@@ -147,8 +169,7 @@ void NetworkConnection::ReadThreadMain() {
m_process_incoming(std::move(msg), this);
}
DEBUG2("read thread died (" << this << ")");
if (m_state != kDead) m_notifier.NotifyConnection(false, info());
m_state = static_cast<int>(kDead);
set_state(kDead);
m_active = false;
m_outgoing.push(Outgoing()); // also kill write thread
@@ -186,8 +207,7 @@ void NetworkConnection::WriteThreadMain() {
DEBUG4("sent " << encoder.size() << " bytes");
}
DEBUG2("write thread died (" << this << ")");
if (m_state != kDead) m_notifier.NotifyConnection(false, info());
m_state = static_cast<int>(kDead);
set_state(kDead);
m_active = false;
if (m_stream) m_stream->close(); // also kill read thread