diff --git a/src/network/host_pairing.cpp b/src/network/host_pairing.cpp index ac4a2b4..b971626 100644 --- a/src/network/host_pairing.cpp +++ b/src/network/host_pairing.cpp @@ -116,6 +116,7 @@ namespace { constexpr std::size_t CLIENT_CHALLENGE_BYTE_COUNT = 16; constexpr std::size_t CLIENT_SECRET_BYTE_COUNT = 16; constexpr int SOCKET_TIMEOUT_MILLISECONDS = 5000; + constexpr int PIN_ENTRY_SOCKET_TIMEOUT_MILLISECONDS = 90000; constexpr uint16_t DEFAULT_SERVERINFO_HTTP_PORT = 47989; constexpr uint16_t FALLBACK_SERVERINFO_HTTP_PORT = 47984; constexpr uint16_t DEFAULT_SERVERINFO_HTTPS_PORT = 47990; @@ -217,7 +218,7 @@ namespace { bool is_timeout_error(int errorCode) { #if defined(NXDK) || !defined(_WIN32) - return errorCode == ETIMEDOUT; + return errorCode == ETIMEDOUT || errorCode == EWOULDBLOCK || errorCode == EAGAIN; #else return errorCode == WSAETIMEDOUT; #endif @@ -249,18 +250,18 @@ namespace { return true; } - void set_socket_timeouts(SOCKET socketHandle) { + void set_socket_timeouts(SOCKET socketHandle, int timeoutMilliseconds) { #if defined(NXDK) || !defined(_WIN32) timeval timeout { - SOCKET_TIMEOUT_MILLISECONDS / 1000, - (SOCKET_TIMEOUT_MILLISECONDS % 1000) * 1000, + timeoutMilliseconds / 1000, + (timeoutMilliseconds % 1000) * 1000, }; setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); #else - const DWORD timeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS; - setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeoutMilliseconds), sizeof(timeoutMilliseconds)); - setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeoutMilliseconds), sizeof(timeoutMilliseconds)); + const DWORD platformTimeoutMilliseconds = static_cast(timeoutMilliseconds); + setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&platformTimeoutMilliseconds), sizeof(platformTimeoutMilliseconds)); + setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&platformTimeoutMilliseconds), sizeof(platformTimeoutMilliseconds)); #endif } @@ -351,7 +352,8 @@ namespace { std::string_view expectedTlsCertificatePem, HttpResponse *response, std::string *errorMessage, - const std::atomic *cancelRequested = nullptr + const std::atomic *cancelRequested = nullptr, + int socketIoTimeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS ); std::string summarize_http_payload_preview(std::string_view text) { @@ -1496,18 +1498,25 @@ namespace { return true; } - bool finalize_connected_socket(SOCKET socketHandle, std::string *errorMessage) { + bool finalize_connected_socket(SOCKET socketHandle, int socketIoTimeoutMilliseconds, std::string *errorMessage) { trace_pairing_phase("restoring blocking mode after connect"); if (!set_socket_non_blocking(socketHandle, false, errorMessage)) { return false; } - set_socket_timeouts(socketHandle); + set_socket_timeouts(socketHandle, socketIoTimeoutMilliseconds); trace_pairing_phase("socket connected"); return true; } - bool connect_socket(const std::string &address, uint16_t port, SocketGuard *socketGuard, std::string *errorMessage, const std::atomic *cancelRequested = nullptr) { + bool connect_socket( + const std::string &address, + uint16_t port, + SocketGuard *socketGuard, + int socketIoTimeoutMilliseconds, + std::string *errorMessage, + const std::atomic *cancelRequested = nullptr + ) { if (socketGuard == nullptr) { return append_error(errorMessage, "Internal pairing error while preparing the host connection"); } @@ -1547,7 +1556,7 @@ namespace { } } - return finalize_connected_socket(socketGuard->handle, errorMessage); + return finalize_connected_socket(socketGuard->handle, socketIoTimeoutMilliseconds, errorMessage); } bool recv_all_plain(SOCKET socketHandle, std::string *response, std::string *errorMessage, const std::atomic *cancelRequested = nullptr) { @@ -1855,7 +1864,8 @@ namespace { std::string_view expectedTlsCertificatePem, HttpResponse *response, std::string *errorMessage, - const std::atomic *cancelRequested + const std::atomic *cancelRequested, + int socketIoTimeoutMilliseconds ) { if (pairing_cancel_requested(cancelRequested)) { return append_cancelled_pairing_error(errorMessage); @@ -1869,6 +1879,7 @@ namespace { useTls, tlsClientIdentity, std::string(expectedTlsCertificatePem), + socketIoTimeoutMilliseconds, }; network::testing::HostPairingHttpTestResponse testResponse {}; if (std::string testError; !testHandler(testRequest, &testResponse, &testError, cancelRequested)) { @@ -1891,7 +1902,7 @@ namespace { SocketGuard socketGuard; trace_pairing_phase("http_get: connect_socket"); - if (!connect_socket(address, port, &socketGuard, errorMessage, cancelRequested)) { + if (!connect_socket(address, port, &socketGuard, socketIoTimeoutMilliseconds, errorMessage, cancelRequested)) { return false; } @@ -2149,12 +2160,18 @@ namespace { return true; } - bool execute_pairing_phase_request(PairingSessionState *session, const std::string &path, bool useTls, std::string_view expectedTlsCertificatePem = {}) { + bool execute_pairing_phase_request( + PairingSessionState *session, + const std::string &path, + bool useTls, + std::string_view expectedTlsCertificatePem = {}, + int socketIoTimeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS + ) { if (session == nullptr) { return false; } - if (!http_get(session->request.address, useTls ? session->serverInfo.httpsPort : session->serverInfo.httpPort, path, useTls, useTls ? &session->request.identity : nullptr, expectedTlsCertificatePem, &session->response, &session->errorMessage, session->cancelRequested)) { + if (!http_get(session->request.address, useTls ? session->serverInfo.httpsPort : session->serverInfo.httpPort, path, useTls, useTls ? &session->request.identity : nullptr, expectedTlsCertificatePem, &session->response, &session->errorMessage, session->cancelRequested, socketIoTimeoutMilliseconds)) { return false; } return parse_pairing_tag(session->response, "paired", &session->phaseValue, &session->errorMessage); @@ -2187,7 +2204,7 @@ namespace { const std::string phasePath = "/pair?uniqueid=" + session->uniqueId + "&uuid=" + session->requestUuid + "&devicename=" + session->deviceName + "&updateState=1&phrase=getservercert&salt=" + session->saltHex + "&clientcert=" + certHex; trace_pairing_phase("phase 1 getservercert request"); - if (!execute_pairing_phase_request(session, phasePath, false)) { + if (!execute_pairing_phase_request(session, phasePath, false, {}, PIN_ENTRY_SOCKET_TIMEOUT_MILLISECONDS)) { return false; } if (session->phaseValue != "1") { diff --git a/src/network/host_pairing.h b/src/network/host_pairing.h index 2bb6013..cb9f392 100644 --- a/src/network/host_pairing.h +++ b/src/network/host_pairing.h @@ -277,6 +277,7 @@ namespace network { bool useTls = false; ///< True when the request would normally use TLS. const PairingIdentity *tlsClientIdentity = nullptr; ///< Optional client identity attached to TLS requests. std::string expectedTlsCertificatePem; ///< Optional pinned host certificate expected by the request. + int socketIoTimeoutMilliseconds = 0; ///< Socket read/write timeout that the real transport would apply. }; /** diff --git a/tests/unit/network/host_pairing_test.cpp b/tests/unit/network/host_pairing_test.cpp index d23ad90..b9c879f 100644 --- a/tests/unit/network/host_pairing_test.cpp +++ b/tests/unit/network/host_pairing_test.cpp @@ -33,6 +33,8 @@ namespace { using network::testing::HostPairingHttpTestRequest; using network::testing::HostPairingHttpTestResponse; + constexpr int kDefaultPairingSocketTimeoutMilliseconds = 5000; + constexpr int kPinEntrySocketTimeoutMilliseconds = 90000; constexpr std::string_view kUnpairedClientErrorMessage = "The host reports that this client is no longer paired. Pair the host again."; class ScopedHostPairingHttpTestHandler { @@ -1309,6 +1311,39 @@ namespace { EXPECT_EQ(result.message, "Pairing failed during phase 1 (getservercert): The host rejected the initial pairing request"); } + TEST(HostPairingTest, PairHostUsesExtendedTimeoutForPinEntryResponse) { + const network::PairingIdentity identity = network::create_pairing_identity(); + ASSERT_TRUE(network::is_valid_pairing_identity(identity)); + + std::size_t callCount = 0U; + ScopedHostPairingHttpTestHandler guard([&callCount](const HostPairingHttpTestRequest &request, HostPairingHttpTestResponse *response, std::string *, const std::atomic *) { + if (callCount++ == 0U) { + EXPECT_EQ(request.socketIoTimeoutMilliseconds, kDefaultPairingSocketTimeoutMilliseconds); + response->statusCode = 200; + response->body = make_server_info_xml(false, 47989U, 47990U, "Pair Host", "pair-host"); + return true; + } + + EXPECT_NE(request.pathAndQuery.find("phrase=getservercert"), std::string::npos); + EXPECT_EQ(request.socketIoTimeoutMilliseconds, kPinEntrySocketTimeoutMilliseconds); + response->statusCode = 200; + response->body = make_pair_phase_response("0"); + return true; + }); + + const network::HostPairingResult result = network::pair_host({ + test_support::kTestIpv4Addresses[test_support::kIpLivingRoom], + 47989U, + "1234", + "MoonlightXboxOG", + identity, + }); + + EXPECT_EQ(callCount, 2U); + EXPECT_FALSE(result.success); + EXPECT_EQ(result.message, "Pairing failed during phase 1 (getservercert): The host rejected the initial pairing request"); + } + TEST(HostPairingTest, PairHostFailsWhenTheChallengeResponseIsTooShort) { const network::PairingIdentity clientIdentity = network::create_pairing_identity(); const network::PairingIdentity serverIdentity = network::create_pairing_identity();