-
Notifications
You must be signed in to change notification settings - Fork 68
fix(ai): wrap streaming responses in AsyncStreamWrapper to support async with #622
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e5c018a
e27d723
0eabcfe
296d2b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| """Shared async streaming utilities for PostHog AI wrappers.""" | ||
|
|
||
| from typing import Any, AsyncGenerator, TypeVar | ||
|
|
||
| T = TypeVar("T") | ||
|
|
||
|
|
||
| class AsyncStreamWrapper: | ||
| """Wraps an async generator so it also implements the async context manager protocol. | ||
|
|
||
| The OpenAI and Anthropic SDKs return stream objects that support both | ||
| ``async for`` iteration **and** ``async with`` (i.e. they are both async | ||
| iterators and async context managers). PostHog's streaming wrappers | ||
| previously returned a bare async generator, which only supports ``async | ||
| for``. Libraries such as pydantic-ai call ``async with response:`` before | ||
| iterating, causing:: | ||
|
|
||
| TypeError: 'async_generator' object does not support the | ||
| asynchronous context manager protocol | ||
|
|
||
| This class wraps the underlying async generator and adds the missing | ||
| ``__aenter__`` / ``__aexit__`` methods. On ``__aexit__`` the generator is | ||
| closed so that the ``finally`` block inside the generator (which fires the | ||
| PostHog usage event) always executes, even when the caller breaks out of | ||
| the loop early. | ||
| """ | ||
|
|
||
| def __init__(self, generator: AsyncGenerator[T, None]) -> None: | ||
| self._generator = generator | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Async iterator protocol # | ||
| # ------------------------------------------------------------------ # | ||
|
|
||
| def __aiter__(self) -> "AsyncStreamWrapper": | ||
| return self | ||
|
|
||
| async def __anext__(self) -> T: | ||
| return await self._generator.__anext__() | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Async context manager protocol # | ||
| # ------------------------------------------------------------------ # | ||
|
|
||
| async def __aenter__(self) -> "AsyncStreamWrapper": | ||
| return self | ||
|
|
||
| async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: | ||
| # Close the generator so the finally block (PostHog event capture) runs | ||
| # even on early exit. If the generator is already exhausted this is a | ||
| # no-op. | ||
| await self._generator.aclose() | ||
| return False | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Attribute proxy – forward any other attribute access to the # | ||
| # underlying generator (e.g. .response on an Anthropic stream). # | ||
| # ------------------------------------------------------------------ # | ||
|
|
||
| def __getattr__(self, name: str) -> Any: | ||
| return getattr(self._generator, name) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| """Regression tests for AsyncStreamWrapper. | ||
|
|
||
| Ensures that PostHog AI streaming wrappers return objects that support both | ||
| the async iterator protocol (``async for``) and the async context manager | ||
| protocol (``async with``), as required by libraries such as pydantic-ai. | ||
|
|
||
| Issue: https://github.com/PostHog/posthog-python/issues/393 | ||
| """ | ||
|
|
||
| import pytest | ||
|
|
||
| from posthog.ai.stream import AsyncStreamWrapper | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| async def _make_gen(items): | ||
| """Simple async generator that yields the given items.""" | ||
| for item in items: | ||
| yield item | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Tests | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_for_iteration(): | ||
| """AsyncStreamWrapper must yield all items when used with ``async for``.""" | ||
| wrapper = AsyncStreamWrapper(_make_gen([1, 2, 3])) | ||
| result = [] | ||
| async for item in wrapper: | ||
| result.append(item) | ||
| assert result == [1, 2, 3] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_context_manager_protocol(): | ||
| """AsyncStreamWrapper must support ``async with`` without raising TypeError.""" | ||
| wrapper = AsyncStreamWrapper(_make_gen(["a", "b"])) | ||
|
|
||
| # This is the call pattern that pydantic-ai uses and that previously raised: | ||
| # TypeError: 'async_generator' object does not support the asynchronous | ||
| # context manager protocol | ||
| async with wrapper as stream: | ||
| result = [] | ||
| async for chunk in stream: | ||
| result.append(chunk) | ||
|
|
||
| assert result == ["a", "b"] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_context_manager_returns_self(): | ||
| """``async with wrapper as w`` should bind the wrapper itself.""" | ||
| wrapper = AsyncStreamWrapper(_make_gen([])) | ||
| async with wrapper as w: | ||
| assert w is wrapper | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_finally_block_runs_on_early_exit(): | ||
| """The underlying generator's finally block must run even when the caller | ||
| breaks out of the loop early (i.e. doesn't fully exhaust the generator).""" | ||
| finally_ran = [] | ||
|
|
||
| async def gen_with_finally(): | ||
| try: | ||
| for i in range(10): | ||
| yield i | ||
| finally: | ||
| finally_ran.append(True) | ||
|
|
||
| wrapper = AsyncStreamWrapper(gen_with_finally()) | ||
| async with wrapper as stream: | ||
| async for chunk in stream: | ||
| if chunk == 2: | ||
| break # early exit | ||
|
|
||
| # __aexit__ must have called aclose(), triggering the finally block | ||
| assert finally_ran == [True], "finally block in generator did not run on early exit" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_finally_block_runs_on_full_exhaustion(): | ||
| """The underlying generator's finally block must also run on normal | ||
| exhaustion (``aclose()`` on an exhausted generator is a no-op).""" | ||
| finally_ran = [] | ||
|
|
||
| async def gen_with_finally(): | ||
| try: | ||
| yield 1 | ||
| yield 2 | ||
| finally: | ||
| finally_ran.append(True) | ||
|
|
||
| wrapper = AsyncStreamWrapper(gen_with_finally()) | ||
| async with wrapper as stream: | ||
| async for _ in stream: | ||
| pass | ||
|
|
||
| assert finally_ran == [True] | ||
|
Comment on lines
+88
to
+106
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Prompt To Fix With AIThis is a comment left during a code review.
Path: posthog/test/test_async_stream_wrapper.py
Line: 88-106
Comment:
**Prefer parameterised tests for similar cases**
`test_finally_block_runs_on_early_exit` and `test_finally_block_runs_on_full_exhaustion` share the same structure (build a generator with a `finally` side-effect, wrap it, drive it with `async with`, assert the flag). The team's preference is to express these as a single `@pytest.mark.parametrize` test so the assertion logic is written OnceAndOnlyOnce and new cases are cheap to add.
How can I resolve this? If you propose a fix, please make it concise.Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_attribute_proxy(): | ||
| """Attributes not on AsyncStreamWrapper itself should be forwarded to the | ||
| underlying generator (for provider-specific metadata access).""" | ||
|
|
||
| class FakeStream: | ||
| extra_attr = "hello" | ||
|
|
||
| def __aiter__(self): | ||
| return self | ||
|
|
||
| async def __anext__(self): | ||
| raise StopAsyncIteration | ||
|
|
||
| async def aclose(self): | ||
| pass | ||
|
|
||
| wrapper = AsyncStreamWrapper(FakeStream()) # type: ignore[arg-type] | ||
| assert wrapper.extra_attr == "hello" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AsyncIteratoris imported but never referenced anywhere in the file — it was added in this PR but left unused.Prompt To Fix With AI