diff --git a/src/agent/CMakeLists.txt b/src/agent/CMakeLists.txt index 6255ffa..06fb269 100644 --- a/src/agent/CMakeLists.txt +++ b/src/agent/CMakeLists.txt @@ -6,6 +6,8 @@ ADD_EXECUTABLE( agent SmartReader.h HttpServer.h HttpServer.cpp + MdnsNetwork.h + MdnsNetwork.cpp MdnsPublisher.h MdnsPublisher.cpp OutputParsersUtils.h diff --git a/src/agent/MdnsNetwork.cpp b/src/agent/MdnsNetwork.cpp new file mode 100644 index 0000000..a406013 --- /dev/null +++ b/src/agent/MdnsNetwork.cpp @@ -0,0 +1,218 @@ +#include "MdnsNetwork.h" + +#ifdef _WIN32 +#define _CRT_SECURE_NO_WARNINGS 1 +#include +#else +#include +#include +#include +#include +#endif + +#include + +#include +#include + + +namespace +{ + +constexpr auto HostnameFallback = "rdhm-agent"; + +#ifdef _WIN32 +void collectLocalAddresses(MdnsNetworkAddresses& result) +{ + IP_ADAPTER_ADDRESSES* adapterAddress = nullptr; + ULONG addressSize = 8000; + ULONG ret = NO_ERROR; + unsigned int retries = 4; + + do { + adapterAddress = static_cast(malloc(addressSize)); + ret = GetAdaptersAddresses(AF_UNSPEC, GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_ANYCAST, + nullptr, adapterAddress, &addressSize); + + if (ret == ERROR_BUFFER_OVERFLOW) { + free(adapterAddress); + adapterAddress = nullptr; + addressSize *= 2; + } else { + break; + } + } while (retries-- > 0); + + if (!adapterAddress || ret != NO_ERROR) { + free(adapterAddress); + return; + } + + for (PIP_ADAPTER_ADDRESSES adapter = adapterAddress; adapter; adapter = adapter->Next) { + if (adapter->TunnelType == TUNNEL_TYPE_TEREDO || adapter->OperStatus != IfOperStatusUp) + continue; + + for (IP_ADAPTER_UNICAST_ADDRESS* unicast = adapter->FirstUnicastAddress; unicast; + unicast = unicast->Next) { + const auto* socketAddress = unicast->Address.lpSockaddr; + + if (socketAddress->sa_family == AF_INET && !result.hasIpv4) { + result.ipv4 = *reinterpret_cast(socketAddress); + result.hasIpv4 = true; + } else if (socketAddress->sa_family == AF_INET6 && !result.hasIpv6) { + const auto* ipv6 = reinterpret_cast(socketAddress); + if (ipv6->sin6_scope_id == 0) { + result.ipv6 = *ipv6; + result.hasIpv6 = true; + } + } + } + } + + free(adapterAddress); +} +#else +bool isUsableInterface(const ifaddrs& interface) +{ + return interface.ifa_addr && + (interface.ifa_flags & IFF_UP) && + (interface.ifa_flags & IFF_MULTICAST) && + !(interface.ifa_flags & IFF_LOOPBACK) && + !(interface.ifa_flags & IFF_POINTOPOINT); +} + +bool isLoopbackAddress(const sockaddr_in& address) +{ + return address.sin_addr.s_addr == htonl(INADDR_LOOPBACK); +} + +bool isLoopbackAddress(const sockaddr_in6& address) +{ + static const unsigned char localhost[] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1}; + return memcmp(address.sin6_addr.s6_addr, localhost, sizeof(localhost)) == 0; +} + +void collectLocalAddresses(MdnsNetworkAddresses& result) +{ + ifaddrs* interfaces = nullptr; + if (getifaddrs(&interfaces) < 0) + return; + + for (ifaddrs* interface = interfaces; interface; interface = interface->ifa_next) { + if (!isUsableInterface(*interface)) + continue; + + if (interface->ifa_addr->sa_family == AF_INET && !result.hasIpv4) { + const auto* ipv4 = reinterpret_cast(interface->ifa_addr); + if (!isLoopbackAddress(*ipv4)) { + result.ipv4 = *ipv4; + result.hasIpv4 = true; + } + } else if (interface->ifa_addr->sa_family == AF_INET6 && !result.hasIpv6) { + const auto* ipv6 = reinterpret_cast(interface->ifa_addr); + if (ipv6->sin6_scope_id == 0 && !isLoopbackAddress(*ipv6)) { + result.ipv6 = *ipv6; + result.hasIpv6 = true; + } + } + } + + freeifaddrs(interfaces); +} +#endif + +void addSocketIfOpen(std::vector& sockets, int socket) +{ + if (socket >= 0) + sockets.push_back(socket); +} + +timeval toTimeval(std::chrono::microseconds timeout) +{ + const auto seconds = std::chrono::duration_cast(timeout); + const auto microseconds = timeout - seconds; + + timeval result{}; + result.tv_sec = static_cast(seconds.count()); + result.tv_usec = static_cast(microseconds.count()); + return result; +} + +} // anonymous namespace + + +std::string getLocalHostname() +{ + char hostname[256] = {}; + if (gethostname(hostname, sizeof(hostname)) != 0 || hostname[0] == '\0') + return HostnameFallback; + + hostname[sizeof(hostname) - 1] = '\0'; + return hostname; +} + + +MdnsNetworkAddresses getLocalAddresses() +{ + MdnsNetworkAddresses result; + collectLocalAddresses(result); + return result; +} + + +std::vector openMdnsServiceSockets() +{ + std::vector sockets; + sockets.reserve(2); + + sockaddr_in ipv4{}; + ipv4.sin_family = AF_INET; +#ifdef _WIN32 + ipv4.sin_addr = in4addr_any; +#else + ipv4.sin_addr.s_addr = INADDR_ANY; +#endif + ipv4.sin_port = htons(MDNS_PORT); + addSocketIfOpen(sockets, mdns_socket_open_ipv4(&ipv4)); + + sockaddr_in6 ipv6{}; + ipv6.sin6_family = AF_INET6; + ipv6.sin6_addr = in6addr_any; + ipv6.sin6_port = htons(MDNS_PORT); + addSocketIfOpen(sockets, mdns_socket_open_ipv6(&ipv6)); + + return sockets; +} + + +std::vector waitForReadableSockets(const std::vector& sockets, + std::chrono::microseconds timeout) +{ + if (sockets.empty()) + return {}; + + fd_set readable; + FD_ZERO(&readable); + + int nfds = 0; + for (int socket : sockets) { + if (socket >= nfds) + nfds = socket + 1; + FD_SET(socket, &readable); + } + + timeval tv = toTimeval(timeout); + if (select(nfds, &readable, nullptr, nullptr, &tv) <= 0) + return {}; + + std::vector ready; + ready.reserve(sockets.size()); + + for (int socket : sockets) { + if (FD_ISSET(socket, &readable)) + ready.push_back(socket); + } + + return ready; +} diff --git a/src/agent/MdnsNetwork.h b/src/agent/MdnsNetwork.h new file mode 100644 index 0000000..3db813b --- /dev/null +++ b/src/agent/MdnsNetwork.h @@ -0,0 +1,27 @@ +#pragma once + +#ifdef _WIN32 +#include +#include +#else +#include +#endif + +#include +#include +#include + + +struct MdnsNetworkAddresses +{ + sockaddr_in ipv4{}; + sockaddr_in6 ipv6{}; + bool hasIpv4 = false; + bool hasIpv6 = false; +}; + +std::string getLocalHostname(); +MdnsNetworkAddresses getLocalAddresses(); +std::vector openMdnsServiceSockets(); +std::vector waitForReadableSockets(const std::vector& sockets, + std::chrono::microseconds timeout); diff --git a/src/agent/MdnsPublisher.cpp b/src/agent/MdnsPublisher.cpp index 426186f..372a1f7 100644 --- a/src/agent/MdnsPublisher.cpp +++ b/src/agent/MdnsPublisher.cpp @@ -1,18 +1,10 @@ #include "MdnsPublisher.h" -#ifdef _WIN32 -#define _CRT_SECURE_NO_WARNINGS 1 -#include -#include -#else -#include -#include -#include -#include -#endif +#include "MdnsNetwork.h" #include +#include #include #include #include @@ -23,278 +15,240 @@ namespace { +constexpr auto DnsSdServiceType = "_services._dns-sd._udp.local."; +constexpr auto ResponsePollTimeout = std::chrono::milliseconds(200); +constexpr size_t ResponseBufferSize = 2048; + struct ServiceData { mdns_string_t service; - mdns_string_t hostname; - mdns_string_t service_instance; - mdns_string_t hostname_qualified; - struct sockaddr_in address_ipv4; - struct sockaddr_in6 address_ipv6; + mdns_string_t serviceInstance; + mdns_string_t hostnameQualified; + sockaddr_in addressIpv4; + sockaddr_in6 addressIpv6; int port; - mdns_record_t record_ptr; - mdns_record_t record_srv; - mdns_record_t record_a; - mdns_record_t record_aaaa; - mdns_record_t txt_record; + mdns_record_t ptrRecord; + mdns_record_t srvRecord; + mdns_record_t aRecord; + mdns_record_t aaaaRecord; + mdns_record_t txtRecord; }; -// Get local network addresses -void getLocalAddresses(struct sockaddr_in& addr_ipv4, struct sockaddr_in6& addr_ipv6, - bool& has_ipv4, bool& has_ipv6) +struct RecordList { - memset(&addr_ipv4, 0, sizeof(addr_ipv4)); - memset(&addr_ipv6, 0, sizeof(addr_ipv6)); - has_ipv4 = false; - has_ipv6 = false; - -#ifdef _WIN32 - IP_ADAPTER_ADDRESSES* adapter_address = nullptr; - ULONG address_size = 8000; - unsigned int ret; - unsigned int num_retries = 4; - do { - adapter_address = (IP_ADAPTER_ADDRESSES*)malloc(address_size); - ret = GetAdaptersAddresses(AF_UNSPEC, GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_ANYCAST, 0, - adapter_address, &address_size); - if (ret == ERROR_BUFFER_OVERFLOW) { - free(adapter_address); - adapter_address = nullptr; - address_size *= 2; - } else { - break; - } - } while (num_retries-- > 0); + std::array records{}; + size_t count = 0; - if (!adapter_address || (ret != NO_ERROR)) { - free(adapter_address); - return; + void add(const mdns_record_t& record) + { + records[count++] = record; } +}; - for (PIP_ADAPTER_ADDRESSES adapter = adapter_address; adapter; adapter = adapter->Next) { - if (adapter->TunnelType == TUNNEL_TYPE_TEREDO) - continue; - if (adapter->OperStatus != IfOperStatusUp) - continue; - - for (IP_ADAPTER_UNICAST_ADDRESS* unicast = adapter->FirstUnicastAddress; unicast; - unicast = unicast->Next) { - if (unicast->Address.lpSockaddr->sa_family == AF_INET && !has_ipv4) { - addr_ipv4 = *(struct sockaddr_in*)unicast->Address.lpSockaddr; - has_ipv4 = true; - } else if (unicast->Address.lpSockaddr->sa_family == AF_INET6 && !has_ipv6) { - auto* saddr = (struct sockaddr_in6*)unicast->Address.lpSockaddr; - if (saddr->sin6_scope_id == 0) { - addr_ipv6 = *saddr; - has_ipv6 = true; - } - } - } - } - free(adapter_address); -#else - struct ifaddrs* ifaddr = nullptr; - if (getifaddrs(&ifaddr) < 0) - return; +mdns_string_t mdnsString(const std::string& value) +{ + return {value.c_str(), value.size()}; +} - for (struct ifaddrs* ifa = ifaddr; ifa; ifa = ifa->ifa_next) { - if (!ifa->ifa_addr) - continue; - if (!(ifa->ifa_flags & IFF_UP) || !(ifa->ifa_flags & IFF_MULTICAST)) - continue; - if ((ifa->ifa_flags & IFF_LOOPBACK) || (ifa->ifa_flags & IFF_POINTOPOINT)) - continue; - - if (ifa->ifa_addr->sa_family == AF_INET && !has_ipv4) { - auto* saddr = (struct sockaddr_in*)ifa->ifa_addr; - if (saddr->sin_addr.s_addr != htonl(INADDR_LOOPBACK)) { - addr_ipv4 = *saddr; - has_ipv4 = true; - } - } else if (ifa->ifa_addr->sa_family == AF_INET6 && !has_ipv6) { - auto* saddr = (struct sockaddr_in6*)ifa->ifa_addr; - if (saddr->sin6_scope_id) - continue; - static const unsigned char localhost[] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1}; - if (memcmp(saddr->sin6_addr.s6_addr, localhost, 16) != 0) { - addr_ipv6 = *saddr; - has_ipv6 = true; - } - } - } - freeifaddrs(ifaddr); -#endif +bool hasIpv4Address(const ServiceData& service) +{ + return service.addressIpv4.sin_family == AF_INET; } -int openServiceSockets(int* sockets, int maxSockets, - struct sockaddr_in& addr_ipv4, struct sockaddr_in6& addr_ipv6, - bool& has_ipv4, bool& has_ipv6) +bool hasIpv6Address(const ServiceData& service) { - int num_sockets = 0; - - getLocalAddresses(addr_ipv4, addr_ipv6, has_ipv4, has_ipv6); - - if (num_sockets < maxSockets) { - struct sockaddr_in sock_addr; - memset(&sock_addr, 0, sizeof(sock_addr)); - sock_addr.sin_family = AF_INET; -#ifdef _WIN32 - sock_addr.sin_addr = in4addr_any; -#else - sock_addr.sin_addr.s_addr = INADDR_ANY; -#endif - sock_addr.sin_port = htons(MDNS_PORT); - int sock = mdns_socket_open_ipv4(&sock_addr); - if (sock >= 0) - sockets[num_sockets++] = sock; - } + return service.addressIpv6.sin6_family == AF_INET6; +} - if (num_sockets < maxSockets) { - struct sockaddr_in6 sock_addr; - memset(&sock_addr, 0, sizeof(sock_addr)); - sock_addr.sin6_family = AF_INET6; - sock_addr.sin6_addr = in6addr_any; - sock_addr.sin6_port = htons(MDNS_PORT); - int sock = mdns_socket_open_ipv6(&sock_addr); - if (sock >= 0) - sockets[num_sockets++] = sock; - } +bool sameName(mdns_string_t lhs, mdns_string_t rhs) +{ + return lhs.length == rhs.length && strncmp(lhs.str, rhs.str, lhs.length) == 0; +} - return num_sockets; +bool sameName(mdns_string_t lhs, const char* rhs) +{ + const size_t rhsLength = strlen(rhs); + return lhs.length == rhsLength && strncmp(lhs.str, rhs, rhsLength) == 0; } +bool acceptsRecordType(uint16_t requestedType, mdns_record_type_t availableType) +{ + return requestedType == availableType || requestedType == MDNS_RECORDTYPE_ANY; +} -int serviceCallback(int sock, const struct sockaddr* from, size_t addrlen, - mdns_entry_type_t entry, uint16_t query_id, uint16_t rtype, - uint16_t rclass, uint32_t /*ttl*/, const void* data, - size_t size, size_t name_offset, size_t /*name_length*/, - size_t /*record_offset*/, size_t /*record_length*/, void* user_data) +RecordList serviceRecords(const ServiceData& service) { - if (entry != MDNS_ENTRYTYPE_QUESTION) - return 0; + RecordList records; + records.add(service.srvRecord); - const char dns_sd[] = "_services._dns-sd._udp.local."; - const auto* service = static_cast(user_data); + if (hasIpv4Address(service)) + records.add(service.aRecord); + if (hasIpv6Address(service)) + records.add(service.aaaaRecord); - char namebuffer[256]; - size_t offset = name_offset; - mdns_string_t name = mdns_string_extract(data, size, &offset, namebuffer, sizeof(namebuffer)); + records.add(service.txtRecord); + return records; +} - char sendbuffer[1024]; - const uint16_t unicast = (rclass & MDNS_UNICAST_RESPONSE); +RecordList instanceRecords(const ServiceData& service) +{ + RecordList records; - if ((name.length == (sizeof(dns_sd) - 1)) && - (strncmp(name.str, dns_sd, sizeof(dns_sd) - 1) == 0)) - { - if ((rtype == MDNS_RECORDTYPE_PTR) || (rtype == MDNS_RECORDTYPE_ANY)) - { - mdns_record_t answer = {}; - answer.name = name; - answer.type = MDNS_RECORDTYPE_PTR; - answer.data.ptr.name = service->service; - - if (unicast) { - mdns_query_answer_unicast(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), - query_id, static_cast(rtype), name.str, name.length, answer, 0, 0, 0, 0); - } else { - mdns_query_answer_multicast(sock, sendbuffer, sizeof(sendbuffer), answer, 0, 0, 0, 0); - } - } - } - else if ((name.length == service->service.length) && - (strncmp(name.str, service->service.str, name.length) == 0)) - { - if ((rtype == MDNS_RECORDTYPE_PTR) || (rtype == MDNS_RECORDTYPE_ANY)) - { - mdns_record_t answer = service->record_ptr; - - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - additional[additional_count++] = service->record_srv; - if (service->address_ipv4.sin_family == AF_INET) - additional[additional_count++] = service->record_a; - if (service->address_ipv6.sin6_family == AF_INET6) - additional[additional_count++] = service->record_aaaa; - additional[additional_count++] = service->txt_record; - - if (unicast) { - mdns_query_answer_unicast(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), - query_id, static_cast(rtype), name.str, name.length, answer, 0, 0, - additional, additional_count); - } else { - mdns_query_answer_multicast(sock, sendbuffer, sizeof(sendbuffer), answer, 0, 0, - additional, additional_count); - } - } - } - else if ((name.length == service->service_instance.length) && - (strncmp(name.str, service->service_instance.str, name.length) == 0)) - { - if ((rtype == MDNS_RECORDTYPE_SRV) || (rtype == MDNS_RECORDTYPE_ANY)) - { - mdns_record_t answer = service->record_srv; - - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - if (service->address_ipv4.sin_family == AF_INET) - additional[additional_count++] = service->record_a; - if (service->address_ipv6.sin6_family == AF_INET6) - additional[additional_count++] = service->record_aaaa; - additional[additional_count++] = service->txt_record; - - if (unicast) { - mdns_query_answer_unicast(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), - query_id, static_cast(rtype), name.str, name.length, answer, 0, 0, - additional, additional_count); - } else { - mdns_query_answer_multicast(sock, sendbuffer, sizeof(sendbuffer), answer, 0, 0, - additional, additional_count); - } - } + if (hasIpv4Address(service)) + records.add(service.aRecord); + if (hasIpv6Address(service)) + records.add(service.aaaaRecord); + + records.add(service.txtRecord); + return records; +} + +RecordList addressRecords(const ServiceData& service, mdns_record_type_t answeredType) +{ + RecordList records; + + if (answeredType == MDNS_RECORDTYPE_A && hasIpv6Address(service)) + records.add(service.aaaaRecord); + if (answeredType == MDNS_RECORDTYPE_AAAA && hasIpv4Address(service)) + records.add(service.aRecord); + + records.add(service.txtRecord); + return records; +} + +void answerQuestion(int socket, const sockaddr* from, size_t addressLength, + uint16_t queryId, uint16_t queryType, bool unicast, + mdns_string_t questionName, const mdns_record_t& answer, + const RecordList& additional) +{ + char buffer[1024] = {}; + const auto recordType = static_cast(queryType); + + if (unicast) { + mdns_query_answer_unicast(socket, from, addressLength, buffer, sizeof(buffer), + queryId, recordType, questionName.str, questionName.length, + answer, nullptr, 0, additional.records.data(), additional.count); + } else { + mdns_query_answer_multicast(socket, buffer, sizeof(buffer), answer, nullptr, 0, + additional.records.data(), additional.count); } - else if ((name.length == service->hostname_qualified.length) && - (strncmp(name.str, service->hostname_qualified.str, name.length) == 0)) - { - if (((rtype == MDNS_RECORDTYPE_A) || (rtype == MDNS_RECORDTYPE_ANY)) && - (service->address_ipv4.sin_family == AF_INET)) - { - mdns_record_t answer = service->record_a; - - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - if (service->address_ipv6.sin6_family == AF_INET6) - additional[additional_count++] = service->record_aaaa; - additional[additional_count++] = service->txt_record; - - if (unicast) { - mdns_query_answer_unicast(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), - query_id, static_cast(rtype), name.str, name.length, answer, 0, 0, - additional, additional_count); - } else { - mdns_query_answer_multicast(sock, sendbuffer, sizeof(sendbuffer), answer, 0, 0, - additional, additional_count); - } - } - else if (((rtype == MDNS_RECORDTYPE_AAAA) || (rtype == MDNS_RECORDTYPE_ANY)) && - (service->address_ipv6.sin6_family == AF_INET6)) - { - mdns_record_t answer = service->record_aaaa; - - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - if (service->address_ipv4.sin_family == AF_INET) - additional[additional_count++] = service->record_a; - additional[additional_count++] = service->txt_record; - - if (unicast) { - mdns_query_answer_unicast(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), - query_id, static_cast(rtype), name.str, name.length, answer, 0, 0, - additional, additional_count); - } else { - mdns_query_answer_multicast(sock, sendbuffer, sizeof(sendbuffer), answer, 0, 0, - additional, additional_count); - } - } +} + +ServiceData buildServiceData(const std::string& serviceTypeLocal, + const std::string& serviceInstanceLocal, + const std::string& hostnameLocal, + const MdnsNetworkAddresses& addresses, + unsigned int port) +{ + ServiceData service{}; + + service.service = mdnsString(serviceTypeLocal); + service.serviceInstance = mdnsString(serviceInstanceLocal); + service.hostnameQualified = mdnsString(hostnameLocal); + service.addressIpv4 = addresses.hasIpv4 ? addresses.ipv4 : sockaddr_in{}; + service.addressIpv6 = addresses.hasIpv6 ? addresses.ipv6 : sockaddr_in6{}; + service.port = static_cast(port); + + service.ptrRecord.name = service.service; + service.ptrRecord.type = MDNS_RECORDTYPE_PTR; + service.ptrRecord.data.ptr.name = service.serviceInstance; + + service.srvRecord.name = service.serviceInstance; + service.srvRecord.type = MDNS_RECORDTYPE_SRV; + service.srvRecord.data.srv.name = service.hostnameQualified; + service.srvRecord.data.srv.port = static_cast(service.port); + service.srvRecord.data.srv.priority = 0; + service.srvRecord.data.srv.weight = 0; + + service.aRecord.name = service.hostnameQualified; + service.aRecord.type = MDNS_RECORDTYPE_A; + service.aRecord.data.a.addr = service.addressIpv4; + + service.aaaaRecord.name = service.hostnameQualified; + service.aaaaRecord.type = MDNS_RECORDTYPE_AAAA; + service.aaaaRecord.data.aaaa.addr = service.addressIpv6; + + service.txtRecord.name = service.serviceInstance; + service.txtRecord.type = MDNS_RECORDTYPE_TXT; + service.txtRecord.data.txt.key = {MDNS_STRING_CONST("RDHAgent")}; + service.txtRecord.data.txt.value = {MDNS_STRING_CONST("1")}; + + return service; +} + +void announceService(const std::vector& sockets, const ServiceData& service) +{ + char buffer[ResponseBufferSize] = {}; + const RecordList additional = serviceRecords(service); + + for (int socket : sockets) + mdns_announce_multicast(socket, buffer, sizeof(buffer), service.ptrRecord, nullptr, 0, + additional.records.data(), additional.count); +} + +void sendGoodbye(const std::vector& sockets, const ServiceData& service) +{ + char buffer[ResponseBufferSize] = {}; + const RecordList additional = serviceRecords(service); + + for (int socket : sockets) + mdns_goodbye_multicast(socket, buffer, sizeof(buffer), service.ptrRecord, nullptr, 0, + additional.records.data(), additional.count); +} + +void closeSockets(std::vector& sockets) +{ + for (int socket : sockets) + mdns_socket_close(socket); + sockets.clear(); +} + +int serviceCallback(int socket, const sockaddr* from, size_t addressLength, + mdns_entry_type_t entry, uint16_t queryId, uint16_t queryType, + uint16_t queryClass, uint32_t /*ttl*/, const void* data, + size_t size, size_t nameOffset, size_t /*nameLength*/, + size_t /*recordOffset*/, size_t /*recordLength*/, void* userData) +{ + if (entry != MDNS_ENTRYTYPE_QUESTION) + return 0; + + const auto* service = static_cast(userData); + const bool unicast = (queryClass & MDNS_UNICAST_RESPONSE) != 0; + + char nameBuffer[256] = {}; + size_t offset = nameOffset; + const mdns_string_t questionName = + mdns_string_extract(data, size, &offset, nameBuffer, sizeof(nameBuffer)); + + if (sameName(questionName, DnsSdServiceType) && + acceptsRecordType(queryType, MDNS_RECORDTYPE_PTR)) { + mdns_record_t answer{}; + answer.name = questionName; + answer.type = MDNS_RECORDTYPE_PTR; + answer.data.ptr.name = service->service; + + answerQuestion(socket, from, addressLength, queryId, queryType, unicast, + questionName, answer, {}); + } else if (sameName(questionName, service->service) && + acceptsRecordType(queryType, MDNS_RECORDTYPE_PTR)) { + answerQuestion(socket, from, addressLength, queryId, queryType, unicast, + questionName, service->ptrRecord, serviceRecords(*service)); + } else if (sameName(questionName, service->serviceInstance) && + acceptsRecordType(queryType, MDNS_RECORDTYPE_SRV)) { + answerQuestion(socket, from, addressLength, queryId, queryType, unicast, + questionName, service->srvRecord, instanceRecords(*service)); + } else if (sameName(questionName, service->hostnameQualified) && + hasIpv4Address(*service) && + acceptsRecordType(queryType, MDNS_RECORDTYPE_A)) { + answerQuestion(socket, from, addressLength, queryId, queryType, unicast, + questionName, service->aRecord, + addressRecords(*service, MDNS_RECORDTYPE_A)); + } else if (sameName(questionName, service->hostnameQualified) && + hasIpv6Address(*service) && + acceptsRecordType(queryType, MDNS_RECORDTYPE_AAAA)) { + answerQuestion(socket, from, addressLength, queryId, queryType, unicast, + questionName, service->aaaaRecord, + addressRecords(*service, MDNS_RECORDTYPE_AAAA)); } return 0; @@ -305,22 +259,30 @@ int serviceCallback(int sock, const struct sockaddr* from, size_t addrlen, struct MdnsPublisher::Impl { - std::string serviceName; // e.g. "MyAgent" - std::string serviceType; // e.g. "_RDHMonitor._tcp" + std::string serviceName; + std::string serviceType; unsigned int port; - // String buffers must outlive the mdns_string_t pointers - std::string serviceTypeLocal; // "_RDHMonitor._tcp.local." - std::string serviceInstanceLocal; // "MyAgent._RDHMonitor._tcp.local." - std::string hostnameLocal; // "myhostname.local." + std::string serviceTypeLocal; + std::string serviceInstanceLocal; + std::string hostnameLocal; ServiceData service; + std::vector sockets; - int sockets[32] = {}; - int numSockets = 0; - - std::atomic running{false}; + std::atomic_bool running{false}; std::thread listenThread; + + void listen() + { + char buffer[ResponseBufferSize] = {}; + + while (running) { + const auto readySockets = waitForReadableSockets(sockets, ResponsePollTimeout); + for (int socket : readySockets) + mdns_socket_listen(socket, buffer, sizeof(buffer), serviceCallback, &service); + } + } }; @@ -345,115 +307,26 @@ void MdnsPublisher::start() if (m_impl->running) return; - // Prepare string buffers m_impl->serviceTypeLocal = m_impl->serviceType + ".local."; m_impl->serviceInstanceLocal = m_impl->serviceName + "." + m_impl->serviceType + ".local."; + m_impl->hostnameLocal = getLocalHostname() + ".local."; - char hostnameBuffer[256] = {}; - gethostname(hostnameBuffer, sizeof(hostnameBuffer)); - m_impl->hostnameLocal = std::string(hostnameBuffer) + ".local."; - - // Get local addresses and open service sockets - bool has_ipv4 = false, has_ipv6 = false; - struct sockaddr_in addr_ipv4; - struct sockaddr_in6 addr_ipv6; - - m_impl->numSockets = openServiceSockets(m_impl->sockets, 32, - addr_ipv4, addr_ipv6, has_ipv4, has_ipv6); - - if (m_impl->numSockets <= 0) { + m_impl->sockets = openMdnsServiceSockets(); + if (m_impl->sockets.empty()) { std::cerr << "MdnsPublisher: failed to open mDNS sockets\n"; return; } - // Build service data - auto& svc = m_impl->service; - memset(&svc, 0, sizeof(svc)); - - svc.service = {m_impl->serviceTypeLocal.c_str(), m_impl->serviceTypeLocal.size()}; - svc.hostname = {m_impl->serviceName.c_str(), m_impl->serviceName.size()}; - svc.service_instance = {m_impl->serviceInstanceLocal.c_str(), m_impl->serviceInstanceLocal.size()}; - svc.hostname_qualified = {m_impl->hostnameLocal.c_str(), m_impl->hostnameLocal.size()}; - svc.address_ipv4 = has_ipv4 ? addr_ipv4 : sockaddr_in{}; - svc.address_ipv6 = has_ipv6 ? addr_ipv6 : sockaddr_in6{}; - svc.port = static_cast(m_impl->port); - - // PTR: "_service._tcp.local." -> "MyAgent._service._tcp.local." - svc.record_ptr.name = svc.service; - svc.record_ptr.type = MDNS_RECORDTYPE_PTR; - svc.record_ptr.data.ptr.name = svc.service_instance; - - // SRV: "MyAgent._service._tcp.local." -> "hostname.local." + port - svc.record_srv.name = svc.service_instance; - svc.record_srv.type = MDNS_RECORDTYPE_SRV; - svc.record_srv.data.srv.name = svc.hostname_qualified; - svc.record_srv.data.srv.port = static_cast(svc.port); - svc.record_srv.data.srv.priority = 0; - svc.record_srv.data.srv.weight = 0; - - // A record - svc.record_a.name = svc.hostname_qualified; - svc.record_a.type = MDNS_RECORDTYPE_A; - svc.record_a.data.a.addr = svc.address_ipv4; - - // AAAA record - svc.record_aaaa.name = svc.hostname_qualified; - svc.record_aaaa.type = MDNS_RECORDTYPE_AAAA; - svc.record_aaaa.data.aaaa.addr = svc.address_ipv6; - - // TXT record - svc.txt_record.name = svc.service_instance; - svc.txt_record.type = MDNS_RECORDTYPE_TXT; - svc.txt_record.data.txt.key = {MDNS_STRING_CONST("RDHAgent")}; - svc.txt_record.data.txt.value = {MDNS_STRING_CONST("1")}; - - // Send initial announcement - { - char buffer[2048]; - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - additional[additional_count++] = svc.record_srv; - if (has_ipv4) - additional[additional_count++] = svc.record_a; - if (has_ipv6) - additional[additional_count++] = svc.record_aaaa; - additional[additional_count++] = svc.txt_record; - - for (int i = 0; i < m_impl->numSockets; ++i) - mdns_announce_multicast(m_impl->sockets[i], buffer, sizeof(buffer), - svc.record_ptr, 0, 0, additional, additional_count); - } + m_impl->service = buildServiceData(m_impl->serviceTypeLocal, m_impl->serviceInstanceLocal, + m_impl->hostnameLocal, getLocalAddresses(), m_impl->port); - std::cout << "mDNS service published: " << m_impl->serviceInstanceLocal << " on port " << m_impl->port << "\n"; + announceService(m_impl->sockets, m_impl->service); + + std::cout << "mDNS service published: " << m_impl->serviceInstanceLocal + << " on port " << m_impl->port << "\n"; - // Start listener thread m_impl->running = true; - m_impl->listenThread = std::thread([this] { - char buffer[2048]; - while (m_impl->running) { - int nfds = 0; - fd_set readfs; - FD_ZERO(&readfs); - for (int i = 0; i < m_impl->numSockets; ++i) { - if (m_impl->sockets[i] >= nfds) - nfds = m_impl->sockets[i] + 1; - FD_SET(m_impl->sockets[i], &readfs); - } - - struct timeval timeout; - timeout.tv_sec = 0; - timeout.tv_usec = 200000; // 200ms poll - - if (select(nfds, &readfs, nullptr, nullptr, &timeout) >= 0) { - for (int i = 0; i < m_impl->numSockets; ++i) { - if (FD_ISSET(m_impl->sockets[i], &readfs)) { - mdns_socket_listen(m_impl->sockets[i], buffer, sizeof(buffer), - serviceCallback, &m_impl->service); - } - } - } - } - }); + m_impl->listenThread = std::thread([this] { m_impl->listen(); }); } @@ -466,28 +339,8 @@ void MdnsPublisher::stop() if (m_impl->listenThread.joinable()) m_impl->listenThread.join(); - // Send goodbye - { - auto& svc = m_impl->service; - char buffer[2048]; - mdns_record_t additional[5] = {}; - size_t additional_count = 0; - additional[additional_count++] = svc.record_srv; - if (svc.address_ipv4.sin_family == AF_INET) - additional[additional_count++] = svc.record_a; - if (svc.address_ipv6.sin6_family == AF_INET6) - additional[additional_count++] = svc.record_aaaa; - additional[additional_count++] = svc.txt_record; - - for (int i = 0; i < m_impl->numSockets; ++i) - mdns_goodbye_multicast(m_impl->sockets[i], buffer, sizeof(buffer), - svc.record_ptr, 0, 0, additional, additional_count); - } - - // Close sockets - for (int i = 0; i < m_impl->numSockets; ++i) - mdns_socket_close(m_impl->sockets[i]); - m_impl->numSockets = 0; + sendGoodbye(m_impl->sockets, m_impl->service); + closeSockets(m_impl->sockets); std::cout << "mDNS service unpublished\n"; }