diff --git a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
index 4a0efce9e1bfe..40c672a3c4285 100644
--- a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
+++ b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
@@ -11,10 +11,12 @@
 import asyncio
 import inspect
 import os
+import weakref
 from collections.abc import Awaitable, Callable
 from functools import lru_cache
 from typing import Any, cast
 
+import httpx
 import openai
 from pydantic import SecretStr
 
@@ -46,6 +48,44 @@ def __del__(self) -> None:
             pass
 
 
+class _LoopAwareAsyncHttpxClientProxy:
+    """A proxy for async httpx clients that maintains per-event-loop instances.
+
+    httpx.AsyncClient connections are bound to the event loop they were created
+    on. This proxy is cached at the (base_url, timeout) level via @lru_cache,
+    preserving fast O(1) init-time lookup. Internally it stores one
+    _AsyncHttpxClientWrapper per event loop in a WeakValueDictionary so that
+    clients are automatically garbage-collected when their loop closes.
+
+    This avoids the 'Event loop is closed' RuntimeError that occurs when a
+    process-global @lru_cache'd client is reused across different event loops
+    (e.g. in multi-threaded environments, Celery workers, or sequential
+    asyncio.run() calls).
+    """
+
+    def __init__(self, base_url: str | None, timeout: Any) -> None:
+        self._base_url = base_url
+        self._timeout = timeout
+        self._loop_clients: weakref.WeakValueDictionary[
+            int, _AsyncHttpxClientWrapper
+        ] = weakref.WeakValueDictionary()
+
+    def get_client(self) -> _AsyncHttpxClientWrapper:
+        """Return the cached client for the current event loop, creating one if needed."""
+        try:
+            loop = asyncio.get_running_loop()
+            loop_id = id(loop)
+        except RuntimeError:
+            # No running event loop; return a fresh client (won't be cached).
+            return _build_async_httpx_client(self._base_url, self._timeout)
+
+        client = self._loop_clients.get(loop_id)
+        if client is None:
+            client = _build_async_httpx_client(self._base_url, self._timeout)
+            self._loop_clients[loop_id] = client
+        return client
+
+
 def _build_sync_httpx_client(
     base_url: str | None, timeout: Any
 ) -> _SyncHttpxClientWrapper:
@@ -76,10 +116,16 @@ def _cached_sync_httpx_client(
 
 
 @lru_cache
-def _cached_async_httpx_client(
+def _cached_async_httpx_client_proxy(
     base_url: str | None, timeout: Any
-) -> _AsyncHttpxClientWrapper:
-    return _build_async_httpx_client(base_url, timeout)
+) -> _LoopAwareAsyncHttpxClientProxy:
+    """Return a per-(base_url, timeout) proxy cached by @lru_cache.
+
+    The proxy itself is lightweight and @lru_cache'd, so init-time overhead is
+    identical to the original implementation. The proxy internally dispatches
+    to a per-event-loop client, preventing cross-loop connection reuse.
+    """
+    return _LoopAwareAsyncHttpxClientProxy(base_url, timeout)
 
 
 def _get_default_httpx_client(
@@ -100,16 +146,22 @@
 def _get_default_async_httpx_client(
     base_url: str | None, timeout: Any
 ) -> _AsyncHttpxClientWrapper:
-    """Get default httpx client.
+    """Get default async httpx client, scoped to the current event loop.
 
-    Uses cached client unless timeout is `httpx.Timeout`, which is not hashable.
+    Async httpx clients are bound to the event loop they were created on, so
+    they cannot be safely shared across different event loops. This function
+    returns a client that is cached per-loop to avoid 'Event loop is closed'
+    errors in multi-threaded or multi-loop environments (e.g. Celery, sequential
+    asyncio.run() calls, or multi-threaded FastAPI handlers).
+
+    Uses a fresh (uncached) client when timeout is not hashable.
""" try: hash(timeout) except TypeError: return _build_async_httpx_client(base_url, timeout) else: - return _cached_async_httpx_client(base_url, timeout) + return _cached_async_httpx_client_proxy(base_url, timeout).get_client() def _resolve_sync_and_async_api_keys( @@ -117,8 +169,8 @@ def _resolve_sync_and_async_api_keys( ) -> tuple[str | None | Callable[[], str], str | Callable[[], Awaitable[str]]]: """Resolve sync and async API key values. - Because OpenAI and AsyncOpenAI clients support either sync or async callables for - the API key, we need to resolve separate values here. + Because OpenAI and AsyncOpenAI clients support either sync or async callables + for the API key, we need to resolve separate values here. """ if isinstance(api_key, SecretStr): sync_api_key_value: str | None | Callable[[], str] = api_key.get_secret_value() @@ -138,5 +190,4 @@ async def async_api_key_wrapper() -> str: ) async_api_key_value = async_api_key_wrapper - return sync_api_key_value, async_api_key_value