Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions src/nexa_backtest/engines/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
GateClosureEvent,
GateClosureSnapshot,
GateClosureWarning,
HistoricalTrade,
MarketEvent,
Order,
OrderBook,
Expand Down Expand Up @@ -496,19 +495,32 @@ def get_position(self, product_id: str) -> Position:
return _zero_position(product_id)

net_mw = Decimal("0")
total_cost = Decimal("0")
total_buy_cost = Decimal("0")
total_buy_vol = Decimal("0")
total_sell_cost = Decimal("0")
total_sell_vol = Decimal("0")
for f in fills:
if f.side == Side.BUY:
net_mw += f.volume
total_cost += f.price * f.volume
total_buy_cost += f.price * f.volume
total_buy_vol += f.volume
else:
net_mw -= f.volume
total_cost -= f.price * f.volume
total_sell_cost += f.price * f.volume
total_sell_vol += f.volume

if net_mw == 0:
return _zero_position(product_id)

avg_price = abs(total_cost / net_mw)
# Avg entry price is the volume-weighted average of the opens that
# created the current directional position (buys for long, sells for
# short). Using the opposite side's cost would corrupt the price when
# a partial close at a different price has occurred.
if net_mw > 0:
avg_price = total_buy_cost / total_buy_vol
else:
avg_price = total_sell_cost / total_sell_vol

if self._idc_engine is not None:
mark = self._idc_engine.get_last_trade_price(product_id) or avg_price
else:
Expand Down Expand Up @@ -1085,10 +1097,6 @@ def _run_idc(
gate_closure_snapshots: list[GateClosureSnapshot] = []
cumulative_pnl = Decimal("0")
gate_warned: set[str] = set()
# IDC market VWAP accumulators: {product_id: (sum_notional, sum_volume)}
# Accumulated incrementally as HistoricalTrade events flow through the loop.
# O(P) memory — never loads more data than the current sliding window.
idc_vwap_accum: dict[str, tuple[Decimal, Decimal]] = {}

current_time = start_dt
while current_time < end_dt:
Expand All @@ -1104,13 +1112,6 @@ def _run_idc(
with contextlib.suppress(Exception):
self._dispatcher.on_error(context, exc)
continue
# Accumulate market-side trade VWAP incrementally.
if isinstance(event, HistoricalTrade):
prev = idc_vwap_accum.get(event.product_id, (Decimal("0"), Decimal("0")))
idc_vwap_accum[event.product_id] = (
prev[0] + event.price_eur_mwh * event.volume_mw,
prev[1] + event.volume_mw,
)
# Forward historical event to @algo stream.
self._dispatcher.on_market_event(context, event)
for fill in fills:
Expand Down Expand Up @@ -1166,12 +1167,13 @@ def _run_idc(
bar_all_fills = [f for f in all_fills if _in_mtu(f.timestamp, current_time, next_time)]
if bar_all_fills:
# Use running per-product VWAP for equity curve.
# This is a partial VWAP (trades seen so far) — a good approximation
# that is much better than zero and directionally correct.
running_vwaps, _ = compute_idc_vwaps(idc_vwap_accum)
# Falls back to the fill price itself (zero alpha) when no
# market trades have occurred yet for a product — avoids
# corrupting cumulative_pnl with an arbitrary zero benchmark.
running_vwaps, _ = compute_idc_vwaps(matching_engine.get_vwap_accumulator())
bar_pnl = sum(
(
compute_fill_pnl(f, running_vwaps.get(f.product_id, Decimal("0")))
compute_fill_pnl(f, running_vwaps.get(f.product_id, f.price))
for f in bar_all_fills
),
Decimal("0"),
Expand Down Expand Up @@ -1203,7 +1205,7 @@ def _run_idc(
context,
clock,
gate_closure_snapshots,
idc_vwap_accum,
matching_engine.get_vwap_accumulator(),
)

# ------------------------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions src/nexa_backtest/engines/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,18 @@ def get_session_vwap(self, product_id: str) -> Decimal | None:
return None
return accum[0] / accum[1]

def get_vwap_accumulator(self) -> dict[str, tuple[Decimal, Decimal]]:
"""Return a snapshot of the raw VWAP accumulator for all products.

Each value is ``(sum_notional, sum_volume)``. Used by the backtest
engine to compute per-product VWAPs for final metrics without
maintaining a separate duplicate accumulator.

Returns:
A copy of the internal accumulator dict.
"""
return dict(self._vwap_accum)

def get_resting_algo_order_ids(self) -> list[str]:
"""Return IDs of all resting algo orders.

Expand Down
91 changes: 91 additions & 0 deletions tests/test_delivery_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,94 @@ def test_get_all_delivery_positions_multiple_periods() -> None:
assert len(result) == 2
assert result[ds_0900].net_mw == Decimal("5")
assert result[ds_0915].net_mw == Decimal("-3")


# ---------------------------------------------------------------------------
# get_position — avg entry price with mixed buy/sell fills
# ---------------------------------------------------------------------------


def test_get_position_avg_price_pure_long() -> None:
"""Pure long position: avg_price is weighted average of all buys."""
ctx = _make_ctx()
fill = Fill(
order_id="o1",
product_id="NO1-QH-0900",
price=Decimal("50"),
volume=Decimal("10"),
timestamp=datetime(2026, 3, 1, 8, 0, tzinfo=UTC),
side=Side.BUY,
)
ctx._record_fill(fill)
pos = ctx.get_position("NO1-QH-0900")
assert pos.net_mw == Decimal("10")
assert pos.avg_entry_price == Decimal("50")


def test_get_position_avg_price_partial_close() -> None:
"""BUY 100@50 then SELL 50@60: remaining long 50 should still have avg_price=50.

The sell closes part of the long at a profit, but the remaining open
position was entered at 50, not at some blended value of 50 and 60.
"""
ctx = _make_ctx()
ts = datetime(2026, 3, 1, 8, 0, tzinfo=UTC)
ctx._record_fill(
Fill(
order_id="o1",
product_id="P",
price=Decimal("50"),
volume=Decimal("100"),
timestamp=ts,
side=Side.BUY,
)
)
ctx._record_fill(
Fill(
order_id="o2",
product_id="P",
price=Decimal("60"),
volume=Decimal("50"),
timestamp=ts,
side=Side.SELL,
)
)
pos = ctx.get_position("P")
assert pos.net_mw == Decimal("50")
assert pos.avg_entry_price == Decimal("50"), (
f"Expected avg_entry_price=50, got {pos.avg_entry_price}. "
"Partial sell at a different price should not change the entry price of the remaining long."
)


def test_get_position_avg_price_short_after_partial_buy_back() -> None:
"""SELL 100@50 then BUY 40@45: remaining short 60 should have avg_price=50."""
ctx = _make_ctx()
ts = datetime(2026, 3, 1, 8, 0, tzinfo=UTC)
ctx._record_fill(
Fill(
order_id="o1",
product_id="P",
price=Decimal("50"),
volume=Decimal("100"),
timestamp=ts,
side=Side.SELL,
)
)
ctx._record_fill(
Fill(
order_id="o2",
product_id="P",
price=Decimal("45"),
volume=Decimal("40"),
timestamp=ts,
side=Side.BUY,
)
)
pos = ctx.get_position("P")
assert pos.net_mw == Decimal("-60")
assert pos.avg_entry_price == Decimal("50"), (
f"Expected avg_entry_price=50, got {pos.avg_entry_price}. "
"Partial buy-back at a different price should not change "
"the entry price of the remaining short."
)
Loading