From 3c3f4a358e9e49ff44be2b8f9fec5e5216ab2c61 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 19:51:18 +0000 Subject: [PATCH] Fix critical and major issues identified in codebase review Critical: - get_vwap() for IDC now returns true session VWAP (volume-weighted average of all historical trades) instead of last trade price. Added _vwap_accum accumulator to ContinuousMatchingEngine and a get_session_vwap() method. - Auto-discovered signals in _discover_signals() now emit an explicit WARNING about look-ahead bias before creating the CsvSignalProvider, making the risk visible at engine level in addition to the provider-level warning. Major: - Subscribed signals that are never dispatched (due to data range mismatches or incorrect publication_offset) now trigger a WARNING at the end of run() instead of silently producing no signal updates for the entire backtest. - Test files no longer construct _BacktestContext directly. Added tests/testing_utils.py with a make_minimal_backtest_context() factory that centralises the internal import. Both test_delivery_position.py and test_algo_decorator.py now import from tests.testing_utils. https://claude.ai/code/session_01XqYtc2t1oCV8eTmXN8aaiY --- src/nexa_backtest/engines/backtest.py | 55 +++++++++++++++++++++++++-- src/nexa_backtest/engines/matching.py | 25 ++++++++++++ tests/test_algo_decorator.py | 13 ++----- tests/test_delivery_position.py | 11 +----- tests/testing_utils.py | 40 +++++++++++++++++++ 5 files changed, 122 insertions(+), 22 deletions(-) create mode 100644 tests/testing_utils.py diff --git a/src/nexa_backtest/engines/backtest.py b/src/nexa_backtest/engines/backtest.py index 68ee18e..06c49a1 100644 --- a/src/nexa_backtest/engines/backtest.py +++ b/src/nexa_backtest/engines/backtest.py @@ -409,9 +409,14 @@ def get_last_price(self, product_id: str) -> Decimal | None: return self._clearing_prices.get(product_id) def get_vwap(self, product_id: str) -> Decimal | None: - """Return the session VWAP for ``product_id``.""" + """Return the session VWAP for ``product_id``. + + For IDC products returns the volume-weighted average of all historical + trades seen so far in the session. For DA products returns the + clearing price (which is also the uniform-price VWAP). + """ if self._idc_engine is not None: - return self._idc_engine.get_last_trade_price(product_id) + return self._idc_engine.get_session_vwap(product_id) return self._clearing_prices.get(product_id) # ------------------------------------------------------------------ @@ -770,6 +775,10 @@ def __init__( self._signals: list[SignalProvider] = signals or [] self._models = models + # Tracks which subscribed signals were successfully dispatched at least once. + # Checked at the end of run() to warn about signals that never had data. + self._signals_ever_dispatched: set[str] = set() + # Resolve dispatcher based on whether this is a SimpleAlgo or @algo fn. if isinstance(algo, SimpleAlgo): self._dispatcher: SimpleAlgoDispatcher | AsyncAlgoDispatcher = SimpleAlgoDispatcher( @@ -801,6 +810,9 @@ def run(self) -> BacktestResult: if not self._products: raise DataError("No products specified.") + # Reset per-run dispatch tracking. + self._signals_ever_dispatched.clear() + da_products = [p for p in self._products if _is_da_product(p)] idc_products = [p for p in self._products if _is_idc_product(p)] @@ -892,6 +904,8 @@ def run(self) -> BacktestResult: ) max_gate_nop = max(nop_values, default=Decimal("0")) + self._warn_undispatched_signals() + return BacktestResult( algo_name=self._algo_name, exchange=self._exchange, @@ -1230,6 +1244,7 @@ def _dispatch_signals(self, context: _BacktestContext, registry: SignalRegistry) try: value = context.get_signal(signal_name) self._dispatcher.on_signal(context, signal_name, value) + self._signals_ever_dispatched.add(signal_name) except SignalError: logger.debug( "No value yet for signal '%s' at %s — skipping.", @@ -1237,8 +1252,34 @@ def _dispatch_signals(self, context: _BacktestContext, registry: SignalRegistry) context.now().isoformat(), ) + def _warn_undispatched_signals(self) -> None: + """Warn about subscribed signals that were never dispatched during the run. + + A signal that is subscribed but never dispatches a value likely has a + data range mismatch or an incorrect ``publication_offset``, causing the + algo to silently receive no signal updates for the entire backtest. + """ + subscribed = set(self._dispatcher.subscribed_signals) + never_dispatched = subscribed - self._signals_ever_dispatched + for name in sorted(never_dispatched): + logger.warning( + "Signal '%s' was subscribed but no value was ever dispatched during " + "this backtest. Check that the signal data covers the backtest period " + "and that publication_offset is set correctly.", + name, + ) + def _discover_signals(self, registry: SignalRegistry) -> None: - """Auto-register CSV providers for subscribed but unregistered signals.""" + """Auto-register CSV providers for subscribed but unregistered signals. + + Auto-discovered signals cannot have a ``publication_offset`` inferred + automatically. They default to ``None``, meaning values are visible at + their exact timestamp. This is correct for actuals but introduces + **look-ahead bias** for forecast data. Pass an explicit + :class:`~nexa_backtest.signals.csv_loader.CsvSignalProvider` with + ``publication_offset`` set via the ``signals`` argument to + :class:`BacktestEngine` to suppress this warning. + """ signals_dir = self._data_dir / "signals" for signal_name in self._dispatcher.subscribed_signals: if registry.has(signal_name): @@ -1249,6 +1290,14 @@ def _discover_signals(self, registry: SignalRegistry) -> None: f"Signal '{signal_name}' is subscribed by the algo but no CSV " f"was found at '{csv_path}'." ) + logger.warning( + "Auto-loading signal '%s' without a publication_offset. " + "If '%s' is forecast data this will introduce look-ahead bias. " + "Pass an explicit CsvSignalProvider(publication_offset=...) via the " + "signals= argument to BacktestEngine to fix this.", + signal_name, + signal_name, + ) provider = CsvSignalProvider( name=signal_name, path=csv_path, diff --git a/src/nexa_backtest/engines/matching.py b/src/nexa_backtest/engines/matching.py index 9def8e6..2f4d79f 100644 --- a/src/nexa_backtest/engines/matching.py +++ b/src/nexa_backtest/engines/matching.py @@ -232,6 +232,9 @@ def __init__(self) -> None: # Last trade price per product (for get_last_price on context) self._last_trade_price: dict[str, Decimal] = {} + # Session VWAP accumulators: {product_id: (sum_notional, sum_volume)} + self._vwap_accum: dict[str, tuple[Decimal, Decimal]] = {} + # Gate closure times per product (set by the engine before each MTU) self._gate_closures: dict[str, datetime] = {} @@ -299,6 +302,11 @@ def _process_market_data_update(self, event: MarketDataUpdate) -> list[Fill]: def _process_historical_trade(self, event: HistoricalTrade) -> list[Fill]: self._last_trade_price[event.product_id] = event.price_eur_mwh + prev = self._vwap_accum.get(event.product_id, (Decimal("0"), Decimal("0"))) + self._vwap_accum[event.product_id] = ( + prev[0] + event.price_eur_mwh * event.volume_mw, + prev[1] + event.volume_mw, + ) return self._check_algo_orders_vs_trade(event) # ------------------------------------------------------------------ @@ -477,6 +485,23 @@ def get_last_trade_price(self, product_id: str) -> Decimal | None: """ return self._last_trade_price.get(product_id) + def get_session_vwap(self, product_id: str) -> Decimal | None: + """Return the session VWAP for a product based on historical trades. + + Computed as the running volume-weighted average of all historical trade + prices seen so far for the product. + + Args: + product_id: Exchange product identifier. + + Returns: + Session VWAP in EUR/MWh, or ``None`` if no trades have occurred. + """ + accum = self._vwap_accum.get(product_id) + if accum is None or accum[1] == Decimal("0"): + return None + return accum[0] / accum[1] + def get_resting_algo_order_ids(self) -> list[str]: """Return IDs of all resting algo orders. diff --git a/tests/test_algo_decorator.py b/tests/test_algo_decorator.py index 7ba0184..be79ab8 100644 --- a/tests/test_algo_decorator.py +++ b/tests/test_algo_decorator.py @@ -27,10 +27,9 @@ SimpleAlgoDispatcher, _BacktestContext, ) -from nexa_backtest.engines.clock import SimulatedClock from nexa_backtest.exceptions import AlgoError -from nexa_backtest.signals.registry import SignalRegistry from nexa_backtest.types import GateClosureEvent, GateClosureWarning, MarketEvent, SignalValue +from tests.testing_utils import make_minimal_backtest_context # --------------------------------------------------------------------------- # Tests: @algo decorator validation @@ -135,12 +134,7 @@ async def run(ctx: TradingContext) -> None: def test_events_raises_in_simple_algo_mode(tmp_path: Path) -> None: """ctx.events() must raise AlgoError when called from SimpleAlgo context.""" - from nexa_backtest.engines.backtest import _BacktestContext - from nexa_backtest.engines.clock import SimulatedClock - from nexa_backtest.signals.registry import SignalRegistry - - clock = SimulatedClock(initial_time=datetime(2026, 3, 1, tzinfo=UTC)) - ctx = _BacktestContext(clock=clock, signal_registry=SignalRegistry()) + ctx = make_minimal_backtest_context(initial_time=datetime(2026, 3, 1, tzinfo=UTC)) # _event_queue is None → should raise AlgoError. with pytest.raises(AlgoError, match="events\\(\\)"): @@ -390,8 +384,7 @@ async def run(ctx: TradingContext) -> None: def _make_ctx() -> _BacktestContext: """Return a minimal _BacktestContext for dispatcher unit tests.""" - clock = SimulatedClock(initial_time=datetime(2026, 3, 1, 12, 0, tzinfo=UTC)) - return _BacktestContext(clock=clock, signal_registry=SignalRegistry()) + return make_minimal_backtest_context(initial_time=datetime(2026, 3, 1, 12, 0, tzinfo=UTC)) def _make_passthrough_algo() -> object: diff --git a/tests/test_delivery_position.py b/tests/test_delivery_position.py index 1b957b0..f9ec1ac 100644 --- a/tests/test_delivery_position.py +++ b/tests/test_delivery_position.py @@ -6,22 +6,15 @@ from decimal import Decimal from nexa_backtest.engines.backtest import _BacktestContext -from nexa_backtest.engines.clock import SimulatedClock -from nexa_backtest.signals.registry import SignalRegistry from nexa_backtest.types import Fill, Side +from tests.testing_utils import make_minimal_backtest_context def _make_ctx( products: dict[str, datetime] | None = None, ) -> _BacktestContext: """Create a minimal _BacktestContext with optional product delivery starts.""" - clock = SimulatedClock( - initial_time=datetime(2026, 3, 1, 9, 0, tzinfo=UTC), - ) - ctx = _BacktestContext(clock=clock, signal_registry=SignalRegistry()) - if products: - ctx._product_delivery_starts = products - return ctx + return make_minimal_backtest_context(products=products) def _record_fill(ctx: _BacktestContext, product_id: str, side: Side, volume: Decimal) -> None: diff --git a/tests/testing_utils.py b/tests/testing_utils.py new file mode 100644 index 0000000..dda8c71 --- /dev/null +++ b/tests/testing_utils.py @@ -0,0 +1,40 @@ +"""Internal test utilities for nexa-backtest. + +These helpers intentionally access private engine implementation details to +support unit tests that need to construct engine state without running a full +backtest. They must not be imported from algo code. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from nexa_backtest.engines.backtest import _BacktestContext +from nexa_backtest.engines.clock import SimulatedClock +from nexa_backtest.signals.registry import SignalRegistry + + +def make_minimal_backtest_context( + products: dict[str, datetime] | None = None, + initial_time: datetime | None = None, +) -> _BacktestContext: + """Create a minimal ``_BacktestContext`` for engine unit tests. + + Args: + products: Optional mapping of ``product_id → delivery_start`` to + pre-populate ``_product_delivery_starts``. + initial_time: Simulated clock start time. Defaults to + ``2026-03-01T09:00:00Z``. + + Returns: + A ``_BacktestContext`` with no fills, no IDC engine, and no model + registry. Callers may call ``ctx._record_fill(fill)`` and access + ``ctx._product_delivery_starts`` directly for state setup. + """ + clock = SimulatedClock( + initial_time=initial_time or datetime(2026, 3, 1, 9, 0, tzinfo=UTC), + ) + ctx = _BacktestContext(clock=clock, signal_registry=SignalRegistry()) + if products: + ctx._product_delivery_starts = products + return ctx