diff --git a/flymyai/core/clients/AsyncClient.py b/flymyai/core/clients/AsyncClient.py index b8664a3..313e4ee 100644 --- a/flymyai/core/clients/AsyncClient.py +++ b/flymyai/core/clients/AsyncClient.py @@ -8,7 +8,14 @@ ) from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo -from flymyai.core.clients.base_client import BaseClient, _predict_timeout +from flymyai.core.clients.base_client import ( + BaseClient, + _predict_timeout, + _http2, + _limits, + _is_reconnectable_error, + _RECONNECT_RETRIES, +) from flymyai.core.exceptions import ( BaseFlyMyAIException, FlyMyAIOpenAPIException, @@ -29,7 +36,8 @@ class BaseAsyncClient(BaseClient[httpx.AsyncClient]): def _construct_client(self): return httpx.AsyncClient( - http2=True, + http2=_http2, + limits=_limits, headers=self.client_info.authorization_headers, base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), timeout=_predict_timeout, @@ -44,11 +52,20 @@ async def _reconnect_client(self): self._client = self._construct_client() async def _awith_reconnect(self, fn): - try: - return await fn() - except httpx.RemoteProtocolError: - await self._reconnect_client() - return await fn() + last_exc = None + for attempt in range(1 + _RECONNECT_RETRIES): + try: + return await fn() + except BaseException as e: + last_exc = e + if not _is_reconnectable_error(e): + raise + if attempt < _RECONNECT_RETRIES: + await self._reconnect_client() + continue + raise + assert last_exc is not None + raise last_exc async def __aenter__(self): return self @@ -236,7 +253,9 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict): except BaseFlyMyAIException as e: raise FlyMyAIPredictException.from_base_exception(e) yield response - except httpx.RemoteProtocolError: + except BaseException as e: + if not _is_reconnectable_error(e): + raise await self._reconnect_client() stream_iterator = self._stream_iterator( client_info, payload, is_long_stream=True diff --git a/flymyai/core/clients/SyncClient.py b/flymyai/core/clients/SyncClient.py index 596d60c..db0ec60 100644 --- a/flymyai/core/clients/SyncClient.py +++ b/flymyai/core/clients/SyncClient.py @@ -5,7 +5,14 @@ from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo -from flymyai.core.clients.base_client import BaseClient, _predict_timeout +from flymyai.core.clients.base_client import ( + BaseClient, + _predict_timeout, + _http2, + _limits, + _is_reconnectable_error, + _RECONNECT_RETRIES, +) from flymyai.core.exceptions import ( BaseFlyMyAIException, FlyMyAIOpenAPIException, @@ -29,18 +36,28 @@ class BaseSyncClient(BaseClient[httpx.Client]): def _construct_client(self): return httpx.Client( - http2=True, + http2=_http2, + limits=_limits, headers=self.client_info.authorization_headers, base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), timeout=_predict_timeout, ) def _with_reconnect(self, fn): - try: - return fn() - except httpx.RemoteProtocolError: - self._reconnect_client() - return fn() + last_exc = None + for attempt in range(1 + _RECONNECT_RETRIES): + try: + return fn() + except BaseException as e: + last_exc = e + if not _is_reconnectable_error(e): + raise + if attempt < _RECONNECT_RETRIES: + self._reconnect_client() + continue + raise + assert last_exc is not None + raise last_exc def __enter__(self): return self @@ -165,7 +182,9 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict): except BaseFlyMyAIException as e: raise FlyMyAIPredictException.from_base_exception(e) yield response - except httpx.RemoteProtocolError: + except BaseException as e: + if not _is_reconnectable_error(e): + raise self._reconnect_client() response_iterator = self._stream_iterator( client_info, payload, is_long_stream=True diff --git a/flymyai/core/clients/base_client.py b/flymyai/core/clients/base_client.py index 5cad66a..63e339a 100644 --- a/flymyai/core/clients/base_client.py +++ b/flymyai/core/clients/base_client.py @@ -35,6 +35,35 @@ "_PossibleClients", bound=Union[httpx.Client, httpx.AsyncClient] ) +# Connection-style errors that warrant one reconnect + retry (high RPS / HTTP2 drops) +_CONNECT_RECONNECT_EXC = ( + httpx.RemoteProtocolError, + httpx.WriteError, + httpx.ReadError, + httpx.ConnectError, +) + +# How many times to reconnect and retry on connection/stream errors (1 initial + this many retries) +_RECONNECT_RETRIES = int(os.getenv("FMA_RECONNECT_RETRIES", "3")) + + +def _is_reconnectable_error(exc: BaseException) -> bool: + if isinstance(exc, _CONNECT_RECONNECT_EXC): + return True + if isinstance(exc, RuntimeError) and "client has been closed" in str(exc).lower(): + return True + if type(exc).__name__ in ("ClosedResourceError", "BrokenResourceError"): + return True + # HTTP/2 connection closed then reused (e.g. "ConnectionState.CLOSED", "SEND_SETTINGS", "StreamReset") + msg = str(exc) + if ( + "ConnectionState.CLOSED" in msg + or "SEND_SETTINGS" in msg + or "StreamReset" in msg + ): + return True + return False + _predict_timeout = httpx.Timeout( connect=int(os.getenv("FMA_CONNECT_TIMEOUT", 999999)), @@ -43,6 +72,13 @@ pool=int(os.getenv("FMA_POOL_TIMEOUT", 999999)), ) +_http2 = os.getenv("FLYMYAI_HTTP2", "true").lower() in ("1", "true", "yes") +_limits = httpx.Limits( + max_connections=int(os.getenv("FMA_MAX_CONNECTIONS", "100")), + max_keepalive_connections=int(os.getenv("FMA_MAX_KEEPALIVE_CONNECTIONS", "50")), + keepalive_expiry=float(os.getenv("FMA_KEEPALIVE_EXPIRY", "60")), +) + class BaseClient(Generic[_PossibleClients]): """ diff --git a/flymyai/core/clients/m1AsyncClient.py b/flymyai/core/clients/m1AsyncClient.py index 34491b7..46f2e1b 100644 --- a/flymyai/core/clients/m1AsyncClient.py +++ b/flymyai/core/clients/m1AsyncClient.py @@ -93,9 +93,7 @@ async def generation_task_result( ) -> FlyMyAIM1Response: while True: response = await self._awith_reconnect( - lambda: self._client.get( - self._populate_result_path(generation_task) - ) + lambda: self._client.get(self._populate_result_path(generation_task)) ) response.raise_for_status() response_data = response.json()