From 582aefc0eba991cd33d4f5b77d571500c3643af4 Mon Sep 17 00:00:00 2001 From: Nick Sweeting Date: Wed, 15 Oct 2025 17:36:18 -0700 Subject: [PATCH] add support for middlewares to hook into event bus handler lifecycle --- README.md | 54 ++++++++- bubus/__init__.py | 4 + bubus/middlewares.py | 257 +++++++++++++++++++++++++++++++++++++++++ bubus/service.py | 237 ++++++++++++++++++++++--------------- tests/test_eventbus.py | 179 +++++++++++++++++++++++++++- 5 files changed, 629 insertions(+), 102 deletions(-) create mode 100644 bubus/middlewares.py diff --git a/README.md b/README.md index afd7ed8..df2c09b 100644 --- a/README.md +++ b/README.md @@ -477,11 +477,29 @@ await bus.dispatch(DataEvent()) Persist events automatically to a `jsonl` file for future replay and debugging: ```python +from pathlib import Path + +from bubus import EventBus +from bubus.middlewares import ( + LoggerEventBusMiddleware, + SQLiteEventBusMiddleware, + WALEventBusMiddleware, +) + # Enable WAL event log persistence (optional) -bus = EventBus(name='MyBus', wal_path='./events.jsonl') +bus = EventBus( + name='MyBus', + middlewares=[ + WALEventBusMiddleware('./events.jsonl'), + LoggerEventBusMiddleware('./events.log'), + SQLiteEventBusMiddleware('./events.sqlite'), + ], +) + +# LoggerEventBusMiddleware defaults to stdout-only logging if no file path is provided # All completed events are automatically appended as JSON lines to the end -bus.dispatch(SecondEventAbc(some_key="banana")) +await bus.dispatch(SecondEventAbc(some_key="banana")) ``` `./events.jsonl`: @@ -507,17 +525,43 @@ The main event bus class that manages event processing and handler execution. ```python EventBus( name: str | None = None, - wal_path: Path | str | None = None, parallel_handlers: bool = False, - max_history_size: int | None = 50 + max_history_size: int | None = 50, + middlewares: Sequence[EventBusMiddleware | type[EventBusMiddleware]] | None = None, ) ``` **Parameters:** - `name`: Optional unique name for the bus (auto-generated if not provided) -- `wal_path`: Path for write-ahead logging of events to a `jsonl` file (optional) - `parallel_handlers`: If `True`, handlers run concurrently for each event, otherwise serially if `False` (the default) +- `middlewares`: Optional list of `EventBusMiddleware` subclasses or instances that hook into handler execution for analytics, logging, retries, etc. + +Handler middlewares subclass `EventBusMiddleware` and override whichever lifecycle hooks they need: + +```python +from bubus.middlewares import EventBusMiddleware + +class AnalyticsMiddleware(EventBusMiddleware): + async def before_handler(self, eventbus, event, event_result): + await analytics_bus.dispatch(HandlerStartedAnalyticsEvent(event_id=event_result.event_id)) + + async def after_handler(self, eventbus, event, event_result): + await analytics_bus.dispatch(HandlerCompletedAnalyticsEvent(event_id=event_result.event_id)) + + async def on_handler_error(self, eventbus, event, event_result, error): + await analytics_bus.dispatch(HandlerCompletedAnalyticsEvent(event_id=event_result.event_id, error=error)) +``` + +Middlewares can observe or mutate the `EventResult` at each step, dispatch additional events, or trigger other side effects (metrics, retries, auth checks, etc.). + +The built-in `SQLiteEventBusMiddleware` mirrors every event and handler transition into append-only `events_log` and `event_results_log` tables, making it easy to inspect or audit the bus state: + +```python +from bubus.middlewares import SQLiteEventBusMiddleware + +bus = EventBus(middlewares=[SQLiteEventBusMiddleware('./events.sqlite')]) +``` - `max_history_size`: Maximum number of events to keep in history (default: 50, None = unlimited) #### `EventBus` Properties diff --git a/bubus/__init__.py b/bubus/__init__.py index df6e6e2..871b740 100644 --- a/bubus/__init__.py +++ b/bubus/__init__.py @@ -1,10 +1,14 @@ """Event bus for the browser-use agent.""" +from bubus.middlewares import EventBusMiddleware, LoggerEventBusMiddleware, SQLiteEventBusMiddleware from bubus.models import BaseEvent, EventHandler, EventResult, PythonIdentifierStr, PythonIdStr, UUIDStr from bubus.service import EventBus __all__ = [ 'EventBus', + 'EventBusMiddleware', + 'LoggerEventBusMiddleware', + 'SQLiteEventBusMiddleware', 'BaseEvent', 'EventResult', 'EventHandler', diff --git a/bubus/middlewares.py b/bubus/middlewares.py new file mode 100644 index 0000000..39efff9 --- /dev/null +++ b/bubus/middlewares.py @@ -0,0 +1,257 @@ +"""Reusable EventBus middleware helpers.""" + +from __future__ import annotations + +import asyncio +import logging +import sqlite3 +import threading +from pathlib import Path +from typing import Any + +from bubus.logging import log_eventbus_tree +from bubus.models import BaseEvent +from bubus.service import EventBus, EventBusMiddleware as _EventBusMiddleware + +__all__ = ['EventBusMiddleware', 'WALEventBusMiddleware', 'LoggerEventBusMiddleware', 'SQLiteEventBusMiddleware'] + +logger = logging.getLogger('bubus.middleware') + +EventBusMiddleware = _EventBusMiddleware + + +class WALEventBusMiddleware(EventBusMiddleware): + """Persist completed events to a JSONL write-ahead log.""" + + def __init__(self, wal_path: Path | str): + self.wal_path = Path(wal_path) + self.wal_path.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() + + async def after_event(self, eventbus: EventBus, event: BaseEvent[Any]) -> None: + if getattr(event, '_wal_written', False): + return + + if not self._event_is_complete(event): + return + + try: + await asyncio.to_thread(self._write_event, event) + setattr(event, '_wal_written', True) + except Exception as exc: # pragma: no cover - logging branch + logger.error( + '❌ %s Failed to save event %s to WAL file %s: %s %s', + eventbus, + event.event_id, + self.wal_path, + type(exc).__name__, + exc, + ) + + def _event_is_complete(self, event: BaseEvent[Any]) -> bool: + signal = event.event_completed_signal + if signal is not None and not signal.is_set(): + return False + if any(result.status not in ('completed', 'error') for result in event.event_results.values()): + return False + return event.event_are_all_children_complete() + + def _write_event(self, event: BaseEvent[Any]) -> None: + event_json = event.model_dump_json() # pyright: ignore[reportUnknownMemberType] + with self._lock: + with self.wal_path.open('a', encoding='utf-8') as fp: + fp.write(event_json + '\n') + + +class LoggerEventBusMiddleware(EventBusMiddleware): + """Log completed events using the existing logging helpers and optionally mirror to a text file.""" + + def __init__(self, log_path: Path | str | None = None): + self.log_path = Path(log_path) if log_path is not None else None + if self.log_path is not None: + self.log_path.parent.mkdir(parents=True, exist_ok=True) + + async def after_event(self, eventbus: EventBus, event: BaseEvent[Any]) -> None: + if getattr(event, '_logger_middleware_logged', False): + return + + if not self._event_is_complete(event): + return + + setattr(event, '_logger_middleware_logged', True) + + summary = event.event_log_safe_summary() + logger.info('✅ %s completed event %s', eventbus, summary) + + line = f'[{eventbus.name}] {summary}\n' + await asyncio.to_thread(self._append_line, line) + + if logger.isEnabledFor(logging.DEBUG): + log_eventbus_tree(eventbus) + + def _event_is_complete(self, event: BaseEvent[Any]) -> bool: + signal = event.event_completed_signal + if signal is not None and not signal.is_set(): + return False + if any(result.status not in ('completed', 'error') for result in event.event_results.values()): + return False + return event.event_are_all_children_complete() + + def _append_line(self, line: str) -> None: + if self.log_path is not None: + with self.log_path.open('a', encoding='utf-8') as fp: + fp.write(line) + print(line.rstrip('\n'), flush=True) + + +class SQLiteEventBusMiddleware(EventBusMiddleware): + """Mirror events and handler results into append-only SQLite tables.""" + + def __init__(self, db_path: str | Path): + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect(self.db_path, check_same_thread=False) + self._conn.execute('PRAGMA journal_mode=WAL') + self._conn.execute('PRAGMA synchronous=NORMAL') + self._setup_schema() + self._lock = asyncio.Lock() + + def __del__(self): + try: + self._conn.close() + except Exception: + pass + + def _setup_schema(self) -> None: + self._conn.execute( + ''' + CREATE TABLE IF NOT EXISTS events_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_id TEXT NOT NULL, + event_type TEXT NOT NULL, + event_status TEXT NOT NULL, + eventbus_name TEXT, + event_json TEXT NOT NULL, + inserted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''' + ) + self._conn.execute( + ''' + CREATE TABLE IF NOT EXISTS event_results_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_id TEXT NOT NULL, + handler_id TEXT NOT NULL, + handler_name TEXT NOT NULL, + eventbus_id TEXT NOT NULL, + eventbus_name TEXT NOT NULL, + status TEXT NOT NULL, + result_repr TEXT, + error_repr TEXT, + inserted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''' + ) + self._conn.commit() + + async def before_handler(self, eventbus: EventBus, event: BaseEvent[Any], event_result) -> None: + await self._insert_event_result(event_result) + + async def after_handler(self, eventbus: EventBus, event: BaseEvent[Any], event_result) -> None: + await self._insert_event_result(event_result) + + async def on_handler_error( + self, + eventbus: EventBus, + event: BaseEvent[Any], + event_result, + error: BaseException, + ) -> None: + await self._insert_event_result(event_result, error_override=error) + + async def after_event(self, eventbus: EventBus, event: BaseEvent[Any]) -> None: + if getattr(event, '_sqlite_logged', False): + return + + if not self._event_is_complete(event): + return + + await self._insert_event(eventbus, event) + setattr(event, '_sqlite_logged', True) + + async def _insert_event_result(self, event_result, error_override: BaseException | None = None) -> None: + error = error_override or event_result.error + error_repr = repr(error) if error is not None else None + result_repr = None + if event_result.result is not None and error is None: + try: + result_repr = repr(event_result.result) + except Exception: + result_repr = '' + + await self._execute( + ''' + INSERT INTO event_results_log ( + event_id, + handler_id, + handler_name, + eventbus_id, + eventbus_name, + status, + result_repr, + error_repr + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ''', + ( + event_result.event_id, + event_result.handler_id, + event_result.handler_name, + event_result.eventbus_id, + event_result.eventbus_name, + event_result.status, + result_repr, + error_repr, + ), + ) + + async def _insert_event(self, eventbus: EventBus, event: BaseEvent[Any]) -> None: + event_json = event.model_dump_json() # pyright: ignore[reportUnknownMemberType] + has_error = any(result.status == 'error' for result in event.event_results.values()) + event_status = 'error' if has_error else event.event_status + + await self._execute( + ''' + INSERT INTO events_log ( + event_id, + event_type, + event_status, + eventbus_name, + event_json + ) + VALUES (?, ?, ?, ?, ?) + ''', + ( + event.event_id, + event.event_type, + event_status, + eventbus.name, + event_json, + ), + ) + + async def _execute(self, sql: str, params: tuple[Any, ...]) -> None: + async with self._lock: + await asyncio.to_thread(self._run_execute, sql, params) + + def _run_execute(self, sql: str, params: tuple[Any, ...]) -> None: + self._conn.execute(sql, params) + self._conn.commit() + + def _event_is_complete(self, event: BaseEvent[Any]) -> bool: + signal = event.event_completed_signal + if signal is not None and not signal.is_set(): + return False + if any(result.status not in ('completed', 'error') for result in event.event_results.values()): + return False + return event.event_are_all_children_complete() diff --git a/bubus/service.py b/bubus/service.py index 72f652e..df19715 100644 --- a/bubus/service.py +++ b/bubus/service.py @@ -6,12 +6,11 @@ import warnings import weakref from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from contextvars import ContextVar from pathlib import Path from typing import Any, Literal, TypeVar, cast, overload -import anyio # pyright: ignore[reportMissingImports] from uuid_extensions import uuid7str # pyright: ignore[reportMissingImports, reportUnknownVariableType] uuid7str: Callable[[], str] = uuid7str # pyright: ignore @@ -34,6 +33,7 @@ UUIDStr, get_handler_id, get_handler_name, + EventResult, ) logger = logging.getLogger('bubus') @@ -52,6 +52,31 @@ class QueueShutDown(Exception): EventPatternType = PythonIdentifierStr | Literal['*'] | type['BaseEvent[Any]'] +class EventBusMiddleware: + """Base class for EventBus middlewares.""" + + async def before_handler( + self, eventbus: 'EventBus', event: 'BaseEvent[Any]', event_result: EventResult[Any] + ) -> None: + return None + + async def after_handler( + self, eventbus: 'EventBus', event: 'BaseEvent[Any]', event_result: EventResult[Any] + ) -> None: + return None + + async def on_handler_error( + self, + eventbus: 'EventBus', + event: 'BaseEvent[Any]', + event_result: EventResult[Any], + error: BaseException, + ) -> None: + return None + + async def after_event(self, eventbus: 'EventBus', event: 'BaseEvent[Any]') -> None: + return None + class CleanShutdownQueue(asyncio.Queue[QueueEntryType]): """asyncio.Queue subclass that handles shutdown cleanly without warnings.""" @@ -263,7 +288,6 @@ class EventBus: # Class Attributes name: PythonIdentifierStr = 'EventBus' parallel_handlers: bool = False - wal_path: Path | None = None # Runtime State id: UUIDStr = '00000000-0000-0000-0000-000000000000' @@ -278,9 +302,9 @@ class EventBus: def __init__( self, name: PythonIdentifierStr | None = None, - wal_path: Path | str | None = None, parallel_handlers: bool = False, max_history_size: int | None = 50, # Keep only 50 events in history + middlewares: Sequence[EventBusMiddleware | type[EventBusMiddleware]] | None = None, ): self.id = uuid7str() self.name = name or f'{self.__class__.__name__}_{self.id[-8:]}' @@ -332,8 +356,9 @@ def __init__( self.event_history = {} self.handlers = defaultdict(list) self.parallel_handlers = parallel_handlers - self.wal_path = Path(wal_path) if wal_path else None self._on_idle = None + self._middlewares: list[EventBusMiddleware] = [] + self.middlewares = list(middlewares or []) # Memory leak prevention settings self.max_history_size = max_history_size @@ -341,11 +366,6 @@ def __init__( # Register this instance EventBus.all_instances.add(self) - # Instead of registering as normal event handlers, - # these special handlers are just called manually at the end of step - # self.on('*', self._default_log_handler) - # self.on('*', self._default_wal_handler) - def __del__(self): """Auto-cleanup on garbage collection""" # Most cleanup should have been done by the event loop close hook @@ -371,6 +391,71 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + @property + def middlewares(self) -> list[EventBusMiddleware]: + return getattr(self, '_middlewares', []) + + @middlewares.setter + def middlewares(self, value: Sequence[EventBusMiddleware | type[EventBusMiddleware]]) -> None: + instances: list[EventBusMiddleware] = [] + for middleware in value: + if isinstance(middleware, EventBusMiddleware): + instances.append(middleware) + elif inspect.isclass(middleware) and issubclass(middleware, EventBusMiddleware): + instances.append(middleware()) + else: + raise TypeError( + f'Invalid middleware {middleware!r}. Expected EventBusMiddleware instance or subclass.' + ) + self._middlewares = instances + + async def _call_middleware_hook( + self, + middleware: EventBusMiddleware, + method_name: str, + *args: Any, + ) -> None: + method = getattr(middleware, method_name, None) + if method is None: + return + result = method(*args) + if inspect.isawaitable(result): + await result + + async def _middlewares_before_handler(self, event: 'BaseEvent[Any]', event_result: EventResult[Any]) -> None: + for middleware in self._middlewares: + await self._call_middleware_hook(middleware, 'before_handler', self, event, event_result) + + async def _middlewares_after_handler(self, event: 'BaseEvent[Any]', event_result: EventResult[Any]) -> None: + for middleware in self._middlewares: + await self._call_middleware_hook(middleware, 'after_handler', self, event, event_result) + + async def _middlewares_on_error( + self, event: 'BaseEvent[Any]', event_result: EventResult[Any], error: BaseException + ) -> None: + for middleware in self._middlewares: + await self._call_middleware_hook(middleware, 'on_handler_error', self, event, event_result, error) + + async def _middleware_after_event(self, event: 'BaseEvent[Any]') -> None: + for middleware in self._middlewares: + await self._call_middleware_hook(middleware, 'after_event', self, event) + + async def _dispatch_after_event_hooks(self, event: 'BaseEvent[Any]') -> None: + if getattr(event, '_after_event_hooks_run', False): + return + + event_completed = False + if event.event_completed_signal is not None and event.event_completed_signal.is_set(): + event_completed = True + elif event.event_results and all(result.status in ('completed', 'error') for result in event.event_results.values()): + event_completed = True + + if not event_completed: + return + + setattr(event, '_after_event_hooks_run', True) + await self._middleware_after_event(event) + @property def events_pending(self) -> list['BaseEvent[Any]']: """Get events that haven't started processing yet (does not include events that have not even finished dispatching yet in self.event_queue)""" @@ -975,12 +1060,11 @@ async def process_event(self, event: 'BaseEvent[Any]', timeout: float | None = N # Execute handlers await self._execute_handlers(event, handlers=applicable_handlers, timeout=timeout) - await self._default_log_handler(event) - await self._default_wal_handler(event) - # Mark event as complete if all handlers are done event.event_mark_complete_if_all_handlers_completed() + await self._dispatch_after_event_hooks(event) + # After processing this event, check if any parent events can now be marked complete # We do this by walking up the parent chain current = event @@ -991,10 +1075,12 @@ async def process_event(self, event: 'BaseEvent[Any]', timeout: float | None = N # Find parent event in any bus's history parent_event = None + parent_bus: EventBus | None = None # Create a list copy to avoid "Set changed size during iteration" error for bus in list(EventBus.all_instances): if bus and current.event_parent_id in bus.event_history: parent_event = bus.event_history[current.event_parent_id] + parent_bus = bus break if not parent_event: @@ -1004,6 +1090,9 @@ async def process_event(self, event: 'BaseEvent[Any]', timeout: float | None = N if parent_event.event_completed_signal and not parent_event.event_completed_signal.is_set(): parent_event.event_mark_complete_if_all_handlers_completed() + if parent_bus: + await parent_bus._dispatch_after_event_hooks(parent_event) + # Move up the chain current = parent_event @@ -1078,35 +1167,39 @@ async def _execute_handlers( # print('FINSIHED EXECUTING ALL HANDLERS') async def execute_handler( - self, event: 'BaseEvent[T_EventResultType]', handler: EventHandler, timeout: float | None = None + self, + event: 'BaseEvent[T_EventResultType]', + handler: EventHandler, + timeout: float | None = None, ) -> Any: - """Safely execute a single handler with deadlock detection""" + """Safely execute a single handler with middleware support.""" - # Check if this handler has already been executed for this event handler_id = get_handler_id(handler, self) - logger.debug(f' ↳ {self}.execute_handler({event}, handler={get_handler_name(handler)}#{handler_id[-4:]})') - if handler_id in event.event_results: - existing_result = event.event_results[handler_id] - if existing_result.started_at is not None: - raise RuntimeError( - f'Handler {get_handler_name(handler)}#{handler_id[-4:]} has already been executed for event {event.event_id}. ' - f'Previous execution started at {existing_result.started_at}' - ) - # Mark handler as started + event_result = event.event_results.get(handler_id) + if event_result is None: + event_result = event.event_result_update( + handler=handler, eventbus=self, status='pending', timeout=timeout or event.event_timeout + ) + elif event_result.started_at is not None: + raise RuntimeError( + f'Handler {get_handler_name(handler)}#{handler_id[-4:]} has already been executed for event {event.event_id}. ' + f'Previous execution started at {event_result.started_at}' + ) + + handler_id = get_handler_id(handler, self) + event_result = event.event_result_update( handler=handler, eventbus=self, status='started', timeout=timeout or event.event_timeout ) - # Set the current event in context so child events can reference it + await self._middlewares_before_handler(event, event_result) + token = _current_event_context.set(event) - # Mark that we're inside a handler handler_token = inside_handler_context.set(True) - # Set the current handler ID so child events can be tracked handler_id_token = _current_handler_id_context.set(handler_id) - # Create a task to monitor for potential deadlock / slow handlers async def deadlock_monitor(): await asyncio.sleep(15.0) logger.warning( @@ -1120,21 +1213,13 @@ async def deadlock_monitor(): ) handler_task = None + final_result: EventResult[Any] | None = None try: if inspect.iscoroutinefunction(handler): - # Create a task for the handler so we can properly cancel it on timeout handler_task = asyncio.create_task(handler(event)) # type: ignore - # This allows us to process child events when the handler awaits them result_value: Any = await asyncio.wait_for(handler_task, timeout=event_result.timeout) elif inspect.isfunction(handler) or inspect.ismethod(handler): - # If handler function is sync function, run it directly in the main thread - # This blocks but ensures we have access to the event loop, dont run it in a subthread! - result_value: Any = handler(event) - - # If the sync handler returned a BaseEvent (from dispatch), DON'T await it - # For forwarding handlers like bus.on('*', other_bus.dispatch), the handler - # has already queued the event on the target bus. The event will be tracked - # as a child event automatically. + result_value = handler(event) if isinstance(result_value, BaseEvent): logger.debug( f'Handler {get_handler_name(handler)} returned BaseEvent, not awaiting to avoid circular dependency' @@ -1145,59 +1230,45 @@ async def deadlock_monitor(): logger.debug( f' ↳ Handler {get_handler_name(handler)}#{handler_id[-4:]} returned: {type(result_value).__name__} {str(result_value)[:26]}...' # pyright: ignore ) - # Cancel the monitor task since handler completed successfully monitor_task.cancel() - # Record successful result - event.event_result_update(handler=handler, eventbus=self, result=result_value) - if handler_id in event.event_results: - # logger.debug( - # f' ↳ Updated result for {get_handler_name(handler)}#{handler_id[-4:]}: {event.event_results[handler_id].status}' - # ) - pass - else: - logger.error(f' ↳ ERROR: Result not found for {get_handler_name(handler)}#{handler_id[-4:]} after update!') - return cast(T_EventResultType, result_value) + final_result = event.event_result_update(handler=handler, eventbus=self, result=result_value) + + await self._middlewares_after_handler(event, final_result) + return cast(T_EventResultType, final_result.result) except asyncio.CancelledError as e: - # Cancel the monitor task on timeout too monitor_task.cancel() - - # Create a RuntimeError for timeout - # TODO: figure out why it breaks when we try to switch to InterruptedError instead of asyncio.CancelledError handler_interrupted_error = asyncio.CancelledError( f'Event handler {get_handler_name(handler)}#{handler_id[-4:]}({event}) was interrupted because of a parent timeout' ) - event.event_result_update(handler=handler, eventbus=self, error=handler_interrupted_error) - - # import ipdb; ipdb.set_trace() + final_result = event.event_result_update(handler=handler, eventbus=self, error=handler_interrupted_error) + await self._middlewares_on_error(event, final_result, handler_interrupted_error) raise handler_interrupted_error from e except TimeoutError as e: - # Cancel the monitor task on timeout too monitor_task.cancel() - - # Create a RuntimeError for timeout children = ( f' and interrupted any processing of {len(event.event_children)} child events' if event.event_children else '' ) handler_timeout_error = TimeoutError( f'Event handler {get_handler_name(handler)}#{handler_id[-4:]}({event}) timed out after {event_result.timeout}s{children}' ) - event.event_result_update(handler=handler, eventbus=self, error=handler_timeout_error) + final_result = event.event_result_update(handler=handler, eventbus=self, error=handler_timeout_error) event.event_cancel_pending_child_processing(handler_timeout_error) from bubus.logging import log_timeout_tree - log_timeout_tree(event, event_result) - # import ipdb; ipdb.set_trace() + if final_result is not None: + log_timeout_tree(event, final_result) + await self._middlewares_on_error(event, final_result, handler_timeout_error) raise handler_timeout_error from e except Exception as e: - # Cancel the monitor task on error too monitor_task.cancel() - # Record error - event.event_result_update(handler=handler, eventbus=self, error=e) + final_result = event.event_result_update(handler=handler, eventbus=self, error=e) + + await self._middlewares_on_error(event, final_result, e) red = '\033[91m' reset = '\033[0m' @@ -1206,29 +1277,28 @@ async def deadlock_monitor(): ) raise finally: - # Reset context _current_event_context.reset(token) inside_handler_context.reset(handler_token) _current_handler_id_context.reset(handler_id_token) - # Ensure handler task is cancelled if it's still running if handler_task and not handler_task.done(): handler_task.cancel() try: await asyncio.wait_for(handler_task, timeout=0.1) except (asyncio.CancelledError, TimeoutError): - pass # Expected when we cancel the task + pass - # Ensure monitor task is cancelled try: if not monitor_task.done(): monitor_task.cancel() await monitor_task except asyncio.CancelledError: - pass # Expected when we cancel the monitor - except Exception as e: - # logger.debug(f"❌ {self} Handler monitor task cleanup error for {get_handler_name(handler)}#{str(id(handler))[-4:]}({event}): {type(e).__name__}: {e}") pass + except Exception: + pass + + assert final_result is not None, 'Handler execution did not produce an EventResult' + return final_result.result def _would_create_loop(self, event: 'BaseEvent[Any]', handler: EventHandler) -> bool: """Check if calling this handler would create a loop""" @@ -1322,27 +1392,6 @@ def _handler_dispatched_ancestor( # Recursively check the parent's ancestry return self._handler_dispatched_ancestor(parent_event, handler_id, visited, depth) - async def _default_log_handler(self, event: 'BaseEvent[Any]') -> None: - """Default handler that logs all events""" - # logger.debug( - # f'✅ {self} completed: {event} -> {list(event.event_results.values()) or ''}' - # ) - pass - - async def _default_wal_handler(self, event: 'BaseEvent[Any]') -> None: - """Persist completed event to WAL file as JSONL""" - - if not self.wal_path: - return None - - try: - event_json = event.model_dump_json() # pyright: ignore[reportUnknownMemberType] - self.wal_path.parent.mkdir(parents=True, exist_ok=True) - async with await anyio.open_file(self.wal_path, 'a', encoding='utf-8') as f: # pyright: ignore[reportUnknownMemberType] - await f.write(event_json + '\n') # pyright: ignore[reportUnknownMemberType] - except Exception as e: - logger.error(f'❌ {self} Failed to save event {event.event_id} to WAL file: {type(e).__name__} {e}\n{event}') - def cleanup_excess_events(self) -> int: """ Clean up excess events from event_history based on max_history_size. diff --git a/tests/test_eventbus.py b/tests/test_eventbus.py index b4cb977..5e86890 100644 --- a/tests/test_eventbus.py +++ b/tests/test_eventbus.py @@ -17,6 +17,7 @@ import asyncio import json import os +import sqlite3 import time from datetime import datetime, timezone from typing import Any @@ -25,6 +26,12 @@ from pydantic import Field from bubus import BaseEvent, EventBus +from bubus.middlewares import ( + EventBusMiddleware, + LoggerEventBusMiddleware, + SQLiteEventBusMiddleware, + WALEventBusMiddleware, +) class CreateAgentTaskEvent(BaseEvent): @@ -694,7 +701,7 @@ async def test_wal_persistence_handler(self, tmp_path): """Test that events are automatically persisted to WAL file""" # Create event bus with WAL path wal_path = tmp_path / 'test_events.jsonl' - bus = EventBus(name='TestBus', wal_path=wal_path) + bus = EventBus(name='TestBus', middlewares=[WALEventBusMiddleware(wal_path)]) try: # Emit some events @@ -734,7 +741,7 @@ async def test_wal_persistence_creates_parent_dir(self, tmp_path): assert not wal_path.parent.exists() # Create event bus - bus = EventBus(name='TestBus', wal_path=wal_path) + bus = EventBus(name='TestBus', middlewares=[WALEventBusMiddleware(wal_path)]) try: # Emit an event @@ -755,7 +762,7 @@ async def test_wal_persistence_creates_parent_dir(self, tmp_path): async def test_wal_persistence_skips_incomplete_events(self, tmp_path): """Test that WAL persistence only writes completed events""" wal_path = tmp_path / 'incomplete_events.jsonl' - bus = EventBus(name='TestBus', wal_path=wal_path) + bus = EventBus(name='TestBus', middlewares=[WALEventBusMiddleware(wal_path)]) try: # Add a slow handler that will delay completion @@ -789,6 +796,172 @@ async def slow_handler(event: BaseEvent) -> str: await bus.stop() +class TestHandlerMiddleware: + """Tests for the handler middleware pipeline.""" + + async def test_middleware_wraps_successful_handler(self): + calls: list[tuple[str, str]] = [] + + class TrackingMiddleware(EventBusMiddleware): + def __init__(self, call_log: list[tuple[str, str]]): + self.call_log = call_log + + async def before_handler(self, eventbus: EventBus, event: BaseEvent, event_result): + self.call_log.append(('before', event_result.status)) + + async def after_handler(self, eventbus: EventBus, event: BaseEvent, event_result): + self.call_log.append(('after', event_result.status)) + + bus = EventBus(middlewares=[TrackingMiddleware(calls)]) + bus.on('UserActionEvent', lambda event: 'ok') + + try: + completed = await bus.dispatch(UserActionEvent(action='test', user_id='user1')) + await bus.wait_until_idle() + + assert completed.event_results + result = next(iter(completed.event_results.values())) + assert result.status == 'completed' + assert result.result == 'ok' + assert calls == [('before', 'started'), ('after', 'completed')] + finally: + await bus.stop() + + async def test_middleware_observes_handler_errors(self): + observations: list[tuple[str, str]] = [] + + class ErrorMiddleware(EventBusMiddleware): + def __init__(self, log: list[tuple[str, str]]): + self.log = log + + async def before_handler(self, eventbus: EventBus, event: BaseEvent, event_result): + self.log.append(('before', event_result.status)) + + async def on_handler_error( + self, + eventbus: EventBus, + event: BaseEvent, + event_result, + error: BaseException, + ): + self.log.append(('error', type(error).__name__)) + + async def failing_handler(event: BaseEvent) -> None: + raise ValueError('boom') + + bus = EventBus(middlewares=[ErrorMiddleware(observations)]) + bus.on('UserActionEvent', failing_handler) + + try: + event = await bus.dispatch(UserActionEvent(action='fail', user_id='user2')) + await bus.wait_until_idle() + + result = next(iter(event.event_results.values())) + assert result.status == 'error' + assert isinstance(result.error, ValueError) + assert observations == [('before', 'started'), ('error', 'ValueError')] + finally: + await bus.stop() + + +class TestSQLiteMiddleware: + async def test_sqlite_middleware_persists_events_and_results(self, tmp_path): + db_path = tmp_path / 'events.sqlite' + middleware = SQLiteEventBusMiddleware(db_path) + bus = EventBus(middlewares=[middleware]) + + async def handler(event: BaseEvent) -> str: + return 'ok' + + bus.on('UserActionEvent', handler) + + try: + await bus.dispatch(UserActionEvent(action='ping', user_id='u-1')) + await bus.wait_until_idle() + + conn = sqlite3.connect(db_path) + events = conn.execute('SELECT event_id, event_type, event_status, event_json FROM events_log').fetchall() + assert len(events) == 1 + assert events[0][1] == 'UserActionEvent' + assert events[0][2] == 'completed' + + result_rows = conn.execute( + 'SELECT status, result_repr, error_repr FROM event_results_log ORDER BY id' + ).fetchall() + conn.close() + + assert [status for status, *_ in result_rows] == ['started', 'completed'] + assert result_rows[-1][1] == "'ok'" + assert result_rows[-1][2] is None + finally: + await bus.stop() + + +class TestLoggerMiddleware: + async def test_logger_middleware_writes_file(self, tmp_path): + log_path = tmp_path / 'events.log' + bus = EventBus(middlewares=[LoggerEventBusMiddleware(log_path)]) + + async def handler(event: BaseEvent) -> str: + return 'logged' + + bus.on('UserActionEvent', handler) + + try: + await bus.dispatch(UserActionEvent(action='log', user_id='user')) + await bus.wait_until_idle() + + assert log_path.exists() + contents = log_path.read_text().strip().splitlines() + assert contents + assert 'UserActionEvent' in contents[-1] + finally: + await bus.stop() + + async def test_logger_middleware_stdout_only(self, capsys): + bus = EventBus(middlewares=[LoggerEventBusMiddleware()]) + + async def handler(event: BaseEvent) -> str: + return 'stdout' + + bus.on('UserActionEvent', handler) + + try: + await bus.dispatch(UserActionEvent(action='log', user_id='user')) + await bus.wait_until_idle() + + captured = capsys.readouterr() + assert 'UserActionEvent' in captured.out + assert 'stdout' not in captured.err + finally: + await bus.stop() + async def test_sqlite_middleware_records_errors(self, tmp_path): + db_path = tmp_path / 'events.sqlite' + middleware = SQLiteEventBusMiddleware(db_path) + bus = EventBus(middlewares=[middleware]) + + async def failing_handler(event: BaseEvent) -> None: + raise RuntimeError('handler boom') + + bus.on('UserActionEvent', failing_handler) + + try: + await bus.dispatch(UserActionEvent(action='boom', user_id='u-2')) + await bus.wait_until_idle() + + conn = sqlite3.connect(db_path) + result_rows = conn.execute( + 'SELECT status, error_repr FROM event_results_log ORDER BY id' + ).fetchall() + events = conn.execute('SELECT event_status FROM events_log').fetchall() + conn.close() + + assert [status for status, _ in result_rows] == ['started', 'error'] + assert 'RuntimeError' in result_rows[-1][1] + assert events[0][0] == 'error' + finally: + await bus.stop() + class TestEventBusHierarchy: """Test hierarchical EventBus subscription patterns"""