diff --git a/src/nexa_backtest/engines/backtest.py b/src/nexa_backtest/engines/backtest.py index 06c49a1..ee1f59b 100644 --- a/src/nexa_backtest/engines/backtest.py +++ b/src/nexa_backtest/engines/backtest.py @@ -67,7 +67,6 @@ GateClosureEvent, GateClosureSnapshot, GateClosureWarning, - HistoricalTrade, MarketEvent, Order, OrderBook, @@ -496,19 +495,32 @@ def get_position(self, product_id: str) -> Position: return _zero_position(product_id) net_mw = Decimal("0") - total_cost = Decimal("0") + total_buy_cost = Decimal("0") + total_buy_vol = Decimal("0") + total_sell_cost = Decimal("0") + total_sell_vol = Decimal("0") for f in fills: if f.side == Side.BUY: net_mw += f.volume - total_cost += f.price * f.volume + total_buy_cost += f.price * f.volume + total_buy_vol += f.volume else: net_mw -= f.volume - total_cost -= f.price * f.volume + total_sell_cost += f.price * f.volume + total_sell_vol += f.volume if net_mw == 0: return _zero_position(product_id) - avg_price = abs(total_cost / net_mw) + # Avg entry price is the volume-weighted average of the opens that + # created the current directional position (buys for long, sells for + # short). Using the opposite side's cost would corrupt the price when + # a partial close at a different price has occurred. + if net_mw > 0: + avg_price = total_buy_cost / total_buy_vol + else: + avg_price = total_sell_cost / total_sell_vol + if self._idc_engine is not None: mark = self._idc_engine.get_last_trade_price(product_id) or avg_price else: @@ -1085,10 +1097,6 @@ def _run_idc( gate_closure_snapshots: list[GateClosureSnapshot] = [] cumulative_pnl = Decimal("0") gate_warned: set[str] = set() - # IDC market VWAP accumulators: {product_id: (sum_notional, sum_volume)} - # Accumulated incrementally as HistoricalTrade events flow through the loop. - # O(P) memory — never loads more data than the current sliding window. - idc_vwap_accum: dict[str, tuple[Decimal, Decimal]] = {} current_time = start_dt while current_time < end_dt: @@ -1104,13 +1112,6 @@ def _run_idc( with contextlib.suppress(Exception): self._dispatcher.on_error(context, exc) continue - # Accumulate market-side trade VWAP incrementally. - if isinstance(event, HistoricalTrade): - prev = idc_vwap_accum.get(event.product_id, (Decimal("0"), Decimal("0"))) - idc_vwap_accum[event.product_id] = ( - prev[0] + event.price_eur_mwh * event.volume_mw, - prev[1] + event.volume_mw, - ) # Forward historical event to @algo stream. self._dispatcher.on_market_event(context, event) for fill in fills: @@ -1166,12 +1167,13 @@ def _run_idc( bar_all_fills = [f for f in all_fills if _in_mtu(f.timestamp, current_time, next_time)] if bar_all_fills: # Use running per-product VWAP for equity curve. - # This is a partial VWAP (trades seen so far) — a good approximation - # that is much better than zero and directionally correct. - running_vwaps, _ = compute_idc_vwaps(idc_vwap_accum) + # Falls back to the fill price itself (zero alpha) when no + # market trades have occurred yet for a product — avoids + # corrupting cumulative_pnl with an arbitrary zero benchmark. + running_vwaps, _ = compute_idc_vwaps(matching_engine.get_vwap_accumulator()) bar_pnl = sum( ( - compute_fill_pnl(f, running_vwaps.get(f.product_id, Decimal("0"))) + compute_fill_pnl(f, running_vwaps.get(f.product_id, f.price)) for f in bar_all_fills ), Decimal("0"), @@ -1203,7 +1205,7 @@ def _run_idc( context, clock, gate_closure_snapshots, - idc_vwap_accum, + matching_engine.get_vwap_accumulator(), ) # ------------------------------------------------------------------ diff --git a/src/nexa_backtest/engines/matching.py b/src/nexa_backtest/engines/matching.py index 2f4d79f..9c37b8d 100644 --- a/src/nexa_backtest/engines/matching.py +++ b/src/nexa_backtest/engines/matching.py @@ -502,6 +502,18 @@ def get_session_vwap(self, product_id: str) -> Decimal | None: return None return accum[0] / accum[1] + def get_vwap_accumulator(self) -> dict[str, tuple[Decimal, Decimal]]: + """Return a snapshot of the raw VWAP accumulator for all products. + + Each value is ``(sum_notional, sum_volume)``. Used by the backtest + engine to compute per-product VWAPs for final metrics without + maintaining a separate duplicate accumulator. + + Returns: + A copy of the internal accumulator dict. + """ + return dict(self._vwap_accum) + def get_resting_algo_order_ids(self) -> list[str]: """Return IDs of all resting algo orders. diff --git a/tests/test_delivery_position.py b/tests/test_delivery_position.py index f9ec1ac..cb78144 100644 --- a/tests/test_delivery_position.py +++ b/tests/test_delivery_position.py @@ -144,3 +144,94 @@ def test_get_all_delivery_positions_multiple_periods() -> None: assert len(result) == 2 assert result[ds_0900].net_mw == Decimal("5") assert result[ds_0915].net_mw == Decimal("-3") + + +# --------------------------------------------------------------------------- +# get_position — avg entry price with mixed buy/sell fills +# --------------------------------------------------------------------------- + + +def test_get_position_avg_price_pure_long() -> None: + """Pure long position: avg_price is weighted average of all buys.""" + ctx = _make_ctx() + fill = Fill( + order_id="o1", + product_id="NO1-QH-0900", + price=Decimal("50"), + volume=Decimal("10"), + timestamp=datetime(2026, 3, 1, 8, 0, tzinfo=UTC), + side=Side.BUY, + ) + ctx._record_fill(fill) + pos = ctx.get_position("NO1-QH-0900") + assert pos.net_mw == Decimal("10") + assert pos.avg_entry_price == Decimal("50") + + +def test_get_position_avg_price_partial_close() -> None: + """BUY 100@50 then SELL 50@60: remaining long 50 should still have avg_price=50. + + The sell closes part of the long at a profit, but the remaining open + position was entered at 50, not at some blended value of 50 and 60. + """ + ctx = _make_ctx() + ts = datetime(2026, 3, 1, 8, 0, tzinfo=UTC) + ctx._record_fill( + Fill( + order_id="o1", + product_id="P", + price=Decimal("50"), + volume=Decimal("100"), + timestamp=ts, + side=Side.BUY, + ) + ) + ctx._record_fill( + Fill( + order_id="o2", + product_id="P", + price=Decimal("60"), + volume=Decimal("50"), + timestamp=ts, + side=Side.SELL, + ) + ) + pos = ctx.get_position("P") + assert pos.net_mw == Decimal("50") + assert pos.avg_entry_price == Decimal("50"), ( + f"Expected avg_entry_price=50, got {pos.avg_entry_price}. " + "Partial sell at a different price should not change the entry price of the remaining long." + ) + + +def test_get_position_avg_price_short_after_partial_buy_back() -> None: + """SELL 100@50 then BUY 40@45: remaining short 60 should have avg_price=50.""" + ctx = _make_ctx() + ts = datetime(2026, 3, 1, 8, 0, tzinfo=UTC) + ctx._record_fill( + Fill( + order_id="o1", + product_id="P", + price=Decimal("50"), + volume=Decimal("100"), + timestamp=ts, + side=Side.SELL, + ) + ) + ctx._record_fill( + Fill( + order_id="o2", + product_id="P", + price=Decimal("45"), + volume=Decimal("40"), + timestamp=ts, + side=Side.BUY, + ) + ) + pos = ctx.get_position("P") + assert pos.net_mw == Decimal("-60") + assert pos.avg_entry_price == Decimal("50"), ( + f"Expected avg_entry_price=50, got {pos.avg_entry_price}. " + "Partial buy-back at a different price should not change " + "the entry price of the remaining short." + )