[wpinet] Add callback for mDNS service resolver (#7986)

This commit is contained in:
Thad House
2025-05-23 13:22:59 -07:00
committed by GitHub
parent 0cb4df7e05
commit 22d12d2345
4 changed files with 94 additions and 10 deletions

View File

@@ -21,9 +21,9 @@ class MulticastServiceResolver {
explicit MulticastServiceResolver(std::string_view serviceType);
~MulticastServiceResolver() noexcept;
struct ServiceData {
/// IPv4 address.
/// IPv4 address in host order.
unsigned int ipv4Address;
/// Port number.
/// Port number in host order.
int port;
/// Service name.
std::string serviceName;
@@ -32,6 +32,17 @@ class MulticastServiceResolver {
/// Service data payload.
std::vector<std::pair<std::string, std::string>> txt;
};
/**
* Set a copy callback to be called when a service is resolved.
* Takes presidence over the move callback. Return true to
* not send the data to the event queue.
*/
bool SetCopyCallback(std::function<bool(const ServiceData&)> callback);
/**
* Set a move callback to be called when a service is resolved.
* Data is moved into the function and cannot be added to the event queue.
*/
bool SetMoveCallback(std::function<void(ServiceData&&)> callback);
/**
* Starts multicast service resolver.
*/
@@ -72,12 +83,23 @@ class MulticastServiceResolver {
private:
void PushData(ServiceData&& data) {
std::scoped_lock lock{mutex};
queue.emplace_back(std::forward<ServiceData>(data));
event.Set();
if (copyCallback) {
if (!copyCallback(data)) {
queue.emplace_back(std::forward<ServiceData>(data));
event.Set();
}
} else if (moveCallback) {
moveCallback(std::move(data));
} else {
queue.emplace_back(std::forward<ServiceData>(data));
event.Set();
}
}
wpi::Event event{true};
std::vector<ServiceData> queue;
wpi::mutex mutex;
std::function<bool(const ServiceData&)> copyCallback;
std::function<void(ServiceData&&)> moveCallback;
std::unique_ptr<Impl> pImpl;
};
} // namespace wpi

View File

@@ -4,6 +4,8 @@
#include "wpinet/MulticastServiceResolver.h"
#include <arpa/inet.h>
#include <memory>
#include <string>
#include <utility>
@@ -44,6 +46,26 @@ bool MulticastServiceResolver::HasImplementation() const {
return pImpl->table.IsValid();
}
bool MulticastServiceResolver::SetCopyCallback(
std::function<bool(const ServiceData&)> callback) {
std::scoped_lock lock{*pImpl->thread};
if (pImpl->client) {
return false;
}
copyCallback = std::move(callback);
return true;
}
bool MulticastServiceResolver::SetMoveCallback(
std::function<void(ServiceData&&)> callback) {
std::scoped_lock lock{*pImpl->thread};
if (pImpl->client) {
return false;
}
moveCallback = std::move(callback);
return true;
}
static void ResolveCallback(AvahiServiceResolver* r, AvahiIfIndex interface,
AvahiProtocol protocol, AvahiResolverEvent event,
const char* name, const char* type,
@@ -83,8 +105,8 @@ static void ResolveCallback(AvahiServiceResolver* r, AvahiIfIndex interface,
outputHostName.append(".");
} while (true);
data.ipv4Address = address->data.ipv4.address;
data.port = port;
data.ipv4Address = ntohl(address->data.ipv4.address);
data.port = ntohs(port);
data.serviceName = name;
data.hostName = std::string{outputHostName};

View File

@@ -72,8 +72,8 @@ void ServiceGetAddrInfoReply(DNSServiceRef sdRef, DNSServiceFlags flags,
DnsResolveState* resolveState = static_cast<DnsResolveState*>(context);
resolveState->data.hostName = hostname;
resolveState->data.ipv4Address =
reinterpret_cast<const struct sockaddr_in*>(address)->sin_addr.s_addr;
resolveState->data.ipv4Address = ntohl(
reinterpret_cast<const struct sockaddr_in*>(address)->sin_addr.s_addr);
resolveState->pImpl->onFound(std::move(resolveState->data));
@@ -179,6 +179,24 @@ bool MulticastServiceResolver::HasImplementation() const {
return true;
}
bool MulticastServiceResolver::SetCopyCallback(
std::function<bool(const ServiceData&)> callback) {
if (pImpl->serviceRef) {
return false;
}
copyCallback = std::move(callback);
return true;
}
bool MulticastServiceResolver::SetMoveCallback(
std::function<void(ServiceData&&)> callback) {
if (pImpl->serviceRef) {
return false;
}
moveCallback = std::move(callback);
return true;
}
void MulticastServiceResolver::Start() {
if (pImpl->serviceRef) {
return;

View File

@@ -65,9 +65,31 @@ bool MulticastServiceResolver::HasImplementation() const {
return pImpl->dynamicDns.CanDnsResolve;
}
bool MulticastServiceResolver::SetCopyCallback(
std::function<bool(const ServiceData&)> callback) {
if (pImpl->serviceCancel.reserved != nullptr) {
return false;
}
copyCallback = std::move(callback);
return true;
}
bool MulticastServiceResolver::SetMoveCallback(
std::function<void(ServiceData&&)> callback) {
if (pImpl->serviceCancel.reserved != nullptr) {
return false;
}
moveCallback = std::move(callback);
return true;
}
static _Function_class_(DNS_QUERY_COMPLETION_ROUTINE) VOID WINAPI
DnsCompletion(_In_ PVOID pQueryContext,
_Inout_ PDNS_QUERY_RESULT pQueryResults) {
if (pQueryResults->QueryStatus != ERROR_SUCCESS) {
return;
}
MulticastServiceResolver::Impl* impl =
reinterpret_cast<MulticastServiceResolver::Impl*>(pQueryContext);
@@ -167,8 +189,8 @@ static _Function_class_(DNS_QUERY_COMPLETION_ROUTINE) VOID WINAPI
wpi::convertUTF16ToUTF8String(wideServiceName, storage);
data.serviceName = std::string{storage};
data.port = foundSrv->Data.Srv.wPort;
data.ipv4Address = A->Data.A.IpAddress;
data.port = ntohs(foundSrv->Data.Srv.wPort);
data.ipv4Address = ntohl(A->Data.A.IpAddress);
impl->onFound(std::move(data));
}