diff --git a/include/ntcore_c.h b/include/ntcore_c.h index d232230b2f..1d2f3d2e9b 100644 --- a/include/ntcore_c.h +++ b/include/ntcore_c.h @@ -152,6 +152,7 @@ struct NT_RpcDefinition { struct NT_RpcCallInfo { unsigned int rpc_id; unsigned int call_uid; + struct NT_ConnectionInfo conn_info; struct NT_String name; struct NT_String params; }; @@ -310,7 +311,8 @@ void NT_SetRpcServerOnStart(void (*on_start)(void *data), void *data); void NT_SetRpcServerOnExit(void (*on_exit)(void *data), void *data); typedef char *(*NT_RpcCallback)(void *data, const char *name, size_t name_len, - const char *params, size_t params_len, + const char *params, size_t params_len, + const struct NT_ConnectionInfo* conn_info, size_t *results_len); void NT_CreateRpc(const char *name, size_t name_len, const char *def, diff --git a/include/ntcore_cpp.h b/include/ntcore_cpp.h index 3376c90290..6779064777 100644 --- a/include/ntcore_cpp.h +++ b/include/ntcore_cpp.h @@ -79,6 +79,7 @@ struct RpcDefinition { struct RpcCallInfo { unsigned int rpc_id; unsigned int call_uid; + ConnectionInfo conn_info; std::string name; std::string params; }; @@ -224,7 +225,8 @@ constexpr double kTimeout_Indefinite = -1; void SetRpcServerOnStart(std::function on_start); void SetRpcServerOnExit(std::function on_exit); -typedef std::function +typedef std::function RpcCallback; void CreateRpc(StringRef name, StringRef def, RpcCallback callback); diff --git a/src/RpcServer.cpp b/src/RpcServer.cpp index 810b35e75c..b46fb82e9f 100644 --- a/src/RpcServer.cpp +++ b/src/RpcServer.cpp @@ -47,15 +47,17 @@ void RpcServer::Stop() { m_owner.Stop(); } void RpcServer::ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, + const ConnectionInfo& conn_info, SendMsgFunc send_response) { if (func) { auto thr = m_owner.GetThread(); if (!thr) return; - thr->m_call_queue.emplace(name, msg, func, conn_id, send_response); + thr->m_call_queue.emplace(name, msg, func, conn_id, conn_info, + send_response); thr->m_cond.notify_one(); } else { std::lock_guard lock(m_mutex); - m_poll_queue.emplace(name, msg, func, conn_id, send_response); + m_poll_queue.emplace(name, msg, func, conn_id, conn_info, send_response); m_poll_cond.notify_one(); } } @@ -96,6 +98,7 @@ bool RpcServer::PollRpc(bool blocking, double time_out, RpcCallInfo* call_info) call_uid = item.msg->seq_num_uid(); call_info->rpc_id = item.msg->id(); call_info->call_uid = call_uid; + call_info->conn_info = item.conn_info; call_info->name = std::move(item.name); call_info->params = item.msg->str(); m_response_map.insert(std::make_pair(std::make_pair(item.msg->id(), call_uid), @@ -138,7 +141,7 @@ void RpcServer::Thread::Main() { // Don't hold mutex during callback execution! lock.unlock(); - auto result = item.func(item.name, item.msg->str()); + auto result = item.func(item.name, item.msg->str(), item.conn_info); item.send_response(Message::RpcResponse(item.msg->id(), item.msg->seq_num_uid(), result)); lock.lock(); diff --git a/src/RpcServer.h b/src/RpcServer.h index d12bb6ad8c..a1971d8473 100644 --- a/src/RpcServer.h +++ b/src/RpcServer.h @@ -41,6 +41,7 @@ class RpcServer { void ProcessRpc(StringRef name, std::shared_ptr msg, RpcCallback func, unsigned int conn_id, + const ConnectionInfo& conn_info, SendMsgFunc send_response); bool PollRpc(bool blocking, RpcCallInfo* call_info); @@ -56,17 +57,20 @@ class RpcServer { struct RpcCall { RpcCall(StringRef name_, std::shared_ptr msg_, RpcCallback func_, - unsigned int conn_id_, SendMsgFunc send_response_) + unsigned int conn_id_, const ConnectionInfo& conn_info_, + SendMsgFunc send_response_) : name(name_), msg(msg_), func(func_), conn_id(conn_id_), + conn_info(conn_info_), send_response(send_response_) {} std::string name; std::shared_ptr msg; RpcCallback func; unsigned int conn_id; + ConnectionInfo conn_info; SendMsgFunc send_response; }; diff --git a/src/Storage.cpp b/src/Storage.cpp index 46f0ff34fb..1402af7ab3 100644 --- a/src/Storage.cpp +++ b/src/Storage.cpp @@ -346,8 +346,12 @@ void Storage::ProcessIncoming(std::shared_ptr msg, DEBUG("received RPC call to non-RPC entry"); return; } + ConnectionInfo conn_info; + auto c = conn_weak.lock(); + if (c) conn_info = c->info(); m_rpc_server.ProcessRpc(entry->name, msg, entry->rpc_callback, - conn->uid(), [=](std::shared_ptr msg) { + conn->uid(), conn_info, + [=](std::shared_ptr msg) { auto c = conn_weak.lock(); if (c) c->QueueOutgoing(msg); }); @@ -1392,8 +1396,12 @@ unsigned int Storage::CallRpc(StringRef name, StringRef params) { // gracefully anyway. auto rpc_callback = entry->rpc_callback; lock.unlock(); + ConnectionInfo conn_info; + conn_info.remote_id = "Server"; + conn_info.remote_ip = "localhost"; m_rpc_server.ProcessRpc( - name, msg, rpc_callback, 0xffffU, [this](std::shared_ptr msg) { + name, msg, rpc_callback, 0xffffU, conn_info, + [this](std::shared_ptr msg) { std::lock_guard lock(m_mutex); m_rpc_results.insert(std::make_pair( std::make_pair(msg->id(), msg->seq_num_uid()), msg->str())); diff --git a/src/ntcore_c.cpp b/src/ntcore_c.cpp index d7d2cb2018..028984c99c 100644 --- a/src/ntcore_c.cpp +++ b/src/ntcore_c.cpp @@ -68,6 +68,7 @@ static void ConvertToC(const RpcDefinition& in, NT_RpcDefinition* out) { static void ConvertToC(const RpcCallInfo& in, NT_RpcCallInfo* out) { out->rpc_id = in.rpc_id; out->call_uid = in.call_uid; + ConvertToC(in.conn_info, &out->conn_info); ConvertToC(in.name, &out->name); ConvertToC(in.params, &out->params); } @@ -234,12 +235,17 @@ void NT_CreateRpc(const char *name, size_t name_len, const char *def, size_t def_len, void *data, NT_RpcCallback callback) { nt::CreateRpc( StringRef(name, name_len), StringRef(def, def_len), - [=](StringRef name, StringRef params) -> std::string { + [=](StringRef name, StringRef params, + const ConnectionInfo& conn_info) -> std::string { + NT_ConnectionInfo conn_c; + ConvertToC(conn_info, &conn_c); size_t results_len; char* results_c = callback(data, name.data(), name.size(), - params.data(), params.size(), &results_len); + params.data(), params.size(), + &conn_c, &results_len); std::string results(results_c, results_len); std::free(results_c); + DisposeConnectionInfo(&conn_c); return results; }); } @@ -521,6 +527,7 @@ void NT_DisposeRpcDefinition(NT_RpcDefinition *def) { } void NT_DisposeRpcCallInfo(NT_RpcCallInfo *call_info) { + DisposeConnectionInfo(&call_info->conn_info); NT_DisposeString(&call_info->name); NT_DisposeString(&call_info->params); } diff --git a/test/rpc_local.cpp b/test/rpc_local.cpp index 20e776ede4..bec06fe9ef 100644 --- a/test/rpc_local.cpp +++ b/test/rpc_local.cpp @@ -4,7 +4,8 @@ #include "ntcore.h" -std::string callback1(nt::StringRef name, nt::StringRef params_str) { +std::string callback1(nt::StringRef name, nt::StringRef params_str, + const nt::ConnectionInfo& conn_info) { auto params = nt::UnpackRpcValues(params_str, NT_DOUBLE); if (params.empty()) { std::fputs("empty params?\n", stderr); diff --git a/test/rpc_speed.cpp b/test/rpc_speed.cpp index 923df3d462..e8513a9dcc 100644 --- a/test/rpc_speed.cpp +++ b/test/rpc_speed.cpp @@ -5,7 +5,8 @@ #include "ntcore.h" -std::string callback1(nt::StringRef name, nt::StringRef params_str) { +std::string callback1(nt::StringRef name, nt::StringRef params_str, + const nt::ConnectionInfo& conn_info) { auto params = nt::UnpackRpcValues(params_str, NT_DOUBLE); if (params.empty()) { std::fputs("empty params?\n", stderr);