diff --git a/src/hyperion_client.c b/src/hyperion_client.c index f732a0b..6ab17c4 100644 --- a/src/hyperion_client.c +++ b/src/hyperion_client.c @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "hyperion_reply_reader.h" @@ -26,7 +27,63 @@ static bool _registered = false; static int _priority = 0; static const char* _origin = NULL; static bool _connected = false; -unsigned char recvBuff[1024]; +static unsigned char recvBuff[1024]; + +#define RX_STALE_SECS 120 + +enum rx_phase { RX_HEADER, + RX_BODY }; + +static struct { + enum rx_phase phase; + uint8_t header[4]; + uint32_t body_len; + uint32_t received; +} rx; + +static time_t rx_last_data; + +static void _rx_reset(void) +{ + memset(&rx, 0, sizeof(rx)); + rx_last_data = time(NULL); +} + +/** + * Try to read more bytes into buf (at offset *received, total needed: len). + * Returns: >0 bytes read this call + * 0 EAGAIN/EWOULDBLOCK (non-fatal, try again later) + * -1 fatal (EOF or real error) + * EINTR is retried internally. + */ +static int _read_exact(int fd, void* buf, size_t len, uint32_t* received) +{ + for (;;) { + ssize_t n = read(fd, (uint8_t*)buf + *received, len - *received); + if (n > 0) { + *received += (uint32_t)n; + return (int)n; + } + if (n == 0) + return 0; /* timeout or EOF — treat as transient */ + /* n < 0 */ + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) + return 0; + return -1; + } +} + +static int _check_stale(void) +{ + if (difftime(time(NULL), rx_last_data) > RX_STALE_SECS) { + WARN("No data received for %d seconds, assuming stale connection", RX_STALE_SECS); + _rx_reset(); + return -1; + } + return 0; +} int hyperion_client(const char* origin, const char* hostname, int port, bool unix_socket, int priority) { @@ -36,26 +93,58 @@ int hyperion_client(const char* origin, const char* hostname, int port, bool uni _registered = false; sockfd = 0; + int ret; if (unix_socket) { - return _connect_unix_socket(hostname); + ret = _connect_unix_socket(hostname); } else { - return _connect_inet_socket(hostname, port); + ret = _connect_inet_socket(hostname, port); } + _rx_reset(); + return ret; } int hyperion_read() { if (!sockfd) return -1; - uint8_t headbuff[4]; - int n = read(sockfd, headbuff, 4); - uint32_t messageSize = ((headbuff[0] << 24) & 0xFF000000) | ((headbuff[1] << 16) & 0x00FF0000) | ((headbuff[2] << 8) & 0x0000FF00) | ((headbuff[3]) & 0x000000FF); - if (n < 0 || messageSize >= sizeof(recvBuff)) - return -1; - n = read(sockfd, recvBuff, messageSize); - if (n < 0) + + int ret; + + /* Phase 1: accumulate 4-byte header */ + if (rx.phase == RX_HEADER) { + ret = _read_exact(sockfd, rx.header, 4, &rx.received); + if (ret < 0) + return -1; + if (ret > 0) + rx_last_data = time(NULL); + if (rx.received < 4) + return _check_stale(); + + rx.body_len = ((uint32_t)rx.header[0] << 24) + | ((uint32_t)rx.header[1] << 16) + | ((uint32_t)rx.header[2] << 8) + | (uint32_t)rx.header[3]; + + if (rx.body_len == 0 || rx.body_len >= sizeof(recvBuff)) { + _rx_reset(); + return -1; + } + + rx.phase = RX_BODY; + rx.received = 0; + } + + /* Phase 2: accumulate body */ + ret = _read_exact(sockfd, recvBuff, rx.body_len, &rx.received); + if (ret < 0) return -1; + if (ret > 0) + rx_last_data = time(NULL); + if (rx.received < rx.body_len) + return _check_stale(); + _parse_reply(hyperionnet_Reply_as_root(recvBuff)); + _rx_reset(); return 0; } @@ -65,6 +154,7 @@ int hyperion_destroy() return 0; close(sockfd); sockfd = 0; + _rx_reset(); return 0; }