diff --git a/src/MJPEGServerImpl.cpp b/src/MJPEGServerImpl.cpp index f40379a9cf..97c127ac1c 100644 --- a/src/MJPEGServerImpl.cpp +++ b/src/MJPEGServerImpl.cpp @@ -27,16 +27,48 @@ using namespace cs; // It separates the multipart stream of pictures #define BOUNDARY "boundarydonotcross" +class MJPEGServerImpl::ConnThread : public wpi::SafeThread { + public: + void Main(); + + bool ProcessCommand(llvm::raw_ostream& os, SourceImpl& source, + llvm::StringRef parameters, bool respond); + void SendJSON(llvm::raw_ostream& os, SourceImpl& source, bool header); + void SendStream(wpi::raw_socket_ostream& os); + void ProcessRequest(); + + std::unique_ptr m_stream; + std::shared_ptr m_source; + bool m_streaming = false; + + private: + std::shared_ptr GetSource() { + std::lock_guard lock(m_mutex); + return m_source; + } + + void StartStream() { + std::lock_guard lock(m_mutex); + m_source->EnableSink(); + m_streaming = true; + } + + void StopStream() { + std::lock_guard lock(m_mutex); + m_source->DisableSink(); + m_streaming = false; + } +}; + // Standard header to send along with other header information like mimetype. // // The parameters should ensure the browser does not cache our answer. // A browser should connect for each file and not serve files from its cache. // Using cached pictures would lead to showing old/outdated pictures. // Many browsers seem to ignore, or at least not always obey, those headers. -void MJPEGServerImpl::SendHeader(llvm::raw_ostream& os, int code, - llvm::StringRef codeText, - llvm::StringRef contentType, - llvm::StringRef extra) { +static void SendHeader(llvm::raw_ostream& os, int code, + llvm::StringRef codeText, llvm::StringRef contentType, + llvm::StringRef extra = llvm::StringRef{}) { os << "HTTP/1.0 " << code << ' ' << codeText << "\r\n"; os << "Connection: close\r\n" "Server: CameraServer/1.0\r\n" @@ -52,8 +84,8 @@ void MJPEGServerImpl::SendHeader(llvm::raw_ostream& os, int code, // Send error header and message // @param code HTTP error code (e.g. 404) // @param message Additional message text -void MJPEGServerImpl::SendError(llvm::raw_ostream& os, int code, - llvm::StringRef message) { +static void SendError(llvm::raw_ostream& os, int code, + llvm::StringRef message) { llvm::StringRef codeText, extra, baseMessage; switch (code) { case 401: @@ -90,14 +122,13 @@ void MJPEGServerImpl::SendError(llvm::raw_ostream& os, int code, // Read a line from an input stream (up to a maximum length). // The returned buffer will contain the trailing \n (unless the maximum length // was reached). -bool MJPEGServerImpl::ReadLine(wpi::raw_istream& istream, - llvm::SmallVectorImpl& buffer, - int maxLen) { +static bool ReadLine(wpi::raw_istream& is, llvm::SmallVectorImpl& buffer, + int maxLen) { buffer.clear(); for (int i = 0; i < maxLen; ++i) { char c; - istream.read(c); - if (istream.has_error()) return false; + is.read(c); + if (is.has_error()) return false; buffer.push_back(c); if (c == '\n') break; } @@ -105,8 +136,7 @@ bool MJPEGServerImpl::ReadLine(wpi::raw_istream& istream, } // Unescape a %xx-encoded URI. Returns false on error. -bool MJPEGServerImpl::UnescapeURI(llvm::StringRef str, - llvm::SmallVectorImpl& out) { +static bool UnescapeURI(llvm::StringRef str, llvm::SmallVectorImpl& out) { for (auto i = str.begin(), end = str.end(); i != end; ++i) { // pass non-escaped characters to output if (*i != '%') { @@ -133,8 +163,10 @@ bool MJPEGServerImpl::UnescapeURI(llvm::StringRef str, } // Perform a command specified by HTTP GET parameters. -bool MJPEGServerImpl::ProcessCommand(llvm::raw_ostream& os, SourceImpl& source, - llvm::StringRef parameters, bool respond) { +bool MJPEGServerImpl::ConnThread::ProcessCommand(llvm::raw_ostream& os, + SourceImpl& source, + llvm::StringRef parameters, + bool respond) { llvm::SmallString<256> responseBuf; llvm::raw_svector_ostream response{responseBuf}; // command format: param1=value1¶m2=value2... @@ -263,8 +295,8 @@ bool MJPEGServerImpl::ProcessCommand(llvm::raw_ostream& os, SourceImpl& source, } // Send a JSON file which is contains information about the source parameters. -void MJPEGServerImpl::SendJSON(llvm::raw_ostream& os, SourceImpl& source, - bool header) { +void MJPEGServerImpl::ConnThread::SendJSON(llvm::raw_ostream& os, + SourceImpl& source, bool header) { if (header) SendHeader(os, 200, "OK", "application/x-javascript"); os << "{\n\"controls\": [\n"; @@ -350,18 +382,20 @@ void MJPEGServerImpl::Stop() { if (m_serverThread.joinable()) m_serverThread.join(); // close streams - for (auto& stream : m_connStreams) stream->close(); + for (auto& connThread : m_connThreads) { + if (auto thr = connThread.GetThread()) { + if (thr->m_stream) thr->m_stream->close(); + } + connThread.Stop(); + } // wake up connection threads by forcing an empty frame to be sent if (auto source = GetSource()) source->Wakeup(); - - // join connection threads - for (auto& connThread : m_connThreads) connThread.join(); } // Send HTTP response and a stream of JPG-frames -void MJPEGServerImpl::SendStream(wpi::raw_socket_ostream& os) { +void MJPEGServerImpl::ConnThread::SendStream(wpi::raw_socket_ostream& os) { os.SetUnbuffered(); llvm::SmallString<256> header; @@ -373,7 +407,7 @@ void MJPEGServerImpl::SendStream(wpi::raw_socket_ostream& os) { DEBUG("HTTP: Headers send, sending stream now"); - Enable(); + StartStream(); while (m_active && !os.has_error()) { auto source = GetSource(); if (!source) { @@ -419,13 +453,12 @@ void MJPEGServerImpl::SendStream(wpi::raw_socket_ostream& os) { os << llvm::StringRef(data, size); // os.flush(); } - Disable(); + StopStream(); } -// thread for clients that connected to this server -void MJPEGServerImpl::ConnThreadMain(wpi::NetworkStream* stream) { - wpi::raw_socket_istream is{*stream}; - wpi::raw_socket_ostream os{*stream, true}; +void MJPEGServerImpl::ConnThread::ProcessRequest() { + wpi::raw_socket_istream is{*m_stream}; + wpi::raw_socket_ostream os{*m_stream, true}; // Read the request string from the stream llvm::SmallString<128> buf; @@ -506,6 +539,21 @@ void MJPEGServerImpl::ConnThreadMain(wpi::NetworkStream* stream) { DEBUG("leaving HTTP client thread"); } +// worker thread for clients that connected to this server +void MJPEGServerImpl::ConnThread::Main() { + std::unique_lock lock(m_mutex); + while (m_active) { + while (!m_stream) { + m_cond.wait(lock); + if (!m_active) return; + } + lock.unlock(); + ProcessRequest(); + lock.lock(); + m_stream = nullptr; + } +} + // Main server thread void MJPEGServerImpl::ServerThreadMain() { if (m_acceptor->start() != 0) { @@ -524,14 +572,50 @@ void MJPEGServerImpl::ServerThreadMain() { DEBUG("server: client connection from " << stream->getPeerIP()); - m_connThreads.emplace_back(&MJPEGServerImpl::ConnThreadMain, this, - stream.get()); - m_connStreams.emplace_back(std::move(stream)); + auto source = GetSource(); + + std::lock_guard lock(m_mutex); + // Find unoccupied worker thread, or create one if necessary + auto it = std::find_if(m_connThreads.begin(), m_connThreads.end(), + [](const wpi::SafeThreadOwner& owner) { + auto thr = owner.GetThread(); + return !thr || !thr->m_stream; + }); + if (it == m_connThreads.end()) { + m_connThreads.emplace_back(); + it = std::prev(m_connThreads.end()); + } + + // Start it if not already started + { + auto thr = it->GetThread(); + if (!thr) it->Start(); + } + + // Hand off connection to it + auto thr = it->GetThread(); + thr->m_stream = std::move(stream); + thr->m_source = source; + thr->m_cond.notify_one(); } DEBUG("leaving server thread"); } +void MJPEGServerImpl::SetSourceImpl(std::shared_ptr source) { + std::lock_guard lock(m_mutex); + for (auto& connThread : m_connThreads) { + if (auto thr = connThread.GetThread()) { + if (thr->m_source != source) { + bool streaming = thr->m_streaming; + if (streaming) thr->m_source->DisableSink(); + thr->m_source = source; + if (streaming) thr->m_source->EnableSink(); + } + } + } +} + namespace cs { CS_Sink CreateMJPEGServer(llvm::StringRef name, llvm::StringRef listenAddress, diff --git a/src/MJPEGServerImpl.h b/src/MJPEGServerImpl.h index 42193a9ad6..cf288b8899 100644 --- a/src/MJPEGServerImpl.h +++ b/src/MJPEGServerImpl.h @@ -16,6 +16,7 @@ #include "llvm/raw_ostream.h" #include "llvm/SmallVector.h" #include "llvm/StringRef.h" +#include "support/SafeThread.h" #include "support/raw_istream.h" #include "support/raw_socket_ostream.h" #include "tcpsockets/NetworkAcceptor.h" @@ -35,31 +36,18 @@ class MJPEGServerImpl : public SinkImpl { void Stop(); - static void SendHeader(llvm::raw_ostream& os, int code, - llvm::StringRef codeText, llvm::StringRef contentType, - llvm::StringRef extra = llvm::StringRef{}); - static void SendError(llvm::raw_ostream& os, int code, - llvm::StringRef message); - static bool ReadLine(wpi::raw_istream& istream, - llvm::SmallVectorImpl& buffer, int maxLen); - static bool UnescapeURI(llvm::StringRef str, - llvm::SmallVectorImpl& out); - static bool ProcessCommand(llvm::raw_ostream& os, SourceImpl& source, - llvm::StringRef parameters, bool respond); - static void SendJSON(llvm::raw_ostream& os, SourceImpl& source, bool header); - - void SendStream(wpi::raw_socket_ostream& os); - private: + void SetSourceImpl(std::shared_ptr source) override; + void ServerThreadMain(); - void ConnThreadMain(wpi::NetworkStream* stream); + + class ConnThread; std::unique_ptr m_acceptor; std::atomic_bool m_active; // set to false to terminate threads std::thread m_serverThread; - std::vector m_connThreads; - std::vector> m_connStreams; + std::vector> m_connThreads; }; } // namespace cs diff --git a/src/SinkImpl.cpp b/src/SinkImpl.cpp index fa76750247..b0c381b4fc 100644 --- a/src/SinkImpl.cpp +++ b/src/SinkImpl.cpp @@ -33,14 +33,17 @@ llvm::StringRef SinkImpl::GetDescription( } void SinkImpl::SetSource(std::shared_ptr source) { - std::lock_guard lock(m_mutex); - if (m_source) { - if (m_enabledCount > 0) m_source->DisableSink(); - m_source->RemoveSink(); + { + std::lock_guard lock(m_mutex); + if (m_source) { + if (m_enabledCount > 0) m_source->DisableSink(); + m_source->RemoveSink(); + } + m_source = source; + m_source->AddSink(); + if (m_enabledCount > 0) m_source->EnableSink(); } - m_source = source; - m_source->AddSink(); - if (m_enabledCount > 0) m_source->EnableSink(); + SetSourceImpl(source); } std::string SinkImpl::GetError() const { @@ -59,3 +62,5 @@ llvm::StringRef SinkImpl::GetError(llvm::SmallVectorImpl& buf) const { buf.append(frame.data(), frame.data() + frame.size()); return llvm::StringRef{buf.data(), buf.size()}; } + +void SinkImpl::SetSourceImpl(std::shared_ptr source) {} diff --git a/src/SinkImpl.h b/src/SinkImpl.h index 2351077e89..af155a09d4 100644 --- a/src/SinkImpl.h +++ b/src/SinkImpl.h @@ -66,6 +66,8 @@ class SinkImpl { llvm::StringRef GetError(llvm::SmallVectorImpl& buf) const; protected: + virtual void SetSourceImpl(std::shared_ptr source); + mutable std::mutex m_mutex; private: