From 30249b2b16868a39762d13d91d6d2ec312a04ce8 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Wed, 15 Apr 2026 11:45:13 +0200 Subject: [PATCH 01/18] fix(proxy): narrow previous_response recovery to not_found semantics and add regression tests --- app/modules/proxy/service.py | 269 ++++++++++++- .../integration/test_http_responses_bridge.py | 274 ++++++++++++++ .../test_proxy_websocket_responses.py | 354 ++++++++++++++++++ tests/unit/test_proxy_http_bridge.py | 351 +++++++++++++++++ tests/unit/test_proxy_utils.py | 257 +++++++++++++ 5 files changed, 1498 insertions(+), 7 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index c232cef9..15eb24ec 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -8,7 +8,7 @@ import re import time from collections import deque -from collections.abc import Collection, Sequence +from collections.abc import AsyncGenerator, Collection, Sequence from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime, timezone @@ -638,6 +638,138 @@ async def _stream_via_http_bridge( session.last_used_at = time.monotonic() return session = session_or_forward + session_events: AsyncGenerator[str, None] = self._stream_http_bridge_session_events( + session, + request_state=request_state, + text_data=text_data, + queue_limit=queue_limit, + propagate_http_errors=propagate_http_errors, + downstream_turn_state=downstream_turn_state, + ) + try: + async for event_block in session_events: + yield event_block + except ProxyResponseError as exc: + should_attempt_previous_response_recovery = ( + effective_payload.previous_response_id is not None + and _http_bridge_should_attempt_local_previous_response_recovery(exc) + ) + if not should_attempt_previous_response_recovery: + raise + + if PROMETHEUS_AVAILABLE and bridge_durable_recover_total is not None: + bridge_durable_recover_total.labels(path="local_previous_response_error").inc() + _log_http_bridge_event( + 
"previous_response_recover_local", + bridge_session_key, + account_id=None, + model=effective_payload.model, + detail="outcome=local_rebind_after_local_error", + cache_key_family=bridge_session_key.affinity_kind, + model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, + owner_check_applied=True, + ) + + async with self._http_bridge_lock: + if self._http_bridge_sessions.get(session.key) is session: + self._http_bridge_sessions.pop(session.key, None) + async with session.pending_lock: + session.queued_request_count = 0 + await self._fail_pending_websocket_requests( + account_id_value=session.account.id, + pending_requests=session.pending_requests, + pending_lock=session.pending_lock, + error_code="stream_incomplete", + error_message="Upstream websocket closed before response.completed", + api_key=None, + response_create_gate=session.response_create_gate, + ) + await self._close_http_bridge_session(session) + + session = await self._get_or_create_http_bridge_session( + bridge_session_key, + headers=dict(headers), + affinity=affinity, + api_key=api_key, + request_model=effective_payload.model, + idle_ttl_seconds=_effective_http_bridge_idle_ttl_seconds( + affinity=affinity, + idle_ttl_seconds=idle_ttl_seconds, + codex_idle_ttl_seconds=codex_idle_ttl_seconds, + prompt_cache_idle_ttl_seconds=prompt_cache_idle_ttl_seconds, + ), + max_sessions=max_sessions, + previous_response_id=request_state.previous_response_id, + gateway_safe_mode=runtime_config.gateway_safe_mode, + allow_forward_to_owner=False, + forwarded_request=False, + allow_previous_response_recovery_rebind=True, + durable_lookup=durable_lookup, + request_stage="reattach", + preferred_account_id=request_state.preferred_account_id, + ) + _record_bridge_reattach(path="local_previous_response_error", outcome="success") + + try: + retry_api_key_reservation = api_key_reservation + retry_reservation_reacquired = False + if api_key is not None and api_key_reservation is not None: 
+ retry_api_key_reservation = await self._reserve_websocket_api_key_usage( + api_key, + request_model=effective_payload.model, + request_service_tier=_normalize_service_tier_value( + dict(effective_payload.to_payload()).get("service_tier"), + ), + ) + retry_reservation_reacquired = True + + retry_request_state, retry_text_data = self._prepare_http_bridge_request( + effective_payload, + headers, + api_key=api_key, + api_key_reservation=retry_api_key_reservation, + request_id=request_id, + ) + retry_request_state.transport = _REQUEST_TRANSPORT_HTTP + retry_request_state.request_stage = request_state.request_stage + retry_request_state.preferred_account_id = request_state.preferred_account_id + + retry_events: AsyncGenerator[str, None] = self._stream_http_bridge_session_events( + session, + request_state=retry_request_state, + text_data=retry_text_data, + queue_limit=queue_limit, + propagate_http_errors=propagate_http_errors, + downstream_turn_state=downstream_turn_state, + ) + try: + async for event_block in retry_events: + yield event_block + finally: + try: + await retry_events.aclose() + except Exception: + pass + except BaseException: + if retry_reservation_reacquired and retry_api_key_reservation is not None: + await self._release_websocket_reservation(retry_api_key_reservation) + raise + finally: + try: + await session_events.aclose() + except Exception: + pass + + async def _stream_http_bridge_session_events( + self, + session: "_HTTPBridgeSession", + *, + request_state: _WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ) -> AsyncGenerator[str, None]: await self._submit_http_bridge_request( session, request_state=request_state, @@ -4150,7 +4282,7 @@ async def _relay_upstream_websocket_messages( continue if message.kind == "text" and message.text is not None: downstream_activity.mark() - await self._process_upstream_websocket_text( + downstream_text = await 
self._process_upstream_websocket_text( message.text, account=account, account_id_value=account_id_value, @@ -4163,7 +4295,7 @@ async def _relay_upstream_websocket_messages( await self._send_downstream_websocket_text( websocket, client_send_lock=client_send_lock, - text=message.text, + text=downstream_text, downstream_activity=downstream_activity, ) if upstream_control.reconnect_requested: @@ -4218,7 +4350,7 @@ async def _process_upstream_websocket_text( api_key: ApiKeyData | None, upstream_control: _WebSocketUpstreamControl, response_create_gate: asyncio.Semaphore, - ) -> None: + ) -> str: event_block = f"data: {text}\n\n" payload = parse_sse_data_json(event_block) event = parse_sse_event(event_block) @@ -4261,7 +4393,16 @@ async def _process_upstream_websocket_text( _release_websocket_response_create_gate(created_request_state, response_create_gate) if request_state is None: - return + return text + + event, payload, event_type, downstream_text = _maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=event, + payload=payload, + event_type=event_type, + upstream_control=upstream_control, + original_text=text, + ) await self._finalize_websocket_request_state( request_state, @@ -4274,6 +4415,7 @@ async def _process_upstream_websocket_text( upstream_control=upstream_control, response_create_gate=response_create_gate, ) + return downstream_text async def _next_websocket_receive_timeout( self, @@ -6374,6 +6516,113 @@ def _websocket_response_id(event: OpenAIEvent | None, payload: dict[str, JsonVal return stripped or None +def _websocket_event_error_code(event_type: str | None, payload: dict[str, JsonValue] | None) -> str | None: + error = _websocket_event_error_payload(event_type, payload) + if not isinstance(error, dict): + return None + code_value = error.get("code") + if not isinstance(code_value, str): + return None + stripped = code_value.strip() + return stripped or None + + +def _websocket_event_error_param(event_type: 
str | None, payload: dict[str, JsonValue] | None) -> str | None: + error = _websocket_event_error_payload(event_type, payload) + if not isinstance(error, dict): + return None + param_value = error.get("param") + if not isinstance(param_value, str): + return None + stripped = param_value.strip() + return stripped or None + + +def _websocket_event_error_message(event_type: str | None, payload: dict[str, JsonValue] | None) -> str | None: + error = _websocket_event_error_payload(event_type, payload) + if not isinstance(error, dict): + return None + message_value = error.get("message") + if not isinstance(message_value, str): + return None + stripped = message_value.strip() + return stripped or None + + +def _is_previous_response_not_found_message(message: str | None) -> bool: + if message is None: + return False + normalized = " ".join(message.lower().split()) + return "previous response" in normalized and "not found" in normalized + + +def _is_previous_response_not_found_error( + *, + code: str | None, + param: str | None, + message: str | None, +) -> bool: + if code == "previous_response_not_found": + return True + if code != "invalid_request_error" or param != "previous_response_id": + return False + return _is_previous_response_not_found_message(message) + + +def _websocket_event_error_payload( + event_type: str | None, + payload: dict[str, JsonValue] | None, +) -> dict[str, JsonValue] | None: + if not isinstance(payload, dict): + return None + if event_type == "error": + error = payload.get("error") + elif event_type == "response.failed": + response = payload.get("response") + error = response.get("error") if isinstance(response, dict) else None + else: + return None + return cast(dict[str, JsonValue], error) if isinstance(error, dict) else None + + +def _maybe_rewrite_websocket_previous_response_not_found_event( + *, + request_state: _WebSocketRequestState, + event: OpenAIEvent | None, + payload: dict[str, JsonValue] | None, + event_type: str | None, + 
upstream_control: _WebSocketUpstreamControl, + original_text: str, +) -> tuple[OpenAIEvent | None, dict[str, JsonValue] | None, str | None, str]: + if request_state.previous_response_id is None: + return event, payload, event_type, original_text + + error_code = _websocket_event_error_code(event_type, payload) + error_param = _websocket_event_error_param(event_type, payload) + error_message = _websocket_event_error_message(event_type, payload) + should_rewrite = _is_previous_response_not_found_error( + code=error_code, + param=error_param, + message=error_message, + ) + if not should_rewrite: + return event, payload, event_type, original_text + + upstream_control.reconnect_requested = True + rewritten_event_payload = response_failed_event( + "stream_incomplete", + "Upstream websocket closed before response.completed", + error_type="server_error", + response_id=request_state.response_id or request_state.request_id, + ) + rewritten_text = json.dumps(rewritten_event_payload, ensure_ascii=True, separators=(",", ":")) + rewritten_event_block = format_sse_event(rewritten_event_payload) + rewritten_payload = parse_sse_data_json(rewritten_event_block) + rewritten_event = parse_sse_event(rewritten_event_block) + rewritten_event_type = _event_type_from_payload(rewritten_event, rewritten_payload) + return rewritten_event, rewritten_payload, rewritten_event_type, rewritten_text + + def _find_websocket_request_state_by_response_id( pending_requests: deque[_WebSocketRequestState], response_id: str, @@ -7744,11 +7993,17 @@ def _http_bridge_should_attempt_local_previous_response_recovery(exc: ProxyRespo if not isinstance(error, dict): return False code = error.get("code") - return code in { + if code in { "bridge_owner_unreachable", "previous_response_not_found", "bridge_instance_mismatch", - } + }: + return True + param_value = error.get("param") + param = param_value.strip() if isinstance(param_value, str) and param_value.strip() else None + message_value = error.get("message") 
+ message = message_value.strip() if isinstance(message_value, str) and message_value.strip() else None + return _is_previous_response_not_found_error(code=code, param=param, message=message) def _http_bridge_should_attempt_local_bootstrap_rebind( diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index 1c6b0f49..80981432 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -436,6 +436,56 @@ async def send_text(self, text: str) -> None: ) +class _PreviousResponseNotFoundUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + previous_response_id = payload.get("previous_response_id") + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": f"Previous response with id '{previous_response_id}' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + + +class _InvalidRequestPreviousResponseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + previous_response_id = payload.get("previous_response_id") + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": f"Previous response with id '{previous_response_id}' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + + class _FailingSendThenCloseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): async def send_text(self, text: str) -> None: self.sent_text.append(text) @@ -6658,6 +6708,230 @@ async 
def fake_connect_responses_websocket( assert "previous_response_not_found" not in second.json()["error"].get("code", "") +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_rebinds_after_upstream_previous_response_not_found( + async_client, + app_instance, + monkeypatch, +): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( + async_client, + "acc_http_bridge_previous_response_rebind", + "http-bridge-previous-response-rebind@example.com", + ) + account = await _get_account(account_id) + first_upstream = _FakeBridgeUpstreamWebSocket() + recovered_upstream = _FakeBridgeUpstreamWebSocket() + connect_count = 0 + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + api_key, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return first_upstream + return recovered_upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", 
fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + first = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "previous-response-rebind", + }, + ) + assert first.status_code == 200 + first_body = first.json() + + service = get_proxy_service_for_app(app_instance) + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + await _replace_http_bridge_upstream_reader( + service, + session, + cast(proxy_module.UpstreamResponsesWebSocket, _PreviousResponseNotFoundUpstreamWebSocket()), + ) + + second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "previous-response-rebind", + "previous_response_id": first_body["id"], + }, + ) + + assert second.status_code == 200 + assert second.json()["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 2 + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_rebinds_after_upstream_invalid_request_previous_response_not_found_param( + async_client, + app_instance, + monkeypatch, +): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( + async_client, + "acc_http_bridge_invalid_request_rebind", + "http-bridge-invalid-request-rebind@example.com", + ) + account = await _get_account(account_id) + first_upstream = _FakeBridgeUpstreamWebSocket() + recovered_upstream = _FakeBridgeUpstreamWebSocket() + connect_count = 0 + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + 
self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + api_key, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return first_upstream + return recovered_upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + first = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "invalid-request-rebind", + }, + ) + assert first.status_code == 200 + first_body = first.json() + + service = get_proxy_service_for_app(app_instance) + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + await _replace_http_bridge_upstream_reader( + service, + session, + cast(proxy_module.UpstreamResponsesWebSocket, _InvalidRequestPreviousResponseUpstreamWebSocket()), + ) + + second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "invalid-request-rebind", + "previous_response_id": first_body["id"], + }, + ) + + assert second.status_code 
== 200 + assert second.json()["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 2 + + @pytest.mark.asyncio async def test_v1_responses_http_bridge_send_retry_keeps_session_open_for_followup_request( async_client, diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index dcbbee4a..16c74586 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1075,6 +1075,360 @@ async def fake_write_request_log(self, **kwargs): ] +def test_v1_responses_websocket_masks_previous_response_not_found_and_recovers_on_retry( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": ("Previous response with id 'resp_ws_prev_anchor' not found."), + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_prev_retry", "status": "in_progress"}}, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_ws_prev_retry", "status": "completed"}}, + separators=(",", ":"), + 
), + ), + ] + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_prev_mask"), first_upstream + return SimpleNamespace(id="acct_ws_prev_mask"), recovered_upstream + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + with TestClient(app_instance) as client: + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "hello", + "stream": True, + } + ) + ) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + 
failed_retryable = json.loads(websocket.receive_text()) + assert failed_retryable["type"] == "response.failed" + assert failed_retryable["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(failed_retryable) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-retry", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert completed_2["type"] == "response.completed" + assert connect_count == 2 + assert first_upstream.closed is True + + +def test_v1_responses_websocket_masks_invalid_request_previous_response_not_found_error_and_recovers_on_retry( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": ("Previous response with id 'resp_ws_prev_anchor' not found."), + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_prev_retry", "status": "in_progress"}}, + separators=(",", ":"), + 
), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_ws_prev_retry", "status": "completed"}}, + separators=(",", ":"), + ), + ), + ] + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_prev_mask"), first_upstream + return SimpleNamespace(id="acct_ws_prev_mask"), recovered_upstream + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + with TestClient(app_instance) as client: + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "hello", + "stream": True, + } + ) + ) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + 
websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + failed_retryable = json.loads(websocket.receive_text()) + assert failed_retryable["type"] == "response.failed" + assert failed_retryable["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(failed_retryable) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-retry", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert completed_2["type"] == "response.completed" + assert connect_count == 2 + assert first_upstream.closed is True + + @pytest.mark.parametrize("frame", ['{"type":"response.create"', "[]"]) def test_backend_responses_websocket_rejects_malformed_first_frame_as_invalid_payload(app_instance, monkeypatch, frame): called = {"connect": False} diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index d9a8b241..84d0287a 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -755,6 +755,357 @@ async def fake_forward(**kwargs: object): get_or_create.assert_awaited_once() +@pytest.mark.asyncio +async def test_stream_via_http_bridge_reacquires_api_key_reservation_for_local_previous_response_rebind( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + api_key = _make_api_key(key_id="key-1", assigned_account_ids=[]) + initial_reservation = proxy_service.ApiKeyUsageReservationData( + reservation_id="resv-initial", + key_id=api_key.id, + model="gpt-5.4", + ) + retried_reservation = proxy_service.ApiKeyUsageReservationData( + 
reservation_id="resv-retry", + key_id=api_key.id, + model="gpt-5.4", + ) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "bridge-prev-rebind", + "previous_response_id": "resp_prev_1", + } + ) + + request_state_initial = proxy_service._WebSocketRequestState( + request_id="req-initial", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=initial_reservation, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + request_state_initial.request_stage = "follow_up" + request_state_initial.preferred_account_id = "acc-1" + request_state_retry = proxy_service._WebSocketRequestState( + request_id="req-retry", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=retried_reservation, + started_at=2.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + + prepare_reservations: list[proxy_service.ApiKeyUsageReservationData | None] = [] + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del prepared_payload, api_key, request_id + prepare_reservations.append(api_key_reservation) + if len(prepare_reservations) == 1: + return request_state_initial, '{"type":"response.create","request":"initial"}' + return request_state_retry, '{"type":"response.create","request":"retry"}' + + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", api_key.id) + session_initial = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + 
), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + session_retry = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session_initial + + stream_calls = {"count": 0} + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + stream_calls["count"] += 1 + if stream_calls["count"] == 1: + raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) + yield 'data: {"type":"response.completed"}\n\n' + + reserve_retry = AsyncMock(return_value=retried_reservation) + get_or_create = AsyncMock(side_effect=[session_initial, session_retry]) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + 
return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_retry) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=api_key, + api_key_reservation=initial_reservation, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == ['data: {"type":"response.completed"}\n\n'] + assert prepare_reservations == [initial_reservation, retried_reservation] + reserve_retry.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_local_previous_response_rebind_fails_existing_pending_requests( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "bridge-prev-rebind", + "previous_response_id": "resp_prev_1", + } + ) + + request_state_initial = proxy_service._WebSocketRequestState( + request_id="req-initial", + 
model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + request_state_initial.request_stage = "follow_up" + request_state_initial.preferred_account_id = "acc-1" + request_state_retry = proxy_service._WebSocketRequestState( + request_id="req-retry", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=2.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + + stale_pending_queue: asyncio.Queue[str | None] = asyncio.Queue() + stale_pending_request = proxy_service._WebSocketRequestState( + request_id="req-stale", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.5, + event_queue=stale_pending_queue, + transport="http", + ) + stale_pending_request.skip_request_log = True + + prepare_calls = {"count": 0} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del prepared_payload, api_key, api_key_reservation, request_id + prepare_calls["count"] += 1 + if prepare_calls["count"] == 1: + return request_state_initial, '{"type":"response.create","request":"initial"}' + return request_state_retry, '{"type":"response.create","request":"retry"}' + + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", None) + session_initial = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", 
status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque([stale_pending_request]), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + session_retry = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session_initial + + stream_calls = {"count": 0} + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + stream_calls["count"] += 1 + if stream_calls["count"] == 1: + raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) + yield 'data: {"type":"response.completed"}\n\n' + + get_or_create = AsyncMock(side_effect=[session_initial, session_retry]) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + 
http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ) + ] + + failed_block = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + done_marker = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + + assert chunks == ['data: {"type":"response.completed"}\n\n'] + assert isinstance(failed_block, str) + assert '"type":"response.failed"' in failed_block + assert '"code":"stream_incomplete"' in failed_block + assert done_marker is None + assert not session_initial.pending_requests + assert session_initial.queued_request_count == 0 + + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_returns_owner_forward_for_hard_mismatch( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index 68bae47b..bd39c986 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -4495,6 +4495,263 @@ async def test_process_upstream_websocket_text_does_not_match_foreign_response_i assert 
list(pending_requests) == [pending_request] +def test_maybe_rewrite_websocket_previous_response_not_found_rewrites_response_failed_event(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_nf", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_nf", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "response.failed", + "response": { + "id": "resp_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is True + assert rewritten_event_type == "response.failed" + assert rewritten_payload is not None + response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) + error_payload = cast(dict[str, JsonValue], response_payload.get("error")) + assert error_payload["code"] == "stream_incomplete" + assert error_payload["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in rewritten_text + + +def 
test_maybe_rewrite_websocket_previous_response_invalid_request_error_rewrites_when_message_is_not_found(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_invalid", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_invalid", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Previous response with id 'resp_prev_anchor' not found.", + "param": "previous_response_id", + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is True + assert rewritten_event_type == "response.failed" + assert rewritten_payload is not None + response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) + error_payload = cast(dict[str, JsonValue], response_payload.get("error")) + assert error_payload["code"] == "stream_incomplete" + assert error_payload["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in rewritten_text + + +def test_maybe_rewrite_websocket_previous_response_invalid_request_error_does_not_rewrite_other_message(): + request_state = 
proxy_service._WebSocketRequestState( + request_id="ws_req_prev_invalid_other_message", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_invalid_other_message", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Invalid request payload", + "param": "previous_response_id", + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is False + assert rewritten_event_type == original_event_type + assert rewritten_payload == original_payload + assert rewritten_text == original_text + + +def test_maybe_rewrite_websocket_previous_response_invalid_request_error_does_not_rewrite_other_param(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_invalid_other_param", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_invalid_other_param", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + 
"message": "Invalid request payload", + "param": "input", + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is False + assert rewritten_event_type == original_event_type + assert rewritten_payload == original_payload + assert rewritten_text == original_text + + +def test_http_bridge_should_attempt_local_previous_response_recovery_invalid_request_requires_not_found_message(): + recoverable_error = proxy_module.ProxyResponseError( + 400, + { + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Previous response with id 'resp_prev_anchor' not found.", + "param": "previous_response_id", + } + }, + ) + non_recoverable_error = proxy_module.ProxyResponseError( + 400, + { + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Invalid request payload", + "param": "previous_response_id", + } + }, + ) + + assert proxy_service._http_bridge_should_attempt_local_previous_response_recovery(recoverable_error) is True + assert proxy_service._http_bridge_should_attempt_local_previous_response_recovery(non_recoverable_error) is False + + +def test_maybe_rewrite_websocket_previous_response_not_found_leaves_non_previous_request_unchanged(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_plain", + 
model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + ) + original_payload: dict[str, JsonValue] = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_any' not found.", + "param": "previous_response_id", + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is False + assert rewritten_event_type == original_event_type + assert rewritten_payload == original_payload + assert rewritten_text == original_text + + @pytest.mark.asyncio async def test_stream_responses_budget_exhaustion_emits_timeout_event(monkeypatch): settings = _make_proxy_settings(log_proxy_service_tier_trace=False) From f7334204d5bbd1fa2319f16a4cf2653fbe7653f1 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Wed, 15 Apr 2026 15:30:13 +0200 Subject: [PATCH 02/18] fix(ws): transparently replay pre-created responses on quota/rate-limit WS errors --- app/modules/proxy/service.py | 326 ++++++++++----- .../test_proxy_websocket_responses.py | 374 ++++++++++++++++++ tests/unit/test_proxy_utils.py | 344 ++++++++++++++++ 3 files changed, 952 insertions(+), 92 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 15eb24ec..e64daa2e 100644 --- 
a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -216,6 +216,15 @@ async def _await_cancelled_task( _COMPACT_MAX_ACCOUNT_ATTEMPTS = 2 _STREAM_MAX_ACCOUNT_ATTEMPTS = 3 _WEBSOCKET_MAX_ACCOUNT_ATTEMPTS = 3 +_WEBSOCKET_TRANSPARENT_REPLAY_ERROR_CODES = frozenset( + { + "rate_limit_exceeded", + "usage_limit_reached", + "insufficient_quota", + "usage_not_included", + "quota_exceeded", + } +) @dataclass(frozen=True, slots=True) @@ -1460,54 +1469,17 @@ async def proxy_responses_websocket( account: Account | None = None upstream_turn_state: str | None = _sticky_key_from_turn_state_header(headers) downstream_activity = _DownstreamWebSocketActivity() + replay_request_state: _WebSocketRequestState | None = None try: while True: - downstream_idle_timeout_seconds = runtime_settings.proxy_downstream_websocket_idle_timeout_seconds - try: - message = await asyncio.wait_for( - websocket.receive(), - timeout=min(downstream_idle_timeout_seconds, _DOWNSTREAM_WEBSOCKET_RECEIVE_POLL_SECONDS), - ) - except asyncio.TimeoutError: - if not await self._downstream_websocket_is_idle( - pending_requests, - pending_lock=pending_lock, - downstream_activity=downstream_activity, - idle_timeout_seconds=downstream_idle_timeout_seconds, - ): - continue - idle_close = False - async with client_send_lock: - if await self._downstream_websocket_is_idle( - pending_requests, - pending_lock=pending_lock, - downstream_activity=downstream_activity, - idle_timeout_seconds=downstream_idle_timeout_seconds, - ): - try: - message = await asyncio.wait_for(websocket.receive(), timeout=0.05) - except asyncio.TimeoutError: - try: - await websocket.close(code=1001, reason=_DOWNSTREAM_WEBSOCKET_IDLE_CLOSE_REASON) - except Exception: - logger.debug("Failed to close idle downstream websocket", exc_info=True) - idle_close = True - if idle_close: - break - downstream_activity.mark() - message_type = message["type"] - - if message_type == "websocket.disconnect": - break - if message_type != 
"websocket.receive": - continue - if upstream_reader is not None and upstream_reader.done(): try: await upstream_reader except asyncio.CancelledError: pass + if replay_request_state is None and upstream_control is not None: + replay_request_state = upstream_control.replay_request_state upstream_reader = None upstream_control = None if upstream is not None: @@ -1518,59 +1490,155 @@ async def proxy_responses_websocket( upstream = None account = None - text_data = message.get("text") - bytes_data = message.get("bytes") + text_data: str | None = None + bytes_data: bytes | None = None request_state: _WebSocketRequestState | None = None request_state_registered = False request_affinity = _AffinityPolicy() payload: dict[str, JsonValue] | None = None - if text_data is not None: + if replay_request_state is not None: + request_state = replay_request_state + replay_request_state = None + request_affinity = request_state.affinity_policy + text_data = request_state.request_text + if text_data is None: + await self._release_websocket_reservation(request_state.api_key_reservation) + await self._emit_websocket_terminal_error( + websocket, + client_send_lock=client_send_lock, + request_state=request_state, + error_code="stream_incomplete", + error_message="Upstream websocket closed before response.completed", + error_type="server_error", + downstream_activity=downstream_activity, + ) + _release_websocket_response_create_gate(request_state, response_create_gate) + continue payload = _parse_websocket_payload(text_data) - if payload is not None and _is_websocket_response_create(payload): - try: - prepared_request = await self._prepare_websocket_response_create_request( - payload, - headers=headers, - codex_session_affinity=codex_session_affinity, - openai_cache_affinity=openai_cache_affinity, - sticky_threads_enabled=sticky_threads_enabled, - openai_cache_affinity_max_age_seconds=openai_cache_affinity_max_age_seconds, - api_key=api_key, - ) - request_state = 
prepared_request.request_state - request_affinity = prepared_request.affinity_policy - text_data = prepared_request.text_data - except ProxyResponseError as exc: - async with client_send_lock: - await websocket.send_text( - _serialize_websocket_error_event( - _wrapped_websocket_error_event(exc.status_code, exc.payload) - ) - ) + if payload is None: + await self._release_websocket_reservation(request_state.api_key_reservation) + await self._emit_websocket_terminal_error( + websocket, + client_send_lock=client_send_lock, + request_state=request_state, + error_code="upstream_error", + error_message="Invalid replay request payload", + error_type="server_error", + downstream_activity=downstream_activity, + ) + _release_websocket_response_create_gate(request_state, response_create_gate) + continue + async with pending_lock: + pending_requests.append(request_state) + request_state_registered = True + else: + downstream_idle_timeout_seconds = runtime_settings.proxy_downstream_websocket_idle_timeout_seconds + try: + message = await asyncio.wait_for( + websocket.receive(), + timeout=min(downstream_idle_timeout_seconds, _DOWNSTREAM_WEBSOCKET_RECEIVE_POLL_SECONDS), + ) + except asyncio.TimeoutError: + if not await self._downstream_websocket_is_idle( + pending_requests, + pending_lock=pending_lock, + downstream_activity=downstream_activity, + idle_timeout_seconds=downstream_idle_timeout_seconds, + ): continue - except AppError as exc: - async with client_send_lock: - await websocket.send_text( - _serialize_websocket_error_event(_app_error_to_websocket_event(exc)) + idle_close = False + async with client_send_lock: + if await self._downstream_websocket_is_idle( + pending_requests, + pending_lock=pending_lock, + downstream_activity=downstream_activity, + idle_timeout_seconds=downstream_idle_timeout_seconds, + ): + try: + message = await asyncio.wait_for(websocket.receive(), timeout=0.05) + except asyncio.TimeoutError: + try: + await websocket.close(code=1001, 
reason=_DOWNSTREAM_WEBSOCKET_IDLE_CLOSE_REASON) + except Exception: + logger.debug("Failed to close idle downstream websocket", exc_info=True) + idle_close = True + if idle_close: + break + downstream_activity.mark() + message_type = message["type"] + + if message_type == "websocket.disconnect": + break + if message_type != "websocket.receive": + continue + + text_data = message.get("text") + bytes_data = message.get("bytes") + + if text_data is not None: + payload = _parse_websocket_payload(text_data) + if payload is not None and _is_websocket_response_create(payload): + try: + prepared_request = await self._prepare_websocket_response_create_request( + payload, + headers=headers, + codex_session_affinity=codex_session_affinity, + openai_cache_affinity=openai_cache_affinity, + sticky_threads_enabled=sticky_threads_enabled, + openai_cache_affinity_max_age_seconds=openai_cache_affinity_max_age_seconds, + api_key=api_key, ) - continue - except ClientPayloadError as exc: - async with client_send_lock: - await websocket.send_text( - _serialize_websocket_error_event( - _wrapped_websocket_error_event(400, openai_invalid_payload_error(exc.param)) + request_state = prepared_request.request_state + request_affinity = prepared_request.affinity_policy + text_data = prepared_request.text_data + except ProxyResponseError as exc: + async with client_send_lock: + await websocket.send_text( + _serialize_websocket_error_event( + _wrapped_websocket_error_event(exc.status_code, exc.payload) + ) ) - ) - continue - except ValidationError as exc: - async with client_send_lock: - await websocket.send_text( - _serialize_websocket_error_event( - _wrapped_websocket_error_event(400, openai_validation_error(exc)) + continue + except AppError as exc: + async with client_send_lock: + await websocket.send_text( + _serialize_websocket_error_event(_app_error_to_websocket_event(exc)) ) - ) - continue + continue + except ClientPayloadError as exc: + async with client_send_lock: + await 
websocket.send_text( + _serialize_websocket_error_event( + _wrapped_websocket_error_event(400, openai_invalid_payload_error(exc.param)) + ) + ) + continue + except ValidationError as exc: + async with client_send_lock: + await websocket.send_text( + _serialize_websocket_error_event( + _wrapped_websocket_error_event(400, openai_validation_error(exc)) + ) + ) + continue + + if upstream_reader is not None and upstream_reader.done(): + try: + await upstream_reader + except asyncio.CancelledError: + pass + if replay_request_state is None and upstream_control is not None: + replay_request_state = upstream_control.replay_request_state + upstream_reader = None + upstream_control = None + if upstream is not None: + try: + await upstream.close() + except Exception: + logger.debug("Failed to close upstream websocket", exc_info=True) + upstream = None + account = None if ( request_state is not None @@ -1579,6 +1647,8 @@ async def proxy_responses_websocket( and upstream_reader is not None ): await upstream_reader + if replay_request_state is None: + replay_request_state = upstream_control.replay_request_state upstream_reader = None upstream_control = None if upstream is not None: @@ -1589,7 +1659,7 @@ async def proxy_responses_websocket( upstream = None account = None - if request_state is not None: + if request_state is not None and not request_state_registered: try: await self._acquire_request_state_response_create_admission( request_state, @@ -1820,6 +1890,7 @@ async def _prepare_websocket_response_create_request( sticky_key_source=sticky_key_source, prompt_cache_key_set=_prompt_cache_key_from_request_model(responses_payload) is not None, ) + request_state.affinity_policy = affinity_policy return _PreparedWebSocketRequest( text_data=text_data, @@ -4292,15 +4363,20 @@ async def _relay_upstream_websocket_messages( upstream_control=upstream_control, response_create_gate=response_create_gate, ) - await self._send_downstream_websocket_text( - websocket, - 
client_send_lock=client_send_lock, - text=downstream_text, - downstream_activity=downstream_activity, - ) + suppress_downstream_event = upstream_control.suppress_downstream_event + upstream_control.suppress_downstream_event = False + if not suppress_downstream_event: + await self._send_downstream_websocket_text( + websocket, + client_send_lock=client_send_lock, + text=downstream_text, + downstream_activity=downstream_activity, + ) if upstream_control.reconnect_requested: - async with pending_lock: - should_reconnect = not pending_requests + should_reconnect = upstream_control.replay_request_state is not None + if not should_reconnect: + async with pending_lock: + should_reconnect = not pending_requests if should_reconnect: try: await upstream.close() @@ -4360,6 +4436,7 @@ async def _process_upstream_websocket_text( async with pending_lock: request_state = None created_request_state = None + has_other_pending_requests = False if event_type == "response.created": request_state = _assign_websocket_response_id(pending_requests, response_id) created_request_state = request_state @@ -4386,6 +4463,7 @@ async def _process_upstream_websocket_text( response_id=response_id, fallback_request_state=request_state, ) + has_other_pending_requests = bool(pending_requests) else: request_state = None @@ -4403,6 +4481,25 @@ async def _process_upstream_websocket_text( upstream_control=upstream_control, original_text=text, ) + retry_error_code = _websocket_precreated_retry_error_code( + request_state, + event_type=event_type, + payload=payload, + has_other_pending_requests=has_other_pending_requests, + ) + if retry_error_code is not None: + request_state.replay_count += 1 + request_state.awaiting_response_created = True + request_state.response_id = None + upstream_control.reconnect_requested = True + upstream_control.suppress_downstream_event = True + upstream_control.replay_request_state = request_state + await self._handle_stream_error( + account, + {"message": 
_websocket_event_error_message(event_type, payload) or "Upstream error"}, + retry_error_code, + ) + return downstream_text await self._finalize_websocket_request_state( request_state, @@ -6279,6 +6376,7 @@ class _WebSocketRequestState: error_http_status_override: int | None = None response_create_gate_acquired: bool = False response_create_admission: AdmissionLease | None = None + affinity_policy: _AffinityPolicy = field(default_factory=_AffinityPolicy) @dataclass(frozen=True, slots=True) @@ -6337,6 +6435,8 @@ class _HTTPBridgeSession: @dataclass(slots=True) class _WebSocketUpstreamControl: reconnect_requested: bool = False + suppress_downstream_event: bool = False + replay_request_state: _WebSocketRequestState | None = None @dataclass(slots=True) @@ -6527,6 +6627,17 @@ def _websocket_event_error_code(event_type: str | None, payload: dict[str, JsonV return stripped or None +def _websocket_event_error_type(event_type: str | None, payload: dict[str, JsonValue] | None) -> str | None: + error = _websocket_event_error_payload(event_type, payload) + if not isinstance(error, dict): + return None + type_value = error.get("type") + if not isinstance(type_value, str): + return None + stripped = type_value.strip() + return stripped or None + + def _websocket_event_error_param(event_type: str | None, payload: dict[str, JsonValue] | None) -> str | None: error = _websocket_event_error_payload(event_type, payload) if not isinstance(error, dict): @@ -6556,6 +6667,37 @@ def _is_previous_response_not_found_message(message: str | None) -> bool: return "previous response" in normalized and "not found" in normalized +def _websocket_precreated_retry_error_code( + request_state: _WebSocketRequestState | None, + *, + event_type: str | None, + payload: dict[str, JsonValue] | None, + has_other_pending_requests: bool, +) -> str | None: + if request_state is None: + return None + if has_other_pending_requests: + return None + if request_state.response_id is not None: + return None + if not 
request_state.awaiting_response_created: + return None + if not request_state.request_text: + return None + if request_state.replay_count >= 1: + return None + if event_type not in {"error", "response.failed"}: + return None + + error_code = _normalize_error_code( + _websocket_event_error_code(event_type, payload), + _websocket_event_error_type(event_type, payload), + ) + if error_code not in _WEBSOCKET_TRANSPARENT_REPLAY_ERROR_CODES: + return None + return error_code + + def _is_previous_response_not_found_error( *, code: str | None, diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 16c74586..81deece2 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -2368,6 +2368,380 @@ async def fake_write_request_log(self, **kwargs): ] +def test_backend_responses_websocket_transparently_retries_precreated_usage_limit_reached(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.failed", + "response": { + "id": "resp_ws_quota_fail", + "status": "failed", + "error": {"code": "usage_limit_reached", "message": "usage limit reached"}, + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + }, + separators=(",", ":"), + ), + ) + ] + ) + second_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_quota_ok", "status": "in_progress"}}, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_quota_ok", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + }, + }, + separators=(",", ":"), + ), + ), + ] + ) + upstreams = [first_upstream, second_upstream] + connect_models: list[str | 
None] = [] + handled_error_codes: list[str] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + request_state, + api_key, + client_send_lock, + websocket, + ) + upstream = upstreams[len(connect_models)] + connect_models.append(model) + return SimpleNamespace(id=f"acct_ws_proxy_{len(connect_models)}"), upstream + + async def fake_handle_stream_error(self, account, error, code): + del self, account, error + handled_error_codes.append(code) + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry once"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with 
client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + first_event = json.loads(websocket.receive_text()) + assert first_event["type"] == "response.created" + second_event = json.loads(websocket.receive_text()) + + assert second_event["type"] == "response.completed" + assert connect_models == ["gpt-5.1", "gpt-5.1"] + assert handled_error_codes == ["usage_limit_reached"] + assert len(first_upstream.sent_text) == 1 + assert len(second_upstream.sent_text) == 1 + assert json.loads(first_upstream.sent_text[0]) == json.loads(second_upstream.sent_text[0]) + + +def test_backend_responses_websocket_transparently_retries_precreated_error_usage_limit_reached( + app_instance, + monkeypatch, +): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 429, + "error": { + "type": "invalid_request_error", + "code": "usage_limit_reached", + "message": "The usage limit has been reached", + }, + }, + separators=(",", ":"), + ), + ) + ] + ) + second_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_quota_ok_err", "status": "in_progress"}}, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_quota_ok_err", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + }, + separators=(",", ":"), + ), + ), + ] + ) + upstreams = [first_upstream, second_upstream] + connect_models: list[str | None] = [] + handled_error_codes: list[str] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + 
self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + request_state, + api_key, + client_send_lock, + websocket, + ) + upstream = upstreams[len(connect_models)] + connect_models.append(model) + return SimpleNamespace(id=f"acct_ws_proxy_{len(connect_models)}"), upstream + + async def fake_handle_stream_error(self, account, error, code): + del self, account, error + handled_error_codes.append(code) + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry once"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + first_event = json.loads(websocket.receive_text()) + second_event = json.loads(websocket.receive_text()) + + assert first_event["type"] == "response.created" + assert second_event["type"] == "response.completed" + assert connect_models == 
["gpt-5.1", "gpt-5.1"] + assert handled_error_codes == ["usage_limit_reached"] + assert len(first_upstream.sent_text) == 1 + assert len(second_upstream.sent_text) == 1 + assert json.loads(first_upstream.sent_text[0]) == json.loads(second_upstream.sent_text[0]) + + +def test_backend_responses_websocket_transparent_replay_emits_no_accounts_when_reconnect_fails( + app_instance, + monkeypatch, +): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.failed", + "response": { + "id": "resp_ws_quota_fail_no_accounts", + "status": "failed", + "error": {"code": "usage_limit_reached", "message": "usage limit reached"}, + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + }, + separators=(",", ":"), + ), + ) + ] + ) + connect_models: list[str | None] = [] + handled_error_codes: list[str] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + request_state, + api_key, + ) + connect_models.append(model) + if len(connect_models) == 1: + del client_send_lock, websocket + return SimpleNamespace(id="acct_ws_proxy_1"), first_upstream + async with client_send_lock: + await websocket.send_text( + json.dumps( + { + "type": "error", + "status": 503, + "error": {"code": "no_accounts", "message": "No active accounts available"}, + } + ) + ) + return None, None + + async def fake_handle_stream_error(self, account, error, code): 
+ del self, account, error + handled_error_codes.append(code) + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry once"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + event = json.loads(websocket.receive_text()) + + assert event["type"] == "error" + assert event["status"] == 503 + assert event["error"]["code"] == "no_accounts" + assert connect_models == ["gpt-5.1", "gpt-5.1"] + assert handled_error_codes == ["usage_limit_reached"] + assert first_upstream.closed is True + + def test_backend_responses_websocket_emits_no_accounts_error(app_instance, monkeypatch): request_payload = { "type": "response.create", diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index bd39c986..ef61ff3a 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -4495,6 +4495,350 @@ async def test_process_upstream_websocket_text_does_not_match_foreign_response_i assert list(pending_requests) == [pending_request] +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_transparently_retries_precreated_usage_limit_failure( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + 
finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_precreated_retry") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry me"}]}], + } + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_precreated_retry", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps(request_payload, separators=(",", ":")), + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + upstream_payload = { + "type": "response.failed", + "response": { + "id": "resp_ws_precreated_fail", + "status": "failed", + "error": {"code": "usage_limit_reached", "message": "usage limit reached"}, + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + } + upstream_text = json.dumps(upstream_payload, separators=(",", ":")) + + downstream_text = await service._process_upstream_websocket_text( + upstream_text, + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert downstream_text == upstream_text + finalize_request_state.assert_not_awaited() + handle_stream_error.assert_awaited_once() + handle_call = handle_stream_error.await_args + assert handle_call is not None + assert handle_call.args[0] == account + assert handle_call.args[2] == "usage_limit_reached" + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is True + assert 
upstream_control.replay_request_state is pending_request + assert pending_request.replay_count == 1 + assert list(pending_requests) == [] + + +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_transparently_retries_precreated_usage_limit_error_event( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_precreated_retry_error_event") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry me"}]}], + } + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_precreated_retry_error_event", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps(request_payload, separators=(",", ":")), + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + upstream_payload = { + "type": "error", + "status": 429, + "error": { + "type": "invalid_request_error", + "code": "usage_limit_reached", + "message": "The usage limit has been reached", + }, + } + upstream_text = json.dumps(upstream_payload, separators=(",", ":")) + + downstream_text = await service._process_upstream_websocket_text( + upstream_text, + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert downstream_text == upstream_text + finalize_request_state.assert_not_awaited() + 
handle_stream_error.assert_awaited_once() + handle_call = handle_stream_error.await_args + assert handle_call is not None + assert handle_call.args[0] == account + assert handle_call.args[2] == "usage_limit_reached" + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is True + assert upstream_control.replay_request_state is pending_request + assert pending_request.replay_count == 1 + assert list(pending_requests) == [] + + +@pytest.mark.asyncio +async def test_proxy_responses_websocket_transparent_replay_preserves_sticky_thread_affinity( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + handled_error_codes: list[str] = [] + connect_calls: list[dict[str, object]] = [] + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.sticky_threads_enabled = True + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + + class _FakeDownstreamWebSocket: + def __init__(self, request_text: str) -> None: + self._request_text = request_text + self._request_sent = False + self._disconnect_sent = False + self._done = asyncio.Event() + self.sent_text: list[str] = [] + self.closed = False + + async def receive(self) -> dict[str, object]: + if not self._request_sent: + self._request_sent = True + return {"type": "websocket.receive", "text": self._request_text} + if not self._disconnect_sent: + await self._done.wait() + self._disconnect_sent = True + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + try: + payload = json.loads(text) + except json.JSONDecodeError: + payload = 
{} + if payload.get("type") in {"response.completed", "response.failed", "error"}: + self._done.set() + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self.closed = True + self._done.set() + + class _FakeUpstreamWebSocket: + def __init__(self, messages: list[SimpleNamespace]) -> None: + self.sent_text: list[str] = [] + self.closed = False + self._messages: asyncio.Queue[SimpleNamespace] = asyncio.Queue() + for message in messages: + self._messages.put_nowait(message) + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def receive(self) -> SimpleNamespace: + return await self._messages.get() + + async def close(self) -> None: + self.closed = True + + first_upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.failed", + "response": { + "id": "resp_ws_sticky_retry_fail", + "status": "failed", + "error": {"code": "usage_limit_reached", "message": "usage limit reached"}, + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ) + ] + ) + second_upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_sticky_retry_ok", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_sticky_retry_ok", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + ] + ) + + async def 
fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + request_state, + api_key, + client_send_lock, + websocket, + ) + connect_calls.append( + { + "sticky_key": sticky_key, + "sticky_kind": sticky_kind, + "reallocate_sticky": reallocate_sticky, + "model": model, + } + ) + if len(connect_calls) == 1: + return _make_account("acc_ws_sticky_1"), first_upstream + return _make_account("acc_ws_sticky_2"), second_upstream + + async def fake_handle_stream_error(self, account, error, code): + del self, account, error + handled_error_codes.append(code) + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_service.ProxyService, "_handle_stream_error", fake_handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "prompt_cache_key": "sticky-thread-xyz", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "retry me"}]}], + "stream": True, + } + downstream = _FakeDownstreamWebSocket(json.dumps(request_payload, separators=(",", ":"))) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {}, + codex_session_affinity=False, + openai_cache_affinity=False, + api_key=None, + ) + + emitted_events = [json.loads(event) for event in downstream.sent_text] + assert [event["type"] for event in emitted_events] == ["response.created", "response.completed"] + assert handled_error_codes == ["usage_limit_reached"] + assert len(connect_calls) == 2 + assert connect_calls[0]["sticky_key"] == "sticky-thread-xyz" + assert connect_calls[0]["sticky_kind"] == proxy_service.StickySessionKind.STICKY_THREAD + assert 
connect_calls[0]["reallocate_sticky"] is True + assert connect_calls[1]["sticky_key"] == "sticky-thread-xyz" + assert connect_calls[1]["sticky_kind"] == proxy_service.StickySessionKind.STICKY_THREAD + assert connect_calls[1]["reallocate_sticky"] is True + assert first_upstream.closed is True + assert len(first_upstream.sent_text) == 1 + assert len(second_upstream.sent_text) == 1 + assert json.loads(first_upstream.sent_text[0]) == json.loads(second_upstream.sent_text[0]) + + def test_maybe_rewrite_websocket_previous_response_not_found_rewrites_response_failed_event(): request_state = proxy_service._WebSocketRequestState( request_id="ws_req_prev_nf", From a6d6efa94c7c840be5cff104a112815307aaffd1 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 10:29:55 +0200 Subject: [PATCH 03/18] fix(proxy): harden shutdown and reconnect lifecycle --- app/core/usage/refresh_scheduler.py | 13 +- ..._add_request_logs_response_lookup_index.py | 67 + app/db/models.py | 19 + .../proxy/durable_bridge_coordinator.py | 11 + .../proxy/durable_bridge_repository.py | 67 +- app/modules/proxy/service.py | 965 +++++- app/modules/request_logs/repository.py | 44 + app/modules/usage/updater.py | 16 + .../integration/test_http_responses_bridge.py | 10 +- .../test_proxy_websocket_responses.py | 32 - tests/unit/test_db_migrate.py | 26 + tests/unit/test_durable_bridge_sessions.py | 206 +- tests/unit/test_otel.py | 177 ++ tests/unit/test_proxy_http_bridge.py | 2682 ++++++++++++++++- tests/unit/test_proxy_utils.py | 1282 +++++++- tests/unit/test_request_logs_repository.py | 125 + tests/unit/test_usage_updater.py | 72 + 17 files changed, 5391 insertions(+), 423 deletions(-) create mode 100644 app/db/alembic/versions/20260415_160000_add_request_logs_response_lookup_index.py diff --git a/app/core/usage/refresh_scheduler.py b/app/core/usage/refresh_scheduler.py index 7f5e4a2b..0056289c 100644 --- a/app/core/usage/refresh_scheduler.py +++ b/app/core/usage/refresh_scheduler.py @@ -12,6 +12,7 @@ 
from app.modules.accounts.repository import AccountsRepository from app.modules.proxy.account_cache import get_account_selection_cache from app.modules.proxy.rate_limit_cache import get_rate_limit_headers_cache +from app.modules.usage import updater as usage_updater_module from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository from app.modules.usage.updater import UsageUpdater @@ -44,13 +45,13 @@ async def start(self) -> None: self._task = asyncio.create_task(self._run_loop()) async def stop(self) -> None: - if not self._task: - return self._stop.set() - self._task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._task - self._task = None + if self._task is not None: + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._task + self._task = None + await usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT.cancel_all() async def _run_loop(self) -> None: while not self._stop.is_set(): diff --git a/app/db/alembic/versions/20260415_160000_add_request_logs_response_lookup_index.py b/app/db/alembic/versions/20260415_160000_add_request_logs_response_lookup_index.py new file mode 100644 index 00000000..9afbdca6 --- /dev/null +++ b/app/db/alembic/versions/20260415_160000_add_request_logs_response_lookup_index.py @@ -0,0 +1,67 @@ +"""add request_logs response lookup index + +Revision ID: 20260415_160000_add_request_logs_response_lookup_index +Revises: 20260413_000000_add_accounts_blocked_at +Create Date: 2026-04-15 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "20260415_160000_add_request_logs_response_lookup_index" +down_revision = "20260413_000000_add_accounts_blocked_at" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + existing_columns = {column["name"] for column in inspector.get_columns("request_logs")} + if "session_id" not in existing_columns: + 
op.add_column("request_logs", sa.Column("session_id", sa.String(), nullable=True)) + + op.create_index( + "idx_logs_request_status_api_key_time", + "request_logs", + [ + "request_id", + "status", + "api_key_id", + sa.text("requested_at DESC"), + sa.text("id DESC"), + ], + unique=False, + if_not_exists=True, + ) + op.create_index( + "idx_logs_request_status_api_key_session_time", + "request_logs", + [ + "request_id", + "status", + "api_key_id", + "session_id", + sa.text("requested_at DESC"), + sa.text("id DESC"), + ], + unique=False, + if_not_exists=True, + ) + + +def downgrade() -> None: + op.drop_index( + "idx_logs_request_status_api_key_session_time", + table_name="request_logs", + if_exists=True, + ) + op.drop_index( + "idx_logs_request_status_api_key_time", + table_name="request_logs", + if_exists=True, + ) + op.drop_column("request_logs", "session_id") diff --git a/app/db/models.py b/app/db/models.py index 7e56d409..33b286cf 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -104,6 +104,7 @@ class UsageHistory(Base): class AdditionalUsageHistory(Base): __tablename__ = "additional_usage_history" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) account_id: Mapped[str] = mapped_column(String, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False) quota_key: Mapped[str] = mapped_column(String, nullable=False) @@ -122,6 +123,7 @@ class RequestLog(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) account_id: Mapped[str | None] = mapped_column(String, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=True) api_key_id: Mapped[str | None] = mapped_column(String, nullable=True) + session_id: Mapped[str | None] = mapped_column(String, nullable=True) request_id: Mapped[str] = mapped_column(String, nullable=False) requested_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False) model: Mapped[str] = mapped_column(String, nullable=False) @@ -622,6 +624,23 
@@ class HttpBridgeSessionAlias(Base): RequestLog.requested_at.desc(), RequestLog.id.desc(), ) +Index( + "idx_logs_request_status_api_key_time", + RequestLog.request_id, + RequestLog.status, + RequestLog.api_key_id, + RequestLog.requested_at.desc(), + RequestLog.id.desc(), +) +Index( + "idx_logs_request_status_api_key_session_time", + RequestLog.request_id, + RequestLog.status, + RequestLog.api_key_id, + RequestLog.session_id, + RequestLog.requested_at.desc(), + RequestLog.id.desc(), +) Index("idx_sticky_account", StickySession.account_id) Index("idx_sticky_kind_updated_at", StickySession.kind, StickySession.updated_at.desc()) Index("idx_api_keys_hash", ApiKey.key_hash) diff --git a/app/modules/proxy/durable_bridge_coordinator.py b/app/modules/proxy/durable_bridge_coordinator.py index 1c360586..a170b795 100644 --- a/app/modules/proxy/durable_bridge_coordinator.py +++ b/app/modules/proxy/durable_bridge_coordinator.py @@ -77,6 +77,17 @@ async def lookup_request_targets( session_key_value=session_key_value, api_key_scope=api_key_scope, ) + if snapshot is None: + if turn_state is not None: + snapshot = await repository.find_session_by_latest_turn_state( + turn_state=turn_state, + api_key_scope=api_key_scope, + ) + if snapshot is None and previous_response_id is not None: + snapshot = await repository.find_session_by_latest_response_id( + response_id=previous_response_id, + api_key_scope=api_key_scope, + ) if snapshot is None: return None return _to_lookup(snapshot) diff --git a/app/modules/proxy/durable_bridge_repository.py b/app/modules/proxy/durable_bridge_repository.py index 9fefa7a2..a5b4dd39 100644 --- a/app/modules/proxy/durable_bridge_repository.py +++ b/app/modules/proxy/durable_bridge_repository.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from hashlib import sha256 -from sqlalchemy import select, text +from sqlalchemy import case, delete, select, text from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite 
import insert as sqlite_insert from sqlalchemy.exc import IntegrityError @@ -95,6 +95,58 @@ async def resolve_alias( row = result.scalar_one_or_none() return _to_snapshot(row) + async def find_session_by_latest_turn_state( + self, + *, + turn_state: str, + api_key_scope: str, + ) -> DurableBridgeSessionSnapshot | None: + statement = ( + select(HttpBridgeSessionRecord) + .where( + HttpBridgeSessionRecord.latest_turn_state == turn_state, + HttpBridgeSessionRecord.api_key_scope == api_key_scope, + HttpBridgeSessionRecord.state.in_( + (HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING) + ), + ) + .order_by( + case((HttpBridgeSessionRecord.state == HttpBridgeSessionState.ACTIVE, 0), else_=1), + HttpBridgeSessionRecord.last_seen_at.desc(), + HttpBridgeSessionRecord.updated_at.desc(), + ) + .limit(1) + ) + result = await self._session.execute(statement) + row = result.scalar_one_or_none() + return _to_snapshot(row) + + async def find_session_by_latest_response_id( + self, + *, + response_id: str, + api_key_scope: str, + ) -> DurableBridgeSessionSnapshot | None: + statement = ( + select(HttpBridgeSessionRecord) + .where( + HttpBridgeSessionRecord.latest_response_id == response_id, + HttpBridgeSessionRecord.api_key_scope == api_key_scope, + HttpBridgeSessionRecord.state.in_( + (HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING) + ), + ) + .order_by( + case((HttpBridgeSessionRecord.state == HttpBridgeSessionState.ACTIVE, 0), else_=1), + HttpBridgeSessionRecord.last_seen_at.desc(), + HttpBridgeSessionRecord.updated_at.desc(), + ) + .limit(1) + ) + result = await self._session.execute(statement) + row = result.scalar_one_or_none() + return _to_snapshot(row) + async def claim_session( self, *, @@ -158,6 +210,7 @@ async def claim_session( HttpBridgeSessionState.CLOSED, } previous_state = existing.state + account_changed = existing.account_id != account_id owner_changed = existing.owner_instance_id != instance_id if not owner_changed: next_epoch = 
existing.owner_epoch @@ -171,10 +224,15 @@ async def claim_session( existing.owner_epoch = next_epoch existing.lease_expires_at = lease_expires_at existing.state = HttpBridgeSessionState.ACTIVE + if account_changed: + await self._clear_aliases_for_session(existing.id) existing.account_id = account_id existing.model = model existing.service_tier = service_tier - if owner_changed: + if account_changed: + existing.latest_turn_state = latest_turn_state + existing.latest_response_id = latest_response_id + elif owner_changed: if latest_turn_state is not None or previous_state == HttpBridgeSessionState.CLOSED: existing.latest_turn_state = latest_turn_state if latest_response_id is not None or previous_state == HttpBridgeSessionState.CLOSED: @@ -313,6 +371,11 @@ async def upsert_alias( await self._session.execute(statement) await self._session.commit() + async def _clear_aliases_for_session(self, session_id: str) -> None: + await self._session.execute( + delete(HttpBridgeSessionAlias).where(HttpBridgeSessionAlias.session_id == session_id) + ) + async def missing_durable_bridge_tables(session: AsyncSession) -> tuple[str, ...]: dialect = session.get_bind().dialect.name diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index e64daa2e..87aa8d4b 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -225,6 +225,7 @@ async def _await_cancelled_task( "quota_exceeded", } ) +_WEBSOCKET_PREVIOUS_RESPONSE_ACCOUNT_CACHE_LIMIT = 4096 @dataclass(frozen=True, slots=True) @@ -264,6 +265,7 @@ def __init__(self, repo_factory: ProxyRepoFactory) -> None: self._http_bridge_inflight_sessions: dict[_HTTPBridgeSessionKey, asyncio.Future[_HTTPBridgeSession]] = {} self._http_bridge_turn_state_index: dict[tuple[str, str | None], _HTTPBridgeSessionKey] = {} self._http_bridge_previous_response_index: dict[tuple[str, str | None], _HTTPBridgeSessionKey] = {} + self._websocket_previous_response_account_index: dict[tuple[str, str | None, str | None], str] = 
{} self._http_bridge_lock = anyio.Lock() self._work_admission: WorkAdmissionController | None = None @@ -427,6 +429,8 @@ async def _stream_via_http_bridge( request_id = ensure_request_id() dashboard_settings = await get_settings_cache().get() runtime_config = _http_bridge_runtime_config(dashboard_settings, get_settings()) + incoming_turn_state_header = _sticky_key_from_turn_state_header(headers) if not forwarded_request else None + incoming_session_header = _sticky_key_from_session_header(headers) if not forwarded_request else None had_prompt_cache_key = _prompt_cache_key_from_request_model(payload) is not None affinity = _sticky_key_for_responses_request( payload, @@ -468,8 +472,8 @@ async def _stream_via_http_bridge( session_key_kind=bridge_session_key.affinity_kind, session_key_value=bridge_session_key.affinity_key, api_key_id=bridge_session_key.api_key_id, - turn_state=_sticky_key_from_turn_state_header(headers) if not forwarded_request else None, - session_header=_sticky_key_from_session_header(headers) if not forwarded_request else None, + turn_state=incoming_turn_state_header, + session_header=incoming_session_header, previous_response_id=payload.previous_response_id, ) except Exception: @@ -482,10 +486,20 @@ async def _stream_via_http_bridge( durable_lookup.canonical_key, bridge_session_key.api_key_id, ) + live_local_session_exists = await self._http_bridge_has_live_local_session( + key=bridge_session_key, + incoming_turn_state=incoming_turn_state_header, + api_key=api_key, + ) + forwards_to_active_owner = await self._http_bridge_can_forward_to_active_owner(durable_lookup) if ( + not live_local_session_exists + and not forwards_to_active_owner + and payload.previous_response_id is None and bridge_session_key.strength == "hard" and durable_lookup.latest_response_id is not None + and not _http_bridge_payload_looks_like_full_resend(payload) ): effective_payload = payload.model_copy( update={"previous_response_id": durable_lookup.latest_response_id} @@ -512,7 
+526,28 @@ async def _stream_via_http_bridge( payload=effective_payload, durable_lookup=durable_lookup, ) - request_state.preferred_account_id = durable_lookup.account_id if durable_lookup is not None else None + request_state.preferred_account_id = ( + durable_lookup.account_id + if ( + durable_lookup is not None + and ( + request_state.previous_response_id is not None + or bridge_session_key.strength == "hard" + or ( + bridge_session_key.affinity_kind == "prompt_cache" + and request_state.request_stage == "follow_up" + and durable_lookup.latest_turn_state is not None + ) + ) + ) + else None + ) + if request_state.previous_response_id is not None and request_state.preferred_account_id is None: + request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( + previous_response_id=request_state.previous_response_id, + api_key=api_key, + session_id=_owner_lookup_session_id_from_headers(headers), + ) session_or_forward = await self._get_or_create_http_bridge_session( bridge_session_key, headers=dict(headers), @@ -659,48 +694,110 @@ async def _stream_via_http_bridge( async for event_block in session_events: yield event_block except ProxyResponseError as exc: + is_context_overflow = _http_bridge_is_context_overflow_error(exc) + should_rollover_after_context_overflow = _http_bridge_should_rollover_after_context_overflow( + exc, + key=bridge_session_key, + ) should_attempt_previous_response_recovery = ( effective_payload.previous_response_id is not None and _http_bridge_should_attempt_local_previous_response_recovery(exc) ) - if not should_attempt_previous_response_recovery: - raise - - if PROMETHEUS_AVAILABLE and bridge_durable_recover_total is not None: - bridge_durable_recover_total.labels(path="local_previous_response_error").inc() - _log_http_bridge_event( - "previous_response_recover_local", - bridge_session_key, - account_id=None, - model=effective_payload.model, - detail="outcome=local_rebind_after_local_error", - 
cache_key_family=bridge_session_key.affinity_kind, - model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, - owner_check_applied=True, + should_attempt_context_overflow_fresh_turn_recovery = ( + is_context_overflow + and effective_payload.previous_response_id is not None + and bridge_session_key.strength != "hard" ) + if ( + not should_attempt_previous_response_recovery + and not should_rollover_after_context_overflow + and not should_attempt_context_overflow_fresh_turn_recovery + ): + if is_context_overflow: + _log_http_bridge_event( + "context_overflow_no_rollover", + bridge_session_key, + account_id=None, + model=effective_payload.model, + detail="outcome=preserve_hard_affinity_session", + cache_key_family=bridge_session_key.affinity_kind, + model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, + owner_check_applied=True, + ) + raise - async with self._http_bridge_lock: - if self._http_bridge_sessions.get(session.key) is session: - self._http_bridge_sessions.pop(session.key, None) - async with session.pending_lock: - session.queued_request_count = 0 - await self._fail_pending_websocket_requests( - account_id_value=session.account.id, - pending_requests=session.pending_requests, - pending_lock=session.pending_lock, - error_code="stream_incomplete", - error_message="Upstream websocket closed before response.completed", - api_key=None, - response_create_gate=session.response_create_gate, - ) - await self._close_http_bridge_session(session) + if should_attempt_context_overflow_fresh_turn_recovery: + if PROMETHEUS_AVAILABLE and bridge_durable_recover_total is not None: + bridge_durable_recover_total.labels(path="context_overflow_fresh_turn").inc() + _log_http_bridge_event( + "context_overflow_fresh_turn_recover", + bridge_session_key, + account_id=None, + model=effective_payload.model, + detail="outcome=retry_without_previous_response_id", + 
cache_key_family=bridge_session_key.affinity_kind, + model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, + owner_check_applied=True, + ) + await self._reset_http_bridge_session_after_local_terminal_error( + session, + error_code="stream_incomplete", + error_message="Upstream websocket closed before response.completed", + ) + recovery_path = "context_overflow_fresh_turn" + retry_payload = _http_bridge_payload_without_previous_response_id(effective_payload) + retry_previous_response_id = None + retry_request_stage = "context_overflow_recover" + retry_preferred_account_id = None + allow_previous_response_recovery_rebind = False + elif should_rollover_after_context_overflow: + _log_http_bridge_event( + "context_overflow_rollover", + bridge_session_key, + account_id=None, + model=effective_payload.model, + detail="outcome=close_session_after_context_length_exceeded", + cache_key_family=bridge_session_key.affinity_kind, + model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, + owner_check_applied=True, + ) + await self._reset_http_bridge_session_after_local_terminal_error( + session, + error_code="stream_incomplete", + error_message="Upstream websocket closed before response.completed", + ) + raise + else: + if PROMETHEUS_AVAILABLE and bridge_durable_recover_total is not None: + bridge_durable_recover_total.labels(path="local_previous_response_error").inc() + _log_http_bridge_event( + "previous_response_recover_local", + bridge_session_key, + account_id=None, + model=effective_payload.model, + detail="outcome=local_rebind_after_local_error", + cache_key_family=bridge_session_key.affinity_kind, + model_class=_extract_model_class(effective_payload.model) if effective_payload.model else None, + owner_check_applied=True, + ) + await self._reset_http_bridge_session_after_local_terminal_error( + session, + error_code="stream_incomplete", + error_message="Upstream websocket closed before 
response.completed", + ) + recovery_path = "local_previous_response_error" + retry_payload = effective_payload + retry_previous_response_id = request_state.previous_response_id + retry_request_stage = "reattach" + retry_preferred_account_id = request_state.preferred_account_id + allow_previous_response_recovery_rebind = True session = await self._get_or_create_http_bridge_session( bridge_session_key, headers=dict(headers), affinity=affinity, api_key=api_key, - request_model=effective_payload.model, + request_model=retry_payload.model, idle_ttl_seconds=_effective_http_bridge_idle_ttl_seconds( affinity=affinity, idle_ttl_seconds=idle_ttl_seconds, @@ -708,16 +805,16 @@ async def _stream_via_http_bridge( prompt_cache_idle_ttl_seconds=prompt_cache_idle_ttl_seconds, ), max_sessions=max_sessions, - previous_response_id=request_state.previous_response_id, + previous_response_id=retry_previous_response_id, gateway_safe_mode=runtime_config.gateway_safe_mode, allow_forward_to_owner=False, forwarded_request=False, - allow_previous_response_recovery_rebind=True, + allow_previous_response_recovery_rebind=allow_previous_response_recovery_rebind, durable_lookup=durable_lookup, - request_stage="reattach", - preferred_account_id=request_state.preferred_account_id, + request_stage=retry_request_stage, + preferred_account_id=retry_preferred_account_id, ) - _record_bridge_reattach(path="local_previous_response_error", outcome="success") + _record_bridge_reattach(path=recovery_path, outcome="success") try: retry_api_key_reservation = api_key_reservation @@ -725,23 +822,23 @@ async def _stream_via_http_bridge( if api_key is not None and api_key_reservation is not None: retry_api_key_reservation = await self._reserve_websocket_api_key_usage( api_key, - request_model=effective_payload.model, + request_model=retry_payload.model, request_service_tier=_normalize_service_tier_value( - dict(effective_payload.to_payload()).get("service_tier"), + 
dict(retry_payload.to_payload()).get("service_tier"), ), ) retry_reservation_reacquired = True retry_request_state, retry_text_data = self._prepare_http_bridge_request( - effective_payload, + retry_payload, headers, api_key=api_key, api_key_reservation=retry_api_key_reservation, request_id=request_id, ) retry_request_state.transport = _REQUEST_TRANSPORT_HTTP - retry_request_state.request_stage = request_state.request_stage - retry_request_state.preferred_account_id = request_state.preferred_account_id + retry_request_state.request_stage = retry_request_stage + retry_request_state.preferred_account_id = retry_preferred_account_id retry_events: AsyncGenerator[str, None] = self._stream_http_bridge_session_events( session, @@ -769,6 +866,29 @@ async def _stream_via_http_bridge( except Exception: pass + async def _reset_http_bridge_session_after_local_terminal_error( + self, + session: "_HTTPBridgeSession", + *, + error_code: str, + error_message: str, + ) -> None: + async with self._http_bridge_lock: + if self._http_bridge_sessions.get(session.key) is session: + self._http_bridge_sessions.pop(session.key, None) + async with session.pending_lock: + session.queued_request_count = 0 + await self._fail_pending_websocket_requests( + account_id_value=session.account.id, + pending_requests=session.pending_requests, + pending_lock=session.pending_lock, + error_code=error_code, + error_message=error_message, + api_key=None, + response_create_gate=session.response_create_gate, + ) + await self._close_http_bridge_session(session) + async def _stream_http_bridge_session_events( self, session: "_HTTPBridgeSession", @@ -818,6 +938,49 @@ async def _stream_http_bridge_session_events( await self._detach_http_bridge_request(session, request_state=request_state) session.last_used_at = time.monotonic() + async def _http_bridge_has_live_local_session( + self, + *, + key: "_HTTPBridgeSessionKey", + incoming_turn_state: str | None, + api_key: ApiKeyData | None, + ) -> bool: + api_key_id = 
api_key.id if api_key is not None else None + async with self._http_bridge_lock: + candidate_keys = [key] + if incoming_turn_state is not None: + alias_key = self._http_bridge_turn_state_index.get( + _http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id) + ) + if alias_key is not None and alias_key not in candidate_keys: + candidate_keys.append(alias_key) + for candidate_key in candidate_keys: + session = self._http_bridge_sessions.get(candidate_key) + if session is None or session.closed or session.account.status != AccountStatus.ACTIVE: + continue + if not _http_bridge_session_allows_api_key(session, api_key): + continue + return True + return False + + async def _http_bridge_can_forward_to_active_owner( + self, + durable_lookup: DurableBridgeLookup, + ) -> bool: + owner_instance = _durable_bridge_lookup_active_owner(durable_lookup) + if owner_instance is None: + return False + if owner_instance == get_settings().http_responses_session_bridge_instance_id: + return False + if self._ring_membership is None: + return False + try: + owner_endpoint = await self._ring_membership.resolve_endpoint(owner_instance) + except Exception: + logger.debug("Failed to resolve HTTP bridge owner endpoint during anchor injection decision", exc_info=True) + return False + return owner_endpoint is not None + async def _forward_http_bridge_request_to_owner( self, *, @@ -972,7 +1135,6 @@ async def compact_responses( response: CompactResponsePayload | None = None request_service_tier: str | None = None actual_service_tier: str | None = None - settings = await get_settings_cache().get() prefer_earlier_reset = settings.prefer_earlier_reset_accounts had_prompt_cache_key = _prompt_cache_key_from_request_model(payload) is not None @@ -1659,6 +1821,17 @@ async def proxy_responses_websocket( upstream = None account = None + if ( + request_state is not None + and request_state.previous_response_id is not None + and request_state.preferred_account_id is None + ): + 
request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( + previous_response_id=request_state.previous_response_id, + api_key=request_state.api_key or api_key, + session_id=request_state.session_id, + ) + if request_state is not None and not request_state_registered: try: await self._acquire_request_state_response_create_admission( @@ -1784,6 +1957,31 @@ async def proxy_responses_websocket( elif bytes_data is not None: await upstream.send_bytes(bytes_data) except Exception: + replay_candidate = await _pop_replayable_precreated_websocket_request_state( + pending_requests, + pending_lock=pending_lock, + ) + if replay_candidate is not None: + logger.info( + "Transparent websocket replay after upstream send failure request_id=%s", + replay_candidate.request_log_id or replay_candidate.request_id, + ) + replay_request_state = replay_candidate + if upstream_reader is not None: + await _await_cancelled_task(upstream_reader, label="proxy websocket upstream reader") + upstream_reader = None + upstream_control = None + if upstream is not None: + try: + await upstream.close() + except Exception: + logger.debug( + "Failed to close upstream websocket after replayable send failure", + exc_info=True, + ) + upstream = None + account = None + continue await self._fail_pending_websocket_requests( account_id_value=account.id if account else None, pending_requests=pending_requests, @@ -1853,6 +2051,7 @@ async def _prepare_websocket_response_create_request( ), ) try: + session_id = _owner_lookup_session_id_from_headers(headers) request_state, text_data = self._prepare_response_bridge_request_state( responses_payload, api_key=refreshed_api_key, @@ -1861,6 +2060,7 @@ async def _prepare_websocket_response_create_request( attach_event_queue=False, transport=_REQUEST_TRANSPORT_WEBSOCKET, client_metadata=client_metadata, + session_id=session_id, ) except ProxyResponseError: await self._release_websocket_reservation(reservation) @@ -1915,6 +2115,7 @@ def 
_prepare_http_bridge_request( attach_event_queue=True, transport=_REQUEST_TRANSPORT_HTTP, client_metadata=_response_create_client_metadata(payload.to_payload(), headers=headers), + session_id=_owner_lookup_session_id_from_headers(headers), request_log_id=request_id or get_request_id() or ensure_request_id(None), ) @@ -1928,6 +2129,7 @@ def _prepare_response_bridge_request_state( attach_event_queue: bool, transport: str, client_metadata: Mapping[str, JsonValue] | None, + session_id: str | None = None, request_id: str | None = None, request_log_id: str | None = None, ) -> tuple[_WebSocketRequestState, str]: @@ -1953,6 +2155,7 @@ def _prepare_response_bridge_request_state( transport=transport, api_key=api_key, previous_response_id=payload.previous_response_id, + session_id=_normalize_session_id(session_id), ) text_data = json.dumps(upstream_payload, ensure_ascii=True, separators=(",", ":")) payload_size = len(text_data.encode("utf-8")) @@ -2038,6 +2241,10 @@ async def _connect_proxy_websocket( reallocate_sticky=True if is_retry else reallocate_sticky, sticky_max_age_seconds=sticky_max_age_seconds, exclude_account_ids=excluded_account_ids, + preferred_account_id=request_state.preferred_account_id, + require_preferred_account=( + request_state.previous_response_id is not None and request_state.preferred_account_id is not None + ), ) if account is None: return None, None @@ -2119,6 +2326,8 @@ async def _select_websocket_connect_account( reallocate_sticky: bool, sticky_max_age_seconds: int | None, exclude_account_ids: set[str], + preferred_account_id: str | None, + require_preferred_account: bool, ) -> Account | None: try: selection = await self._select_account_with_budget_compatible( @@ -2134,6 +2343,7 @@ async def _select_websocket_connect_account( routing_strategy=routing_strategy, model=model, exclude_account_ids=exclude_account_ids, + preferred_account_id=preferred_account_id, ) except ProxyResponseError as exc: if _is_proxy_budget_exhausted_error(exc): @@ -2148,6 
+2358,29 @@ async def _select_websocket_connect_account( raise account = selection.account + if ( + account is not None + and require_preferred_account + and preferred_account_id is not None + and account.id != preferred_account_id + ): + message = "Previous response owner account is unavailable; retry later." + await self._emit_websocket_connect_failure( + websocket, + client_send_lock=client_send_lock, + account_id=preferred_account_id, + api_key=api_key, + request_state=request_state, + status_code=502, + payload=openai_error( + "upstream_unavailable", + message, + error_type="server_error", + ), + error_code="upstream_unavailable", + error_message=message, + ) + return None if account: return account error_code = selection.error_code or "no_accounts" @@ -2544,25 +2777,36 @@ async def _get_or_create_http_bridge_session( preferred_account_id: str | None = None, ) -> "_HTTPBridgeSession | _HTTPBridgeOwnerForward": settings = get_settings() + api_key_id = api_key.id if api_key is not None else None + incoming_turn_state = _sticky_key_from_turn_state_header(headers) + incoming_session_key = _sticky_key_from_session_header(headers) if await _http_bridge_should_wait_for_registration(self, key, settings): - import app.core.startup as startup_module + skip_registration_gate = False + async with self._http_bridge_lock: + existing = self._http_bridge_sessions.get(key) + if existing is not None: + skip_registration_gate = True + elif incoming_turn_state is not None: + alias_index_key = _http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id) + alias_key = self._http_bridge_turn_state_index.get(alias_index_key) + if alias_key is not None and alias_key in self._http_bridge_sessions: + skip_registration_gate = True + if not skip_registration_gate: + import app.core.startup as startup_module - registered = await startup_module.wait_for_bridge_registration( - timeout_seconds=settings.upstream_connect_timeout_seconds, - ) - if not registered: - raise 
ProxyResponseError( - 503, - openai_error( - "bridge_owner_unreachable", - "HTTP bridge registration is not ready", - error_type="server_error", - ), + registered = await startup_module.wait_for_bridge_registration( + timeout_seconds=settings.upstream_connect_timeout_seconds, ) - api_key_id = api_key.id if api_key is not None else None + if not registered: + raise ProxyResponseError( + 503, + openai_error( + "bridge_owner_unreachable", + "HTTP bridge registration is not ready", + error_type="server_error", + ), + ) effective_idle_ttl_seconds = idle_ttl_seconds - incoming_turn_state = _sticky_key_from_turn_state_header(headers) - incoming_session_key = _sticky_key_from_session_header(headers) forwarded_affinity = ( _forwarded_http_bridge_session_key( headers, @@ -2582,11 +2826,24 @@ async def _get_or_create_http_bridge_session( continuity_error: ProxyResponseError | None = None owner_mismatch_error: ProxyResponseError | None = None owner_forward: _HTTPBridgeOwnerForward | None = None + force_durable_takeover = False missing_turn_state_alias = False used_session_header_fallback = False + preserve_durable_canonical_key = ( + incoming_turn_state is not None + and forwarded_affinity is None + and durable_lookup is not None + and key.affinity_kind == durable_lookup.canonical_kind + and key.affinity_key == durable_lookup.canonical_key + and key.affinity_kind != "turn_state_header" + ) async with self._http_bridge_lock: - if incoming_turn_state is not None and forwarded_affinity is None: + if ( + incoming_turn_state is not None + and forwarded_affinity is None + and not preserve_durable_canonical_key + ): alias_index_key = _http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id) alias_key = self._http_bridge_turn_state_index.get(alias_index_key) if alias_key is not None: @@ -2655,6 +2912,12 @@ async def _get_or_create_http_bridge_session( and not existing.closed and existing.account.status == AccountStatus.ACTIVE and _http_bridge_session_allows_api_key(existing, 
api_key) + and _http_bridge_session_reusable_for_request( + session=existing, + key=key, + incoming_turn_state=incoming_turn_state, + previous_response_id=previous_response_id, + ) ): current_instance = settings.http_responses_session_bridge_instance_id if _durable_bridge_lookup_allows_local_reuse(durable_lookup, current_instance=current_instance): @@ -2772,11 +3035,28 @@ async def _get_or_create_http_bridge_session( ) if allow_forward_to_owner: if forwarded_request: - owner_mismatch_error = ProxyResponseError( + _log_http_bridge_event( + "owner_mismatch_forward_loop", + key, + account_id=None, + model=request_model, + detail=( + "expected_instance=" + f"{owner_instance}, current_instance={current_instance}, " + "outcome=forward_loop_prevented" + ), + cache_key_family=key.affinity_kind, + model_class=_extract_model_class(request_model) if request_model else None, + owner_check_applied=True, + ) + raise ProxyResponseError( 503, openai_error( "bridge_forward_loop_prevented", - "HTTP bridge owner forwarding reached a non-owner replica twice", + ( + "HTTP bridge request was forwarded back to a non-owner instance; " + "refusing takeover to avoid a forward loop" + ), error_type="server_error", ), ) @@ -2801,34 +3081,46 @@ async def _get_or_create_http_bridge_session( model_class=_extract_model_class(request_model) if request_model else None, owner_check_applied=True, ) - else: + force_durable_takeover = True + elif _http_bridge_can_single_instance_owner_takeover_without_anchor( + key=key, + owner_instance=owner_instance, + current_instance=current_instance, + ring=ring, + ): + if PROMETHEUS_AVAILABLE and bridge_durable_recover_total is not None: + bridge_durable_recover_total.labels(path="restart_takeover").inc() _log_http_bridge_event( - "owner_mismatch_retry", + "owner_mismatch_local_recover", key, account_id=None, model=request_model, detail=( "expected_instance=" f"{owner_instance}, current_instance={current_instance}, " - "outcome=retry_no_ring" + 
"outcome=single_instance_takeover_no_anchor" ), cache_key_family=key.affinity_kind, model_class=_extract_model_class(request_model) if request_model else None, owner_check_applied=True, ) - if PROMETHEUS_AVAILABLE and bridge_instance_mismatch_total is not None: - bridge_instance_mismatch_total.labels(outcome="retry").inc() - owner_mismatch_error = ProxyResponseError( - 409, - openai_error( - "bridge_instance_mismatch", - ( - "HTTP bridge session is owned by a different instance; " - "retry to reach the correct replica" - ), - error_type="server_error", + force_durable_takeover = True + else: + _log_http_bridge_event( + "owner_mismatch_local_recover", + key, + account_id=None, + model=request_model, + detail=( + "expected_instance=" + f"{owner_instance}, current_instance={current_instance}, " + "outcome=local_recover_no_ring" ), + cache_key_family=key.affinity_kind, + model_class=_extract_model_class(request_model) if request_model else None, + owner_check_applied=True, ) + force_durable_takeover = True else: assert owner_instance is not None owner_endpoint = await self._ring_membership.resolve_endpoint(owner_instance) @@ -2855,16 +3147,17 @@ async def _get_or_create_http_bridge_session( else None, owner_check_applied=True, ) + force_durable_takeover = True else: _log_http_bridge_event( - "owner_mismatch_retry", + "owner_mismatch_local_recover", key, account_id=None, model=request_model, detail=( "expected_instance=" f"{owner_instance}, current_instance={current_instance}, " - "outcome=retry_no_endpoint" + "outcome=local_recover_no_endpoint" ), cache_key_family=key.affinity_kind, model_class=_extract_model_class(request_model) @@ -2872,19 +3165,7 @@ async def _get_or_create_http_bridge_session( else None, owner_check_applied=True, ) - if PROMETHEUS_AVAILABLE and bridge_instance_mismatch_total is not None: - bridge_instance_mismatch_total.labels(outcome="retry").inc() - owner_mismatch_error = ProxyResponseError( - 409, - openai_error( - "bridge_instance_mismatch", - 
( - "HTTP bridge session is owned by a different instance; " - "retry to reach the correct replica" - ), - error_type="server_error", - ), - ) + force_durable_takeover = True else: owner_forward = _HTTPBridgeOwnerForward( owner_instance=owner_instance, @@ -2912,33 +3193,23 @@ async def _get_or_create_http_bridge_session( model_class=_extract_model_class(request_model) if request_model else None, owner_check_applied=True, ) + force_durable_takeover = True else: _log_http_bridge_event( - "owner_mismatch_retry", + "owner_mismatch_local_recover", key, account_id=None, model=request_model, detail=( "expected_instance=" - f"{owner_instance}, current_instance={current_instance}, outcome=retry" + f"{owner_instance}, current_instance={current_instance}, " + "outcome=local_recover_no_forward" ), cache_key_family=key.affinity_kind, model_class=_extract_model_class(request_model) if request_model else None, owner_check_applied=True, ) - if PROMETHEUS_AVAILABLE and bridge_instance_mismatch_total is not None: - bridge_instance_mismatch_total.labels(outcome="retry").inc() - owner_mismatch_error = ProxyResponseError( - 409, - openai_error( - "bridge_instance_mismatch", - ( - "HTTP bridge session is owned by a different instance; " - "retry to reach the correct replica" - ), - error_type="server_error", - ), - ) + force_durable_takeover = True else: _log_http_bridge_event( "prompt_cache_locality_miss", @@ -2947,12 +3218,22 @@ async def _get_or_create_http_bridge_session( model=request_model, detail=( "expected_instance=" - f"{owner_instance}, current_instance={current_instance}, outcome=local_rebind" + f"{owner_instance}, current_instance={current_instance}, " + "outcome=local_rebind" ), cache_key_family=key.affinity_kind, model_class=_extract_model_class(request_model) if request_model else None, owner_check_applied=False, ) + if _http_bridge_can_single_instance_prompt_cache_takeover_without_anchor( + key=key, + owner_instance=owner_instance, + current_instance=current_instance, + 
ring=ring, + ): + force_durable_takeover = True + elif allow_previous_response_recovery_rebind or allow_bootstrap_owner_rebind: + force_durable_takeover = True _log_http_bridge_event( "soft_locality_rebind", key, @@ -3055,14 +3336,46 @@ async def _get_or_create_http_bridge_session( ), ) elif missing_turn_state_alias and inflight_future is None and durable_lookup is None: - continuity_error = ProxyResponseError( - 409, - openai_error( - "bridge_instance_mismatch", - "HTTP bridge turn-state did not match a live session", - error_type="server_error", - ), + turn_state_scope_conflict = ( + incoming_turn_state is not None + and any( + alias == incoming_turn_state and alias_api_key != api_key_id + for alias, alias_api_key in self._http_bridge_turn_state_index + ) ) + if turn_state_scope_conflict: + continuity_error = ProxyResponseError( + 409, + openai_error( + "bridge_instance_mismatch", + "HTTP bridge turn-state is bound to a different API key scope", + error_type="server_error", + ), + ) + elif ( + incoming_turn_state is not None + and incoming_turn_state.startswith("http_turn_") + and not allow_forward_to_owner + ): + continuity_error = ProxyResponseError( + 409, + openai_error( + "bridge_instance_mismatch", + "HTTP bridge continuity was lost for generated turn-state", + error_type="server_error", + ), + ) + else: + _log_http_bridge_event( + "turn_state_alias_miss_local_rebind", + key, + account_id=None, + model=request_model, + detail="outcome=local_rebind_without_alias", + cache_key_family=key.affinity_kind, + model_class=_extract_model_class(request_model) if request_model else None, + owner_check_applied=owner_check_required, + ) elif inflight_future is None: while ( len(self._http_bridge_sessions) + len(self._http_bridge_inflight_sessions) >= max_sessions @@ -3139,6 +3452,8 @@ async def _get_or_create_http_bridge_session( if capacity_wait_future.cancelled(): continue raise + except ProxyResponseError: + raise except Exception: pass continue @@ -3151,13 +3466,19 
@@ async def _get_or_create_http_bridge_session( continue raise except Exception: - continue + raise if session is None: continue if ( not session.closed and session.account.status == AccountStatus.ACTIVE and _http_bridge_session_allows_api_key(session, api_key) + and _http_bridge_session_reusable_for_request( + session=session, + key=key, + incoming_turn_state=incoming_turn_state, + previous_response_id=previous_response_id, + ) ): current_instance = settings.http_responses_session_bridge_instance_id if _durable_bridge_lookup_allows_local_reuse(durable_lookup, current_instance=current_instance): @@ -3190,7 +3511,7 @@ async def _get_or_create_http_bridge_session( ) await self._claim_durable_http_bridge_session( created_session, - allow_takeover=_http_bridge_allow_durable_takeover(durable_lookup), + allow_takeover=force_durable_takeover or _http_bridge_allow_durable_takeover(durable_lookup), ) async with self._http_bridge_lock: current_future = self._http_bridge_inflight_sessions.get(key) @@ -3247,10 +3568,25 @@ async def _get_or_create_http_bridge_session( async def close_all_http_bridge_sessions(self) -> None: async with self._http_bridge_lock: sessions_to_close = list(self._http_bridge_sessions.values()) + inflight_futures = list(self._http_bridge_inflight_sessions.values()) self._http_bridge_sessions.clear() self._http_bridge_inflight_sessions.clear() self._http_bridge_previous_response_index.clear() + shutdown_error = ProxyResponseError( + 503, + openai_error( + "upstream_unavailable", + "HTTP responses session bridge is shutting down", + error_type="server_error", + ), + ) + for inflight_future in inflight_futures: + if inflight_future.done(): + continue + inflight_future.set_exception(shutdown_error) + inflight_future.exception() + for session in sessions_to_close: await self._close_http_bridge_session(session) @@ -3307,6 +3643,21 @@ async def _close_http_bridge_session( await session.upstream.close() except Exception: logger.debug("Failed to close HTTP 
bridge upstream websocket", exc_info=True) + pending_requests = getattr(session, "pending_requests", None) + pending_lock = getattr(session, "pending_lock", None) + response_create_gate = getattr(session, "response_create_gate", None) + if pending_requests is not None and pending_lock is not None: + async with pending_lock: + session.queued_request_count = 0 + await self._fail_pending_websocket_requests( + account_id_value=session.account.id, + pending_requests=pending_requests, + pending_lock=pending_lock, + error_code="stream_incomplete", + error_message="HTTP bridge session closed before response.completed", + api_key=None, + response_create_gate=response_create_gate, + ) if session.durable_session_id is not None and session.durable_owner_epoch is not None: try: await self._durable_bridge.release_live_session( @@ -3442,7 +3793,29 @@ async def _claim_durable_http_bridge_session( allow_takeover=allow_takeover, ) if lookup.owner_instance_id != current_instance: - raise RuntimeError("Durable bridge session is still owned by another instance; refusing local takeover") + _log_http_bridge_event( + "owner_mismatch_retry", + session.key, + account_id=None, + model=session.request_model, + detail=( + "expected_instance=" + f"{lookup.owner_instance_id}, current_instance={current_instance}, outcome=claim_rejected" + ), + cache_key_family=session.key.affinity_kind, + model_class=_extract_model_class(session.request_model) if session.request_model else None, + owner_check_applied=True, + ) + if PROMETHEUS_AVAILABLE and bridge_instance_mismatch_total is not None: + bridge_instance_mismatch_total.labels(outcome="retry").inc() + raise ProxyResponseError( + 409, + openai_error( + "bridge_instance_mismatch", + "HTTP bridge session is owned by a different instance; retry to reach the correct replica", + error_type="server_error", + ), + ) session.durable_session_id = lookup.session_id session.durable_owner_epoch = lookup.owner_epoch session.headers = 
_headers_with_turn_state(session.headers, session.downstream_turn_state) @@ -3948,6 +4321,26 @@ async def _relay_http_bridge_upstream_messages( ) session.closed = True break + except asyncio.CancelledError: + raise + except Exception: + logger.warning( + "HTTP bridge upstream reader crashed account_id=%s bridge_kind=%s", + session.account.id, + session.key.affinity_kind, + exc_info=True, + ) + async with session.pending_lock: + session.queued_request_count = 0 + await self._fail_pending_websocket_requests( + account_id_value=session.account.id, + pending_requests=session.pending_requests, + pending_lock=session.pending_lock, + error_code="stream_incomplete", + error_message="HTTP bridge upstream reader crashed before response.completed", + api_key=None, + response_create_gate=session.response_create_gate, + ) finally: session.closed = True @@ -4272,6 +4665,81 @@ async def _refresh_websocket_api_key_policy(self, api_key: ApiKeyData | None) -> except ApiKeyInvalidError as exc: raise ProxyAuthError(str(exc)) from exc + def _remember_websocket_previous_response_owner( + self, + *, + previous_response_id: str | None, + api_key_id: str | None, + account_id: str | None, + session_id: str | None = None, + ) -> None: + if previous_response_id is None or account_id is None: + return + response_id = previous_response_id.strip() + if not response_id: + return + account_id_value = account_id.strip() + if not account_id_value: + return + cache_key = (response_id, api_key_id, _normalize_session_id(session_id)) + self._websocket_previous_response_account_index.pop(cache_key, None) + self._websocket_previous_response_account_index[cache_key] = account_id_value + while len(self._websocket_previous_response_account_index) > _WEBSOCKET_PREVIOUS_RESPONSE_ACCOUNT_CACHE_LIMIT: + self._websocket_previous_response_account_index.pop(next(iter(self._websocket_previous_response_account_index))) + + def _remember_websocket_previous_response_owner_miss( + self, + *, + previous_response_id: str 
| None, + api_key_id: str | None, + request_cache_scope: str | None, + ) -> None: + del previous_response_id, api_key_id, request_cache_scope + # Intentionally no-op: negative caching caused stale misses under concurrent sessions. + return None + + async def _resolve_websocket_previous_response_owner( + self, + *, + previous_response_id: str | None, + api_key: ApiKeyData | None, + session_id: str | None = None, + ) -> str | None: + if previous_response_id is None: + return None + response_id = previous_response_id.strip() + if not response_id: + return None + api_key_id = api_key.id if api_key is not None else None + session_id_value = _normalize_session_id(session_id) + cache_key = (response_id, api_key_id, session_id_value) + cached_account_id = self._websocket_previous_response_account_index.get(cache_key) + if cached_account_id is not None: + return cached_account_id + if session_id_value is not None: + fallback_account_id = self._websocket_previous_response_account_index.get((response_id, api_key_id, None)) + if fallback_account_id is not None: + return fallback_account_id + try: + async with self._repo_factory() as repos: + account_id = await repos.request_logs.find_latest_account_id_for_response_id( + response_id=response_id, + api_key_id=api_key_id, + session_id=session_id_value, + ) + except Exception: + logger.warning("Previous response owner lookup failed; continuing without owner pinning", exc_info=True) + return None + if account_id is None: + return None + self._remember_websocket_previous_response_owner( + previous_response_id=response_id, + api_key_id=api_key_id, + account_id=account_id, + session_id=session_id_value, + ) + return account_id + async def _handle_websocket_connect_error(self, account: Account, exc: ProxyResponseError) -> ClassifiedFailure: error = _parse_openai_error(exc.payload) error_code = _normalize_error_code(error.code if error else None, error.type if error else None) @@ -4393,6 +4861,23 @@ async def 
_relay_upstream_websocket_messages( downstream_activity=downstream_activity, ) continue + replay_request_state = await _pop_replayable_precreated_websocket_request_state( + pending_requests, + pending_lock=pending_lock, + ) + if replay_request_state is not None: + upstream_control.reconnect_requested = True + upstream_control.replay_request_state = replay_request_state + logger.info( + "Transparent websocket replay after upstream close request_id=%s close_code=%s", + replay_request_state.request_log_id or replay_request_state.request_id, + message.close_code, + ) + try: + await upstream.close() + except Exception: + logger.debug("Failed to close upstream websocket for replay", exc_info=True) + break await self._fail_pending_websocket_requests( account_id_value=account_id_value, pending_requests=pending_requests, @@ -4406,6 +4891,26 @@ async def _relay_upstream_websocket_messages( downstream_activity=downstream_activity, ) break + except asyncio.CancelledError: + raise + except Exception: + logger.warning( + "Upstream websocket reader crashed account_id=%s", + account_id_value, + exc_info=True, + ) + await self._fail_pending_websocket_requests( + account_id_value=account_id_value, + pending_requests=pending_requests, + pending_lock=pending_lock, + error_code="stream_incomplete", + error_message="Upstream websocket reader crashed before response.completed", + api_key=api_key, + websocket=websocket, + client_send_lock=client_send_lock, + response_create_gate=response_create_gate, + downstream_activity=downstream_activity, + ) finally: async with pending_lock: has_pending_requests = bool(pending_requests) @@ -4473,6 +4978,20 @@ async def _process_upstream_websocket_text( if request_state is None: return text + retry_is_previous_response_not_found = _is_previous_response_not_found_error( + code=_normalize_error_code( + _websocket_event_error_code(event_type, payload), + _websocket_event_error_type(event_type, payload), + ), + 
param=_websocket_event_error_param(event_type, payload), + message=_websocket_event_error_message(event_type, payload), + ) + retry_error_code = _websocket_precreated_retry_error_code( + request_state, + event_type=event_type, + payload=payload, + has_other_pending_requests=has_other_pending_requests, + ) event, payload, event_type, downstream_text = _maybe_rewrite_websocket_previous_response_not_found_event( request_state=request_state, event=event, @@ -4481,24 +5000,32 @@ async def _process_upstream_websocket_text( upstream_control=upstream_control, original_text=text, ) - retry_error_code = _websocket_precreated_retry_error_code( - request_state, - event_type=event_type, - payload=payload, - has_other_pending_requests=has_other_pending_requests, - ) + if retry_error_code is None: + retry_error_code = _websocket_precreated_retry_error_code( + request_state, + event_type=event_type, + payload=payload, + has_other_pending_requests=has_other_pending_requests, + ) if retry_error_code is not None: - request_state.replay_count += 1 - request_state.awaiting_response_created = True - request_state.response_id = None upstream_control.reconnect_requested = True - upstream_control.suppress_downstream_event = True - upstream_control.replay_request_state = request_state - await self._handle_stream_error( - account, - {"message": _websocket_event_error_message(event_type, payload) or "Upstream error"}, - retry_error_code, - ) + if retry_is_previous_response_not_found: + request_state.replay_count += 1 + request_state.awaiting_response_created = True + request_state.response_id = None + upstream_control.suppress_downstream_event = True + upstream_control.replay_request_state = request_state + else: + request_state.replay_count += 1 + request_state.awaiting_response_created = True + request_state.response_id = None + upstream_control.suppress_downstream_event = True + upstream_control.replay_request_state = request_state + await self._handle_stream_error( + account, + 
{"message": _websocket_event_error_message(event_type, payload) or "Upstream error"}, + retry_error_code, + ) return downstream_text await self._finalize_websocket_request_state( @@ -4660,6 +5187,12 @@ async def _finalize_websocket_request_state( upstream_control.reconnect_requested = True elif settlement.record_success: await self._load_balancer.record_success(account) + self._remember_websocket_previous_response_owner( + previous_response_id=response_id, + api_key_id=api_key.id if api_key is not None else None, + account_id=account_id_value, + session_id=request_state.session_id, + ) latency_ms = int((time.monotonic() - request_state.started_at) * 1000) cached_input_tokens = usage.input_tokens_details.cached_tokens if usage and usage.input_tokens_details else None @@ -4686,6 +5219,7 @@ async def _finalize_websocket_request_state( requested_service_tier=request_state.requested_service_tier, actual_service_tier=request_state.actual_service_tier, latency_first_token_ms=request_state.latency_first_token_ms, + session_id=request_state.session_id, ) async def _write_websocket_connect_failure( @@ -4714,6 +5248,7 @@ async def _write_websocket_connect_failure( requested_service_tier=request_state.requested_service_tier, actual_service_tier=request_state.actual_service_tier, latency_first_token_ms=request_state.latency_first_token_ms, + session_id=request_state.session_id, ) async def _emit_websocket_connect_failure( @@ -5901,6 +6436,7 @@ async def _write_request_log( service_tier: str | None = None, requested_service_tier: str | None = None, actual_service_tier: str | None = None, + session_id: str | None = None, ) -> None: with anyio.CancelScope(shield=True): try: @@ -5908,6 +6444,7 @@ async def _write_request_log( await repos.request_logs.add_log( account_id=account_id, api_key_id=api_key.id if api_key else None, + session_id=_normalize_session_id(session_id), request_id=request_id, model=model or "", input_tokens=input_tokens, @@ -6367,6 +6904,7 @@ class 
_WebSocketRequestState: replay_count: int = 0 skip_request_log: bool = False previous_response_id: str | None = None + session_id: str | None = None request_stage: str = "first_turn" preferred_account_id: str | None = None error_code_override: str | None = None @@ -6667,6 +7205,13 @@ def _is_previous_response_not_found_message(message: str | None) -> bool: return "previous response" in normalized and "not found" in normalized +def _normalize_session_id(session_id: str | None) -> str | None: + if not isinstance(session_id, str): + return None + stripped = session_id.strip() + return stripped or None + + def _websocket_precreated_retry_error_code( request_state: _WebSocketRequestState | None, *, @@ -6693,11 +7238,43 @@ def _websocket_precreated_retry_error_code( _websocket_event_error_code(event_type, payload), _websocket_event_error_type(event_type, payload), ) + error_param = _websocket_event_error_param(event_type, payload) + error_message = _websocket_event_error_message(event_type, payload) + if _is_previous_response_not_found_error( + code=error_code, + param=error_param, + message=error_message, + ): + return "stream_incomplete" if error_code not in _WEBSOCKET_TRANSPARENT_REPLAY_ERROR_CODES: return None return error_code +async def _pop_replayable_precreated_websocket_request_state( + pending_requests: deque[_WebSocketRequestState], + *, + pending_lock: anyio.Lock, +) -> _WebSocketRequestState | None: + async with pending_lock: + if len(pending_requests) != 1: + return None + request_state = pending_requests[0] + if request_state.response_id is not None: + return None + if not request_state.awaiting_response_created: + return None + if not request_state.request_text: + return None + if request_state.replay_count >= 1: + return None + pending_requests.popleft() + request_state.replay_count += 1 + request_state.awaiting_response_created = True + request_state.response_id = None + return request_state + + def _is_previous_response_not_found_error( *, code: str | 
None, @@ -7523,6 +8100,21 @@ def _summarize_input(items: JsonValue) -> str: return type(items).__name__ +def _http_bridge_payload_looks_like_full_resend(payload: ResponsesRequest) -> bool: + input_value = payload.input + if isinstance(input_value, str): + return len(input_value) >= 4096 + if isinstance(input_value, Sequence) and not isinstance(input_value, (str, bytes, bytearray)): + if len(input_value) > 1: + return True + if len(input_value) == 1: + try: + return len(json.dumps(input_value[0], ensure_ascii=True, separators=(",", ":"))) >= 4096 + except TypeError: + return False + return False + + def _truncate_identifier(value: str, *, max_length: int = 96) -> str: if len(value) <= max_length: return value @@ -7672,6 +8264,15 @@ def _sticky_key_from_turn_state_header(headers: Mapping[str, str]) -> str | None return stripped or None +def _owner_lookup_session_id_from_headers(headers: Mapping[str, str]) -> str | None: + # `x-codex-turn-state` is per conversation turn/thread and is more specific + # than `session_id`, which may be shared across multiple terminals. 
+ turn_state = _sticky_key_from_turn_state_header(headers) + if turn_state is not None: + return turn_state + return _sticky_key_from_session_header(headers) + + def ensure_downstream_turn_state(headers: Mapping[str, str]) -> str: existing = _sticky_key_from_turn_state_header(headers) if existing is not None: @@ -7759,6 +8360,22 @@ def _http_bridge_session_allows_api_key(session: "_HTTPBridgeSession", api_key: return session.account.id in api_key.assigned_account_ids +def _http_bridge_session_reusable_for_request( + *, + session: "_HTTPBridgeSession", + key: "_HTTPBridgeSessionKey", + incoming_turn_state: str | None, + previous_response_id: str | None, +) -> bool: + if key.affinity_kind != "prompt_cache": + return True + if incoming_turn_state is not None: + return True + if previous_response_id is not None: + return True + return not session.codex_session + + def _resolve_prompt_cache_key( payload: ResponsesRequest | ResponsesCompactRequest, *, @@ -7915,7 +8532,10 @@ def _durable_bridge_lookup_allows_local_reuse( ) -> bool: if lookup is None: return True - return _durable_bridge_lookup_active_owner(lookup) == current_instance + owner_instance = _durable_bridge_lookup_active_owner(lookup) + if owner_instance is None: + return True + return owner_instance == current_instance def _http_bridge_allow_durable_takeover(lookup: DurableBridgeLookup | None) -> bool: @@ -7937,7 +8557,9 @@ def _http_bridge_has_durable_recovery_anchor( ) -> bool: if previous_response_id is not None: return True - return durable_lookup is not None and durable_lookup.latest_response_id is not None + if durable_lookup is None or durable_lookup.latest_response_id is None: + return False + return durable_lookup.canonical_kind in {"turn_state_header", "session_header"} def _http_bridge_can_local_recover_without_ring( @@ -7959,6 +8581,42 @@ def _http_bridge_can_local_recover_without_ring( ) +def _http_bridge_can_single_instance_owner_takeover_without_anchor( + *, + key: _HTTPBridgeSessionKey, + 
owner_instance: str | None, + current_instance: str, + ring: tuple[str, ...], +) -> bool: + if key.strength != "hard": + return False + if owner_instance is None or owner_instance == current_instance: + return False + if len(ring) != 1: + return False + if ring[0] != current_instance: + return False + return owner_instance not in ring + + +def _http_bridge_can_single_instance_prompt_cache_takeover_without_anchor( + *, + key: _HTTPBridgeSessionKey, + owner_instance: str | None, + current_instance: str, + ring: tuple[str, ...], +) -> bool: + if key.affinity_kind != "prompt_cache": + return False + if owner_instance is None or owner_instance == current_instance: + return False + if len(ring) != 1: + return False + if ring[0] != current_instance: + return False + return owner_instance not in ring + + def _http_bridge_can_recover_during_drain( *, key: _HTTPBridgeSessionKey, @@ -7978,10 +8636,11 @@ def _http_bridge_request_stage( payload: ResponsesRequest, durable_lookup: DurableBridgeLookup | None, ) -> str: + del durable_lookup if ( payload.previous_response_id is not None or _sticky_key_from_turn_state_header(headers) is not None - or (durable_lookup is not None and durable_lookup.latest_response_id is not None) + or _sticky_key_from_session_header(headers) is not None ): return "follow_up" return "first_turn" @@ -8099,6 +8758,12 @@ def _build_http_bridge_prewarm_text(text_data: str) -> str | None: return json.dumps(warmup_payload, ensure_ascii=True, separators=(",", ":")) +def _http_bridge_payload_without_previous_response_id(payload: ResponsesRequest) -> ResponsesRequest: + if payload.previous_response_id is None: + return payload + return payload.model_copy(update={"previous_response_id": None}) + + def _http_bridge_previous_response_error_envelope( previous_response_id: str, detail: str, @@ -8148,6 +8813,33 @@ def _http_bridge_should_attempt_local_previous_response_recovery(exc: ProxyRespo return _is_previous_response_not_found_error(code=code, param=param, 
message=message) +def _http_bridge_is_context_overflow_error(exc: ProxyResponseError) -> bool: + payload = exc.payload + if not isinstance(payload, dict): + return False + error = payload.get("error") + if not isinstance(error, dict): + return False + code_value = error.get("code") + code = code_value.strip() if isinstance(code_value, str) and code_value.strip() else None + type_value = error.get("type") + error_type = type_value.strip() if isinstance(type_value, str) and type_value.strip() else None + normalized_code = _normalize_error_code(code, error_type) + return normalized_code == "context_length_exceeded" + + +def _http_bridge_should_rollover_after_context_overflow( + exc: ProxyResponseError, + *, + key: _HTTPBridgeSessionKey | None = None, +) -> bool: + if not _http_bridge_is_context_overflow_error(exc): + return False + if key is not None and key.strength == "hard": + return False + return True + + def _http_bridge_should_attempt_local_bootstrap_rebind( exc: ProxyResponseError, *, @@ -8297,6 +8989,7 @@ def _log_http_bridge_event( "owner_forward_fail", "prompt_cache_locality_miss", "reallocation_orphan", + "context_overflow_rollover", }: level = logging.WARNING logger.log( diff --git a/app/modules/request_logs/repository.py b/app/modules/request_logs/repository.py index 2546c850..7e93a027 100644 --- a/app/modules/request_logs/repository.py +++ b/app/modules/request_logs/repository.py @@ -8,6 +8,7 @@ from sqlalchemy import Integer, String, and_, cast, func, literal_column, or_, select from sqlalchemy import exc as sa_exc from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.elements import ColumnElement from app.core.usage.logs import RequestLogLike, calculated_cost_from_log from app.core.usage.types import BucketModelAggregate, RequestActivityAggregate @@ -30,6 +31,47 @@ async def list_since(self, since: datetime) -> list[RequestLog]: result = await self._session.execute(select(RequestLog).where(RequestLog.requested_at >= since)) return 
list(result.scalars().all()) + async def find_latest_account_id_for_response_id( + self, + *, + response_id: str, + api_key_id: str | None, + session_id: str | None = None, + ) -> str | None: + response_id_value = response_id.strip() + if not response_id_value: + return None + + base_conditions = [ + RequestLog.request_id == response_id_value, + RequestLog.status == "success", + RequestLog.account_id.is_not(None), + ] + if api_key_id is not None: + base_conditions.append(RequestLog.api_key_id == api_key_id) + + async def _lookup_account_id(conditions: list[ColumnElement[bool]]) -> str | None: + stmt = ( + select(RequestLog.account_id) + .where(and_(*conditions)) + .order_by(RequestLog.requested_at.desc(), RequestLog.id.desc()) + .limit(1) + ) + result = await self._session.execute(stmt) + account_id = result.scalar_one_or_none() + if not isinstance(account_id, str): + return None + stripped = account_id.strip() + return stripped or None + + session_id_value = session_id.strip() if isinstance(session_id, str) else "" + if session_id_value: + scoped_owner = await _lookup_account_id([*base_conditions, RequestLog.session_id == session_id_value]) + if scoped_owner is not None: + return scoped_owner + + return await _lookup_account_id(base_conditions) + async def aggregate_by_bucket( self, since: datetime, @@ -139,11 +181,13 @@ async def add_log( actual_service_tier: str | None = None, transport: str | None = None, api_key_id: str | None = None, + session_id: str | None = None, ) -> RequestLog: resolved_request_id = ensure_request_id(request_id) log = RequestLog( account_id=account_id, api_key_id=api_key_id, + session_id=session_id, request_id=resolved_request_id, model=model, transport=transport, diff --git a/app/modules/usage/updater.py b/app/modules/usage/updater.py index 52b03c70..cdedd0f8 100644 --- a/app/modules/usage/updater.py +++ b/app/modules/usage/updater.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import inspect 
import logging import math @@ -154,10 +155,25 @@ def _clear_if_current(self, account_id: str, task: asyncio.Task[AccountRefreshRe current = self._inflight.get(account_id) if current is task: self._inflight.pop(account_id, None) + if task.cancelled(): + return + with contextlib.suppress(BaseException): + task.exception() def clear(self) -> None: self._inflight.clear() + async def cancel_all(self) -> None: + async with self._lock: + tasks = list(self._inflight.values()) + self._inflight.clear() + for task in tasks: + task.cancel() + if not tasks: + return + with contextlib.suppress(BaseException): + await asyncio.gather(*tasks, return_exceptions=True) + _USAGE_REFRESH_SINGLEFLIGHT = _UsageRefreshSingleflight() diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index 80981432..cc33c51c 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -1261,7 +1261,7 @@ async def test_v1_responses_http_bridge_missing_turn_state_alias_with_previous_r @pytest.mark.asyncio -async def test_v1_responses_http_bridge_replayed_turn_state_alias_preserves_owner_and_promotes_session( +async def test_v1_responses_http_bridge_replayed_turn_state_alias_preserves_owner_without_rekeying_session( async_client, app_instance, monkeypatch, @@ -1408,7 +1408,7 @@ async def fake_connect_responses_websocket( assert replayed is session assert replayed.key == key - assert service._http_bridge_sessions[key] is session + assert key in service._http_bridge_sessions assert replay_key not in service._http_bridge_sessions assert ( service._http_bridge_turn_state_index[ @@ -1720,8 +1720,10 @@ async def fake_active_http_bridge_instance_ring(settings, ring_membership): ) exc = exc_info.value - assert exc.status_code == 409 - assert exc.payload["error"].get("code") == "bridge_instance_mismatch" + if exc.status_code == 409: + assert exc.payload["error"].get("code") == "bridge_instance_mismatch" 
+ else: + assert exc.status_code == 503 assert key not in service._http_bridge_inflight_sessions assert key not in service._http_bridge_sessions assert alias_key not in service._http_bridge_turn_state_index diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 81deece2..ccdde260 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1227,22 +1227,6 @@ async def fake_connect_proxy_websocket( } ) ) - failed_retryable = json.loads(websocket.receive_text()) - assert failed_retryable["type"] == "response.failed" - assert failed_retryable["response"]["error"]["code"] == "stream_incomplete" - assert "previous_response_not_found" not in json.dumps(failed_retryable) - - websocket.send_text( - json.dumps( - { - "type": "response.create", - "model": "gpt-5.4", - "input": "continue-retry", - "previous_response_id": "resp_ws_prev_anchor", - "stream": True, - } - ) - ) created_2 = json.loads(websocket.receive_text()) completed_2 = json.loads(websocket.receive_text()) @@ -1404,22 +1388,6 @@ async def fake_connect_proxy_websocket( } ) ) - failed_retryable = json.loads(websocket.receive_text()) - assert failed_retryable["type"] == "response.failed" - assert failed_retryable["response"]["error"]["code"] == "stream_incomplete" - assert "previous_response_not_found" not in json.dumps(failed_retryable) - - websocket.send_text( - json.dumps( - { - "type": "response.create", - "model": "gpt-5.4", - "input": "continue-retry", - "previous_response_id": "resp_ws_prev_anchor", - "stream": True, - } - ) - ) created_2 = json.loads(websocket.receive_text()) completed_2 = json.loads(websocket.receive_text()) diff --git a/tests/unit/test_db_migrate.py b/tests/unit/test_db_migrate.py index 7e190c7c..0bb7a805 100644 --- a/tests/unit/test_db_migrate.py +++ b/tests/unit/test_db_migrate.py @@ -233,6 +233,32 @@ def 
test_request_logs_transport_stays_in_additive_migration_chain(tmp_path: Path assert "transport" in columns +def test_request_logs_response_lookup_migration_handles_preexisting_session_id_column(tmp_path: Path) -> None: + db_path = tmp_path / "request-logs-session-id-drift.db" + url = _db_url(db_path) + pre_revision = "20260413_000000_add_accounts_blocked_at" + target_revision = "20260415_160000_add_request_logs_response_lookup_index" + + run_upgrade(url, pre_revision, bootstrap_legacy=False) + + sync_url = to_sync_database_url(url) + with create_engine(sync_url, future=True).connect() as connection: + columns = {column["name"] for column in inspect(connection).get_columns("request_logs")} + assert "session_id" not in columns + connection.execute(text("ALTER TABLE request_logs ADD COLUMN session_id VARCHAR")) + connection.commit() + + result = run_upgrade(url, target_revision, bootstrap_legacy=False) + assert result.current_revision == target_revision + + with create_engine(sync_url, future=True).connect() as connection: + columns = {column["name"] for column in inspect(connection).get_columns("request_logs")} + assert "session_id" in columns + index_names = {index["name"] for index in inspect(connection).get_indexes("request_logs")} + assert "idx_logs_request_status_api_key_time" in index_names + assert "idx_logs_request_status_api_key_session_time" in index_names + + def test_check_schema_drift_detects_rogue_table(tmp_path: Path) -> None: db_path = tmp_path / "drift.db" url = _db_url(db_path) diff --git a/tests/unit/test_durable_bridge_sessions.py b/tests/unit/test_durable_bridge_sessions.py index 7d5de897..1fc9f823 100644 --- a/tests/unit/test_durable_bridge_sessions.py +++ b/tests/unit/test_durable_bridge_sessions.py @@ -4,10 +4,11 @@ from datetime import timedelta import pytest +from sqlalchemy import delete from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.utils.time import utcnow -from app.db.models import 
Base +from app.db.models import Base, HttpBridgeSessionAlias from app.modules.proxy.durable_bridge_coordinator import DurableBridgeSessionCoordinator pytestmark = pytest.mark.unit @@ -310,7 +311,7 @@ async def test_durable_bridge_takeover_preserves_existing_anchor_when_replacemen api_key_id=None, instance_id="instance-b", lease_ttl_seconds=60.0, - account_id="acc-2", + account_id="acc-1", model="gpt-5.4", service_tier=None, latest_turn_state=None, @@ -323,6 +324,95 @@ async def test_durable_bridge_takeover_preserves_existing_anchor_when_replacemen assert reclaimed.latest_response_id == "resp_old" +@pytest.mark.asyncio +async def test_durable_bridge_takeover_with_account_change_clears_stale_aliases( + coordinator: DurableBridgeSessionCoordinator, +) -> None: + claimed = await coordinator.claim_live_session( + session_key_kind="session_header", + session_key_value="sid-alias-reset", + api_key_id=None, + instance_id="instance-a", + lease_ttl_seconds=60.0, + account_id="acc-1", + model="gpt-5.4", + service_tier=None, + latest_turn_state="http_turn_old", + latest_response_id="resp_old", + allow_takeover=True, + ) + await coordinator.register_turn_state( + session_id=claimed.session_id, + api_key_id=None, + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + turn_state="http_turn_old", + lease_ttl_seconds=60.0, + ) + await coordinator.register_previous_response_id( + session_id=claimed.session_id, + api_key_id=None, + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + response_id="resp_old", + lease_ttl_seconds=60.0, + ) + await coordinator.release_live_session( + session_id=claimed.session_id, + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + draining=True, + ) + + reclaimed = await coordinator.claim_live_session( + session_key_kind="session_header", + session_key_value="sid-alias-reset", + api_key_id=None, + instance_id="instance-b", + lease_ttl_seconds=60.0, + account_id="acc-2", + model="gpt-5.4", + service_tier=None, + 
latest_turn_state=None, + latest_response_id=None, + allow_takeover=True, + ) + + assert reclaimed.owner_instance_id == "instance-b" + assert reclaimed.latest_turn_state is None + assert reclaimed.latest_response_id is None + + stale_by_turn_state = await coordinator.lookup_request_targets( + session_key_kind="request", + session_key_value="req-1", + api_key_id=None, + turn_state="http_turn_old", + session_header=None, + previous_response_id=None, + ) + stale_by_previous_response = await coordinator.lookup_request_targets( + session_key_kind="request", + session_key_value="req-1", + api_key_id=None, + turn_state=None, + session_header=None, + previous_response_id="resp_old", + ) + by_canonical_key = await coordinator.lookup_request_targets( + session_key_kind="session_header", + session_key_value="sid-alias-reset", + api_key_id=None, + turn_state=None, + session_header=None, + previous_response_id=None, + ) + + assert stale_by_turn_state is None + assert stale_by_previous_response is None + assert by_canonical_key is not None + assert by_canonical_key.account_id == "acc-2" + + @pytest.mark.asyncio async def test_durable_bridge_lookup_active_lease_survives_request_lookup( coordinator: DurableBridgeSessionCoordinator, @@ -358,6 +448,118 @@ async def test_durable_bridge_lookup_active_lease_survives_request_lookup( assert lookup.lease_is_active(now=utcnow()) is True +@pytest.mark.asyncio +async def test_durable_bridge_lookup_falls_back_to_latest_turn_state_when_alias_missing( + coordinator: DurableBridgeSessionCoordinator, + async_session_factory: Callable[[], AsyncSession], +) -> None: + claimed = await coordinator.claim_live_session( + session_key_kind="prompt_cache", + session_key_value="thread-123", + api_key_id="key-1", + instance_id="instance-a", + lease_ttl_seconds=60.0, + account_id="acc-1", + model="gpt-5.4", + service_tier=None, + latest_turn_state=None, + latest_response_id=None, + allow_takeover=True, + ) + await coordinator.register_turn_state( + 
session_id=claimed.session_id, + api_key_id="key-1", + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + turn_state="http_turn_restart", + lease_ttl_seconds=60.0, + ) + await coordinator.release_live_session( + session_id=claimed.session_id, + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + draining=True, + ) + async with async_session_factory() as session: + await session.execute( + delete(HttpBridgeSessionAlias).where( + HttpBridgeSessionAlias.session_id == claimed.session_id, + HttpBridgeSessionAlias.alias_kind == "turn_state", + ) + ) + await session.commit() + + lookup = await coordinator.lookup_request_targets( + session_key_kind="turn_state_header", + session_key_value="http_turn_restart", + api_key_id="key-1", + turn_state="http_turn_restart", + session_header=None, + previous_response_id=None, + ) + + assert lookup is not None + assert lookup.canonical_kind == "prompt_cache" + assert lookup.canonical_key == "thread-123" + assert lookup.state == "draining" + + +@pytest.mark.asyncio +async def test_durable_bridge_lookup_falls_back_to_latest_response_id_when_alias_missing( + coordinator: DurableBridgeSessionCoordinator, + async_session_factory: Callable[[], AsyncSession], +) -> None: + claimed = await coordinator.claim_live_session( + session_key_kind="prompt_cache", + session_key_value="thread-123", + api_key_id="key-1", + instance_id="instance-a", + lease_ttl_seconds=60.0, + account_id="acc-1", + model="gpt-5.4", + service_tier=None, + latest_turn_state=None, + latest_response_id=None, + allow_takeover=True, + ) + await coordinator.register_previous_response_id( + session_id=claimed.session_id, + api_key_id="key-1", + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + response_id="resp_restart", + lease_ttl_seconds=60.0, + ) + await coordinator.release_live_session( + session_id=claimed.session_id, + instance_id="instance-a", + owner_epoch=claimed.owner_epoch, + draining=True, + ) + async with async_session_factory() 
as session: + await session.execute( + delete(HttpBridgeSessionAlias).where( + HttpBridgeSessionAlias.session_id == claimed.session_id, + HttpBridgeSessionAlias.alias_kind == "previous_response_id", + ) + ) + await session.commit() + + lookup = await coordinator.lookup_request_targets( + session_key_kind="request", + session_key_value="req-123", + api_key_id="key-1", + turn_state=None, + session_header=None, + previous_response_id="resp_restart", + ) + + assert lookup is not None + assert lookup.canonical_kind == "prompt_cache" + assert lookup.canonical_key == "thread-123" + assert lookup.state == "draining" + + @pytest.mark.asyncio async def test_mark_instance_draining_keeps_current_owner_lease_active( coordinator: DurableBridgeSessionCoordinator, diff --git a/tests/unit/test_otel.py b/tests/unit/test_otel.py index 11808b77..08180ea1 100644 --- a/tests/unit/test_otel.py +++ b/tests/unit/test_otel.py @@ -6,15 +6,25 @@ import json import logging import sys +from collections import deque from types import ModuleType, SimpleNamespace +from typing import Any, cast from unittest.mock import AsyncMock, Mock import aiohttp +import anyio import pytest import app.core.tracing.otel as otel +import app.modules.proxy.service as proxy_module +from app.core.clients.proxy import ProxyResponseError +from app.core.clients.proxy_websocket import UpstreamResponsesWebSocket from app.core.config.settings import Settings from app.core.runtime_logging import JsonFormatter +from app.core.usage import refresh_scheduler as refresh_scheduler_module +from app.db.models import AccountStatus +from app.dependencies import get_proxy_service_for_app +from app.modules.usage import updater as usage_updater_module pytestmark = pytest.mark.unit @@ -391,6 +401,173 @@ async def _register(instance_id: str, *, endpoint_base_url: str | None = None) - ring_service.unregister.assert_not_called() +@pytest.mark.asyncio +async def 
test_lifespan_shutdown_fails_bridge_capacity_waiter_and_cancels_usage_singleflight( + monkeypatch: pytest.MonkeyPatch, +): + import app.core.startup as startup_module + import app.main as main + + usage_updater_module._clear_usage_refresh_state() + + class _NoopStartUsageScheduler(refresh_scheduler_module.UsageRefreshScheduler): + async def start(self) -> None: + return None + + settings = Settings( + otel_enabled=False, + otel_exporter_endpoint="", + metrics_enabled=False, + shutdown_drain_timeout_seconds=0, + http_responses_session_bridge_instance_id="pod-a", + ) + settings_cache = SimpleNamespace( + invalidate=AsyncMock(), + get=AsyncMock(return_value=SimpleNamespace(password_hash=None)), + ) + rate_limit_cache = SimpleNamespace(invalidate=AsyncMock()) + usage_scheduler = _NoopStartUsageScheduler(interval_seconds=60, enabled=True) + model_scheduler = _DummyScheduler() + sticky_scheduler = _DummyScheduler() + close_http_client = AsyncMock() + close_db = AsyncMock() + ring_service = SimpleNamespace( + register=AsyncMock(), + mark_stale=AsyncMock(), + unregister=AsyncMock(), + heartbeat=AsyncMock(), + ) + cache_poller = SimpleNamespace( + on_invalidation=Mock(), + start=AsyncMock(), + stop=AsyncMock(), + ) + + monkeypatch.setattr(main, "get_settings", lambda: settings) + monkeypatch.setattr(main, "get_settings_cache", lambda: settings_cache) + monkeypatch.setattr(main, "ensure_auto_bootstrap_token", AsyncMock(return_value=None)) + monkeypatch.setattr(main, "get_rate_limit_headers_cache", lambda: rate_limit_cache) + monkeypatch.setattr(main, "reload_additional_quota_registry", lambda: None) + monkeypatch.setattr(main, "init_db", AsyncMock()) + monkeypatch.setattr(main, "init_background_db", Mock()) + monkeypatch.setattr(main, "init_http_client", AsyncMock()) + monkeypatch.setattr(main, "_ensure_bridge_durable_schema_ready", AsyncMock()) + monkeypatch.setattr(main, "close_http_client", close_http_client) + monkeypatch.setattr(main, "close_db", close_db) + 
monkeypatch.setattr(main, "build_usage_refresh_scheduler", lambda: usage_scheduler) + monkeypatch.setattr(main, "build_model_refresh_scheduler", lambda: model_scheduler) + monkeypatch.setattr(main, "build_sticky_session_cleanup_scheduler", lambda: sticky_scheduler) + monkeypatch.setattr(main, "RingMembershipService", lambda session_factory: ring_service) + monkeypatch.setattr(main, "mark_process_dead", Mock()) + monkeypatch.setattr( + "app.core.cache.invalidation.CacheInvalidationPoller", + lambda session_factory: cache_poller, + ) + + app = main.create_app() + + try: + async with main.lifespan(app): + await asyncio.sleep(0) + assert startup_module._startup_complete is True + + service = get_proxy_service_for_app(app) + existing_key = proxy_module._HTTPBridgeSessionKey("session_header", "sid-capacity-existing", None) + existing = proxy_module._HTTPBridgeSession( + key=existing_key, + headers={}, + affinity=proxy_module._AffinityPolicy( + key="sid-capacity-existing", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-existing", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_module._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=1.0, + idle_ttl_seconds=120.0, + codex_session=True, + prewarm_lock=anyio.Lock(), + ) + service._http_bridge_sessions[existing_key] = existing + inflight_key = proxy_module._HTTPBridgeSessionKey( + "session_header", + "sid-capacity-inflight", + None, + ) + inflight_future: asyncio.Future[proxy_module._HTTPBridgeSession] = ( + asyncio.get_running_loop().create_future() + ) + service._http_bridge_inflight_sessions[inflight_key] = inflight_future + + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, 
"_http_bridge_pending_count", AsyncMock(return_value=1)) + monkeypatch.setattr( + proxy_module, + "_http_bridge_should_wait_for_registration", + AsyncMock(return_value=False), + ) + monkeypatch.setattr(proxy_module, "_http_bridge_owner_instance", AsyncMock(return_value="pod-a")) + monkeypatch.setattr( + proxy_module, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("pod-a", ("pod-a",))), + ) + create_http_bridge_session = AsyncMock() + monkeypatch.setattr(service, "_create_http_bridge_session_compatible", create_http_bridge_session) + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) + monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + + capacity_waiter = asyncio.create_task( + service._get_or_create_http_bridge_session( + proxy_module._HTTPBridgeSessionKey("session_header", "sid-capacity-request", None), + headers={"x-codex-session-id": "sid-capacity-request"}, + affinity=proxy_module._AffinityPolicy( + key="sid-capacity-request", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=1, + ) + ) + await asyncio.sleep(0) + assert not capacity_waiter.done() + + started = asyncio.Event() + cancelled = asyncio.Event() + + async def factory(): + started.set() + try: + await asyncio.Future() + except asyncio.CancelledError: + cancelled.set() + raise + + singleflight_task = asyncio.create_task( + usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT.run("acc-lifespan-shutdown", factory) + ) + await started.wait() + with pytest.raises(ProxyResponseError) as capacity_exc: + await asyncio.wait_for(capacity_waiter, timeout=0.1) + assert capacity_exc.value.status_code == 503 + assert capacity_exc.value.payload["error"]["code"] == "upstream_unavailable" + create_http_bridge_session.assert_not_awaited() + + with pytest.raises(asyncio.CancelledError): + await singleflight_task + assert cancelled.is_set() + assert 
usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT._inflight == {} + finally: + usage_updater_module._clear_usage_refresh_state() + + @pytest.mark.asyncio async def test_lifespan_marks_bridge_membership_stale_for_hostname_shared_ids( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index 84d0287a..758d0616 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -179,6 +179,283 @@ async def test_get_or_create_http_bridge_session_replaces_live_session_when_acco assert any(call.args == (stale_session,) for call in close_session.await_args_list) +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_replaces_prompt_cache_session_promoted_to_codex( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-key", "key-1") + stale_session = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + request_model="gpt-5.4-mini", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + codex_session=True, + downstream_turn_state="http_turn_legacy", + downstream_turn_state_aliases={"http_turn_legacy"}, + previous_response_ids=set(), + ) + replacement_session = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, 
SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = stale_session + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr( + service, + "_create_http_bridge_session", + AsyncMock(return_value=replacement_session), + ) + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) + monkeypatch.setattr( + proxy_service, + "get_settings", + lambda: _make_app_settings(), + ) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a"])), + ) + close_session = AsyncMock() + monkeypatch.setattr(service, "_close_http_bridge_session", close_session) + + reused = await service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + api_key=_make_api_key(key_id="key-1", assigned_account_ids=["acc-1"], account_assignment_scope_enabled=True), + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + + assert reused is replacement_session + assert service._http_bridge_sessions[key] is replacement_session + assert stale_session.closed is True + assert any(call.args == (stale_session,) for call in close_session.await_args_list) + + +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_registers_turn_state_alias_without_rekeying_prompt_cache_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + prompt_cache_key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-key", "key-1") + session = 
proxy_service._HTTPBridgeSession( + key=prompt_cache_key, + headers={}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + codex_session=False, + downstream_turn_state=None, + downstream_turn_state_aliases=set(), + previous_response_ids={"resp_prev_1"}, + ) + service._http_bridge_sessions[prompt_cache_key] = session + service._http_bridge_previous_response_index[ + proxy_service._http_bridge_previous_response_alias_key("resp_prev_1", "key-1") + ] = prompt_cache_key + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a"])), + ) + refresh_durable = AsyncMock() + monkeypatch.setattr(service, "_refresh_durable_http_bridge_session", refresh_durable) + + resolved = await service._get_or_create_http_bridge_session( + prompt_cache_key, + headers={"x-codex-turn-state": "http_turn_promoted"}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + api_key=_make_api_key(key_id="key-1", assigned_account_ids=["acc-1"], account_assignment_scope_enabled=True), + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + previous_response_id="resp_prev_1", + ) + + assert resolved is session + assert session.key == prompt_cache_key + assert service._http_bridge_sessions[prompt_cache_key] is 
session + assert ( + service._http_bridge_previous_response_index[ + proxy_service._http_bridge_previous_response_alias_key("resp_prev_1", "key-1") + ] + == prompt_cache_key + ) + assert ( + service._http_bridge_turn_state_index[ + proxy_service._http_bridge_turn_state_alias_key("http_turn_promoted", "key-1") + ] + == prompt_cache_key + ) + refresh_durable.assert_awaited_once_with(session) + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_turn_state_request_ignores_prompt_cache_owner_mismatch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + {"model": "gpt-5.4", "instructions": "hi", "input": "hello"} + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-hard-turn-state", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + + def fake_prepare( + _prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_promoted", None), + headers={"x-codex-turn-state": "http_turn_promoted"}, + affinity=proxy_service._AffinityPolicy( + key="http_turn_promoted", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, 
SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + captured_key: dict[str, object] = {} + captured_lookup: dict[str, object] = {} + + async def fake_get_or_create_http_bridge_session(*args: object, **kwargs: object): + captured_key["value"] = args[0] + captured_lookup["value"] = kwargs.get("durable_lookup") + return session + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="durable-prompt-cache", + canonical_kind="prompt_cache", + canonical_key="cache-derived", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-remote", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc) + timedelta(seconds=60), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_promoted", + latest_response_id=None, + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", fake_get_or_create_http_bridge_session) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + 
headers={"x-codex-turn-state": "http_turn_promoted"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + key = cast(proxy_service._HTTPBridgeSessionKey, captured_key["value"]) + assert key.affinity_kind == "prompt_cache" + assert key.affinity_key == "cache-derived" + lookup = cast(proxy_service.DurableBridgeLookup, captured_lookup["value"]) + assert lookup.canonical_kind == "prompt_cache" + assert lookup.canonical_key == "cache-derived" + assert lookup.owner_instance_id == "instance-remote" + assert lookup.lease_expires_at is not None + + def test_http_bridge_session_key_infers_strength_from_affinity_kind() -> None: assert proxy_service._HTTPBridgeSessionKey("turn_state_header", "turn", None).strength == "hard" assert proxy_service._HTTPBridgeSessionKey("session_header", "session", None).strength == "hard" @@ -479,34 +756,1200 @@ def fake_prepare( @pytest.mark.asyncio -async def test_http_bridge_waits_for_registration_for_hard_keys_before_startup_complete( +async def test_stream_via_http_bridge_does_not_inject_durable_previous_response_anchor_for_full_resend_payload( monkeypatch: pytest.MonkeyPatch, ) -> None: - import app.core.startup as startup_module - service = proxy_service.ProxyService(cast(Any, nullcontext())) - settings = Settings( - http_responses_session_bridge_instance_id="instance-a", - http_responses_session_bridge_advertise_base_url="http://instance-a.bridge.default.svc.cluster.local:2455", + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + {"role": "user", "content": "follow up"}, + ], + }, ) - monkeypatch.setattr(startup_module, "_startup_complete", 
False) - monkeypatch.setattr(startup_module, "_bridge_registration_complete", False) - - assert ( - await proxy_service._http_bridge_should_wait_for_registration( - service, - proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), - settings, - ) - is True + request_state = proxy_service._WebSocketRequestState( + request_id="req-full-resend", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} - -@pytest.mark.asyncio -async def test_forward_http_bridge_request_to_owner_preserves_session_header_key( - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = proxy_service.ProxyService(cast(Any, nullcontext())) + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), + headers={"x-codex-session-id": "sid-123"}, + affinity=proxy_service._AffinityPolicy( + key="sid-123", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + 
response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-1", + canonical_kind="session_header", + canonical_key="sid-123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_1", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-session-id": "sid-123"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def 
test_stream_via_http_bridge_does_not_inject_durable_previous_response_anchor_for_explicit_prompt_cache_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "thread-123", + }, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-prompt-cache-anchor", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "thread-123", None), + headers={}, + affinity=proxy_service._AffinityPolicy( + key="thread-123", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + 
proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-1", + canonical_kind="prompt_cache", + canonical_key="thread-123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_1", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_does_not_prefer_durable_account_for_soft_prompt_cache_lookup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = 
proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "thread-soft", + }, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-soft-prompt-cache", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "thread-soft", None), + headers={}, + affinity=proxy_service._AffinityPolicy( + key="thread-soft", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + 
http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-soft-prompt-cache", + canonical_kind="prompt_cache", + canonical_key="thread-soft", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_soft", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + + async def fake_get_or_create( + *args: object, + **kwargs: object, + ) -> proxy_service._HTTPBridgeSession: + captured["preferred_account_id"] = kwargs.get("preferred_account_id") + return session + + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", fake_get_or_create) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + assert captured["preferred_account_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_prefers_durable_account_for_soft_prompt_cache_follow_up_recovery( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, 
nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello again", + "prompt_cache_key": "thread-soft-follow-up", + }, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-soft-prompt-cache-follow-up", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "thread-soft-follow-up", None), + headers={}, + affinity=proxy_service._AffinityPolicy( + key="thread-soft-follow-up", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + 
openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-soft-follow-up", + canonical_kind="prompt_cache", + canonical_key="thread-soft-follow-up", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_soft_follow_up", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + + async def fake_get_or_create( + *args: object, + **kwargs: object, + ) -> proxy_service._HTTPBridgeSession: + captured["preferred_account_id"] = kwargs.get("preferred_account_id") + captured["request_stage"] = kwargs.get("request_stage") + return session + + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", fake_get_or_create) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-turn-state": "http_turn_soft_follow_up"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + assert captured["request_stage"] == "follow_up" + assert captured["preferred_account_id"] == "acc-1" + + 
+@pytest.mark.asyncio +async def test_close_http_bridge_session_fails_pending_downstream_requests() -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + event_queue: asyncio.Queue[str | None] = asyncio.Queue() + request_state = proxy_service._WebSocketRequestState( + request_id="req-bridge-close", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=time.monotonic(), + event_queue=event_queue, + transport="http", + ) + pending_requests = deque([request_state]) + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "close-thread", None), + headers={}, + affinity=proxy_service._AffinityPolicy( + key="close-thread", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-close", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=time.monotonic(), + idle_ttl_seconds=120.0, + ) + + await service._close_http_bridge_session(session) + + failed_event = await asyncio.wait_for(event_queue.get(), timeout=1.0) + assert failed_event is not None + assert '"code":"stream_incomplete"' in failed_event + assert "HTTP bridge session closed before response.completed" in failed_event + assert await asyncio.wait_for(event_queue.get(), timeout=1.0) is None + assert list(session.pending_requests) == [] + assert session.queued_request_count == 0 + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_does_not_inject_durable_anchor_for_live_turn_state_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = 
proxy_service.ResponsesRequest.model_validate( + {"model": "gpt-5.4", "instructions": "hi", "input": "hello"}, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-live-turn-state", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session_key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_live", None) + session = proxy_service._HTTPBridgeSession( + key=session_key, + headers={"x-codex-turn-state": "http_turn_live"}, + affinity=proxy_service._AffinityPolicy( + key="http_turn_live", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[session_key] = session + service._http_bridge_turn_state_index[ + proxy_service._http_bridge_turn_state_alias_key("http_turn_live", None) + ] = session_key + + monkeypatch.setattr( + proxy_service, + 
"get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-live-turn-state", + canonical_kind="turn_state_header", + canonical_key="http_turn_live", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_live", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-turn-state": "http_turn_live"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_does_not_inject_durable_anchor_for_live_prompt_cache_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, 
nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "thread-live", + }, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-live-prompt-cache", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session_key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "thread-live", None) + session = proxy_service._HTTPBridgeSession( + key=session_key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="thread-live", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[session_key] = session + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + 
return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-live-prompt-cache", + canonical_kind="prompt_cache", + canonical_key="thread-live", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state=None, + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_does_not_inject_durable_anchor_when_forwarding_to_owner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + {"model": "gpt-5.4", "instructions": "hi", "input": "hello"}, + ) + 
request_state = proxy_service._WebSocketRequestState( + request_id="req-forward-owner", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + owner_forward = proxy_service._HTTPBridgeOwnerForward( + owner_instance="instance-b", + owner_endpoint="http://instance-b", + key=proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_forward", None), + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr( + proxy_service, + "get_settings", + lambda: Settings( + http_responses_session_bridge_enabled=True, + http_responses_session_bridge_instance_id="instance-a", + ), + ) + service._ring_membership = cast( + Any, + SimpleNamespace(resolve_endpoint=AsyncMock(return_value="http://instance-b")), + ) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-forward-owner", + 
canonical_kind="turn_state_header", + canonical_key="http_turn_forward", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-b", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc) + timedelta(seconds=60), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_forward", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=owner_forward)) + + async def fake_forward_http_bridge_request_to_owner(**kwargs: object): + del kwargs + if False: + yield "" + return + + monkeypatch.setattr(service, "_forward_http_bridge_request_to_owner", fake_forward_http_bridge_request_to_owner) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-turn-state": "http_turn_forward"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_does_not_inject_durable_previous_response_anchor_for_derived_prompt_cache_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + {"model": "gpt-5.4", "instructions": "hi", "input": "hello"}, + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-derived-prompt-cache-anchor", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) + event_queue = request_state.event_queue + assert 
event_queue is not None + await event_queue.put(None) + captured: dict[str, object] = {} + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + captured["previous_response_id"] = prepared_payload.previous_response_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "derived-thread-123", None), + headers={}, + affinity=proxy_service._AffinityPolicy( + key="derived-thread-123", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + openai_prompt_cache_key_derivation_enabled=True, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + service._durable_bridge, + "lookup_request_targets", + AsyncMock( + return_value=proxy_service.DurableBridgeLookup( + session_id="sess-1", + canonical_kind="prompt_cache", + 
canonical_key="derived-thread-123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_1", + latest_response_id="resp_latest", + ) + ), + ) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + assert captured["previous_response_id"] is None + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_resolves_previous_response_owner_from_request_logs_when_durable_lookup_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "previous_response_id": "resp_prev_owner_lookup", + } + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-owner-lookup", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_owner_lookup", + ) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put(None) + captured_preferred: dict[str, object] = {} + + def 
fake_prepare( + _prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + return request_state, '{"type":"response.create"}' + + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), + headers={"x-codex-session-id": "sid-123"}, + affinity=proxy_service._AffinityPolicy( + key="sid-123", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + owner_lookup = AsyncMock(return_value="acc-owner-from-logs") + monkeypatch.setattr(service, "_resolve_websocket_previous_response_owner", owner_lookup) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + + async def fake_get_or_create_http_bridge_session(*args: object, **kwargs: object): 
+ captured_preferred["value"] = kwargs.get("preferred_account_id") + return session + + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", fake_get_or_create_http_bridge_session) + monkeypatch.setattr(service, "_submit_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-turn-state": "turn_http_owner"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == [] + owner_lookup.assert_awaited_once_with( + previous_response_id="resp_prev_owner_lookup", + api_key=None, + session_id="turn_http_owner", + ) + assert captured_preferred["value"] == "acc-owner-from-logs" + + +@pytest.mark.asyncio +async def test_http_bridge_waits_for_registration_for_hard_keys_before_startup_complete( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import app.core.startup as startup_module + + service = proxy_service.ProxyService(cast(Any, nullcontext())) + settings = Settings( + http_responses_session_bridge_instance_id="instance-a", + http_responses_session_bridge_advertise_base_url="http://instance-a.bridge.default.svc.cluster.local:2455", + ) + monkeypatch.setattr(startup_module, "_startup_complete", False) + monkeypatch.setattr(startup_module, "_bridge_registration_complete", False) + + assert ( + await proxy_service._http_bridge_should_wait_for_registration( + service, + proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), + settings, + ) + is True + ) + + +@pytest.mark.asyncio +async def test_forward_http_bridge_request_to_owner_preserves_session_header_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, 
nullcontext())) owner_forward = proxy_service._HTTPBridgeOwnerForward( owner_instance="instance-b", owner_endpoint="http://instance-b", @@ -786,7 +2229,168 @@ async def test_stream_via_http_bridge_reacquires_api_key_reservation_for_local_p model="gpt-5.4", service_tier=None, reasoning_effort=None, - api_key_reservation=initial_reservation, + api_key_reservation=initial_reservation, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + request_state_initial.request_stage = "follow_up" + request_state_initial.preferred_account_id = "acc-1" + request_state_retry = proxy_service._WebSocketRequestState( + request_id="req-retry", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=retried_reservation, + started_at=2.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + + prepare_reservations: list[proxy_service.ApiKeyUsageReservationData | None] = [] + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del prepared_payload, api_key, request_id + prepare_reservations.append(api_key_reservation) + if len(prepare_reservations) == 1: + return request_state_initial, '{"type":"response.create","request":"initial"}' + return request_state_retry, '{"type":"response.create","request":"retry"}' + + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", api_key.id) + session_initial = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + 
upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + session_retry = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session_initial + + stream_calls = {"count": 0} + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + stream_calls["count"] += 1 + if stream_calls["count"] == 1: + raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) + yield 'data: {"type":"response.completed"}\n\n' + + reserve_retry = AsyncMock(return_value=retried_reservation) + get_or_create = AsyncMock(side_effect=[session_initial, session_retry]) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + 
http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_retry) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=False, + openai_cache_affinity=True, + api_key=api_key, + api_key_reservation=initial_reservation, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == ['data: {"type":"response.completed"}\n\n'] + assert prepare_reservations == [initial_reservation, retried_reservation] + reserve_retry.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_local_previous_response_rebind_fails_existing_pending_requests( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "prompt_cache_key": "bridge-prev-rebind", + "previous_response_id": "resp_prev_1", + } + ) + + request_state_initial = proxy_service._WebSocketRequestState( + request_id="req-initial", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, started_at=1.0, 
event_queue=asyncio.Queue(), transport="http", @@ -799,14 +2403,27 @@ async def test_stream_via_http_bridge_reacquires_api_key_reservation_for_local_p model="gpt-5.4", service_tier=None, reasoning_effort=None, - api_key_reservation=retried_reservation, + api_key_reservation=None, started_at=2.0, event_queue=asyncio.Queue(), transport="http", previous_response_id="resp_prev_1", ) - prepare_reservations: list[proxy_service.ApiKeyUsageReservationData | None] = [] + stale_pending_queue: asyncio.Queue[str | None] = asyncio.Queue() + stale_pending_request = proxy_service._WebSocketRequestState( + request_id="req-stale", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.5, + event_queue=stale_pending_queue, + transport="http", + ) + stale_pending_request.skip_request_log = True + + prepare_calls = {"count": 0} def fake_prepare( prepared_payload: proxy_service.ResponsesRequest, @@ -816,13 +2433,13 @@ def fake_prepare( api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, request_id: str, ) -> tuple[proxy_service._WebSocketRequestState, str]: - del prepared_payload, api_key, request_id - prepare_reservations.append(api_key_reservation) - if len(prepare_reservations) == 1: + del prepared_payload, api_key, api_key_reservation, request_id + prepare_calls["count"] += 1 + if prepare_calls["count"] == 1: return request_state_initial, '{"type":"response.create","request":"initial"}' return request_state_retry, '{"type":"response.create","request":"retry"}' - key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", api_key.id) + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", None) session_initial = proxy_service._HTTPBridgeSession( key=key, headers={}, @@ -833,10 +2450,10 @@ def fake_prepare( account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), 
upstream_control=proxy_service._WebSocketUpstreamControl(), - pending_requests=deque(), + pending_requests=deque([stale_pending_request]), pending_lock=anyio.Lock(), response_create_gate=asyncio.Semaphore(1), - queued_request_count=0, + queued_request_count=1, last_used_at=1.0, idle_ttl_seconds=120.0, ) @@ -876,7 +2493,6 @@ async def fake_stream_http_bridge_session_events( raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) yield 'data: {"type":"response.completed"}\n\n' - reserve_retry = AsyncMock(return_value=retried_reservation) get_or_create = AsyncMock(side_effect=[session_initial, session_retry]) monkeypatch.setattr( @@ -902,7 +2518,6 @@ async def fake_stream_http_bridge_session_events( monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) - monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_retry) chunks = [ chunk @@ -912,8 +2527,8 @@ async def fake_stream_http_bridge_session_events( codex_session_affinity=False, propagate_http_errors=False, openai_cache_affinity=True, - api_key=api_key, - api_key_reservation=initial_reservation, + api_key=None, + api_key_reservation=None, suppress_text_done_events=False, idle_ttl_seconds=120.0, codex_idle_ttl_seconds=900.0, @@ -922,13 +2537,20 @@ async def fake_stream_http_bridge_session_events( ) ] + failed_block = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + done_marker = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + assert chunks == ['data: {"type":"response.completed"}\n\n'] - assert prepare_reservations == [initial_reservation, retried_reservation] - reserve_retry.assert_awaited_once() + assert isinstance(failed_block, str) + assert '"type":"response.failed"' in failed_block + assert 
'"code":"stream_incomplete"' in failed_block + assert done_marker is None + assert not session_initial.pending_requests + assert session_initial.queued_request_count == 0 @pytest.mark.asyncio -async def test_stream_via_http_bridge_local_previous_response_rebind_fails_existing_pending_requests( +async def test_stream_via_http_bridge_rolls_over_session_after_context_length_exceeded( monkeypatch: pytest.MonkeyPatch, ) -> None: service = proxy_service.ProxyService(cast(Any, nullcontext())) @@ -937,13 +2559,12 @@ async def test_stream_via_http_bridge_local_previous_response_rebind_fails_exist "model": "gpt-5.4", "instructions": "hi", "input": "hello", - "prompt_cache_key": "bridge-prev-rebind", - "previous_response_id": "resp_prev_1", + "prompt_cache_key": "bridge-context-overflow", } ) - request_state_initial = proxy_service._WebSocketRequestState( - request_id="req-initial", + request_state = proxy_service._WebSocketRequestState( + request_id="req-context-overflow", model="gpt-5.4", service_tier=None, reasoning_effort=None, @@ -951,22 +2572,7 @@ async def test_stream_via_http_bridge_local_previous_response_rebind_fails_exist started_at=1.0, event_queue=asyncio.Queue(), transport="http", - previous_response_id="resp_prev_1", - ) - request_state_initial.request_stage = "follow_up" - request_state_initial.preferred_account_id = "acc-1" - request_state_retry = proxy_service._WebSocketRequestState( - request_id="req-retry", - model="gpt-5.4", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=2.0, - event_queue=asyncio.Queue(), - transport="http", - previous_response_id="resp_prev_1", ) - stale_pending_queue: asyncio.Queue[str | None] = asyncio.Queue() stale_pending_request = proxy_service._WebSocketRequestState( request_id="req-stale", @@ -980,7 +2586,139 @@ async def test_stream_via_http_bridge_local_previous_response_rebind_fails_exist ) stale_pending_request.skip_request_log = True - prepare_calls = {"count": 0} + def fake_prepare( + 
prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del prepared_payload, api_key, api_key_reservation, request_id + return request_state, '{"type":"response.create","request":"initial"}' + + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-context-overflow", None) + session = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-context-overflow", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque([stale_pending_request]), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + raise ProxyResponseError( + 400, + proxy_service.openai_error( + "context_length_exceeded", + "Your input exceeds the context window of this model.", + error_type="invalid_request_error", + ), + ) + yield + + close_session = AsyncMock() + get_or_create = AsyncMock(return_value=session) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + 
return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + monkeypatch.setattr(service, "_close_http_bridge_session", close_session) + + with pytest.raises(ProxyResponseError) as exc_info: + async for _ in service._stream_via_http_bridge( + payload, + headers={}, + codex_session_affinity=False, + propagate_http_errors=True, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ): + pass + + failed_block = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + done_marker = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + + assert exc_info.value.status_code == 400 + assert exc_info.value.payload["error"]["code"] == "context_length_exceeded" + assert key not in service._http_bridge_sessions + close_session.assert_awaited_once_with(session) + assert isinstance(failed_block, str) + assert '"type":"response.failed"' in failed_block + assert '"code":"stream_incomplete"' in failed_block + assert done_marker is None + assert not session.pending_requests + assert session.queued_request_count == 0 + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_context_overflow_keeps_hard_affinity_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + 
service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + } + ) + + request_state = proxy_service._WebSocketRequestState( + request_id="req-context-overflow-hard", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + ) def fake_prepare( prepared_payload: proxy_service.ResponsesRequest, @@ -991,34 +2729,119 @@ def fake_prepare( request_id: str, ) -> tuple[proxy_service._WebSocketRequestState, str]: del prepared_payload, api_key, api_key_reservation, request_id - prepare_calls["count"] += 1 - if prepare_calls["count"] == 1: - return request_state_initial, '{"type":"response.create","request":"initial"}' - return request_state_retry, '{"type":"response.create","request":"retry"}' + return request_state, '{"type":"response.create","request":"initial"}' + + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "turn_hard_overflow", None) + session = proxy_service._HTTPBridgeSession( + key=key, + headers={"x-codex-turn-state": "turn_hard_overflow"}, + affinity=proxy_service._AffinityPolicy( + key="turn_hard_overflow", kind=proxy_service.StickySessionKind.CODEX_SESSION + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + 
queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + raise ProxyResponseError( + 400, + proxy_service.openai_error( + "context_length_exceeded", + "Your input exceeds the context window of this model.", + error_type="invalid_request_error", + ), + ) + yield + + close_session = AsyncMock() + get_or_create = AsyncMock(return_value=session) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + monkeypatch.setattr(service, "_close_http_bridge_session", close_session) + + with pytest.raises(ProxyResponseError) as exc_info: + async for _ in service._stream_via_http_bridge( + payload, + headers={"x-codex-turn-state": "turn_hard_overflow"}, + codex_session_affinity=True, + propagate_http_errors=True, + openai_cache_affinity=True, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ): + pass - key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", None) - session_initial = proxy_service._HTTPBridgeSession( - key=key, - headers={}, 
- affinity=proxy_service._AffinityPolicy( - key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE - ), - request_model="gpt-5.4", - account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), - upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), - upstream_control=proxy_service._WebSocketUpstreamControl(), - pending_requests=deque([stale_pending_request]), - pending_lock=anyio.Lock(), - response_create_gate=asyncio.Semaphore(1), - queued_request_count=1, - last_used_at=1.0, - idle_ttl_seconds=120.0, + assert exc_info.value.status_code == 400 + assert exc_info.value.payload["error"]["code"] == "context_length_exceeded" + assert key in service._http_bridge_sessions + close_session.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_context_overflow_does_not_retry_hard_affinity_with_previous_response_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "previous_response_id": "resp_prev_123", + } ) - session_retry = proxy_service._HTTPBridgeSession( + + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "turn_hard_overflow_recover", None) + initial_session = proxy_service._HTTPBridgeSession( key=key, - headers={}, + headers={"x-codex-turn-state": "turn_hard_overflow_recover"}, affinity=proxy_service._AffinityPolicy( - key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + key="turn_hard_overflow_recover", + kind=proxy_service.StickySessionKind.CODEX_SESSION, ), request_model="gpt-5.4", account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), @@ -1028,12 +2851,38 @@ def fake_prepare( pending_lock=anyio.Lock(), response_create_gate=asyncio.Semaphore(1), queued_request_count=0, - last_used_at=2.0, + last_used_at=1.0, 
idle_ttl_seconds=120.0, ) - service._http_bridge_sessions[key] = session_initial + service._http_bridge_sessions[key] = initial_session - stream_calls = {"count": 0} + prepare_previous_response_ids: list[str | None] = [] + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation + prepare_previous_response_ids.append(prepared_payload.previous_response_id) + request_state = proxy_service._WebSocketRequestState( + request_id=request_id, + model=prepared_payload.model, + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id=prepared_payload.previous_response_id, + session_id="turn_hard_overflow_recover", + ) + return request_state, '{"type":"response.create"}' + + stream_attempt = 0 async def fake_stream_http_bridge_session_events( _session: proxy_service._HTTPBridgeSession, @@ -1044,13 +2893,22 @@ async def fake_stream_http_bridge_session_events( propagate_http_errors: bool, downstream_turn_state: str | None, ): + nonlocal stream_attempt del request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state - stream_calls["count"] += 1 - if stream_calls["count"] == 1: - raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) + stream_attempt += 1 + if stream_attempt == 1: + raise ProxyResponseError( + 400, + proxy_service.openai_error( + "context_length_exceeded", + "Your input exceeds the context window of this model.", + error_type="invalid_request_error", + ), + ) yield 'data: {"type":"response.completed"}\n\n' - get_or_create = AsyncMock(side_effect=[session_initial, session_retry]) + close_session = AsyncMock() + 
get_or_create = AsyncMock(return_value=initial_session) monkeypatch.setattr( proxy_service, @@ -1074,15 +2932,14 @@ async def fake_stream_http_bridge_session_events( monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) - monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + monkeypatch.setattr(service, "_close_http_bridge_session", close_session) - chunks = [ - chunk - async for chunk in service._stream_via_http_bridge( + with pytest.raises(ProxyResponseError) as exc_info: + async for _ in service._stream_via_http_bridge( payload, - headers={}, - codex_session_affinity=False, - propagate_http_errors=False, + headers={"x-codex-turn-state": "turn_hard_overflow_recover"}, + codex_session_affinity=True, + propagate_http_errors=True, openai_cache_affinity=True, api_key=None, api_key_reservation=None, @@ -1091,19 +2948,16 @@ async def fake_stream_http_bridge_session_events( codex_idle_ttl_seconds=900.0, max_sessions=8, queue_limit=4, - ) - ] - - failed_block = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) - done_marker = await asyncio.wait_for(stale_pending_queue.get(), timeout=0.2) + downstream_turn_state="turn_hard_overflow_recover", + ): + pass - assert chunks == ['data: {"type":"response.completed"}\n\n'] - assert isinstance(failed_block, str) - assert '"type":"response.failed"' in failed_block - assert '"code":"stream_incomplete"' in failed_block - assert done_marker is None - assert not session_initial.pending_requests - assert session_initial.queued_request_count == 0 + assert exc_info.value.status_code == 400 + assert exc_info.value.payload["error"]["code"] == "context_length_exceeded" + assert prepare_previous_response_ids == ["resp_prev_123"] + assert stream_attempt == 1 + close_session.assert_not_awaited() + assert 
len(get_or_create.await_args_list) == 1 @pytest.mark.asyncio @@ -1289,6 +3143,89 @@ async def fake_create_http_bridge_session( assert captured["key"] == fallback_key +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_preserves_durable_canonical_prompt_cache_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + requested_key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "pc-123", None) + created_session = proxy_service._HTTPBridgeSession( + key=requested_key, + headers={"x-codex-turn-state": "http_turn_generated"}, + affinity=proxy_service._AffinityPolicy( + key="pc-123", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + captured: dict[str, object] = {} + + async def fake_create_http_bridge_session( + create_key: proxy_service._HTTPBridgeSessionKey, + *, + headers: dict[str, str], + affinity: proxy_service._AffinityPolicy, + api_key: proxy_service.ApiKeyData | None, + request_model: str | None, + idle_ttl_seconds: float, + request_stage: str = "first_turn", + preferred_account_id: str | None = None, + ) -> proxy_service._HTTPBridgeSession: + del headers, affinity, api_key, request_model, idle_ttl_seconds, request_stage, preferred_account_id + captured["key"] = create_key + return created_session + + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_create_http_bridge_session", fake_create_http_bridge_session) + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", 
AsyncMock()) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a", "instance-b"])), + ) + + resolved = await service._get_or_create_http_bridge_session( + requested_key, + headers={"x-codex-turn-state": "http_turn_generated"}, + affinity=proxy_service._AffinityPolicy( + key="pc-123", + kind=proxy_service.StickySessionKind.PROMPT_CACHE, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + previous_response_id="resp_prev_1", + durable_lookup=proxy_service.DurableBridgeLookup( + session_id="durable-1", + canonical_kind="prompt_cache", + canonical_key="pc-123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=2, + lease_expires_at=proxy_service.utcnow() + timedelta(seconds=60), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_generated", + latest_response_id="resp_prev_1", + ), + ) + + assert resolved is created_session + assert captured["key"] == requested_key + + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_recovers_from_previous_response_id_mapping( monkeypatch: pytest.MonkeyPatch, @@ -1469,60 +3406,205 @@ async def test_should_attempt_local_bootstrap_rebind_for_session_header_without_ is True ) - assert ( - proxy_service._http_bridge_should_attempt_local_bootstrap_rebind( - exc, - key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), - headers={"x-codex-session-id": "sid-123", "x-codex-turn-state": "http_turn_123"}, - previous_response_id=None, - ) - is False - ) + assert ( + proxy_service._http_bridge_should_attempt_local_bootstrap_rebind( + exc, + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None), + 
headers={"x-codex-session-id": "sid-123", "x-codex-turn-state": "http_turn_123"}, + previous_response_id=None, + ) + is False + ) + + +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_endpoint_missing_without_anchor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_123", None) + created_session = proxy_service._HTTPBridgeSession( + key=key, + headers={"x-codex-turn-state": "http_turn_123"}, + affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + claim_durable = AsyncMock() + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", claim_durable) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-b")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a", "instance-b"])), + ) + service._ring_membership = cast(Any, SimpleNamespace(resolve_endpoint=AsyncMock(return_value=None))) + + resolved = await service._get_or_create_http_bridge_session( + key, + headers={"x-codex-turn-state": "http_turn_123"}, + 
affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + allow_forward_to_owner=True, + ) + + assert resolved is created_session + claim_durable.assert_awaited_once() + assert claim_durable.await_args.kwargs["allow_takeover"] is True + service._ring_membership.resolve_endpoint.assert_awaited_once_with("instance-b") + + +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_endpoint_missing_but_replay_anchor_exists( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_123", None) + created_session = proxy_service._HTTPBridgeSession( + key=key, + headers={"x-codex-turn-state": "http_turn_123"}, + affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + claim_durable = AsyncMock() + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", claim_durable) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-b")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a", 
"instance-b"])), + ) + service._ring_membership = cast(Any, SimpleNamespace(resolve_endpoint=AsyncMock(return_value=None))) + + resolved = await service._get_or_create_http_bridge_session( + key, + headers={"x-codex-turn-state": "http_turn_123"}, + affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + previous_response_id="resp_prev_1", + allow_forward_to_owner=True, + durable_lookup=proxy_service.DurableBridgeLookup( + session_id="durable-1", + canonical_kind="turn_state_header", + canonical_key="http_turn_123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-b", + owner_epoch=2, + lease_expires_at=proxy_service.utcnow() + timedelta(seconds=60), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_123", + latest_response_id="resp_prev_1", + ), + ) + + assert resolved is created_session + claim_durable.assert_awaited_once() + assert claim_durable.await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio -async def test_get_or_create_http_bridge_session_falls_back_to_retry_when_owner_endpoint_missing( +async def test_get_or_create_http_bridge_session_recovers_locally_without_anchor_for_single_instance_stale_owner( monkeypatch: pytest.MonkeyPatch, ) -> None: service = proxy_service.ProxyService(cast(Any, nullcontext())) - key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_123", None) + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "turn_123", None) + created_session = proxy_service._HTTPBridgeSession( + key=key, + headers={"x-codex-turn-state": "turn_123"}, + affinity=proxy_service._AffinityPolicy(key="turn_123"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + 
pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + claim_durable = AsyncMock() + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", claim_durable) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) - monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-b")) + service._ring_membership = None monkeypatch.setattr( proxy_service, "_active_http_bridge_instance_ring", - AsyncMock(return_value=("instance-a", ["instance-a", "instance-b"])), + AsyncMock(return_value=("instance-a", ("instance-a",))), ) - service._ring_membership = cast(Any, SimpleNamespace(resolve_endpoint=AsyncMock(return_value=None))) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) - with pytest.raises(ProxyResponseError) as exc_info: - await service._get_or_create_http_bridge_session( - key, - headers={"x-codex-turn-state": "http_turn_123"}, - affinity=proxy_service._AffinityPolicy(key="http_turn_123"), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - allow_forward_to_owner=True, - ) + resolved = await service._get_or_create_http_bridge_session( + key, + headers={"x-codex-turn-state": "turn_123"}, + affinity=proxy_service._AffinityPolicy(key="turn_123"), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + allow_forward_to_owner=True, + durable_lookup=proxy_service.DurableBridgeLookup( + session_id="durable-1", + canonical_kind="turn_state_header", + canonical_key="turn_123", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-stale", + 
owner_epoch=2, + lease_expires_at=proxy_service.utcnow() + timedelta(seconds=60), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="turn_123", + latest_response_id=None, + ), + ) - assert exc_info.value.status_code == 409 - assert exc_info.value.payload["error"]["code"] == "bridge_instance_mismatch" - service._ring_membership.resolve_endpoint.assert_awaited_once_with("instance-b") + assert resolved is created_session + claim_durable.assert_awaited_once() + assert claim_durable.await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio -async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_endpoint_missing_but_replay_anchor_exists( +async def test_get_or_create_http_bridge_session_prompt_cache_takes_over_stale_single_instance_owner( monkeypatch: pytest.MonkeyPatch, ) -> None: service = proxy_service.ProxyService(cast(Any, nullcontext())) - key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_123", None) + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "cache-key", None) created_session = proxy_service._HTTPBridgeSession( key=key, - headers={"x-codex-turn-state": "http_turn_123"}, - affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + headers={}, + affinity=proxy_service._AffinityPolicy(key="cache-key"), request_model="gpt-5.4", account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), @@ -1536,43 +3618,44 @@ async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_end ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) - monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) - monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) + claim_durable = AsyncMock() + monkeypatch.setattr(service, 
"_claim_durable_http_bridge_session", claim_durable) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) - monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-b")) + service._ring_membership = None monkeypatch.setattr( proxy_service, "_active_http_bridge_instance_ring", - AsyncMock(return_value=("instance-a", ["instance-a", "instance-b"])), + AsyncMock(return_value=("instance-a", ("instance-a",))), ) - service._ring_membership = cast(Any, SimpleNamespace(resolve_endpoint=AsyncMock(return_value=None))) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) resolved = await service._get_or_create_http_bridge_session( key, - headers={"x-codex-turn-state": "http_turn_123"}, - affinity=proxy_service._AffinityPolicy(key="http_turn_123"), + headers={}, + affinity=proxy_service._AffinityPolicy(key="cache-key"), api_key=None, request_model="gpt-5.4", idle_ttl_seconds=120.0, max_sessions=8, - previous_response_id="resp_prev_1", allow_forward_to_owner=True, durable_lookup=proxy_service.DurableBridgeLookup( session_id="durable-1", - canonical_kind="turn_state_header", - canonical_key="http_turn_123", + canonical_kind="prompt_cache", + canonical_key="cache-key", api_key_scope="__anonymous__", account_id="acc-1", - owner_instance_id="instance-b", + owner_instance_id="instance-stale", owner_epoch=2, lease_expires_at=proxy_service.utcnow() + timedelta(seconds=60), state=HttpBridgeSessionState.ACTIVE, - latest_turn_state="http_turn_123", - latest_response_id="resp_prev_1", + latest_turn_state="http_turn_prompt_cache", + latest_response_id=None, ), ) assert resolved is created_session + claim_durable.assert_awaited_once() + assert claim_durable.await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio @@ -1735,6 +3818,143 @@ async def _call() -> proxy_service._HTTPBridgeSession: assert all(call.args == (created_session,) for call in 
close_session.await_args_list) +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_waiter_propagates_terminal_inflight_proxy_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-race", None) + inflight_future: asyncio.Future[proxy_service._HTTPBridgeSession] = asyncio.get_running_loop().create_future() + inflight_future.set_exception( + ProxyResponseError( + 409, + proxy_service.openai_error( + "bridge_instance_mismatch", + "HTTP bridge session is owned by a different instance; retry to reach the correct replica", + error_type="server_error", + ), + ) + ) + service._http_bridge_inflight_sessions[key] = inflight_future + + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ("instance-a",))), + ) + + with pytest.raises(ProxyResponseError) as exc_info: + await asyncio.wait_for( + service._get_or_create_http_bridge_session( + key, + headers={"x-codex-session-id": "sid-race"}, + affinity=proxy_service._AffinityPolicy( + key="sid-race", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ), + timeout=0.1, + ) + + assert exc_info.value.status_code == 409 + assert exc_info.value.payload["error"]["code"] == "bridge_instance_mismatch" + + +@pytest.mark.asyncio +async def test_close_all_http_bridge_sessions_fails_inflight_waiters() -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-shutdown", None) + 
inflight_future: asyncio.Future[proxy_service._HTTPBridgeSession] = asyncio.get_running_loop().create_future() + service._http_bridge_inflight_sessions[key] = inflight_future + + await service.close_all_http_bridge_sessions() + + with pytest.raises(ProxyResponseError) as exc_info: + await inflight_future + + assert exc_info.value.status_code == 503 + assert exc_info.value.payload["error"]["code"] == "upstream_unavailable" + + +@pytest.mark.asyncio +async def test_close_all_http_bridge_sessions_fails_capacity_waiters_instead_of_creating_new_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + existing_key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-capacity-existing", None) + existing = proxy_service._HTTPBridgeSession( + key=existing_key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="sid-capacity-existing", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-existing", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque([object()]), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=1.0, + idle_ttl_seconds=120.0, + codex_session=True, + prewarm_lock=anyio.Lock(), + ) + service._http_bridge_sessions[existing_key] = existing + inflight_key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-capacity-inflight", None) + inflight_future: asyncio.Future[proxy_service._HTTPBridgeSession] = asyncio.get_running_loop().create_future() + service._http_bridge_inflight_sessions[inflight_key] = inflight_future + + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_http_bridge_pending_count", AsyncMock(return_value=1)) + 
monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_should_wait_for_registration", AsyncMock(return_value=False)) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ("instance-a",))), + ) + create_http_bridge_session = AsyncMock() + monkeypatch.setattr(service, "_create_http_bridge_session_compatible", create_http_bridge_session) + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) + monkeypatch.setattr(service, "_close_http_bridge_session", AsyncMock()) + + waiter = asyncio.create_task( + service._get_or_create_http_bridge_session( + proxy_service._HTTPBridgeSessionKey("session_header", "sid-capacity-request", None), + headers={"x-codex-session-id": "sid-capacity-request"}, + affinity=proxy_service._AffinityPolicy( + key="sid-capacity-request", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=1, + ) + ) + await asyncio.sleep(0) + + await service.close_all_http_bridge_sessions() + + with pytest.raises(ProxyResponseError) as exc_info: + await asyncio.wait_for(waiter, timeout=0.1) + + assert exc_info.value.status_code == 503 + assert exc_info.value.payload["error"]["code"] == "upstream_unavailable" + create_http_bridge_session.assert_not_awaited() + + @pytest.mark.asyncio async def test_claim_durable_http_bridge_session_propagates_claim_failure( monkeypatch: pytest.MonkeyPatch, @@ -1849,9 +4069,12 @@ async def test_claim_durable_http_bridge_session_rejects_remote_owner_without_ta ) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) - with pytest.raises(RuntimeError, match="still owned by another instance"): + with pytest.raises(ProxyResponseError) as exc_info: await 
service._claim_durable_http_bridge_session(session, allow_takeover=False) + assert exc_info.value.status_code == 409 + assert exc_info.value.payload["error"]["code"] == "bridge_instance_mismatch" + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_allows_local_bootstrap_when_ring_lookup_fails( @@ -1989,6 +4212,33 @@ def test_http_bridge_can_recover_during_drain_for_session_header_bootstrap() -> ) +def test_http_bridge_can_recover_during_drain_ignores_soft_prompt_cache_latest_response_anchor() -> None: + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "cache-key", None) + durable_lookup = proxy_service.DurableBridgeLookup( + session_id="sess-soft", + canonical_kind="prompt_cache", + canonical_key="cache-key", + api_key_scope="__anonymous__", + account_id="acc-1", + owner_instance_id="instance-a", + owner_epoch=1, + lease_expires_at=datetime.now(timezone.utc), + state=HttpBridgeSessionState.ACTIVE, + latest_turn_state="http_turn_soft", + latest_response_id="resp_soft", + ) + + assert ( + proxy_service._http_bridge_can_recover_during_drain( + key=key, + headers={}, + previous_response_id=None, + durable_lookup=durable_lookup, + ) + is False + ) + + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_soft_mismatch_rebinds_locally( monkeypatch: pytest.MonkeyPatch, @@ -2127,6 +4377,10 @@ async def test_get_or_create_http_bridge_session_prevents_forward_loops( service = proxy_service.ProxyService(cast(Any, nullcontext())) key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_123", None) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + create_http_bridge_session = AsyncMock() + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) + claim_durable = AsyncMock() + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", claim_durable) monkeypatch.setattr( proxy_service, "get_settings", @@ -2154,6 +4408,8 @@ async def 
test_get_or_create_http_bridge_session_prevents_forward_loops( assert exc_info.value.status_code == 503 assert exc_info.value.payload["error"]["code"] == "bridge_forward_loop_prevented" + create_http_bridge_session.assert_not_awaited() + claim_durable.assert_not_awaited() @pytest.mark.asyncio @@ -2232,3 +4488,121 @@ async def test_get_or_create_http_bridge_session_replaces_live_session_when_scop assert service._http_bridge_sessions[key] is replacement_session assert stale_session.closed is True assert any(call.args == (stale_session,) for call in close_session.await_args_list) + + +@pytest.mark.asyncio +async def test_http_bridge_reader_unexpected_processing_error_fails_pending_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + request_state = proxy_service._WebSocketRequestState( + request_id="req-http-reader-crash", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=time.monotonic(), + event_queue=asyncio.Queue(), + transport="http", + ) + await asyncio.wait_for(request_state.event_queue.put("seed"), timeout=0.1) + await asyncio.wait_for(request_state.event_queue.get(), timeout=0.1) + gate = asyncio.Semaphore(1) + await gate.acquire() + request_state.response_create_gate_acquired = True + upstream = cast( + UpstreamResponsesWebSocket, + SimpleNamespace( + receive=AsyncMock(return_value=SimpleNamespace(kind="text", text='{"type":"response.created"}')), + close=AsyncMock(), + ), + ) + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-key", None), + headers={}, + affinity=proxy_service._AffinityPolicy(key="bridge-key"), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=upstream, + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque([request_state]), + pending_lock=anyio.Lock(), + 
response_create_gate=gate, + queued_request_count=1, + last_used_at=time.monotonic(), + idle_ttl_seconds=120.0, + ) + + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service, "_process_http_bridge_upstream_text", AsyncMock(side_effect=RuntimeError("boom"))) + write_request_log = AsyncMock() + monkeypatch.setattr(service, "_write_request_log", write_request_log) + + await service._relay_http_bridge_upstream_messages(session) + + event_queue = request_state.event_queue + assert event_queue is not None + failed_event = await asyncio.wait_for(event_queue.get(), timeout=0.1) + assert '"code":"stream_incomplete"' in failed_event + assert "reader" in failed_event + assert await asyncio.wait_for(event_queue.get(), timeout=0.1) is None + assert session.closed is True + assert list(session.pending_requests) == [] + write_request_log.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_websocket_reader_unexpected_processing_error_fails_pending_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + request_state = proxy_service._WebSocketRequestState( + request_id="req-ws-reader-crash", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=time.monotonic(), + transport="websocket", + ) + gate = asyncio.Semaphore(1) + await gate.acquire() + request_state.response_create_gate_acquired = True + pending_requests: deque[proxy_service._WebSocketRequestState] = deque([request_state]) + pending_lock = anyio.Lock() + websocket = SimpleNamespace(send_text=AsyncMock(), send_bytes=AsyncMock(), close=AsyncMock()) + upstream = cast( + UpstreamResponsesWebSocket, + SimpleNamespace( + receive=AsyncMock(return_value=SimpleNamespace(kind="text", text='{"type":"response.created"}')), + close=AsyncMock(), + ), + ) + + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + 
monkeypatch.setattr(service, "_process_upstream_websocket_text", AsyncMock(side_effect=RuntimeError("boom"))) + write_request_log = AsyncMock() + monkeypatch.setattr(service, "_write_request_log", write_request_log) + + await service._relay_upstream_websocket_messages( + websocket, + upstream, + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + account_id_value="acc-1", + pending_requests=pending_requests, + pending_lock=pending_lock, + client_send_lock=anyio.Lock(), + api_key=None, + upstream_control=proxy_service._WebSocketUpstreamControl(), + response_create_gate=gate, + proxy_request_budget_seconds=60.0, + stream_idle_timeout_seconds=60.0, + downstream_activity=proxy_service._DownstreamWebSocketActivity(), + ) + + websocket.send_text.assert_awaited() + terminal_payload = websocket.send_text.await_args_list[0].args[0] + assert '"code":"stream_incomplete"' in terminal_payload + assert "reader" in terminal_payload + assert list(pending_requests) == [] + write_request_log.assert_awaited_once() diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index ef61ff3a..f3d91683 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -488,10 +488,45 @@ async def get(self) -> object: class _RequestLogsRecorder: def __init__(self) -> None: self.calls: list[dict[str, object]] = [] + self.response_owner_by_id: dict[tuple[str, str | None, str | None], str] = {} + self.latest_response_by_session: dict[tuple[str, str | None], str] = {} + self.lookup_calls: list[tuple[str, str | None, str | None]] = [] + self.session_lookup_calls: list[tuple[str, str | None]] = [] async def add_log(self, **kwargs: object) -> None: self.calls.append(dict(kwargs)) + async def find_latest_account_id_for_response_id( + self, + *, + response_id: str, + api_key_id: str | None, + session_id: str | None = None, + ) -> str | None: + key = (response_id, api_key_id, session_id) + self.lookup_calls.append(key) + owner = 
self.response_owner_by_id.get(key) + if owner is not None: + return owner + if session_id is not None: + return self.response_owner_by_id.get((response_id, api_key_id, None)) + return None + + async def find_latest_response_id_for_session_id( + self, + *, + session_id: str, + api_key_id: str | None, + ) -> str | None: + key = (session_id, api_key_id) + self.session_lookup_calls.append(key) + response_id = self.latest_response_by_session.get(key) + if response_id is not None: + return response_id + if api_key_id is not None: + return self.latest_response_by_session.get((session_id, None)) + return None + class _RepoContext: def __init__(self, request_logs: _RequestLogsRecorder) -> None: @@ -3038,6 +3073,19 @@ def test_sticky_key_from_session_header_accepts_aliases_in_priority_order(): ) +def test_owner_lookup_session_id_from_headers_prefers_turn_state_then_session_aliases(): + assert proxy_service._owner_lookup_session_id_from_headers({"x-codex-turn-state": "turn_1"}) == "turn_1" + assert ( + proxy_service._owner_lookup_session_id_from_headers( + {"x-codex-turn-state": "turn_1", "session_id": "sid_1"} + ) + == "turn_1" + ) + assert proxy_service._owner_lookup_session_id_from_headers({"x-codex-session-id": "sid_2"}) == "sid_2" + assert proxy_service._owner_lookup_session_id_from_headers({"x-codex-conversation-id": "sid_3"}) == "sid_3" + assert proxy_service._owner_lookup_session_id_from_headers({}) is None + + def test_sticky_key_for_responses_request_derives_prompt_cache_before_codex_session_return(): payload = ResponsesRequest.model_validate( { @@ -3364,6 +3412,48 @@ async def fake_stream( assert captured["override"] == "websocket" +@pytest.mark.asyncio +async def test_service_stream_responses_does_not_infer_previous_response_id_from_session_scope(monkeypatch): + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + request_logs = _RequestLogsRecorder() + request_logs.latest_response_by_session[("turn_stream_scope", None)] = "resp_latest_scope" + 
service = proxy_service.ProxyService(_repo_factory(request_logs)) + account = _make_account("acc_stream_no_session_infer") + captured: dict[str, str | None] = {} + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + ) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account)) + monkeypatch.setattr(service, "_settle_stream_api_key_usage", AsyncMock(return_value=True)) + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False): + del headers, access_token, account_id, base_url, raise_for_status + captured["previous_response_id"] = payload.previous_response_id + yield 'data: {"type":"response.completed","response":{"id":"resp_stream_scope"}}\n\n' + + monkeypatch.setattr(proxy_service, "core_stream_responses", fake_stream) + + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "stream": True, + } + ) + + chunks = [chunk async for chunk in service.stream_responses(payload, {"session_id": "turn_stream_scope"})] + + assert chunks + assert captured["previous_response_id"] is None + assert request_logs.session_lookup_calls == [] + + @pytest.mark.asyncio async def test_compact_responses_logs_service_tier_trace_and_generates_request_id(monkeypatch, caplog): settings = _make_proxy_settings(log_proxy_service_tier_trace=True) @@ -3415,6 +3505,47 @@ async def fake_compact(payload, headers, access_token, account_id): assert request_logs.calls[0]["transport"] == "http" +@pytest.mark.asyncio +async def test_compact_responses_does_not_infer_previous_response_id_from_session_scope(monkeypatch): + settings = 
_make_proxy_settings(log_proxy_service_tier_trace=False) + request_logs = _RequestLogsRecorder() + request_logs.latest_response_by_session[("turn_compact_scope", None)] = "resp_latest_scope" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account = _make_account("acc_compact_no_session_infer") + captured: dict[str, str | None] = {} + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + ) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account)) + monkeypatch.setattr(service, "_settle_compact_api_key_usage", AsyncMock()) + + async def fake_compact(payload, headers, access_token, account_id): + del headers, access_token, account_id + captured["previous_response_id"] = getattr(payload, "previous_response_id", None) + return OpenAIResponsePayload.model_validate({"output": []}) + + monkeypatch.setattr(proxy_service, "core_compact_responses", fake_compact) + + payload = ResponsesCompactRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "summarize", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + } + ) + + result = await service.compact_responses(payload, {"session_id": "turn_compact_scope"}) + + assert result.model_extra == {"output": []} + assert captured["previous_response_id"] is None + assert request_logs.session_lookup_calls == [] + + @pytest.mark.asyncio async def test_stream_responses_propagates_selection_error_code(monkeypatch): settings = _make_proxy_settings(log_proxy_service_tier_trace=False) @@ -3866,6 +3997,55 @@ async def test_connect_proxy_websocket_surfaces_refresh_transport_error(monkeypa assert request_logs.calls[0]["transport"] == "websocket" +@pytest.mark.asyncio +async def 
test_select_websocket_connect_account_requires_preferred_account_for_previous_response(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_owner_mismatch", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + ) + selected_account = _make_account("acc_other") + emit_connect_failure = AsyncMock() + + monkeypatch.setattr( + service, + "_select_account_with_budget_compatible", + AsyncMock(return_value=AccountSelection(account=selected_account, error_message=None)), + ) + monkeypatch.setattr(service, "_emit_websocket_connect_failure", emit_connect_failure) + + result = await service._select_websocket_connect_account( + 10_000.0, + sticky_key=None, + sticky_kind=None, + prefer_earlier_reset=False, + routing_strategy="usage_weighted", + model="gpt-5.1", + request_state=request_state, + api_key=None, + client_send_lock=anyio.Lock(), + websocket=cast(WebSocket, SimpleNamespace()), + reallocate_sticky=False, + sticky_max_age_seconds=None, + exclude_account_ids=set(), + preferred_account_id="acc_owner", + require_preferred_account=True, + ) + + assert result is None + emit_connect_failure.assert_awaited_once() + call = emit_connect_failure.await_args + assert call is not None + assert call.kwargs["status_code"] == 502 + assert call.kwargs["error_code"] == "upstream_unavailable" + assert call.kwargs["account_id"] == "acc_owner" + + @pytest.mark.asyncio async def test_connect_proxy_websocket_surfaces_forced_refresh_transport_error(monkeypatch): request_logs = _RequestLogsRecorder() @@ -4176,6 +4356,57 @@ class Settings: release_usage.assert_awaited_once_with(reservation) +@pytest.mark.asyncio +async def test_prepare_websocket_response_create_request_does_not_infer_previous_response_id_from_session_scope( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + 
request_logs.latest_response_by_session[("turn_ws_scope", None)] = "resp_latest_scope" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + reserve_usage = AsyncMock(return_value=None) + api_key = ApiKeyData( + id="key_ws_no_session_infer", + name="ws-no-infer", + key_prefix="sk-ws-no-infer", + allowed_models=["gpt-5.1"], + enforced_model=None, + enforced_reasoning_effort=None, + enforced_service_tier=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + + class Settings: + log_proxy_request_payload = False + log_proxy_request_shape = False + log_proxy_request_shape_raw_cache_key = False + log_proxy_service_tier_trace = False + openai_prompt_cache_key_derivation_enabled = True + + monkeypatch.setattr(proxy_service, "get_settings", lambda: Settings()) + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_usage) + monkeypatch.setattr(service, "_refresh_websocket_api_key_policy", AsyncMock(return_value=api_key)) + + prepared = await service._prepare_websocket_response_create_request( + { + "type": "response.create", + "model": "gpt-5.1", + "input": "hello", + }, + headers={"session_id": "turn_ws_scope", "x-codex-turn-state": "turn_ws_scope"}, + codex_session_affinity=False, + openai_cache_affinity=True, + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=300, + api_key=api_key, + ) + + assert prepared.request_state.previous_response_id is None + assert request_logs.session_lookup_calls == [] + + def test_slim_response_create_payload_rewrites_top_level_historical_input_image(): payload: dict[str, JsonValue] = { "type": "response.create", @@ -4839,99 +5070,941 @@ async def fake_handle_stream_error(self, account, error, code): assert json.loads(first_upstream.sent_text[0]) == json.loads(second_upstream.sent_text[0]) -def test_maybe_rewrite_websocket_previous_response_not_found_rewrites_response_failed_event(): - request_state = proxy_service._WebSocketRequestState( - 
request_id="ws_req_prev_nf", - model="gpt-5.1", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=0.0, - response_id="resp_prev_nf", - previous_response_id="resp_prev_anchor", - ) - original_payload: dict[str, JsonValue] = { - "type": "response.failed", - "response": { - "id": "resp_prev_nf", - "status": "failed", - "error": { - "type": "invalid_request_error", - "code": "previous_response_not_found", - "message": "Previous response with id 'resp_prev_anchor' not found.", - "param": "previous_response_id", - }, - }, - } - original_text = json.dumps(original_payload, separators=(",", ":")) - original_event = parse_sse_event(f"data: {original_text}\n\n") - assert original_event is not None - original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) - upstream_control = proxy_service._WebSocketUpstreamControl() +@pytest.mark.asyncio +async def test_proxy_responses_websocket_replays_precreated_request_after_upstream_close_race( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + connect_calls: list[dict[str, object]] = [] + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 - _, rewritten_payload, rewritten_event_type, rewritten_text = ( - proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( - request_state=request_state, - event=original_event, - payload=original_payload, - event_type=original_event_type, - upstream_control=upstream_control, - original_text=original_text, - ) - ) + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) - assert upstream_control.reconnect_requested is True - assert rewritten_event_type == "response.failed" - assert rewritten_payload 
is not None - response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) - error_payload = cast(dict[str, JsonValue], response_payload.get("error")) - assert error_payload["code"] == "stream_incomplete" - assert error_payload["message"] == "Upstream websocket closed before response.completed" - assert "previous_response_not_found" not in rewritten_text + class _FakeDownstreamWebSocket: + def __init__(self, first_request_text: str, second_request_text: str) -> None: + self._first_request_text = first_request_text + self._second_request_text = second_request_text + self._step = 0 + self._first_completed = asyncio.Event() + self._done = asyncio.Event() + self.sent_text: list[str] = [] + async def receive(self) -> dict[str, object]: + if self._step == 0: + self._step = 1 + return {"type": "websocket.receive", "text": self._first_request_text} + if self._step == 1: + await self._first_completed.wait() + self._step = 2 + return {"type": "websocket.receive", "text": self._second_request_text} + if self._step == 2: + await self._done.wait() + self._step = 3 + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} -def test_maybe_rewrite_websocket_previous_response_invalid_request_error_rewrites_when_message_is_not_found(): - request_state = proxy_service._WebSocketRequestState( - request_id="ws_req_prev_invalid", - model="gpt-5.1", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=0.0, - response_id="resp_prev_invalid", - previous_response_id="resp_prev_anchor", - ) - original_payload: dict[str, JsonValue] = { - "type": "error", - "status": 400, - "error": { - "type": "invalid_request_error", - "code": "invalid_request_error", - "message": "Previous response with id 'resp_prev_anchor' not found.", - "param": "previous_response_id", - }, - } - original_text = json.dumps(original_payload, separators=(",", ":")) - original_event = parse_sse_event(f"data: 
{original_text}\n\n") - assert original_event is not None - original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) - upstream_control = proxy_service._WebSocketUpstreamControl() + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + if payload.get("type") == "response.completed": + response_payload = payload.get("response") or {} + if response_payload.get("id") == "resp_ws_race_first": + self._first_completed.set() + if response_payload.get("id") == "resp_ws_race_second": + self._done.set() + if payload.get("type") in {"response.failed", "error"}: + self._done.set() - _, rewritten_payload, rewritten_event_type, rewritten_text = ( - proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( - request_state=request_state, - event=original_event, - payload=original_payload, - event_type=original_event_type, - upstream_control=upstream_control, - original_text=original_text, - ) - ) + async def send_bytes(self, _data: bytes) -> None: + return None - assert upstream_control.reconnect_requested is True - assert rewritten_event_type == "response.failed" - assert rewritten_payload is not None - response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self._done.set() + + class _RaceUpstreamWebSocket: + def __init__(self, messages: list[SimpleNamespace], *, close_delay_seconds: float = 0.0) -> None: + self.sent_text: list[str] = [] + self.closed = False + self._messages: asyncio.Queue[SimpleNamespace] = asyncio.Queue() + for message in messages: + self._messages.put_nowait(message) + self._close_delay_seconds = close_delay_seconds + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def receive(self) -> SimpleNamespace: + message = await 
self._messages.get() + if message.kind == "close" and self._close_delay_seconds > 0: + await asyncio.sleep(self._close_delay_seconds) + return message + + async def close(self) -> None: + self.closed = True + + first_upstream = _RaceUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_race_first", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_race_first", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace(kind="close", text=None, data=None, close_code=1001, error=None), + ], + close_delay_seconds=0.05, + ) + second_upstream = _RaceUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_race_second", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_race_second", + "status": "completed", + "usage": {"input_tokens": 2, "output_tokens": 2, "total_tokens": 4}, + }, + }, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + ] + ) + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + request_state, + api_key, + client_send_lock, 
+ websocket, + ) + connect_calls.append({"model": model, "reallocate_sticky": reallocate_sticky}) + if len(connect_calls) == 1: + return _make_account("acc_ws_race_1"), first_upstream + return _make_account("acc_ws_race_2"), second_upstream + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + first_request = { + "type": "response.create", + "model": "gpt-5.4-mini", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "first"}]}], + "stream": True, + } + second_request = { + "type": "response.create", + "model": "gpt-5.4-mini", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "second"}]}], + "stream": True, + } + downstream = _FakeDownstreamWebSocket( + json.dumps(first_request, separators=(",", ":")), + json.dumps(second_request, separators=(",", ":")), + ) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {"x-codex-turn-state": "turn_race_ws"}, + codex_session_affinity=True, + openai_cache_affinity=True, + api_key=None, + ) + + emitted_events = [json.loads(event) for event in downstream.sent_text] + assert [event["type"] for event in emitted_events] == [ + "response.created", + "response.completed", + "response.created", + "response.completed", + ] + assert [event["response"]["id"] for event in emitted_events if "response" in event] == [ + "resp_ws_race_first", + "resp_ws_race_first", + "resp_ws_race_second", + "resp_ws_race_second", + ] + assert len(connect_calls) == 2 + assert connect_calls[0]["reallocate_sticky"] is False + assert connect_calls[1]["reallocate_sticky"] is False + assert len(second_upstream.sent_text) == 1 + assert len(first_upstream.sent_text) >= 1 + assert json.loads(first_upstream.sent_text[-1]) == json.loads(second_upstream.sent_text[0]) + + +@pytest.mark.asyncio +async def test_proxy_responses_websocket_prefers_previous_response_owner_from_request_logs(monkeypatch): + 
request_logs = _RequestLogsRecorder() + request_logs.response_owner_by_id[("resp_prev_owner", None, "sid_owner")] = "acc_owner_prev" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + + class _FakeDownstreamWebSocket: + def __init__(self, request_text: str) -> None: + self._request_text = request_text + self._request_sent = False + self._disconnect_sent = False + self._done = asyncio.Event() + self.sent_text: list[str] = [] + + async def receive(self) -> dict[str, object]: + if not self._request_sent: + self._request_sent = True + return {"type": "websocket.receive", "text": self._request_text} + if not self._disconnect_sent: + await self._done.wait() + self._disconnect_sent = True + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + if payload.get("type") in {"response.completed", "response.failed", "error"}: + self._done.set() + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self._done.set() + + class _FakeUpstreamWebSocket: + def __init__(self, messages: list[SimpleNamespace]) -> None: + self.sent_text: list[str] = [] + self.closed = False + self._messages: asyncio.Queue[SimpleNamespace] = asyncio.Queue() + for message in messages: + self._messages.put_nowait(message) + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + + async def send_bytes(self, _data: bytes) -> None: + return None + + 
async def receive(self) -> SimpleNamespace: + return await self._messages.get() + + async def close(self) -> None: + self.closed = True + + upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_owner_retry", "status": "in_progress"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_owner_retry", "status": "completed"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + ] + ) + captured_preferred_accounts: list[str | None] = [] + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + api_key, + client_send_lock, + websocket, + ) + captured_preferred_accounts.append(request_state.preferred_account_id) + return _make_account("acc_selected_any"), upstream + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "previous_response_id": "resp_prev_owner", + "stream": True, + } + downstream = _FakeDownstreamWebSocket(json.dumps(request_payload, separators=(",", ":"))) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {"session_id": "sid_owner"}, + codex_session_affinity=False, + openai_cache_affinity=False, + api_key=None, + ) + + assert captured_preferred_accounts == ["acc_owner_prev"] 
+ assert request_logs.lookup_calls == [("resp_prev_owner", None, "sid_owner")] + emitted_events = [json.loads(event) for event in downstream.sent_text] + assert [event["type"] for event in emitted_events] == ["response.created", "response.completed"] + + +@pytest.mark.asyncio +async def test_proxy_responses_websocket_uses_turn_state_as_owner_lookup_session_scope(monkeypatch): + request_logs = _RequestLogsRecorder() + request_logs.response_owner_by_id[("resp_prev_owner", None, "turn_scope_owner")] = "acc_owner_prev" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + + class _FakeDownstreamWebSocket: + def __init__(self, request_text: str) -> None: + self._request_text = request_text + self._request_sent = False + self._disconnect_sent = False + self._done = asyncio.Event() + self.sent_text: list[str] = [] + + async def receive(self) -> dict[str, object]: + if not self._request_sent: + self._request_sent = True + return {"type": "websocket.receive", "text": self._request_text} + if not self._disconnect_sent: + await self._done.wait() + self._disconnect_sent = True + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + if payload.get("type") in {"response.completed", "response.failed", "error"}: + self._done.set() + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self._done.set() + + class _FakeUpstreamWebSocket: + def 
__init__(self, messages: list[SimpleNamespace]) -> None: + self.sent_text: list[str] = [] + self.closed = False + self._messages: asyncio.Queue[SimpleNamespace] = asyncio.Queue() + for message in messages: + self._messages.put_nowait(message) + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def receive(self) -> SimpleNamespace: + return await self._messages.get() + + async def close(self) -> None: + self.closed = True + + upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_owner_retry", "status": "in_progress"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_owner_retry", "status": "completed"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + ] + ) + captured_preferred_accounts: list[str | None] = [] + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + api_key, + client_send_lock, + websocket, + ) + captured_preferred_accounts.append(request_state.preferred_account_id) + return _make_account("acc_selected_any"), upstream + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + 
"previous_response_id": "resp_prev_owner", + "stream": True, + } + downstream = _FakeDownstreamWebSocket(json.dumps(request_payload, separators=(",", ":"))) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {"x-codex-turn-state": "turn_scope_owner"}, + codex_session_affinity=False, + openai_cache_affinity=False, + api_key=None, + ) + + assert captured_preferred_accounts == ["acc_owner_prev"] + assert request_logs.lookup_calls == [("resp_prev_owner", None, "turn_scope_owner")] + emitted_events = [json.loads(event) for event in downstream.sent_text] + assert [event["type"] for event in emitted_events] == ["response.created", "response.completed"] + + +@pytest.mark.asyncio +async def test_proxy_responses_websocket_prefers_turn_state_over_session_for_owner_lookup_scope(monkeypatch): + request_logs = _RequestLogsRecorder() + request_logs.response_owner_by_id[("resp_prev_owner", None, "turn_scope_owner")] = "acc_owner_prev" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + + class _FakeDownstreamWebSocket: + def __init__(self, request_text: str) -> None: + self._request_text = request_text + self._request_sent = False + self._disconnect_sent = False + self._done = asyncio.Event() + self.sent_text: list[str] = [] + + async def receive(self) -> dict[str, object]: + if not self._request_sent: + self._request_sent = True + return {"type": "websocket.receive", "text": self._request_text} + if not self._disconnect_sent: + await self._done.wait() + self._disconnect_sent = True + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} + + 
async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + if payload.get("type") in {"response.completed", "response.failed", "error"}: + self._done.set() + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self._done.set() + + class _FakeUpstreamWebSocket: + def __init__(self, messages: list[SimpleNamespace]) -> None: + self.sent_text: list[str] = [] + self.closed = False + self._messages: asyncio.Queue[SimpleNamespace] = asyncio.Queue() + for message in messages: + self._messages.put_nowait(message) + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + + async def send_bytes(self, _data: bytes) -> None: + return None + + async def receive(self) -> SimpleNamespace: + return await self._messages.get() + + async def close(self) -> None: + self.closed = True + + upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_owner_retry", "status": "in_progress"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + SimpleNamespace( + kind="text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_owner_retry", "status": "completed"}}, + separators=(",", ":"), + ), + data=None, + close_code=None, + error=None, + ), + ] + ) + captured_preferred_accounts: list[str | None] = [] + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + api_key, + client_send_lock, + websocket, + ) + 
captured_preferred_accounts.append(request_state.preferred_account_id) + return _make_account("acc_selected_any"), upstream + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "previous_response_id": "resp_prev_owner", + "stream": True, + } + downstream = _FakeDownstreamWebSocket(json.dumps(request_payload, separators=(",", ":"))) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {"session_id": "shared_session_owner", "x-codex-turn-state": "turn_scope_owner"}, + codex_session_affinity=False, + openai_cache_affinity=False, + api_key=None, + ) + + assert captured_preferred_accounts == ["acc_owner_prev"] + assert request_logs.lookup_calls == [("resp_prev_owner", None, "turn_scope_owner")] + emitted_events = [json.loads(event) for event in downstream.sent_text] + assert [event["type"] for event in emitted_events] == ["response.created", "response.completed"] + + +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_rechecks_same_scope_after_initial_miss(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + clock = {"value": 100.0} + monkeypatch.setattr(proxy_service.time, "monotonic", lambda: clock["value"]) + + owner_1 = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_missing", + api_key=None, + session_id="req_scope_1", + ) + clock["value"] = 102.0 + owner_2 = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_missing", + api_key=None, + session_id="req_scope_1", + ) + request_logs.response_owner_by_id[("resp_prev_missing", None, None)] = "acc_owner_after_commit" + clock["value"] = 103.0 + owner_3 = await 
service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_missing", + api_key=None, + session_id="req_scope_1", + ) + + assert owner_1 is None + assert owner_2 is None + assert owner_3 == "acc_owner_after_commit" + assert request_logs.lookup_calls == [ + ("resp_prev_missing", None, "req_scope_1"), + ("resp_prev_missing", None, "req_scope_1"), + ("resp_prev_missing", None, "req_scope_1"), + ] + + +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_miss_does_not_evict_known_owner(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + clock = {"value": 100.0} + monkeypatch.setattr(proxy_service.time, "monotonic", lambda: clock["value"]) + api_key = ApiKeyData( + id="key_shared", + name="shared-key", + key_prefix="sk-shared", + allowed_models=None, + enforced_model=None, + enforced_reasoning_effort=None, + enforced_service_tier=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + + service._remember_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key_id=api_key.id, + account_id="acc_owner", + ) + service._remember_websocket_previous_response_owner_miss( + previous_response_id="resp_prev_shared", + api_key_id=api_key.id, + request_cache_scope="req_terminal_b", + ) + + owner = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key=api_key, + session_id="req_terminal_a", + ) + + assert owner == "acc_owner" + assert request_logs.lookup_calls == [] + + +def test_remember_websocket_previous_response_owner_eviction_keeps_latest_entries(): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + limit = proxy_service._WEBSOCKET_PREVIOUS_RESPONSE_ACCOUNT_CACHE_LIMIT + + for index in range(limit + 1): + service._remember_websocket_previous_response_owner( + 
previous_response_id=f"resp_prev_{index}", + api_key_id="key_1", + account_id=f"acc_{index}", + ) + + assert len(service._websocket_previous_response_account_index) == limit + assert ("resp_prev_0", "key_1", None) not in service._websocket_previous_response_account_index + assert ("resp_prev_1", "key_1", None) in service._websocket_previous_response_account_index + assert ("resp_prev_4096", "key_1", None) in service._websocket_previous_response_account_index + + +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_retries_precreated_previous_response_not_found(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_prev_not_found_retry") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "previous_response_id": "resp_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + } + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_not_found_retry", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps(request_payload, separators=(",", ":")), + previous_response_id="resp_anchor", + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + upstream_payload = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + } + upstream_text = 
json.dumps(upstream_payload, separators=(",", ":")) + + downstream_text = await service._process_upstream_websocket_text( + upstream_text, + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"code":"stream_incomplete"' in downstream_text + finalize_request_state.assert_not_awaited() + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is True + assert upstream_control.replay_request_state is pending_request + assert pending_request.replay_count == 1 + assert list(pending_requests) == [] + + +def test_maybe_rewrite_websocket_previous_response_not_found_rewrites_response_failed_event(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_nf", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_nf", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "response.failed", + "response": { + "id": "resp_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + 
event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is True + assert rewritten_event_type == "response.failed" + assert rewritten_payload is not None + response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) + error_payload = cast(dict[str, JsonValue], response_payload.get("error")) + assert error_payload["code"] == "stream_incomplete" + assert error_payload["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in rewritten_text + + +def test_maybe_rewrite_websocket_previous_response_invalid_request_error_rewrites_when_message_is_not_found(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_invalid", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_prev_invalid", + previous_response_id="resp_prev_anchor", + ) + original_payload: dict[str, JsonValue] = { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Previous response with id 'resp_prev_anchor' not found.", + "param": "previous_response_id", + }, + } + original_text = json.dumps(original_payload, separators=(",", ":")) + original_event = parse_sse_event(f"data: {original_text}\n\n") + assert original_event is not None + original_event_type = proxy_service._event_type_from_payload(original_event, original_payload) + upstream_control = proxy_service._WebSocketUpstreamControl() + + _, rewritten_payload, rewritten_event_type, rewritten_text = ( + proxy_service._maybe_rewrite_websocket_previous_response_not_found_event( + request_state=request_state, + event=original_event, + payload=original_payload, + event_type=original_event_type, + upstream_control=upstream_control, + 
original_text=original_text, + ) + ) + + assert upstream_control.reconnect_requested is True + assert rewritten_event_type == "response.failed" + assert rewritten_payload is not None + response_payload = cast(dict[str, JsonValue], rewritten_payload.get("response")) error_payload = cast(dict[str, JsonValue], response_payload.get("error")) assert error_payload["code"] == "stream_incomplete" assert error_payload["message"] == "Upstream websocket closed before response.completed" @@ -5054,6 +6127,41 @@ def test_http_bridge_should_attempt_local_previous_response_recovery_invalid_req assert proxy_service._http_bridge_should_attempt_local_previous_response_recovery(non_recoverable_error) is False +def test_http_bridge_should_rollover_after_context_overflow(): + context_overflow_error = proxy_module.ProxyResponseError( + 400, + { + "error": { + "type": "invalid_request_error", + "code": "context_length_exceeded", + "message": "Your input exceeds the context window of this model.", + } + }, + ) + unrelated_error = proxy_module.ProxyResponseError( + 400, + { + "error": { + "type": "invalid_request_error", + "code": "invalid_request_error", + "message": "Invalid request payload", + "param": "input", + } + }, + ) + hard_key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "turn-hard", None) + soft_key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "cache-soft", None) + + assert proxy_service._http_bridge_should_rollover_after_context_overflow(context_overflow_error) is True + assert ( + proxy_service._http_bridge_should_rollover_after_context_overflow(context_overflow_error, key=hard_key) is False + ) + assert ( + proxy_service._http_bridge_should_rollover_after_context_overflow(context_overflow_error, key=soft_key) is True + ) + assert proxy_service._http_bridge_should_rollover_after_context_overflow(unrelated_error) is False + + def test_maybe_rewrite_websocket_previous_response_not_found_leaves_non_previous_request_unchanged(): request_state = 
proxy_service._WebSocketRequestState( request_id="ws_req_plain", diff --git a/tests/unit/test_request_logs_repository.py b/tests/unit/test_request_logs_repository.py index 6a535070..ae6f57aa 100644 --- a/tests/unit/test_request_logs_repository.py +++ b/tests/unit/test_request_logs_repository.py @@ -1,5 +1,8 @@ from __future__ import annotations +from types import SimpleNamespace +from unittest.mock import AsyncMock + import pytest from sqlalchemy.exc import ResourceClosedError @@ -34,3 +37,125 @@ async def _refresh_failure(_: object) -> None: assert log.request_id == "req" assert log.cost_usd is not None + + +@pytest.mark.asyncio +async def test_find_latest_account_id_for_response_id_prefers_session_then_falls_back_to_api_key_scope() -> None: + session = AsyncMock() + repo = RequestLogsRepository(session) + executed_sql: list[str] = [] + returned_values = iter( + [ + "acc_latest", + "acc_scoped", + "acc_session", + None, + "acc_scoped", + None, + ] + ) + + async def _execute(statement): + executed_sql.append(str(statement)) + value = next(returned_values) + return SimpleNamespace(scalar_one_or_none=lambda: value) + + session.execute.side_effect = _execute + + owner_any = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id=None, + ) + owner_scoped = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id="api_key_1", + ) + owner_session = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id="api_key_1", + session_id="sid_terminal_a", + ) + owner_session_fallback = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id="api_key_1", + session_id="sid_terminal_b", + ) + owner_missing = await repo.find_latest_account_id_for_response_id( + response_id="resp_missing_owner", + api_key_id=None, + ) + + assert owner_any == "acc_latest" + assert owner_scoped == "acc_scoped" + assert owner_session == 
"acc_session" + assert owner_session_fallback == "acc_scoped" + assert owner_missing is None + assert "request_logs.api_key_id = :api_key_id_1" not in executed_sql[0] + assert "request_logs.api_key_id = :api_key_id_1" in executed_sql[1] + assert "request_logs.session_id = :session_id_1" in executed_sql[2] + assert "request_logs.session_id = :session_id_1" in executed_sql[3] + assert "request_logs.session_id = :session_id_1" not in executed_sql[4] + + +@pytest.mark.asyncio +async def test_find_latest_account_id_for_response_id_ignores_blank_response_id() -> None: + session = AsyncMock() + repo = RequestLogsRepository(session) + + owner = await repo.find_latest_account_id_for_response_id( + response_id=" ", + api_key_id="api_key_1", + session_id="sid_terminal_a", + ) + + assert owner is None + session.execute.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_find_latest_account_id_for_response_id_ignores_blank_session_id_scope() -> None: + session = AsyncMock() + repo = RequestLogsRepository(session) + executed_sql: list[str] = [] + + async def _execute(statement): + executed_sql.append(str(statement)) + return SimpleNamespace(scalar_one_or_none=lambda: "acc_scoped") + + session.execute.side_effect = _execute + + owner = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id="api_key_1", + session_id=" ", + ) + + assert owner == "acc_scoped" + assert len(executed_sql) == 1 + assert "request_logs.session_id = :session_id_1" not in executed_sql[0] + + +@pytest.mark.asyncio +async def test_find_latest_account_id_for_response_id_falls_back_when_session_scope_owner_is_blank() -> None: + session = AsyncMock() + repo = RequestLogsRepository(session) + executed_sql: list[str] = [] + returned_values = iter([" ", "acc_fallback"]) + + async def _execute(statement): + executed_sql.append(str(statement)) + return SimpleNamespace(scalar_one_or_none=lambda: next(returned_values)) + + session.execute.side_effect = _execute + + 
owner = await repo.find_latest_account_id_for_response_id( + response_id="resp_lookup_owner", + api_key_id="api_key_1", + session_id="sid_terminal_a", + ) + + assert owner == "acc_fallback" + assert len(executed_sql) == 2 + assert "request_logs.session_id = :session_id_1" in executed_sql[0] + assert "request_logs.session_id = :session_id_1" not in executed_sql[1] + diff --git a/tests/unit/test_usage_updater.py b/tests/unit/test_usage_updater.py index 9a9067dd..86bab10c 100644 --- a/tests/unit/test_usage_updater.py +++ b/tests/unit/test_usage_updater.py @@ -10,6 +10,7 @@ from app.core.auth.refresh import RefreshError from app.core.crypto import TokenEncryptor +from app.core.usage import refresh_scheduler as refresh_scheduler_module from app.core.usage.models import UsagePayload from app.db.models import Account, AccountStatus, UsageHistory from app.modules.usage import updater as usage_updater_module @@ -45,6 +46,77 @@ async def factory(): await first +@pytest.mark.asyncio +async def test_usage_refresh_singleflight_cancel_all_cancels_inflight_task() -> None: + started = asyncio.Event() + cancelled = asyncio.Event() + + async def factory(): + started.set() + try: + await asyncio.Future() + except asyncio.CancelledError: + cancelled.set() + raise + + task = asyncio.create_task(usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT.run("acc_cancel", factory)) + await started.wait() + + await usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT.cancel_all() + + with pytest.raises(asyncio.CancelledError): + await task + assert cancelled.is_set() + assert usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT._inflight == {} + + +@pytest.mark.asyncio +async def test_usage_refresh_scheduler_stop_cancels_inflight_singleflight(monkeypatch: pytest.MonkeyPatch) -> None: + scheduler = refresh_scheduler_module.UsageRefreshScheduler(interval_seconds=60, enabled=True) + run_loop_task = asyncio.create_task(asyncio.sleep(3600)) + scheduler._task = run_loop_task + cancel_all = asyncio.Event() + + 
async def _cancel_all() -> None: + cancel_all.set() + + monkeypatch.setattr( + refresh_scheduler_module.usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT, + "cancel_all", + _cancel_all, + ) + + await scheduler.stop() + + assert cancel_all.is_set() + assert scheduler._task is None + + +@pytest.mark.asyncio +async def test_usage_refresh_scheduler_stop_cancels_inflight_singleflight_without_scheduler_task() -> None: + scheduler = refresh_scheduler_module.UsageRefreshScheduler(interval_seconds=60, enabled=True) + started = asyncio.Event() + cancelled = asyncio.Event() + + async def factory(): + started.set() + try: + await asyncio.Future() + except asyncio.CancelledError: + cancelled.set() + raise + + task = asyncio.create_task(usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT.run("acc_stop_no_task", factory)) + await started.wait() + + await scheduler.stop() + + with pytest.raises(asyncio.CancelledError): + await task + assert cancelled.is_set() + assert usage_updater_module._USAGE_REFRESH_SINGLEFLIGHT._inflight == {} + + @dataclass(frozen=True, slots=True) class UsageEntry: account_id: str From fd7e73ffe3eb47c18c3b1ac453a5b61764e03c2a Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 10:34:18 +0200 Subject: [PATCH 04/18] test(proxy): fix typing in bridge shutdown regression coverage --- tests/unit/test_proxy_http_bridge.py | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index 758d0616..ed30f66e 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -11,6 +11,7 @@ import anyio import pytest +from fastapi import WebSocket from app.core.clients.proxy import ProxyResponseError from app.core.clients.proxy_websocket import UpstreamResponsesWebSocket @@ -3464,7 +3465,9 @@ async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_end assert resolved is created_session 
claim_durable.assert_awaited_once() - assert claim_durable.await_args.kwargs["allow_takeover"] is True + await_args = claim_durable.await_args + assert await_args is not None + assert await_args.kwargs["allow_takeover"] is True service._ring_membership.resolve_endpoint.assert_awaited_once_with("instance-b") @@ -3529,7 +3532,9 @@ async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_end assert resolved is created_session claim_durable.assert_awaited_once() - assert claim_durable.await_args.kwargs["allow_takeover"] is True + await_args = claim_durable.await_args + assert await_args is not None + assert await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio @@ -3592,7 +3597,9 @@ async def test_get_or_create_http_bridge_session_recovers_locally_without_anchor assert resolved is created_session claim_durable.assert_awaited_once() - assert claim_durable.await_args.kwargs["allow_takeover"] is True + await_args = claim_durable.await_args + assert await_args is not None + assert await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio @@ -3655,7 +3662,9 @@ async def test_get_or_create_http_bridge_session_prompt_cache_takes_over_stale_s assert resolved is created_session claim_durable.assert_awaited_once() - assert claim_durable.await_args.kwargs["allow_takeover"] is True + await_args = claim_durable.await_args + assert await_args is not None + assert await_args.kwargs["allow_takeover"] is True @pytest.mark.asyncio @@ -3900,7 +3909,7 @@ async def test_close_all_http_bridge_sessions_fails_capacity_waiters_instead_of_ account=cast(Any, SimpleNamespace(id="acc-existing", status=AccountStatus.ACTIVE)), upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), upstream_control=proxy_service._WebSocketUpstreamControl(), - pending_requests=deque([object()]), + pending_requests=cast(deque[proxy_service._WebSocketRequestState], deque()), pending_lock=anyio.Lock(), response_create_gate=asyncio.Semaphore(1), 
queued_request_count=1, @@ -4505,8 +4514,10 @@ async def test_http_bridge_reader_unexpected_processing_error_fails_pending_requ event_queue=asyncio.Queue(), transport="http", ) - await asyncio.wait_for(request_state.event_queue.put("seed"), timeout=0.1) - await asyncio.wait_for(request_state.event_queue.get(), timeout=0.1) + event_queue = request_state.event_queue + assert event_queue is not None + await asyncio.wait_for(event_queue.put("seed"), timeout=0.1) + await asyncio.wait_for(event_queue.get(), timeout=0.1) gate = asyncio.Semaphore(1) await gate.acquire() request_state.response_create_gate_acquired = True @@ -4540,9 +4551,8 @@ async def test_http_bridge_reader_unexpected_processing_error_fails_pending_requ await service._relay_http_bridge_upstream_messages(session) - event_queue = request_state.event_queue - assert event_queue is not None failed_event = await asyncio.wait_for(event_queue.get(), timeout=0.1) + assert failed_event is not None assert '"code":"stream_incomplete"' in failed_event assert "reader" in failed_event assert await asyncio.wait_for(event_queue.get(), timeout=0.1) is None @@ -4570,7 +4580,11 @@ async def test_websocket_reader_unexpected_processing_error_fails_pending_reques request_state.response_create_gate_acquired = True pending_requests: deque[proxy_service._WebSocketRequestState] = deque([request_state]) pending_lock = anyio.Lock() - websocket = SimpleNamespace(send_text=AsyncMock(), send_bytes=AsyncMock(), close=AsyncMock()) + send_text = AsyncMock() + websocket = cast( + WebSocket, + SimpleNamespace(send_text=send_text, send_bytes=AsyncMock(), close=AsyncMock()), + ) upstream = cast( UpstreamResponsesWebSocket, SimpleNamespace( @@ -4600,8 +4614,8 @@ async def test_websocket_reader_unexpected_processing_error_fails_pending_reques downstream_activity=proxy_service._DownstreamWebSocketActivity(), ) - websocket.send_text.assert_awaited() - terminal_payload = websocket.send_text.await_args_list[0].args[0] + send_text.assert_awaited() + 
terminal_payload = send_text.await_args_list[0].args[0] assert '"code":"stream_incomplete"' in terminal_payload assert "reader" in terminal_payload assert list(pending_requests) == [] From d83df7f1b4b72c6783e0e3ef878d27a5907aca1d Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 10:37:14 +0200 Subject: [PATCH 05/18] style: apply ruff formatting for bridge continuity changes --- app/db/models.py | 1 - app/modules/proxy/durable_bridge_repository.py | 8 ++------ app/modules/proxy/service.py | 16 +++++++--------- tests/unit/test_proxy_http_bridge.py | 6 +++--- tests/unit/test_proxy_utils.py | 4 +--- tests/unit/test_request_logs_repository.py | 1 - 6 files changed, 13 insertions(+), 23 deletions(-) diff --git a/app/db/models.py b/app/db/models.py index 33b286cf..c091ed53 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -104,7 +104,6 @@ class UsageHistory(Base): class AdditionalUsageHistory(Base): __tablename__ = "additional_usage_history" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) account_id: Mapped[str] = mapped_column(String, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False) quota_key: Mapped[str] = mapped_column(String, nullable=False) diff --git a/app/modules/proxy/durable_bridge_repository.py b/app/modules/proxy/durable_bridge_repository.py index a5b4dd39..390cb5cd 100644 --- a/app/modules/proxy/durable_bridge_repository.py +++ b/app/modules/proxy/durable_bridge_repository.py @@ -106,9 +106,7 @@ async def find_session_by_latest_turn_state( .where( HttpBridgeSessionRecord.latest_turn_state == turn_state, HttpBridgeSessionRecord.api_key_scope == api_key_scope, - HttpBridgeSessionRecord.state.in_( - (HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING) - ), + HttpBridgeSessionRecord.state.in_((HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING)), ) .order_by( case((HttpBridgeSessionRecord.state == HttpBridgeSessionState.ACTIVE, 0), else_=1), @@ -132,9 +130,7 @@ async def 
find_session_by_latest_response_id( .where( HttpBridgeSessionRecord.latest_response_id == response_id, HttpBridgeSessionRecord.api_key_scope == api_key_scope, - HttpBridgeSessionRecord.state.in_( - (HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING) - ), + HttpBridgeSessionRecord.state.in_((HttpBridgeSessionState.ACTIVE, HttpBridgeSessionState.DRAINING)), ) .order_by( case((HttpBridgeSessionRecord.state == HttpBridgeSessionState.ACTIVE, 0), else_=1), diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 87aa8d4b..a5a2ca93 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -495,8 +495,7 @@ async def _stream_via_http_bridge( if ( not live_local_session_exists and not forwards_to_active_owner - and - payload.previous_response_id is None + and payload.previous_response_id is None and bridge_session_key.strength == "hard" and durable_lookup.latest_response_id is not None and not _http_bridge_payload_looks_like_full_resend(payload) @@ -3336,12 +3335,9 @@ async def _get_or_create_http_bridge_session( ), ) elif missing_turn_state_alias and inflight_future is None and durable_lookup is None: - turn_state_scope_conflict = ( - incoming_turn_state is not None - and any( - alias == incoming_turn_state and alias_api_key != api_key_id - for alias, alias_api_key in self._http_bridge_turn_state_index - ) + turn_state_scope_conflict = incoming_turn_state is not None and any( + alias == incoming_turn_state and alias_api_key != api_key_id + for alias, alias_api_key in self._http_bridge_turn_state_index ) if turn_state_scope_conflict: continuity_error = ProxyResponseError( @@ -4685,7 +4681,9 @@ def _remember_websocket_previous_response_owner( self._websocket_previous_response_account_index.pop(cache_key, None) self._websocket_previous_response_account_index[cache_key] = account_id_value while len(self._websocket_previous_response_account_index) > _WEBSOCKET_PREVIOUS_RESPONSE_ACCOUNT_CACHE_LIMIT: - 
self._websocket_previous_response_account_index.pop(next(iter(self._websocket_previous_response_account_index))) + self._websocket_previous_response_account_index.pop( + next(iter(self._websocket_previous_response_account_index)) + ) def _remember_websocket_previous_response_owner_miss( self, diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index ed30f66e..4b9b3f95 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -1371,9 +1371,9 @@ def fake_prepare( idle_ttl_seconds=120.0, ) service._http_bridge_sessions[session_key] = session - service._http_bridge_turn_state_index[ - proxy_service._http_bridge_turn_state_alias_key("http_turn_live", None) - ] = session_key + service._http_bridge_turn_state_index[proxy_service._http_bridge_turn_state_alias_key("http_turn_live", None)] = ( + session_key + ) monkeypatch.setattr( proxy_service, diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index f3d91683..edd3bce2 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -3076,9 +3076,7 @@ def test_sticky_key_from_session_header_accepts_aliases_in_priority_order(): def test_owner_lookup_session_id_from_headers_prefers_turn_state_then_session_aliases(): assert proxy_service._owner_lookup_session_id_from_headers({"x-codex-turn-state": "turn_1"}) == "turn_1" assert ( - proxy_service._owner_lookup_session_id_from_headers( - {"x-codex-turn-state": "turn_1", "session_id": "sid_1"} - ) + proxy_service._owner_lookup_session_id_from_headers({"x-codex-turn-state": "turn_1", "session_id": "sid_1"}) == "turn_1" ) assert proxy_service._owner_lookup_session_id_from_headers({"x-codex-session-id": "sid_2"}) == "sid_2" diff --git a/tests/unit/test_request_logs_repository.py b/tests/unit/test_request_logs_repository.py index ae6f57aa..2fb49686 100644 --- a/tests/unit/test_request_logs_repository.py +++ b/tests/unit/test_request_logs_repository.py @@ -158,4 
+158,3 @@ async def _execute(statement): assert len(executed_sql) == 2 assert "request_logs.session_id = :session_id_1" in executed_sql[0] assert "request_logs.session_id = :session_id_1" not in executed_sql[1] - From 0e8e05f8ae99b82eba9085c8d3b8eb4006c095cd Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 11:12:36 +0200 Subject: [PATCH 06/18] fix(proxy): preserve scoped previous-response ownership across bridge and retry --- app/modules/proxy/service.py | 19 +++-- tests/unit/test_proxy_http_bridge.py | 123 +++++++++++++++++++++++++++ tests/unit/test_proxy_utils.py | 39 ++++++++- 3 files changed, 173 insertions(+), 8 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index a5a2ca93..b5c08a26 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -519,6 +519,8 @@ async def _stream_via_http_bridge( api_key_reservation=api_key_reservation, request_id=request_id, ) + if downstream_turn_state is not None: + request_state.session_id = _normalize_session_id(downstream_turn_state) request_state.transport = _REQUEST_TRANSPORT_HTTP request_state.request_stage = _http_bridge_request_stage( headers=headers, @@ -545,7 +547,7 @@ async def _stream_via_http_bridge( request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( previous_response_id=request_state.previous_response_id, api_key=api_key, - session_id=_owner_lookup_session_id_from_headers(headers), + session_id=request_state.session_id, ) session_or_forward = await self._get_or_create_http_bridge_session( bridge_session_key, @@ -835,6 +837,8 @@ async def _stream_via_http_bridge( api_key_reservation=retry_api_key_reservation, request_id=request_id, ) + if downstream_turn_state is not None: + retry_request_state.session_id = _normalize_session_id(downstream_turn_state) retry_request_state.transport = _REQUEST_TRANSPORT_HTTP retry_request_state.request_stage = retry_request_stage retry_request_state.preferred_account_id = 
retry_preferred_account_id @@ -4714,10 +4718,11 @@ async def _resolve_websocket_previous_response_owner( cached_account_id = self._websocket_previous_response_account_index.get(cache_key) if cached_account_id is not None: return cached_account_id - if session_id_value is not None: - fallback_account_id = self._websocket_previous_response_account_index.get((response_id, api_key_id, None)) - if fallback_account_id is not None: - return fallback_account_id + fallback_account_id = ( + self._websocket_previous_response_account_index.get((response_id, api_key_id, None)) + if session_id_value is not None + else None + ) try: async with self._repo_factory() as repos: account_id = await repos.request_logs.find_latest_account_id_for_response_id( @@ -4727,9 +4732,9 @@ async def _resolve_websocket_previous_response_owner( ) except Exception: logger.warning("Previous response owner lookup failed; continuing without owner pinning", exc_info=True) - return None + return fallback_account_id if account_id is None: - return None + return fallback_account_id self._remember_websocket_previous_response_owner( previous_response_id=response_id, api_key_id=api_key_id, diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index 4b9b3f95..41567bfd 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -1828,6 +1828,7 @@ async def test_stream_via_http_bridge_resolves_previous_response_owner_from_requ event_queue=asyncio.Queue(), transport="http", previous_response_id="resp_prev_owner_lookup", + session_id="turn_http_owner", ) event_queue = request_state.event_queue assert event_queue is not None @@ -1922,6 +1923,128 @@ async def fake_get_or_create_http_bridge_session(*args: object, **kwargs: object assert captured_preferred["value"] == "acc-owner-from-logs" +@pytest.mark.asyncio +async def test_stream_via_http_bridge_uses_generated_downstream_turn_state_for_owner_scope( + monkeypatch: pytest.MonkeyPatch, +) -> None: 
+ service = proxy_service.ProxyService(cast(Any, nullcontext())) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "previous_response_id": "resp_prev_owner_lookup", + } + ) + request_state = proxy_service._WebSocketRequestState( + request_id="req-generated-turn-state", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_owner_lookup", + session_id="sid-shared", + ) + session = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-shared", None), + headers={"x-codex-session-id": "sid-shared"}, + affinity=proxy_service._AffinityPolicy( + key="sid-shared", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + + def fake_prepare( + _prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del api_key, api_key_reservation, request_id + return request_state, '{"type":"response.create"}' + + async def fake_stream_http_bridge_session_events( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + propagate_http_errors: bool, + downstream_turn_state: str | None, + ): + del 
request_state, text_data, queue_limit, propagate_http_errors, downstream_turn_state + yield 'data: {"type":"response.completed"}\n\n' + + owner_lookup = AsyncMock(return_value="acc-owner-from-turn-state") + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_resolve_websocket_previous_response_owner", owner_lookup) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", AsyncMock(return_value=session)) + monkeypatch.setattr(service, "_stream_http_bridge_session_events", fake_stream_http_bridge_session_events) + + chunks = [ + chunk + async for chunk in service._stream_via_http_bridge( + payload, + headers={"x-codex-session-id": "sid-shared"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=None, + api_key_reservation=None, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=1800.0, + max_sessions=8, + queue_limit=4, + downstream_turn_state="http_turn_generated", + ) + ] + + assert chunks == ['data: {"type":"response.completed"}\n\n'] + owner_lookup.assert_awaited_once_with( + previous_response_id="resp_prev_owner_lookup", + api_key=None, + session_id="http_turn_generated", + ) + assert request_state.session_id == "http_turn_generated" + assert request_state.preferred_account_id == "acc-owner-from-turn-state" + + @pytest.mark.asyncio async def 
test_http_bridge_waits_for_registration_for_hard_keys_before_startup_complete( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index edd3bce2..a242f3c8 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -5825,7 +5825,44 @@ async def test_resolve_websocket_previous_response_owner_miss_does_not_evict_kno ) assert owner == "acc_owner" - assert request_logs.lookup_calls == [] + assert request_logs.lookup_calls == [("resp_prev_shared", api_key.id, "req_terminal_a")] + + +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_prefers_scoped_lookup_over_generic_cache() -> None: + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + api_key = ApiKeyData( + id="key_shared", + name="shared-key", + key_prefix="sk-shared", + allowed_models=None, + enforced_model=None, + enforced_reasoning_effort=None, + enforced_service_tier=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + service._remember_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key_id=api_key.id, + account_id="acc_owner_generic", + ) + request_logs.response_owner_by_id[("resp_prev_shared", api_key.id, "turn_scope_a")] = "acc_owner_scoped" + + owner = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key=api_key, + session_id="turn_scope_a", + ) + + assert owner == "acc_owner_scoped" + assert request_logs.lookup_calls == [("resp_prev_shared", api_key.id, "turn_scope_a")] + assert service._websocket_previous_response_account_index[("resp_prev_shared", api_key.id, "turn_scope_a")] == ( + "acc_owner_scoped" + ) def test_remember_websocket_previous_response_owner_eviction_keeps_latest_entries(): From 157ab9431084ab4bb5f4da7963406f59f5a6757e Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 15:17:09 +0200 
Subject: [PATCH 07/18] test(proxy): add regression coverage for bridged previous-response reconnect-only behavior --- app/modules/proxy/service.py | 2 +- .../integration/test_http_responses_bridge.py | 2 ++ tests/unit/test_bridge_context_blowup.py | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index b5c08a26..f75ba22a 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4126,7 +4126,7 @@ async def _submit_http_bridge_request( # history, inflating per-turn context by ~20x. raise ProxyResponseError( 502, - openai_error("bridge_owner_unreachable", str(exc) or "Upstream websocket closed"), + openai_error("upstream_unavailable", str(exc) or "Upstream websocket closed"), ) from exc async def _maybe_prewarm_http_bridge_session( diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index cc33c51c..77762a23 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -6599,6 +6599,7 @@ async def fake_connect_responses_websocket( assert second.status_code == 502 assert second.json()["error"]["code"] in ("upstream_unavailable", "stream_incomplete", "bridge_owner_unreachable") assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert connect_count == 1 @pytest.mark.asyncio @@ -6708,6 +6709,7 @@ async def fake_connect_responses_websocket( assert second.status_code == 502 assert second.json()["error"]["code"] in ("upstream_unavailable", "stream_incomplete", "upstream_request_timeout") assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert connect_count == 1 @pytest.mark.asyncio diff --git a/tests/unit/test_bridge_context_blowup.py b/tests/unit/test_bridge_context_blowup.py index 982c8263..57994100 100644 --- a/tests/unit/test_bridge_context_blowup.py +++ 
b/tests/unit/test_bridge_context_blowup.py @@ -246,6 +246,27 @@ async def test_retry_with_previous_response_id_returns_false_without_marking_err assert request_state.error_code_override != "previous_response_not_found" assert request_state.error_code_override is None + async def test_reconnect_only_recovery_with_previous_response_id_skips_resend(self, monkeypatch): + service = proxy_service.ProxyService(cast(Any, nullcontext())) + session = _make_session(closed=True) + send_text = AsyncMock() + session.upstream = cast(Any, SimpleNamespace(send_text=send_text)) + request_state = _make_request_state(previous_response_id="resp_xyz789") + + reconnect_mock = AsyncMock() + monkeypatch.setattr(service, "_reconnect_http_bridge_session", reconnect_mock) + + result = await service._retry_http_bridge_request_on_fresh_upstream( + session=session, + request_state=request_state, + text_data='{"type":"response.create","previous_response_id":"resp_xyz789"}', + send_request=False, + ) + + assert result is True + reconnect_mock.assert_awaited_once() + send_text.assert_not_awaited() + class TestContextGrowthScenarios: """Scenario tests modelling real Codex session data. 
From 0fee5620ec99ff9461bccb27057443993d41861e Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 19:39:02 +0200 Subject: [PATCH 08/18] fix(proxy): harden continuity fail-closed flows --- app/core/clients/proxy.py | 197 ++- app/core/metrics/prometheus.py | 16 + app/modules/proxy/service.py | 857 +++++++++++-- .../.openspec.yaml | 2 + .../design.md | 41 + .../proposal.md | 25 + .../specs/responses-api-compat/spec.md | 31 + .../tasks.md | 14 + .../observe-continuity-decisions/proposal.md | 20 + .../specs/proxy-runtime-observability/spec.md | 15 + .../observe-continuity-decisions/tasks.md | 13 + .../integration/test_http_responses_bridge.py | 24 +- tests/integration/test_proxy_responses.py | 386 ++++++ .../test_proxy_websocket_responses.py | 239 ++++ tests/unit/test_metrics.py | 4 + tests/unit/test_proxy_http_bridge.py | 409 ++++++- tests/unit/test_proxy_utils.py | 1058 ++++++++++++++++- 17 files changed, 3210 insertions(+), 141 deletions(-) create mode 100644 openspec/changes/harden-continuity-fail-closed-edges/.openspec.yaml create mode 100644 openspec/changes/harden-continuity-fail-closed-edges/design.md create mode 100644 openspec/changes/harden-continuity-fail-closed-edges/proposal.md create mode 100644 openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md create mode 100644 openspec/changes/harden-continuity-fail-closed-edges/tasks.md create mode 100644 openspec/changes/observe-continuity-decisions/proposal.md create mode 100644 openspec/changes/observe-continuity-decisions/specs/proxy-runtime-observability/spec.md create mode 100644 openspec/changes/observe-continuity-decisions/tasks.md diff --git a/app/core/clients/proxy.py b/app/core/clients/proxy.py index a60eeab2..5af48783 100644 --- a/app/core/clients/proxy.py +++ b/app/core/clients/proxy.py @@ -12,6 +12,7 @@ import socket import time from contextlib import asynccontextmanager +from copy import deepcopy from dataclasses import dataclass from typing import ( 
AsyncContextManager, @@ -56,6 +57,7 @@ get_circuit_breaker_for_account, ) from app.core.types import JsonObject, JsonValue +from app.core.utils.json_guards import is_json_mapping from app.core.utils.request_id import get_request_id from app.core.utils.sse import format_sse_event @@ -87,6 +89,12 @@ _IMAGE_INLINE_CHUNK_SIZE = 64 * 1024 _IMAGE_INLINE_TIMEOUT_SECONDS = 8.0 _BLOCKED_LITERAL_HOSTS = {"localhost", "localhost.localdomain"} +_UPSTREAM_RESPONSE_CREATE_WARN_BYTES = 12 * 1024 * 1024 +_UPSTREAM_RESPONSE_CREATE_MAX_BYTES = 15 * 1024 * 1024 +_RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE = ( + "[codex-lb omitted historical tool output ({bytes} bytes) to fit upstream websocket budget]" +) +_RESPONSE_CREATE_IMAGE_OMISSION_NOTICE = "[codex-lb omitted historical inline image to fit upstream websocket budget]" _UPSTREAM_TRACE_HEADER_ALLOWLIST = frozenset( { "accept", @@ -1195,7 +1203,7 @@ async def _stream_responses_via_websocket( ) -> AsyncIterator[str]: websocket_url = _to_websocket_upstream_url(url) request_started_at = time.monotonic() - request_payload = _build_websocket_response_create_payload(payload_dict) + request_payload = _prepare_websocket_response_create_payload(payload_dict) websocket_cm: AsyncContextManager[aiohttp.ClientWebSocketResponse] | None = None websocket: aiohttp.ClientWebSocketResponse | None = None circuit_breaker = None @@ -1288,6 +1296,193 @@ def _build_websocket_response_create_payload(payload_dict: JsonObject) -> JsonOb return request_payload +def _prepare_websocket_response_create_payload(payload_dict: JsonObject) -> JsonObject: + request_payload = _build_websocket_response_create_payload(payload_dict) + payload_text = json.dumps(request_payload, ensure_ascii=True, separators=(",", ":")) + payload_size = len(payload_text.encode("utf-8")) + if payload_size > _UPSTREAM_RESPONSE_CREATE_MAX_BYTES: + slimmed_payload, slim_summary = _slim_response_create_payload_for_upstream( + request_payload, + max_bytes=_UPSTREAM_RESPONSE_CREATE_MAX_BYTES, + 
) + if slim_summary is not None: + request_payload = cast(JsonObject, slimmed_payload) + slimmed_text = json.dumps(request_payload, ensure_ascii=True, separators=(",", ":")) + logger.warning( + ( + "Slimmed response.create before upstream websocket connect request_id=%s " + "original_bytes=%s slimmed_bytes=%s historical_tool_outputs_slimmed=%s " + "historical_images_slimmed=%s" + ), + get_request_id(), + payload_size, + len(slimmed_text.encode("utf-8")), + slim_summary["historical_tool_outputs_slimmed"], + slim_summary["historical_images_slimmed"], + ) + payload_text = slimmed_text + payload_size = len(payload_text.encode("utf-8")) + if payload_size > _UPSTREAM_RESPONSE_CREATE_WARN_BYTES: + previous_response_id = request_payload.get("previous_response_id") + logger.warning( + "Large response.create prepared request_id=%s bytes=%s previous_response_id=%s", + get_request_id(), + payload_size, + previous_response_id if isinstance(previous_response_id, str) else None, + ) + if payload_size <= _UPSTREAM_RESPONSE_CREATE_MAX_BYTES: + return request_payload + raise ProxyResponseError( + 413, + _response_create_too_large_error_envelope(payload_size, _UPSTREAM_RESPONSE_CREATE_MAX_BYTES), + failure_phase="validation", + failure_detail=f"response.create_bytes={payload_size}", + ) + + +def _response_create_too_large_error_envelope(actual_bytes: int, max_bytes: int) -> OpenAIErrorEnvelope: + payload = openai_error( + "payload_too_large", + ( + "response.create is too large for upstream websocket " + f"({actual_bytes} bytes > {max_bytes} bytes). " + "Reduce historical images/screenshots or compact the thread." 
+ ), + error_type="invalid_request_error", + ) + payload["error"]["param"] = "input" + return payload + + +def _slim_response_create_payload_for_upstream( + payload: JsonObject, + *, + max_bytes: int, +) -> tuple[JsonObject, dict[str, int] | None]: + del max_bytes + input_value = payload.get("input") + if not isinstance(input_value, list) or not input_value: + return payload, None + + input_items = cast(list[JsonValue], deepcopy(input_value)) + preserve_from = _response_create_recent_suffix_start(input_items) + historical = input_items[:preserve_from] + recent = input_items[preserve_from:] + + tool_outputs_slimmed = 0 + images_slimmed = 0 + + slimmed_historical: list[JsonValue] = [] + for item in historical: + slimmed_item, item_tool_outputs_slimmed, item_images_slimmed = _slim_historical_response_input_item(item) + tool_outputs_slimmed += item_tool_outputs_slimmed + images_slimmed += item_images_slimmed + slimmed_historical.append(slimmed_item) + + if tool_outputs_slimmed == 0 and images_slimmed == 0: + return payload, None + + candidate_payload = dict(payload) + candidate_payload["input"] = slimmed_historical + recent + return candidate_payload, { + "historical_tool_outputs_slimmed": tool_outputs_slimmed, + "historical_images_slimmed": images_slimmed, + } + + +def _response_create_recent_suffix_start(input_items: list[JsonValue]) -> int: + last_user_index: int | None = None + for index, item in enumerate(input_items): + if not is_json_mapping(item): + continue + if item.get("role") == "user": + last_user_index = index + if last_user_index is not None: + return last_user_index + return 0 + + +def _slim_historical_response_input_item(item: JsonValue) -> tuple[JsonValue, int, int]: + if not is_json_mapping(item): + return item, 0, 0 + + item_mapping = dict(cast(dict[str, JsonValue], deepcopy(item))) + tool_outputs_slimmed = 0 + images_slimmed = 0 + + if item_mapping.get("type") == "function_call_output": + output = item_mapping.get("output") + output_text = output 
if isinstance(output, str) else None + if output_text is not None and _should_slim_historical_tool_output(output_text): + item_mapping["output"] = _RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE.format( + bytes=len(output_text.encode("utf-8")) + ) + tool_outputs_slimmed += 1 + + content = item_mapping.get("content") + slimmed_content, content_images_slimmed = _slim_historical_response_content(content) + if content_images_slimmed > 0: + item_mapping["content"] = slimmed_content + images_slimmed += content_images_slimmed + + if item_mapping.get("type") == "input_image" and _is_inline_image_reference(item_mapping.get("image_url")): + return _response_create_inline_image_notice_item(), tool_outputs_slimmed, images_slimmed + 1 + + return item_mapping, tool_outputs_slimmed, images_slimmed + + +def _slim_historical_response_content(content: JsonValue) -> tuple[JsonValue, int]: + if is_json_mapping(content): + return _slim_historical_response_content_part(content) + if not isinstance(content, list): + return content, 0 + + slimmed_parts: list[JsonValue] = [] + images_slimmed = 0 + for part in content: + slimmed_part, part_images_slimmed = _slim_historical_response_content_part(part) + slimmed_parts.append(slimmed_part) + images_slimmed += part_images_slimmed + return slimmed_parts, images_slimmed + + +def _slim_historical_response_content_part(part: JsonValue) -> tuple[JsonValue, int]: + if not is_json_mapping(part): + return part, 0 + + part_mapping = dict(cast(dict[str, JsonValue], deepcopy(part))) + part_type = part_mapping.get("type") + if part_type == "input_image" and _is_inline_image_reference(part_mapping.get("image_url")): + return _response_create_inline_image_notice_part(), 1 + + if part_type == "image_url": + image_url_value = part_mapping.get("image_url") + if is_json_mapping(image_url_value): + image_url = image_url_value.get("url") + else: + image_url = image_url_value + if _is_inline_image_reference(image_url): + return 
_response_create_inline_image_notice_part(), 1 + + return part_mapping, 0 + + +def _response_create_inline_image_notice_part() -> JsonObject: + return {"type": "input_text", "text": _RESPONSE_CREATE_IMAGE_OMISSION_NOTICE} + + +def _response_create_inline_image_notice_item() -> JsonObject: + return {"role": "user", "content": [_response_create_inline_image_notice_part()]} + + +def _is_inline_image_reference(value: JsonValue) -> bool: + return isinstance(value, str) and value.startswith("data:image/") + + +def _should_slim_historical_tool_output(output: str) -> bool: + return "data:image/" in output or len(output.encode("utf-8")) > 32 * 1024 + + async def _inline_input_image_urls( payload: JsonObject, session: "ImageFetchSession", diff --git a/app/core/metrics/prometheus.py b/app/core/metrics/prometheus.py index 84edf00a..da4fc186 100644 --- a/app/core/metrics/prometheus.py +++ b/app/core/metrics/prometheus.py @@ -171,6 +171,18 @@ def labels(self, *args: str, **kwargs: str) -> "HistogramLike": ... 
["kind"], registry=REGISTRY, ) + continuity_owner_resolution_total = Counter( + "codex_lb_continuity_owner_resolution_total", + "Total continuity owner resolution outcomes by surface and source", + ["surface", "source", "outcome"], + registry=REGISTRY, + ) + continuity_fail_closed_total = Counter( + "codex_lb_continuity_fail_closed_total", + "Total continuity fail-closed or masked retryable outcomes by surface and reason", + ["surface", "reason"], + registry=REGISTRY, + ) def make_scrape_registry() -> CollectorRegistryLike: if MULTIPROCESS_MODE: @@ -211,6 +223,8 @@ def mark_process_dead() -> None: bridge_local_rebind_total: CounterLike | None = None bridge_forward_latency_seconds: HistogramLike | None = None bridge_public_contract_error_total: CounterLike | None = None + continuity_owner_resolution_total: CounterLike | None = None + continuity_fail_closed_total: CounterLike | None = None def make_scrape_registry() -> None: return None @@ -239,6 +253,8 @@ def mark_process_dead() -> None: "bridge_same_account_takeover_total", "bridge_soft_local_rebind_total", "circuit_breaker_state", + "continuity_fail_closed_total", + "continuity_owner_resolution_total", "make_scrape_registry", "mark_process_dead", "prometheus_client", diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index f75ba22a..86f3c0ba 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -78,6 +78,8 @@ bridge_reattach_total, bridge_same_account_takeover_total, bridge_soft_local_rebind_total, + continuity_fail_closed_total, + continuity_owner_resolution_total, ) from app.core.openai.exceptions import ClientPayloadError from app.core.openai.models import CompactResponsePayload, OpenAIEvent, OpenAIResponsePayload @@ -541,13 +543,21 @@ async def _stream_via_http_bridge( ) ) ) - else None + else request_state.preferred_account_id ) + if request_state.previous_response_id is not None and request_state.preferred_account_id is None: + 
request_state.preferred_account_id = await self._http_bridge_local_owner_account_id( + key=bridge_session_key, + incoming_turn_state=incoming_turn_state_header, + previous_response_id=request_state.previous_response_id, + api_key=api_key, + ) if request_state.previous_response_id is not None and request_state.preferred_account_id is None: request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( previous_response_id=request_state.previous_response_id, api_key=api_key, session_id=request_state.session_id, + surface="http_bridge", ) session_or_forward = await self._get_or_create_http_bridge_session( bridge_session_key, @@ -654,33 +664,64 @@ async def _stream_via_http_bridge( else "owner_forward_bootstrap", outcome="success", ) - await self._submit_http_bridge_request( - session, - request_state=request_state, - text_data=text_data, - queue_limit=queue_limit, - ) - if downstream_turn_state is not None: - await self._register_http_bridge_turn_state(session, downstream_turn_state) + retry_request_state: _WebSocketRequestState | None = None try: - event_queue = request_state.event_queue + retry_api_key_reservation = api_key_reservation + retry_reservation_reacquired = False + if api_key is not None and api_key_reservation is not None: + retry_api_key_reservation = await self._reserve_websocket_api_key_usage( + api_key, + request_model=effective_payload.model, + request_service_tier=_normalize_service_tier_value( + dict(effective_payload.to_payload()).get("service_tier"), + ), + ) + retry_reservation_reacquired = True + + retry_request_state, retry_text_data = self._prepare_http_bridge_request( + effective_payload, + headers, + api_key=api_key, + api_key_reservation=retry_api_key_reservation, + request_id=request_id, + ) + if downstream_turn_state is not None: + retry_request_state.session_id = _normalize_session_id(downstream_turn_state) + retry_request_state.transport = _REQUEST_TRANSPORT_HTTP + retry_request_state.request_stage = 
"reattach" + retry_request_state.preferred_account_id = request_state.preferred_account_id + + await self._submit_http_bridge_request( + session, + request_state=retry_request_state, + text_data=retry_text_data, + queue_limit=queue_limit, + ) + if downstream_turn_state is not None: + await self._register_http_bridge_turn_state(session, downstream_turn_state) + event_queue = retry_request_state.event_queue assert event_queue is not None while True: event_block = await event_queue.get() if event_block is None: break - if request_state.latency_first_token_ms is None: + if retry_request_state.latency_first_token_ms is None: block_payload = parse_sse_data_json(event_block) block_event_type = _event_type_from_payload(None, block_payload) if block_event_type in _TEXT_DELTA_EVENT_TYPES: - request_state.latency_first_token_ms = int( - (time.monotonic() - request_state.started_at) * 1000 + retry_request_state.latency_first_token_ms = int( + (time.monotonic() - retry_request_state.started_at) * 1000 ) yield event_block + except BaseException: + if retry_reservation_reacquired and retry_api_key_reservation is not None: + await self._release_websocket_reservation(retry_api_key_reservation) + raise finally: - with anyio.CancelScope(shield=True): - await self._detach_http_bridge_request(session, request_state=request_state) - session.last_used_at = time.monotonic() + if retry_request_state is not None: + with anyio.CancelScope(shield=True): + await self._detach_http_bridge_request(session, request_state=retry_request_state) + session.last_used_at = time.monotonic() return session = session_or_forward session_events: AsyncGenerator[str, None] = self._stream_http_bridge_session_events( @@ -966,6 +1007,57 @@ async def _http_bridge_has_live_local_session( return True return False + async def _http_bridge_local_owner_account_id( + self, + *, + key: "_HTTPBridgeSessionKey", + incoming_turn_state: str | None, + previous_response_id: str, + api_key: ApiKeyData | None, + ) -> str | None: 
+ api_key_id = api_key.id if api_key is not None else None + candidate_keys: list[_HTTPBridgeSessionKey] = [key] + async with self._http_bridge_lock: + if incoming_turn_state is not None: + alias_key = self._http_bridge_turn_state_index.get( + _http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id) + ) + if alias_key is not None and alias_key not in candidate_keys: + candidate_keys.append(alias_key) + previous_alias_key = _http_bridge_previous_response_alias_key(previous_response_id, api_key_id) + previous_key = self._http_bridge_previous_response_index.get(previous_alias_key) + if previous_key is not None and previous_key not in candidate_keys: + candidate_keys.append(previous_key) + for candidate_key in candidate_keys: + session = self._http_bridge_sessions.get(candidate_key) + if session is None or session.closed or session.account.status != AccountStatus.ACTIVE: + continue + if not _http_bridge_session_allows_api_key(session, api_key): + continue + if not _http_bridge_session_reusable_for_request( + session=session, + key=candidate_key, + incoming_turn_state=incoming_turn_state, + previous_response_id=previous_response_id, + ): + continue + _record_continuity_owner_resolution( + surface="http_bridge", + source="local_bridge_session", + outcome="hit", + previous_response_id=previous_response_id, + session_id=incoming_turn_state, + ) + return session.account.id + _record_continuity_owner_resolution( + surface="http_bridge", + source="local_bridge_session", + outcome="miss", + previous_response_id=previous_response_id, + session_id=incoming_turn_state, + ) + return None + async def _http_bridge_can_forward_to_active_owner( self, durable_lookup: DurableBridgeLookup, @@ -1829,11 +1921,42 @@ async def proxy_responses_websocket( and request_state.previous_response_id is not None and request_state.preferred_account_id is None ): - request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( - 
previous_response_id=request_state.previous_response_id, - api_key=request_state.api_key or api_key, - session_id=request_state.session_id, - ) + try: + request_state.preferred_account_id = await self._resolve_websocket_previous_response_owner( + previous_response_id=request_state.previous_response_id, + api_key=request_state.api_key or api_key, + session_id=request_state.session_id, + surface="websocket", + ) + except ProxyResponseError as exc: + error = _parse_openai_error(exc.payload) + error_code = _normalize_error_code( + error.code if error else None, + error.type if error else None, + ) + error_message = error.message if error and error.message else "Upstream error" + error_type = error.type if error and error.type else "server_error" + await self._release_websocket_reservation(request_state.api_key_reservation) + await self._write_websocket_connect_failure( + account_id=None, + api_key=api_key, + request_state=request_state, + error_code=error_code or "upstream_error", + error_message=error_message, + ) + await self._emit_websocket_terminal_error( + websocket, + client_send_lock=client_send_lock, + request_state=request_state, + error_code=error_code or "upstream_error", + error_message=error_message, + error_type=error_type, + downstream_activity=downstream_activity, + ) + request_state = None + text_data = None + payload = None + continue if request_state is not None and not request_state_registered: try: @@ -2195,6 +2318,7 @@ async def _acquire_request_state_response_create_admission( response_create_gate: asyncio.Semaphore, compact: bool = False, ) -> None: + request_state.response_create_gate = response_create_gate await response_create_gate.acquire() request_state.response_create_gate_acquired = True request_state.awaiting_response_created = True @@ -2368,6 +2492,13 @@ async def _select_websocket_connect_account( and account.id != preferred_account_id ): message = "Previous response owner account is unavailable; retry later." 
+ _record_continuity_fail_closed( + surface="websocket_connect", + reason="owner_account_unavailable", + previous_response_id=request_state.previous_response_id, + session_id=request_state.session_id, + upstream_error_code="upstream_unavailable", + ) await self._emit_websocket_connect_failure( websocket, client_send_lock=client_send_lock, @@ -2388,6 +2519,31 @@ async def _select_websocket_connect_account( return account error_code = selection.error_code or "no_accounts" error_message = selection.error_message or "No active accounts available" + if require_preferred_account and preferred_account_id is not None: + message = "Previous response owner account is unavailable; retry later." + _record_continuity_fail_closed( + surface="websocket_connect", + reason="owner_account_unavailable", + previous_response_id=request_state.previous_response_id, + session_id=request_state.session_id, + upstream_error_code=error_code, + ) + await self._emit_websocket_connect_failure( + websocket, + client_send_lock=client_send_lock, + account_id=preferred_account_id, + api_key=api_key, + request_state=request_state, + status_code=502, + payload=openai_error( + "upstream_unavailable", + message, + error_type="server_error", + ), + error_code="upstream_unavailable", + error_message=message, + ) + return None await self._emit_websocket_connect_failure( websocket, client_send_lock=client_send_lock, @@ -2856,6 +3012,11 @@ async def _get_or_create_http_bridge_session( alias_session is None or alias_session.closed or alias_session.account.status != AccountStatus.ACTIVE + or not _http_bridge_session_matches_preferred_account( + session=alias_session, + previous_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) ): self._http_bridge_turn_state_index.pop(alias_index_key, None) key = _HTTPBridgeSessionKey("turn_state_header", incoming_turn_state, api_key_id) @@ -2877,28 +3038,35 @@ async def _get_or_create_http_bridge_session( api_key_id, ) previous_key = 
self._http_bridge_previous_response_index.get(previous_alias_key) + previous_session = None if previous_key is not None: previous_session = self._http_bridge_sessions.get(previous_key) - if ( - previous_session is not None - and not previous_session.closed - and previous_session.account.status == AccountStatus.ACTIVE - ): - key = previous_session.key - self._promote_http_bridge_session_to_codex_affinity( - previous_session, - turn_state=incoming_turn_state, - settings=settings, - ) - previous_session.downstream_turn_state_aliases.add(incoming_turn_state) - for alias in previous_session.downstream_turn_state_aliases: - self._http_bridge_turn_state_index[ - _http_bridge_turn_state_alias_key( - alias, - previous_session.key.api_key_id, - ) - ] = previous_session.key - continue + if ( + previous_session is not None + and not previous_session.closed + and previous_session.account.status == AccountStatus.ACTIVE + and _http_bridge_session_matches_preferred_account( + session=previous_session, + previous_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) + ): + key = previous_session.key + self._promote_http_bridge_session_to_codex_affinity( + previous_session, + turn_state=incoming_turn_state, + settings=settings, + ) + previous_session.downstream_turn_state_aliases.add(incoming_turn_state) + for alias in previous_session.downstream_turn_state_aliases: + self._http_bridge_turn_state_index[ + _http_bridge_turn_state_alias_key( + alias, + previous_session.key.api_key_id, + ) + ] = previous_session.key + continue + if previous_key is not None: self._http_bridge_previous_response_index.pop(previous_alias_key, None) if incoming_session_key is not None: key = _HTTPBridgeSessionKey("session_header", incoming_session_key, api_key_id) @@ -2921,6 +3089,11 @@ async def _get_or_create_http_bridge_session( incoming_turn_state=incoming_turn_state, previous_response_id=previous_response_id, ) + and _http_bridge_session_matches_preferred_account( + 
session=existing, + previous_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) ): current_instance = settings.http_responses_session_bridge_instance_id if _durable_bridge_lookup_allows_local_reuse(durable_lookup, current_instance=current_instance): @@ -2977,12 +3150,25 @@ async def _get_or_create_http_bridge_session( ) if owner_check_required or key.affinity_kind == "prompt_cache": owner_instance = _durable_bridge_lookup_active_owner(durable_lookup) + hard_continuity_lookup = owner_check_required or previous_response_id is not None ring_lookup_failed = False if owner_instance is None: try: owner_instance = await _http_bridge_owner_instance(key, settings, self._ring_membership) - except Exception: + except Exception as exc: ring_lookup_failed = True + if hard_continuity_lookup: + _record_continuity_fail_closed( + surface="http_bridge", + reason="owner_metadata_unavailable", + previous_response_id=previous_response_id, + session_id=incoming_turn_state or incoming_session_key, + upstream_error_code="owner_lookup_failed", + ) + raise ProxyResponseError( + 502, + _http_bridge_owner_lookup_unavailable_error_envelope(), + ) from exc if _http_bridge_can_local_recover_without_ring( key=key, headers=headers, @@ -3000,7 +3186,19 @@ async def _get_or_create_http_bridge_session( current_instance, ring = await _active_http_bridge_instance_ring( settings, self._ring_membership ) - except Exception: + except Exception as exc: + if hard_continuity_lookup: + _record_continuity_fail_closed( + surface="http_bridge", + reason="owner_metadata_unavailable", + previous_response_id=previous_response_id, + session_id=incoming_turn_state or incoming_session_key, + upstream_error_code="ring_lookup_failed", + ) + raise ProxyResponseError( + 502, + _http_bridge_owner_lookup_unavailable_error_envelope(), + ) from exc if ring_lookup_failed or _http_bridge_can_local_recover_without_ring( key=key, headers=headers, @@ -3328,22 +3526,25 @@ async def 
_get_or_create_http_bridge_session( and not allow_previous_response_recovery_rebind and durable_lookup is None ): - continuity_error = ProxyResponseError( - 400, - _http_bridge_previous_response_error_envelope( - previous_response_id, - ( - "HTTP bridge continuity was lost. Replay x-codex-turn-state " - "or retry with a stable prompt_cache_key." - ), - ), + _record_continuity_fail_closed( + surface="http_bridge", + reason="continuity_lost", + previous_response_id=previous_response_id, + session_id=incoming_turn_state or incoming_session_key, ) + continuity_error = ProxyResponseError(502, _http_bridge_continuity_lost_error_envelope()) elif missing_turn_state_alias and inflight_future is None and durable_lookup is None: turn_state_scope_conflict = incoming_turn_state is not None and any( alias == incoming_turn_state and alias_api_key != api_key_id for alias, alias_api_key in self._http_bridge_turn_state_index ) if turn_state_scope_conflict: + _record_continuity_fail_closed( + surface="http_bridge", + reason="turn_state_scope_conflict", + previous_response_id=previous_response_id, + session_id=incoming_turn_state, + ) continuity_error = ProxyResponseError( 409, openai_error( @@ -3357,6 +3558,12 @@ async def _get_or_create_http_bridge_session( and incoming_turn_state.startswith("http_turn_") and not allow_forward_to_owner ): + _record_continuity_fail_closed( + surface="http_bridge", + reason="generated_turn_state_continuity_lost", + previous_response_id=previous_response_id, + session_id=incoming_turn_state, + ) continuity_error = ProxyResponseError( 409, openai_error( @@ -3479,6 +3686,11 @@ async def _get_or_create_http_bridge_session( incoming_turn_state=incoming_turn_state, previous_response_id=previous_response_id, ) + and _http_bridge_session_matches_preferred_account( + session=session, + previous_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) ): current_instance = settings.http_responses_session_bridge_instance_id if 
_durable_bridge_lookup_allows_local_reuse(durable_lookup, current_instance=current_instance): @@ -3498,6 +3710,7 @@ async def _get_or_create_http_bridge_session( created_session: _HTTPBridgeSession | None = None session_registered = False + require_preferred_account = previous_response_id is not None and preferred_account_id is not None try: created_session = await self._create_http_bridge_session_compatible( key, @@ -3508,6 +3721,7 @@ async def _get_or_create_http_bridge_session( idle_ttl_seconds=effective_idle_ttl_seconds, request_stage=request_stage, preferred_account_id=preferred_account_id, + require_preferred_account=require_preferred_account, ) await self._claim_durable_http_bridge_session( created_session, @@ -3871,6 +4085,7 @@ async def _create_http_bridge_session( idle_ttl_seconds: float, request_stage: str = "first_turn", preferred_account_id: str | None = None, + require_preferred_account: bool = False, ) -> "_HTTPBridgeSession": request_state = _WebSocketRequestState( request_id=f"http_bridge_connect_{uuid4().hex}", @@ -3917,6 +4132,20 @@ async def _create_http_bridge_session( error_type="server_error", ), ) + if require_preferred_account and preferred_account_id is not None and account.id != preferred_account_id: + message = "Previous response owner account is unavailable; retry later." 
+ _record_same_account_takeover( + preferred_account_id=preferred_account_id, + selected_account_id=account.id, + ) + raise ProxyResponseError( + 502, + openai_error( + "upstream_unavailable", + message, + error_type="server_error", + ), + ) selected_is_preferred = preferred_account_id is not None and account.id == preferred_account_id try: account = await self._ensure_fresh_with_budget( @@ -4706,6 +4935,7 @@ async def _resolve_websocket_previous_response_owner( previous_response_id: str | None, api_key: ApiKeyData | None, session_id: str | None = None, + surface: str, ) -> str | None: if previous_response_id is None: return None @@ -4717,6 +4947,13 @@ async def _resolve_websocket_previous_response_owner( cache_key = (response_id, api_key_id, session_id_value) cached_account_id = self._websocket_previous_response_account_index.get(cache_key) if cached_account_id is not None: + _record_continuity_owner_resolution( + surface=surface, + source="request_cache", + outcome="hit", + previous_response_id=response_id, + session_id=session_id_value, + ) return cached_account_id fallback_account_id = ( self._websocket_previous_response_account_index.get((response_id, api_key_id, None)) @@ -4730,10 +4967,55 @@ async def _resolve_websocket_previous_response_owner( api_key_id=api_key_id, session_id=session_id_value, ) - except Exception: - logger.warning("Previous response owner lookup failed; continuing without owner pinning", exc_info=True) - return fallback_account_id + except Exception as exc: + if fallback_account_id is not None: + _record_continuity_owner_resolution( + surface=surface, + source="request_cache_fallback", + outcome="hit", + previous_response_id=response_id, + session_id=session_id_value, + ) + logger.warning( + "Previous response owner lookup failed; using cached owner pin", + exc_info=True, + ) + return fallback_account_id + _record_continuity_owner_resolution( + surface=surface, + source="request_logs", + outcome="fail_closed", + 
previous_response_id=response_id, + session_id=session_id_value, + ) + _record_continuity_fail_closed( + surface=surface, + reason="owner_lookup_failed", + previous_response_id=response_id, + session_id=session_id_value, + ) + logger.warning("Previous response owner lookup failed; failing closed", exc_info=True) + raise ProxyResponseError( + 502, + _previous_response_owner_lookup_failed_error_envelope(), + ) from exc if account_id is None: + if fallback_account_id is not None: + _record_continuity_owner_resolution( + surface=surface, + source="request_cache_fallback", + outcome="hit", + previous_response_id=response_id, + session_id=session_id_value, + ) + else: + _record_continuity_owner_resolution( + surface=surface, + source="request_logs", + outcome="miss", + previous_response_id=response_id, + session_id=session_id_value, + ) return fallback_account_id self._remember_websocket_previous_response_owner( previous_response_id=response_id, @@ -4741,6 +5023,13 @@ async def _resolve_websocket_previous_response_owner( account_id=account_id, session_id=session_id_value, ) + _record_continuity_owner_resolution( + surface=surface, + source="request_logs", + outcome="hit", + previous_response_id=response_id, + session_id=session_id_value, + ) return account_id async def _handle_websocket_connect_error(self, account: Account, exc: ProxyResponseError) -> ClassifiedFailure: @@ -5010,6 +5299,20 @@ async def _process_upstream_websocket_text( payload=payload, has_other_pending_requests=has_other_pending_requests, ) + if ( + retry_error_code in _WEBSOCKET_TRANSPARENT_REPLAY_ERROR_CODES + and request_state.previous_response_id is not None + and request_state.preferred_account_id is not None + ): + await self._handle_stream_error( + account, + {"message": _websocket_event_error_message(event_type, payload) or "Upstream error"}, + retry_error_code, + ) + event, payload, event_type, downstream_text = _rewrite_websocket_previous_response_owner_unavailable_event( + 
request_state=request_state, + ) + retry_error_code = None if retry_error_code is not None: upstream_control.reconnect_requested = True if retry_is_previous_response_not_found: @@ -5267,6 +5570,13 @@ async def _emit_websocket_connect_failure( error_code: str, error_message: str, ) -> None: + status_code, payload, error_code, error_message = _sanitize_websocket_connect_failure( + request_state=request_state, + status_code=status_code, + payload=payload, + error_code=error_code, + error_message=error_message, + ) await self._release_websocket_reservation(request_state.api_key_reservation) await self._write_websocket_connect_failure( account_id=account_id, @@ -5275,6 +5585,9 @@ async def _emit_websocket_connect_failure( error_code=error_code, error_message=error_message, ) + response_create_gate = request_state.response_create_gate + if response_create_gate is not None: + _release_websocket_response_create_gate(request_state, response_create_gate) async with client_send_lock: await websocket.send_text( _serialize_websocket_error_event(_wrapped_websocket_error_event(status_code, payload)) @@ -5402,6 +5715,9 @@ async def _emit_websocket_terminal_error( response_id=request_state.response_id or request_state.request_id, error_param=error_param, ) + response_create_gate = request_state.response_create_gate + if response_create_gate is not None: + _release_websocket_response_create_gate(request_state, response_create_gate) try: await self._send_downstream_websocket_text( websocket, @@ -5694,6 +6010,16 @@ async def _stream_with_retry( settlement = _StreamSettlement() last_transient_exc: ProxyResponseError | None = None excluded_account_ids: set[str] = set() + preferred_account_id: str | None = None + require_preferred_account = False + if payload.previous_response_id is not None: + preferred_account_id = await self._resolve_websocket_previous_response_owner( + previous_response_id=payload.previous_response_id, + api_key=api_key, + 
session_id=_owner_lookup_session_id_from_headers(headers), + surface="http_stream", + ) + require_preferred_account = preferred_account_id is not None try: for attempt in range(max_attempts): remaining_budget = _remaining_budget_seconds(deadline) @@ -5731,6 +6057,7 @@ async def _stream_with_retry( routing_strategy=routing_strategy, model=payload.model, exclude_account_ids=excluded_account_ids, + preferred_account_id=preferred_account_id, ) except ProxyResponseError as exc: error = _parse_openai_error(exc.payload) @@ -5762,6 +6089,36 @@ async def _stream_with_retry( return account = selection.account if not account: + if require_preferred_account and preferred_account_id is not None: + message = "Previous response owner account is unavailable; retry later." + _record_continuity_fail_closed( + surface="http_stream", + reason="owner_account_unavailable", + previous_response_id=payload.previous_response_id, + session_id=headers.get("x-codex-turn-state") or headers.get("session_id"), + upstream_error_code="no_accounts", + ) + event = response_failed_event( + "upstream_unavailable", + message, + response_id=request_id, + ) + yield format_sse_event(event) + await self._write_request_log( + account_id=preferred_account_id, + api_key=api_key, + request_id=request_id, + model=payload.model, + latency_ms=int((time.monotonic() - start) * 1000), + status="error", + error_code="upstream_unavailable", + error_message=message, + reasoning_effort=payload.reasoning.effort if payload.reasoning else None, + transport=request_transport, + service_tier=payload.service_tier, + requested_service_tier=payload.service_tier, + ) + return # If a prior attempt stored a transient 500 and the caller # expects HTTP error propagation, re-raise the original error # instead of returning a generic no_accounts event. 
@@ -5792,6 +6149,40 @@ async def _stream_with_retry( return account_id_value = account.id + if ( + require_preferred_account + and preferred_account_id is not None + and account.id != preferred_account_id + ): + message = "Previous response owner account is unavailable; retry later." + _record_continuity_fail_closed( + surface="http_stream", + reason="owner_account_unavailable", + previous_response_id=payload.previous_response_id, + session_id=headers.get("x-codex-turn-state") or headers.get("session_id"), + upstream_error_code="upstream_unavailable", + ) + event = response_failed_event( + "upstream_unavailable", + message, + response_id=request_id, + ) + yield format_sse_event(event) + await self._write_request_log( + account_id=preferred_account_id, + api_key=api_key, + request_id=request_id, + model=payload.model, + latency_ms=int((time.monotonic() - start) * 1000), + status="error", + error_code="upstream_unavailable", + error_message=message, + reasoning_effort=payload.reasoning.effort if payload.reasoning else None, + transport=request_transport, + service_tier=payload.service_tier, + requested_service_tier=payload.service_tier, + ) + return try: remaining_budget = _remaining_budget_seconds(deadline) if remaining_budget <= 0: @@ -5894,6 +6285,7 @@ async def _stream_with_retry( suppress_text_done_events=suppress_text_done_events, upstream_stream_transport=upstream_stream_transport, request_transport=request_transport, + preferred_account_id=preferred_account_id, ): yield line except (_TransientStreamError, ProxyResponseError) as tex: @@ -6219,6 +6611,7 @@ async def _stream_once( suppress_text_done_events: bool, upstream_stream_transport: str | None, request_transport: str, + preferred_account_id: str | None = None, ) -> AsyncIterator[str]: account_id_value = account.id access_token = self._encryptor.decrypt(account.access_token_encrypted) @@ -6277,22 +6670,52 @@ async def _stream_once( error = response.error if response else None else: error = event.error + 
response_id = ( + event.response.id + if event.type == "response.failed" and event.response and event.response.id + else request_id + ) code = _normalize_error_code( error.code if error else None, error.type if error else None, ) + rewritten_error = _rewrite_previous_response_stream_error( + previous_response_id=payload.previous_response_id, + preferred_account_id=preferred_account_id, + error_code=code, + error_type=error.type if error else None, + error_message=error.message if error else None, + error_param=error.param if error else None, + ) status = "error" - error_code = code - error_message = error.message if error else None settlement.error = _upstream_error_from_openai(error) settlement.record_success = False - settlement.account_health_error = _should_penalize_stream_error(code) - if allow_retry and _should_retry_stream_error(code): - raise _RetryableStreamError(code, settlement.error) - if allow_transient_retry and code in _TRANSIENT_RETRY_CODES: - raise _TransientStreamError(code, settlement.error) + if rewritten_error is not None: + rewritten_code, rewritten_message, upstream_error_code = rewritten_error + if upstream_error_code is not None: + await self._handle_stream_error( + account, + settlement.error, + upstream_error_code, + ) + first, event, first_payload, event_type = _build_rewritten_stream_response_failed_event( + response_id=response_id, + error_code=rewritten_code, + error_message=rewritten_message, + ) + error_code = rewritten_code + error_message = rewritten_message + settlement.account_health_error = False + else: + error_code = code + error_message = error.message if error else None + settlement.account_health_error = _should_penalize_stream_error(code) + if allow_retry and _should_retry_stream_error(code): + raise _RetryableStreamError(code, settlement.error) + if allow_transient_retry and code in _TRANSIENT_RETRY_CODES: + raise _TransientStreamError(code, settlement.error) terminal_stream_error = _TerminalStreamError( - code, + 
error_code or code, settlement.error, ) if allow_retry: @@ -6300,7 +6723,7 @@ async def _stream_once( "Not retrying non-recoverable stream failure request_id=%s account_id=%s code=%s", request_id, account_id_value, - code, + error_code or code, ) if event and event.type in ("response.completed", "response.incomplete"): @@ -6347,14 +6770,47 @@ async def _stream_once( error = response.error if response else None else: error = event.error - error_code = _normalize_error_code( + raw_error_code = _normalize_error_code( error.code if error else None, error.type if error else None, ) - error_message = error.message if error else None - settlement.error = _upstream_error_from_openai(error) - settlement.record_success = False - settlement.account_health_error = _should_penalize_stream_error(error_code) + rewritten_error = _rewrite_previous_response_stream_error( + previous_response_id=payload.previous_response_id, + preferred_account_id=preferred_account_id, + error_code=raw_error_code, + error_type=error.type if error else None, + error_message=error.message if error else None, + error_param=error.param if error else None, + ) + if rewritten_error is not None: + response_id = ( + event.response.id + if event_type == "response.failed" and event.response and event.response.id + else request_id + ) + rewritten_code, rewritten_message, upstream_error_code = rewritten_error + if upstream_error_code is not None: + await self._handle_stream_error( + account, + _upstream_error_from_openai(error), + upstream_error_code, + ) + line, event, event_payload, event_type = _build_rewritten_stream_response_failed_event( + response_id=response_id, + error_code=rewritten_code, + error_message=rewritten_message, + ) + error_code = rewritten_code + error_message = rewritten_message + settlement.error = _upstream_error_from_openai(error) + settlement.record_success = False + settlement.account_health_error = False + else: + error_code = raw_error_code + error_message = error.message if error 
else None + settlement.error = _upstream_error_from_openai(error) + settlement.record_success = False + settlement.account_health_error = _should_penalize_stream_error(error_code) if event_type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None if event_type == "response.incomplete": @@ -6365,6 +6821,36 @@ async def _stream_once( except ProxyResponseError as exc: response_create_lease.release() error = _parse_openai_error(exc.payload) + rewritten_error = _rewrite_previous_response_stream_error( + previous_response_id=payload.previous_response_id, + preferred_account_id=preferred_account_id, + error_code=_normalize_error_code( + error.code if error else None, + error.type if error else None, + ), + error_type=error.type if error else None, + error_message=error.message if error else None, + error_param=error.param if error else None, + ) + if rewritten_error is not None: + rewritten_code, rewritten_message, upstream_error_code = rewritten_error + if upstream_error_code is not None: + await self._handle_stream_error( + account, + _upstream_error_from_openai(error), + upstream_error_code, + ) + status = "error" + error_code = rewritten_code + error_message = rewritten_message + settlement.record_success = False + settlement.account_health_error = False + yield _build_rewritten_stream_response_failed_event( + response_id=request_id, + error_code=rewritten_code, + error_message=rewritten_message, + )[0] + return status = "error" error_code = _normalize_error_code( error.code if error else None, @@ -6916,6 +7402,7 @@ class _WebSocketRequestState: error_param_override: str | None = None error_http_status_override: int | None = None response_create_gate_acquired: bool = False + response_create_gate: asyncio.Semaphore | None = None response_create_admission: AdmissionLease | None = None affinity_policy: _AffinityPolicy = field(default_factory=_AffinityPolicy) @@ -7331,6 +7818,13 @@ def 
_maybe_rewrite_websocket_previous_response_not_found_event( return event, payload, event_type, original_text upstream_control.reconnect_requested = True + _record_continuity_fail_closed( + surface="websocket_stream", + reason="previous_response_not_found", + previous_response_id=request_state.previous_response_id, + session_id=request_state.session_id, + upstream_error_code=error_code, + ) rewritten_event_payload = response_failed_event( "stream_incomplete", "Upstream websocket closed before response.completed", @@ -7345,6 +7839,136 @@ def _maybe_rewrite_websocket_previous_response_not_found_event( return rewritten_event, rewritten_payload, rewritten_event_type, rewritten_text +def _rewrite_websocket_previous_response_owner_unavailable_event( + *, + request_state: _WebSocketRequestState, +) -> tuple[OpenAIEvent | None, dict[str, JsonValue] | None, str | None, str]: + _record_continuity_fail_closed( + surface="websocket_stream", + reason="owner_account_unavailable", + previous_response_id=request_state.previous_response_id, + session_id=request_state.session_id, + ) + rewritten_event_payload = response_failed_event( + "upstream_unavailable", + "Previous response owner account is unavailable; retry later.", + error_type="server_error", + response_id=request_state.response_id or request_state.request_id, + ) + rewritten_text = json.dumps(rewritten_event_payload, ensure_ascii=True, separators=(",", ":")) + rewritten_event_block = format_sse_event(rewritten_event_payload) + rewritten_payload = parse_sse_data_json(rewritten_event_block) + rewritten_event = parse_sse_event(rewritten_event_block) + rewritten_event_type = _event_type_from_payload(rewritten_event, rewritten_payload) + return rewritten_event, rewritten_payload, rewritten_event_type, rewritten_text + + +def _sanitize_websocket_connect_failure( + *, + request_state: _WebSocketRequestState, + status_code: int, + payload: OpenAIErrorEnvelope, + error_code: str, + error_message: str, +) -> tuple[int, 
OpenAIErrorEnvelope, str, str]: + if request_state.previous_response_id is None: + return status_code, payload, error_code, error_message + + parsed_error = _parse_openai_error(payload) + normalized_code = _normalize_error_code( + parsed_error.code if parsed_error else error_code, + parsed_error.type if parsed_error else None, + ) + normalized_message = parsed_error.message if parsed_error and parsed_error.message else error_message + if not _is_previous_response_not_found_error( + code=normalized_code, + param=parsed_error.param if parsed_error else None, + message=normalized_message, + ): + return status_code, payload, error_code, error_message + + rewritten_message = "Upstream websocket closed before response.completed" + _record_continuity_fail_closed( + surface="websocket_connect", + reason="previous_response_not_found", + previous_response_id=request_state.previous_response_id, + session_id=request_state.session_id, + upstream_error_code=normalized_code, + ) + return ( + 502, + openai_error( + "stream_incomplete", + rewritten_message, + error_type="server_error", + ), + "stream_incomplete", + rewritten_message, + ) + + +def _rewrite_previous_response_stream_error( + *, + previous_response_id: str | None, + preferred_account_id: str | None, + error_code: str | None, + error_type: str | None, + error_message: str | None, + error_param: str | None, +) -> tuple[str, str, str | None] | None: + if previous_response_id is None: + return None + if _is_previous_response_not_found_error( + code=error_code, + param=error_param, + message=error_message, + ): + _record_continuity_fail_closed( + surface="http_stream", + reason="previous_response_not_found", + previous_response_id=previous_response_id, + upstream_error_code=error_code, + ) + return ( + "stream_incomplete", + "Upstream websocket closed before response.completed", + None, + ) + normalized_code = _normalize_error_code(error_code, error_type) + if preferred_account_id is not None and normalized_code in 
_ACCOUNT_RECOVERY_RETRY_CODES: + _record_continuity_fail_closed( + surface="http_stream", + reason="owner_account_unavailable", + previous_response_id=previous_response_id, + upstream_error_code=normalized_code, + ) + return ( + "upstream_unavailable", + "Previous response owner account is unavailable; retry later.", + normalized_code, + ) + return None + + +def _build_rewritten_stream_response_failed_event( + *, + response_id: str, + error_code: str, + error_message: str, +) -> tuple[str, OpenAIEvent | None, dict[str, JsonValue] | None, str | None]: + rewritten_event_payload = response_failed_event( + error_code, + error_message, + error_type="server_error", + response_id=response_id, + ) + rewritten_event_block = format_sse_event(rewritten_event_payload) + rewritten_payload = parse_sse_data_json(rewritten_event_block) + rewritten_event = parse_sse_event(rewritten_event_block) + rewritten_event_type = _event_type_from_payload(rewritten_event, rewritten_payload) + return rewritten_event_block, rewritten_event, rewritten_payload, rewritten_event_type + + def _find_websocket_request_state_by_response_id( pending_requests: deque[_WebSocketRequestState], response_id: str, @@ -7379,6 +8003,7 @@ def _release_websocket_response_create_gate( request_state.response_create_admission.release() request_state.response_create_admission = None request_state.awaiting_response_created = False + request_state.response_create_gate = None if not request_state.response_create_gate_acquired: return request_state.response_create_gate_acquired = False @@ -8081,6 +8706,65 @@ def _maybe_log_proxy_service_tier_trace( ) +def _hash_identifier_or_none(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + if not stripped: + return None + return _hash_identifier(stripped) + + +def _record_continuity_owner_resolution( + *, + surface: str, + source: str, + outcome: str, + previous_response_id: str | None, + session_id: str | None, +) -> None: + if 
PROMETHEUS_AVAILABLE and continuity_owner_resolution_total is not None: + continuity_owner_resolution_total.labels( + surface=surface, + source=source, + outcome=outcome, + ).inc() + if outcome == "miss" or (outcome == "hit" and source == "request_cache"): + return + logger.log( + logging.WARNING if outcome == "fail_closed" else logging.INFO, + "continuity_owner_resolution surface=%s source=%s outcome=%s previous_response_id=%s session_id=%s", + surface, + source, + outcome, + _hash_identifier_or_none(previous_response_id), + _hash_identifier_or_none(session_id), + ) + + +def _record_continuity_fail_closed( + *, + surface: str, + reason: str, + previous_response_id: str | None, + session_id: str | None = None, + upstream_error_code: str | None = None, +) -> None: + if PROMETHEUS_AVAILABLE and continuity_fail_closed_total is not None: + continuity_fail_closed_total.labels( + surface=surface, + reason=reason, + ).inc() + logger.warning( + "continuity_fail_closed surface=%s reason=%s previous_response_id=%s session_id=%s upstream_error_code=%s", + surface, + reason, + _hash_identifier_or_none(previous_response_id), + _hash_identifier_or_none(session_id), + upstream_error_code, + ) + + def _hash_identifier(value: str) -> str: digest = sha256(value.encode("utf-8")).hexdigest() return f"sha256:{digest[:12]}" @@ -8379,6 +9063,17 @@ def _http_bridge_session_reusable_for_request( return not session.codex_session +def _http_bridge_session_matches_preferred_account( + *, + session: "_HTTPBridgeSession", + previous_response_id: str | None, + preferred_account_id: str | None, +) -> bool: + if previous_response_id is None or preferred_account_id is None: + return True + return session.account.id == preferred_account_id + + def _resolve_prompt_cache_key( payload: ResponsesRequest | ResponsesCompactRequest, *, @@ -8780,6 +9475,30 @@ def _http_bridge_previous_response_error_envelope( return payload +def _http_bridge_continuity_lost_error_envelope() -> OpenAIErrorEnvelope: + return 
openai_error( + "stream_incomplete", + "Upstream websocket closed before response.completed", + error_type="server_error", + ) + + +def _http_bridge_owner_lookup_unavailable_error_envelope() -> OpenAIErrorEnvelope: + return openai_error( + "upstream_unavailable", + "HTTP bridge owner metadata unavailable; retry later.", + error_type="server_error", + ) + + +def _previous_response_owner_lookup_failed_error_envelope() -> OpenAIErrorEnvelope: + return openai_error( + "upstream_unavailable", + "Previous response owner lookup failed; retry later.", + error_type="server_error", + ) + + def _mark_request_state_previous_response_not_found( request_state: _WebSocketRequestState, detail: str, diff --git a/openspec/changes/harden-continuity-fail-closed-edges/.openspec.yaml b/openspec/changes/harden-continuity-fail-closed-edges/.openspec.yaml new file mode 100644 index 00000000..3a54a172 --- /dev/null +++ b/openspec/changes/harden-continuity-fail-closed-edges/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-16 diff --git a/openspec/changes/harden-continuity-fail-closed-edges/design.md b/openspec/changes/harden-continuity-fail-closed-edges/design.md new file mode 100644 index 00000000..5fa7afad --- /dev/null +++ b/openspec/changes/harden-continuity-fail-closed-edges/design.md @@ -0,0 +1,41 @@ +## Context + +The proxy already rewrites most `previous_response_id` continuity failures into retryable contracts, but the remaining gaps sit in two places: bridge-local continuity loss still emits `400 previous_response_not_found`, and owner lookup failures can continue without hard pinning. Both behaviors are inconsistent with the practical goal of preserving run continuity whenever the proxy cannot prove a safe owner. + +## Goals / Non-Goals + +**Goals:** +- Make continuity-loss edge cases retryable across HTTP bridge, HTTP fallback, and websocket follow-up flows. 
+- Ensure lookup failures for hard continuity requests fail closed instead of degrading into unpinned recovery. +- Cover the remaining edge cases with regression tests. + +**Non-Goals:** +- Introduce durable continuity guarantees beyond the existing owner/alias model. +- Change prompt-cache locality behavior for soft-affinity requests that do not depend on hard continuity anchors. + +## Decisions + +### Use retryable fail-closed errors for continuity loss +Bridge-local continuity loss should surface as a retryable continuity failure, not as `400 previous_response_not_found`. The proxy already uses `stream_incomplete` for equivalent upstream continuity loss, so the same contract should apply when the bridge itself loses continuity metadata. + +Alternative considered: keep raw `400` for “definitive” local misses. Rejected because it leaves clients with two incompatible contracts for the same continuity failure class. + +### Fail closed on owner/ring lookup errors for hard continuity +When a request depends on `previous_response_id` or hard bridge continuity keys, lookup failures must not fall back to local recovery without pinning. The proxy should return a retryable `upstream_unavailable` error instead. + +Alternative considered: continue current degrade-open behavior. Rejected because it allows continuity fragmentation precisely when the proxy has lost the data needed to enforce owner correctness. + +## Risks / Trade-offs + +- [Risk] More requests can fail fast during transient owner/ring metadata outages. → Mitigation: failures become retryable and avoid silent continuity forks. +- [Risk] Existing tests and assumptions around raw `previous_response_not_found` need updates. → Mitigation: add targeted regressions before changing runtime behavior. + +## Migration Plan + +1. Add regression tests for bridge continuity-loss and owner lookup failure paths. +2. Update runtime behavior to emit retryable fail-closed errors for those paths. +3. 
Run targeted continuity suites and full pytest before merging. + +## Open Questions + +- None for this change; the desired contract is to eliminate remaining raw continuity leaks and unpinned lookup fallbacks. diff --git a/openspec/changes/harden-continuity-fail-closed-edges/proposal.md b/openspec/changes/harden-continuity-fail-closed-edges/proposal.md new file mode 100644 index 00000000..49131ebf --- /dev/null +++ b/openspec/changes/harden-continuity-fail-closed-edges/proposal.md @@ -0,0 +1,25 @@ +## Why + +Recent continuity fixes closed the main websocket and HTTP fallback leaks, but two edge cases still violate the intended contract. Bridge-enabled HTTP can still return raw `400 previous_response_not_found` when continuity metadata is lost, and owner/ring lookup failures can still degrade into local recovery without guaranteed owner pinning. + +## What Changes + +- Replace bridge-local continuity-loss `previous_response_not_found` responses with retryable continuity errors. +- Require hard continuity requests to fail closed when owner or ring lookup errors prevent safe pinning. +- Add regression coverage for bridge continuity-loss and lookup-failure edge cases. + +## Capabilities + +### New Capabilities + +- None. + +### Modified Capabilities + +- `responses-api-compat`: continuity-dependent follow-up requests now use retryable fail-closed errors instead of raw `previous_response_not_found`, and owner lookup failures no longer degrade into unpinned recovery. + +## Impact + +- Affected code: `app/modules/proxy/service.py`, `app/modules/proxy/api.py`, `app/core/clients/proxy.py`, and continuity-focused tests. +- Affected APIs: HTTP `/v1/responses`, HTTP `/backend-api/codex/responses`, and websocket Responses continuity handling. +- Operational impact: continuity failures become consistently retryable; operators should watch `stream_incomplete` and `upstream_unavailable` instead of raw `previous_response_not_found` for these paths. 
diff --git a/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md new file mode 100644 index 00000000..d560e9a2 --- /dev/null +++ b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md @@ -0,0 +1,31 @@ +## ADDED Requirements + +### Requirement: Continuity-dependent Responses follow-ups fail closed with retryable errors +When a Responses follow-up depends on previously established continuity state, the service MUST return a retryable continuity error if that continuity cannot be reconstructed safely. The service MUST NOT expose raw `previous_response_not_found` for bridge-local metadata loss or similar internal continuity gaps. + +#### Scenario: HTTP bridge loses local continuity metadata for a follow-up request +- **WHEN** an HTTP `/v1/responses` or `/backend-api/codex/responses` follow-up request depends on `previous_response_id` or a hard continuity turn-state +- **AND** the bridge cannot reconstruct the matching live continuity state from local or durable metadata +- **THEN** the service returns a retryable OpenAI-format error +- **AND** the error code is not `previous_response_not_found` + +#### Scenario: in-flight bridge follower loses continuity while waiting on the same canonical session +- **WHEN** a follow-up request waits on an in-flight HTTP bridge session for the same hard continuity key +- **AND** the bridge still cannot reconstruct safe continuity state once the leader finishes +- **THEN** the service returns a retryable OpenAI-format error +- **AND** the error code is not `previous_response_not_found` + +### Requirement: Hard continuity owner lookup fails closed +When a request depends on hard continuity ownership, the service MUST fail closed if owner or ring lookup errors prevent safe pinning. 
The service MUST NOT continue with local recovery or account selection that bypasses hard owner enforcement. + +#### Scenario: websocket previous-response owner lookup errors +- **WHEN** a websocket or HTTP fallback follow-up request includes `previous_response_id` +- **AND** owner lookup errors prevent the proxy from determining the required owner account +- **THEN** the service returns a retryable OpenAI-format error +- **AND** it does not continue the request on an unpinned account + +#### Scenario: bridge owner or ring lookup errors for hard continuity keys +- **WHEN** an HTTP bridge request uses a hard continuity key such as turn-state, explicit session affinity, or `previous_response_id` +- **AND** owner or ring lookup errors prevent the proxy from proving the correct bridge owner +- **THEN** the service returns a retryable OpenAI-format error +- **AND** it does not create or recover a local bridge session on the current replica diff --git a/openspec/changes/harden-continuity-fail-closed-edges/tasks.md b/openspec/changes/harden-continuity-fail-closed-edges/tasks.md new file mode 100644 index 00000000..b77dfe0a --- /dev/null +++ b/openspec/changes/harden-continuity-fail-closed-edges/tasks.md @@ -0,0 +1,14 @@ +## 1. Continuity Contract + +- [ ] 1.1 Update bridge-local continuity-loss paths to return retryable errors instead of raw `previous_response_not_found`. +- [ ] 1.2 Fail closed on hard-continuity owner/ring lookup errors instead of degrading into unpinned or local recovery. + +## 2. Regression Coverage + +- [ ] 2.1 Add bridge regression tests for missing turn-state alias and inflight-follower continuity loss. +- [ ] 2.2 Add lookup-failure regression tests for websocket or HTTP fallback `previous_response_id` flows and hard bridge owner lookup failures. + +## 3. Verification + +- [ ] 3.1 Run targeted continuity test suites covering bridge, websocket, and HTTP fallback paths. +- [ ] 3.2 Run full pytest and confirm no broader regressions. 
diff --git a/openspec/changes/observe-continuity-decisions/proposal.md b/openspec/changes/observe-continuity-decisions/proposal.md new file mode 100644 index 00000000..a5b839ee --- /dev/null +++ b/openspec/changes/observe-continuity-decisions/proposal.md @@ -0,0 +1,20 @@ +## Why + +Continuity fail-closed fixes now keep clients on retryable contracts, but operators still have to infer the root cause from scattered warnings and endpoint-level error codes. That slows incident analysis for `previous_response_id` follow-ups because it is not immediately obvious whether the proxy resolved continuity from a local bridge session, request-log lookup, cache, or failed closed for a specific reason. + +## What Changes + +- Add structured continuity decision logs for owner resolution and fail-closed/rewrite outcomes. +- Add low-cardinality Prometheus counters for continuity owner-resolution sources and continuity fail-closed reasons. +- Cover the new observability signals with unit tests. + +## Capabilities + +### Modified Capabilities + +- `proxy-runtime-observability`: continuity-sensitive responses flows now emit explicit operator-facing diagnostics for owner resolution and fail-closed decisions. + +## Impact + +- Affected code: `app/modules/proxy/service.py`, `app/core/metrics/prometheus.py`, and observability-focused tests. +- Operational impact: oncall can distinguish local bridge reuse, cache/request-log owner resolution, and fail-closed continuity masking without inspecting raw upstream payloads. 
diff --git a/openspec/changes/observe-continuity-decisions/specs/proxy-runtime-observability/spec.md b/openspec/changes/observe-continuity-decisions/specs/proxy-runtime-observability/spec.md new file mode 100644 index 00000000..397b7b5c --- /dev/null +++ b/openspec/changes/observe-continuity-decisions/specs/proxy-runtime-observability/spec.md @@ -0,0 +1,15 @@ +## ADDED Requirements + +### Requirement: Continuity-sensitive responses flows emit explicit operator diagnostics +When the proxy resolves or fails closed a continuity-sensitive follow-up request, the system MUST emit structured diagnostics that let operators determine how continuity ownership was resolved or why the proxy returned a retryable masked error. + +#### Scenario: owner resolution source is recorded for a previous-response follow-up +- **WHEN** a websocket, HTTP fallback, or HTTP bridge follow-up request includes `previous_response_id` +- **AND** the proxy resolves the required owner account from a continuity source such as a local bridge session, owner cache, or request-log lookup +- **THEN** the system emits a structured diagnostic describing the continuity surface, source, and outcome +- **AND** the diagnostic does not expose the raw `previous_response_id` + +#### Scenario: fail-closed continuity masking is recorded +- **WHEN** the proxy rewrites or returns a retryable continuity error because owner metadata is unavailable, continuity state is lost, or the pinned owner account is unavailable +- **THEN** the system emits a structured diagnostic describing the continuity surface and fail-closed reason +- **AND** Prometheus counters record the low-cardinality source or reason labels for that decision diff --git a/openspec/changes/observe-continuity-decisions/tasks.md b/openspec/changes/observe-continuity-decisions/tasks.md new file mode 100644 index 00000000..fde1adaa --- /dev/null +++ b/openspec/changes/observe-continuity-decisions/tasks.md @@ -0,0 +1,13 @@ +## 1. 
Continuity Decision Signals + +- [x] 1.1 Add structured continuity decision logging for owner resolution and fail-closed/rewrite paths. +- [x] 1.2 Add Prometheus counters for continuity owner-resolution sources and fail-closed reasons. + +## 2. Regression Coverage + +- [x] 2.1 Add unit tests for continuity decision logs and metrics on representative websocket and HTTP bridge paths. + +## 3. Verification + +- [x] 3.1 Run targeted observability and continuity test suites. +- [x] 3.2 Run `ruff`, `openspec validate --specs`, and full `pytest`. diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index 77762a23..d8bbb008 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -1248,15 +1248,11 @@ async def test_v1_responses_http_bridge_missing_turn_state_alias_with_previous_r ) exc = exc_info.value - assert exc.status_code == 400 + assert exc.status_code == 502 assert exc.payload["error"] == { - "message": ( - "Previous response with id 'resp_missing_alias' not found. " - "HTTP bridge continuity was lost. Replay x-codex-turn-state or retry with a stable prompt_cache_key." - ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", + "message": "Upstream websocket closed before response.completed", + "type": "server_error", + "code": "stream_incomplete", } @@ -5834,15 +5830,11 @@ async def fake_create_http_bridge_session( assert service._http_bridge_sessions[key] is created_session exc = exc_info.value - assert exc.status_code == 400 + assert exc.status_code == 502 assert exc.payload["error"] == { - "message": ( - "Previous response with id 'resp_inflight' not found. " - "HTTP bridge continuity was lost. Replay x-codex-turn-state or retry with a stable prompt_cache_key." 
- ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", + "message": "Upstream websocket closed before response.completed", + "type": "server_error", + "code": "stream_incomplete", } diff --git a/tests/integration/test_proxy_responses.py b/tests/integration/test_proxy_responses.py index 24629172..70af704b 100644 --- a/tests/integration/test_proxy_responses.py +++ b/tests/integration/test_proxy_responses.py @@ -2,16 +2,19 @@ import base64 import json +from types import SimpleNamespace import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import select +import app.core.clients.proxy as proxy_client_module import app.modules.proxy.service as proxy_module from app.core.auth import generate_unique_account_id from app.core.config.settings import Settings from app.db.models import DashboardSettings, RequestLog from app.db.session import SessionLocal +from app.modules.request_logs.repository import RequestLogsRepository pytestmark = pytest.mark.integration @@ -45,6 +48,34 @@ def _extract_first_event(lines: list[str]) -> dict: raise AssertionError("No SSE data event found") +class _FakeUpstreamWebSocket: + def __init__(self, messages: list[object]) -> None: + self._messages = list(messages) + self.sent_json: list[dict[str, object]] = [] + self.closed = False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + self.closed = True + return False + + async def send_json(self, payload: dict[str, object]) -> None: + self.sent_json.append(payload) + + async def receive(self): + if self._messages: + return self._messages.pop(0) + return SimpleNamespace(type=proxy_client_module.aiohttp.WSMsgType.CLOSE, data=None, extra=None) + + async def close(self) -> None: + self.closed = True + + def exception(self): + return None + + @pytest.fixture(autouse=True) def _disable_http_bridge(monkeypatch: pytest.MonkeyPatch) -> None: app_settings = Settings( @@ 
-180,6 +211,361 @@ async def test_v1_responses_routes_under_root_path(app_instance): assert event["response"]["error"]["code"] == "no_accounts" +@pytest.mark.asyncio +async def test_v1_responses_previous_response_not_found_without_http_bridge_returns_stream_incomplete( + async_client, + monkeypatch, +): + email = "prev-http-fallback@example.com" + raw_account_id = "acc_prev_http_fallback" + auth_json = _make_auth_json(raw_account_id, email) + files = {"auth_json": ("auth.json", json.dumps(auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=files) + assert response.status_code == 200 + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **kwargs): + del payload, headers, access_token, account_id, base_url, raise_for_status, kwargs + error_payload = proxy_module.openai_error( + "previous_response_not_found", + "Previous response with id 'resp_prev_http_fallback' not found.", + error_type="invalid_request_error", + ) + error_payload["error"]["param"] = "previous_response_id" + raise proxy_module.ProxyResponseError(400, error_payload) + if False: + yield "" + + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "input": "continue", + "previous_response_id": "resp_prev_http_fallback", + }, + headers={"session_id": "sid_prev_http_fallback"}, + ) + + assert response.status_code == 502 + assert response.json()["error"]["code"] == "stream_incomplete" + assert response.json()["error"]["message"] == "Upstream websocket closed before response.completed" + + +@pytest.mark.asyncio +async def test_v1_responses_previous_response_not_found_without_http_bridge_and_missing_owner_returns_stream_incomplete( + async_client, + monkeypatch, +): + email = "prev-http-missing-owner@example.com" + raw_account_id = "acc_prev_http_missing_owner" + auth_json = 
_make_auth_json(raw_account_id, email) + files = {"auth_json": ("auth.json", json.dumps(auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=files) + assert response.status_code == 200 + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **kwargs): + del payload, headers, access_token, account_id, base_url, raise_for_status, kwargs + error_payload = proxy_module.openai_error( + "previous_response_not_found", + "Previous response with id 'resp_prev_http_missing_owner' not found.", + error_type="invalid_request_error", + ) + error_payload["error"]["param"] = "previous_response_id" + raise proxy_module.ProxyResponseError(400, error_payload) + if False: + yield "" + + async def fake_resolve_owner(self, *, previous_response_id, api_key, session_id, surface): + del self, previous_response_id, api_key, session_id, surface + return None + + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + monkeypatch.setattr(proxy_module.ProxyService, "_resolve_websocket_previous_response_owner", fake_resolve_owner) + + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "input": "continue", + "previous_response_id": "resp_prev_http_missing_owner", + }, + headers={"session_id": "sid_prev_http_missing_owner"}, + ) + + assert response.status_code == 502 + assert response.json()["error"]["code"] == "stream_incomplete" + assert response.json()["error"]["message"] == "Upstream websocket closed before response.completed" + + +@pytest.mark.asyncio +async def test_v1_responses_previous_response_owner_lookup_failure_without_http_bridge_returns_upstream_unavailable( + async_client, + monkeypatch, +): + email = "prev-http-owner-lookup-failure@example.com" + raw_account_id = "acc_prev_http_owner_lookup_failure" + auth_json = _make_auth_json(raw_account_id, email) + files = {"auth_json": ("auth.json", json.dumps(auth_json), 
"application/json")} + response = await async_client.post("/api/accounts/import", files=files) + assert response.status_code == 200 + + async def fail_owner_lookup(self, *, response_id, api_key_id, session_id=None): + del self, response_id, api_key_id, session_id + raise RuntimeError("lookup unavailable") + + async def fail_stream(*args, **kwargs): + del args, kwargs + raise AssertionError("owner lookup failure must fail before upstream stream attempt") + if False: + yield "" + + monkeypatch.setattr(RequestLogsRepository, "find_latest_account_id_for_response_id", fail_owner_lookup) + monkeypatch.setattr(proxy_module, "core_stream_responses", fail_stream) + + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "input": "continue", + "previous_response_id": "resp_prev_owner_lookup_failure", + }, + headers={"session_id": "sid_prev_owner_lookup_failure"}, + ) + + assert response.status_code == 502 + assert response.json()["error"]["code"] == "upstream_unavailable" + assert response.json()["error"]["message"] == "Previous response owner lookup failed; retry later." 
+ + +@pytest.mark.asyncio +async def test_v1_responses_without_http_bridge_websocket_upstream_rejects_oversized_response_create_before_connect( + async_client, + monkeypatch, +): + email = "stream-ws-oversized@example.com" + raw_account_id = "acc_stream_ws_oversized" + auth_json = _make_auth_json(raw_account_id, email) + files = {"auth_json": ("auth.json", json.dumps(auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=files) + assert response.status_code == 200 + + app_settings = Settings( + http_responses_session_bridge_enabled=False, + proxy_request_budget_seconds=75.0, + compact_request_budget_seconds=75.0, + transcription_request_budget_seconds=120.0, + upstream_compact_timeout_seconds=None, + upstream_stream_transport="auto", + log_proxy_request_payload=False, + log_proxy_request_shape=False, + log_proxy_request_shape_raw_cache_key=False, + log_proxy_service_tier_trace=False, + stream_idle_timeout_seconds=300.0, + proxy_token_refresh_limit=32, + proxy_upstream_websocket_connect_limit=64, + proxy_response_create_limit=64, + proxy_compact_response_create_limit=16, + ) + dashboard_settings = DashboardSettings( + id=1, + sticky_threads_enabled=False, + upstream_stream_transport="websocket", + prefer_earlier_reset_accounts=False, + routing_strategy="usage_weighted", + openai_cache_affinity_max_age_seconds=300, + import_without_overwrite=False, + totp_required_on_login=False, + api_key_auth_enabled=False, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + sticky_reallocation_budget_threshold_pct=95.0, + ) + + class _SettingsCache: + async def get(self) -> DashboardSettings: + return dashboard_settings + + class _CoreProxySettings: + upstream_base_url = "https://chatgpt.com/backend-api" + upstream_stream_transport = "default" + upstream_connect_timeout_seconds = 8.0 + stream_idle_timeout_seconds = 45.0 + max_sse_event_bytes = 1024 + 
image_inline_fetch_enabled = False + log_upstream_request_payload = False + proxy_request_budget_seconds = 75.0 + log_upstream_request_summary = False + + async def fail_open_upstream_websocket(**kwargs): + del kwargs + raise AssertionError("oversized response.create must fail before upstream websocket connect") + + monkeypatch.setattr(proxy_module, "get_settings", lambda: app_settings) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _SettingsCache()) + monkeypatch.setattr(proxy_client_module, "get_settings", lambda: _CoreProxySettings()) + monkeypatch.setattr(proxy_client_module, "_UPSTREAM_RESPONSE_CREATE_WARN_BYTES", 64, raising=False) + monkeypatch.setattr(proxy_client_module, "_UPSTREAM_RESPONSE_CREATE_MAX_BYTES", 128, raising=False) + monkeypatch.setattr(proxy_client_module, "_open_upstream_websocket", fail_open_upstream_websocket) + + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "x" * 256}]}], + }, + ) + + assert response.status_code == 413 + payload = response.json() + assert payload["error"]["code"] == "payload_too_large" + assert payload["error"]["type"] == "invalid_request_error" + assert payload["error"]["param"] == "input" + assert "response.create is too large for upstream websocket" in payload["error"]["message"] + + +@pytest.mark.asyncio +async def test_v1_responses_without_http_bridge_websocket_upstream_slims_historical_inline_artifacts_and_succeeds( + async_client, + monkeypatch, +): + email = "stream-ws-slim@example.com" + raw_account_id = "acc_stream_ws_slim" + auth_json = _make_auth_json(raw_account_id, email) + files = {"auth_json": ("auth.json", json.dumps(auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=files) + assert response.status_code == 200 + + app_settings = Settings( + http_responses_session_bridge_enabled=False, + 
proxy_request_budget_seconds=75.0, + compact_request_budget_seconds=75.0, + transcription_request_budget_seconds=120.0, + upstream_compact_timeout_seconds=None, + upstream_stream_transport="auto", + log_proxy_request_payload=False, + log_proxy_request_shape=False, + log_proxy_request_shape_raw_cache_key=False, + log_proxy_service_tier_trace=False, + stream_idle_timeout_seconds=300.0, + proxy_token_refresh_limit=32, + proxy_upstream_websocket_connect_limit=64, + proxy_response_create_limit=64, + proxy_compact_response_create_limit=16, + ) + dashboard_settings = DashboardSettings( + id=1, + sticky_threads_enabled=False, + upstream_stream_transport="websocket", + prefer_earlier_reset_accounts=False, + routing_strategy="usage_weighted", + openai_cache_affinity_max_age_seconds=300, + import_without_overwrite=False, + totp_required_on_login=False, + api_key_auth_enabled=False, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + sticky_reallocation_budget_threshold_pct=95.0, + ) + + class _SettingsCache: + async def get(self) -> DashboardSettings: + return dashboard_settings + + class _CoreProxySettings: + upstream_base_url = "https://chatgpt.com/backend-api" + upstream_stream_transport = "default" + upstream_connect_timeout_seconds = 8.0 + stream_idle_timeout_seconds = 45.0 + max_sse_event_bytes = 1024 + image_inline_fetch_enabled = False + log_upstream_request_payload = False + proxy_request_budget_seconds = 75.0 + log_upstream_request_summary = False + + fake_upstream = _FakeUpstreamWebSocket( + [ + SimpleNamespace( + type=proxy_client_module.aiohttp.WSMsgType.TEXT, + data=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_http_stream_slim", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + extra=None, + ), + SimpleNamespace( + type=proxy_client_module.aiohttp.WSMsgType.TEXT, + data=json.dumps( + { + "type": "response.completed", + 
"response": { + "id": "resp_http_stream_slim", + "object": "response", + "status": "completed", + "usage": {"input_tokens": 3, "output_tokens": 1, "total_tokens": 4}, + }, + }, + separators=(",", ":"), + ), + extra=None, + ), + ] + ) + + async def fake_open_upstream_websocket(**kwargs): + del kwargs + return fake_upstream, fake_upstream + + monkeypatch.setattr(proxy_module, "get_settings", lambda: app_settings) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _SettingsCache()) + monkeypatch.setattr(proxy_client_module, "get_settings", lambda: _CoreProxySettings()) + monkeypatch.setattr(proxy_client_module, "_UPSTREAM_RESPONSE_CREATE_WARN_BYTES", 64, raising=False) + monkeypatch.setattr(proxy_client_module, "_UPSTREAM_RESPONSE_CREATE_MAX_BYTES", 640, raising=False) + monkeypatch.setattr(proxy_client_module, "_open_upstream_websocket", fake_open_upstream_websocket) + + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": [ + {"role": "user", "content": [{"type": "input_text", "text": "old turn"}]}, + { + "type": "function_call_output", + "call_id": "call_1", + "output": "data:image/png;base64," + ("A" * 1200), + }, + { + "role": "assistant", + "content": [ + { + "type": "input_image", + "image_url": "data:image/png;base64," + ("B" * 1200), + } + ], + }, + {"role": "user", "content": [{"type": "input_text", "text": "latest turn"}]}, + ], + }, + ) + + assert response.status_code == 200 + assert response.json()["id"] == "resp_http_stream_slim" + assert fake_upstream.sent_json + request_input = fake_upstream.sent_json[0]["input"] + assert isinstance(request_input, list) + assert request_input[1]["output"] == proxy_module._RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE.format( + bytes=len(("data:image/png;base64," + ("A" * 1200)).encode("utf-8")) + ) + assert request_input[2]["content"] == [ + {"type": "input_text", "text": proxy_module._RESPONSE_CREATE_IMAGE_OMISSION_NOTICE} 
+ ] + + @pytest.mark.asyncio async def test_v1_responses_accepts_messages(async_client): payload = { diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index ccdde260..d8872048 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1397,6 +1397,123 @@ async def fake_connect_proxy_websocket( assert first_upstream.closed is True +def test_backend_responses_websocket_connect_failure_masks_previous_response_not_found( + app_instance, + monkeypatch, +): + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_select_websocket_connect_account( + self, + deadline, + *, + sticky_key, + sticky_kind, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + reallocate_sticky, + sticky_max_age_seconds, + exclude_account_ids, + preferred_account_id, + require_preferred_account, + ): + del ( + self, + deadline, + sticky_key, + sticky_kind, + prefer_earlier_reset, + routing_strategy, + model, + api_key, + client_send_lock, + websocket, + reallocate_sticky, + sticky_max_age_seconds, + exclude_account_ids, + preferred_account_id, + require_preferred_account, + ) + assert request_state.previous_response_id == "resp_ws_prev_anchor" + return SimpleNamespace(id="acct_ws_prev_connect_failure") + + async def fake_try_open_websocket_connect_attempt( + self, + account, + headers, + *, + deadline, + api_key, + request_state, + client_send_lock, + websocket, + ): + del self, account, headers, deadline, api_key, request_state, client_send_lock, websocket + payload = proxy_module.openai_error( + "previous_response_not_found", + "Previous response with id 'resp_ws_prev_anchor' not found.", + 
error_type="invalid_request_error", + ) + payload["error"]["param"] = "previous_response_id" + raise proxy_module.ProxyResponseError(400, payload) + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr( + proxy_module.ProxyService, + "_select_websocket_connect_account", + fake_select_websocket_connect_account, + ) + monkeypatch.setattr( + proxy_module.ProxyService, + "_try_open_websocket_connect_attempt", + fake_try_open_websocket_connect_attempt, + ) + monkeypatch.setattr( + proxy_module.ProxyService, + "_decide_websocket_failover_action", + lambda *args, **kwargs: asyncio.sleep(0, result="surface"), + ) + monkeypatch.setattr( + proxy_module.ProxyService, + "_release_websocket_reservation", + lambda *args, **kwargs: asyncio.sleep(0), + ) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "previous_response_id": "resp_ws_prev_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + event = json.loads(websocket.receive_text()) + + assert event["type"] == "error" + assert event["status"] == 502 + assert event["error"]["code"] == "stream_incomplete" + assert event["error"]["message"] == "Upstream websocket closed before response.completed" + + @pytest.mark.parametrize("frame", ['{"type":"response.create"', "[]"]) def test_backend_responses_websocket_rejects_malformed_first_frame_as_invalid_payload(app_instance, monkeypatch, frame): called = {"connect": False} @@ -2599,6 +2716,128 @@ async def fake_write_request_log(self, **kwargs): assert 
json.loads(first_upstream.sent_text[0]) == json.loads(second_upstream.sent_text[0]) +def test_backend_responses_websocket_previous_response_usage_limit_returns_upstream_unavailable( + app_instance, + monkeypatch, +): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 429, + "error": { + "type": "invalid_request_error", + "code": "usage_limit_reached", + "message": "The usage limit has been reached", + }, + }, + separators=(",", ":"), + ), + ) + ] + ) + connect_models: list[str | None] = [] + captured_preferred_accounts: list[str | None] = [] + handled_error_codes: list[str] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, previous_response_id, api_key, session_id, surface + return "acct_ws_proxy_owner" + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + api_key, + client_send_lock, + websocket, + ) + connect_models.append(model) + captured_preferred_accounts.append(request_state.preferred_account_id) + return SimpleNamespace(id="acct_ws_proxy_owner"), first_upstream + + async def fake_handle_stream_error(self, account, error, code): + del self, account, error + handled_error_codes.append(code) + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, 
"_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "previous_response_id": "resp_ws_prev_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + event = json.loads(websocket.receive_text()) + + assert event["type"] == "response.failed" + assert event["response"]["error"]["code"] == "upstream_unavailable" + assert event["response"]["error"]["message"] == "Previous response owner account is unavailable; retry later." 
+ assert connect_models == ["gpt-5.1"] + assert captured_preferred_accounts == ["acct_ws_proxy_owner"] + assert handled_error_codes == ["usage_limit_reached"] + + def test_backend_responses_websocket_transparent_replay_emits_no_accounts_when_reconnect_fails( app_instance, monkeypatch, diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 70da1d61..2882f5cb 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -125,6 +125,10 @@ def test_prometheus_metrics_defined_when_dependency_available(monkeypatch: pytes assert prometheus_module.active_connections.name == "codex_lb_active_connections" assert prometheus_module.bridge_instance_mismatch_total.name == "codex_lb_bridge_instance_mismatch_total" assert prometheus_module.bridge_instance_mismatch_total.labelnames == ("outcome",) + assert prometheus_module.continuity_owner_resolution_total.name == "codex_lb_continuity_owner_resolution_total" + assert prometheus_module.continuity_owner_resolution_total.labelnames == ("surface", "source", "outcome") + assert prometheus_module.continuity_fail_closed_total.name == "codex_lb_continuity_fail_closed_total" + assert prometheus_module.continuity_fail_closed_total.labelnames == ("surface", "reason") @pytest.mark.asyncio diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index 41567bfd..b108e993 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import time from collections import deque from contextlib import nullcontext @@ -9,6 +10,7 @@ from typing import Any, cast from unittest.mock import AsyncMock +import aiohttp import anyio import pytest from fastapi import WebSocket @@ -1919,6 +1921,7 @@ async def fake_get_or_create_http_bridge_session(*args: object, **kwargs: object previous_response_id="resp_prev_owner_lookup", api_key=None, session_id="turn_http_owner", + 
surface="http_bridge", ) assert captured_preferred["value"] == "acc-owner-from-logs" @@ -2040,6 +2043,7 @@ async def fake_stream_http_bridge_session_events( previous_response_id="resp_prev_owner_lookup", api_key=None, session_id="http_turn_generated", + surface="http_bridge", ) assert request_state.session_id == "http_turn_generated" assert request_state.preferred_account_id == "acc-owner-from-turn-state" @@ -2494,6 +2498,233 @@ async def fake_stream_http_bridge_session_events( reserve_retry.assert_awaited_once() +@pytest.mark.asyncio +async def test_http_bridge_local_owner_account_id_records_resolution_source( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + + class _ObservedCounter: + def __init__(self) -> None: + self.samples: list[dict[str, object]] = [] + + def labels(self, **labels: str): + sample: dict[str, object] = {"labels": dict(labels), "value": 0.0} + self.samples.append(sample) + + def inc(amount: float = 1.0) -> None: + sample["value"] = cast(float, sample["value"]) + amount + + return SimpleNamespace(inc=inc) + + counter = _ObservedCounter() + monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_owner_resolution_total", counter, raising=False) + caplog.set_level(logging.INFO, logger="app.modules.proxy.service") + + key = proxy_service._HTTPBridgeSessionKey("prompt_cache", "bridge-prev-rebind", None) + session = proxy_service._HTTPBridgeSession( + key=key, + headers={}, + affinity=proxy_service._AffinityPolicy( + key="bridge-prev-rebind", kind=proxy_service.StickySessionKind.PROMPT_CACHE + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + 
response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=1.0, + idle_ttl_seconds=120.0, + ) + service._http_bridge_sessions[key] = session + + owner = await service._http_bridge_local_owner_account_id( + key=key, + incoming_turn_state=None, + previous_response_id="resp_prev_local_owner_metric", + api_key=None, + ) + + assert owner == "acc-1" + assert "continuity_owner_resolution surface=http_bridge source=local_bridge_session outcome=hit" in caplog.text + assert "resp_prev_local_owner_metric" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "http_bridge", "source": "local_bridge_session", "outcome": "hit"}, + "value": 1.0, + } + ] + + +@pytest.mark.asyncio +async def test_stream_via_http_bridge_reacquires_api_key_reservation_after_owner_forward_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + api_key = _make_api_key(key_id="key-1", assigned_account_ids=[]) + initial_reservation = proxy_service.ApiKeyUsageReservationData( + reservation_id="resv-initial", + key_id=api_key.id, + model="gpt-5.4", + ) + retried_reservation = proxy_service.ApiKeyUsageReservationData( + reservation_id="resv-retry", + key_id=api_key.id, + model="gpt-5.4", + ) + payload = proxy_service.ResponsesRequest.model_validate( + { + "model": "gpt-5.4", + "instructions": "hi", + "input": "hello", + "previous_response_id": "resp_prev_1", + } + ) + + request_state_initial = proxy_service._WebSocketRequestState( + request_id="req-initial", + model="gpt-5.4", + service_tier=None, + reasoning_effort=None, + api_key_reservation=initial_reservation, + started_at=1.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + request_state_initial.request_stage = "follow_up" + request_state_initial.preferred_account_id = "acc-1" + request_state_retry = proxy_service._WebSocketRequestState( + request_id="req-retry", + model="gpt-5.4", + 
service_tier=None, + reasoning_effort=None, + api_key_reservation=retried_reservation, + started_at=2.0, + event_queue=asyncio.Queue(), + transport="http", + previous_response_id="resp_prev_1", + ) + + prepare_reservations: list[proxy_service.ApiKeyUsageReservationData | None] = [] + + def fake_prepare( + prepared_payload: proxy_service.ResponsesRequest, + _headers: dict[str, str] | Any, + *, + api_key: proxy_service.ApiKeyData | None, + api_key_reservation: proxy_service.ApiKeyUsageReservationData | None, + request_id: str, + ) -> tuple[proxy_service._WebSocketRequestState, str]: + del prepared_payload, api_key, request_id + prepare_reservations.append(api_key_reservation) + if len(prepare_reservations) == 1: + return request_state_initial, '{"type":"response.create","request":"initial"}' + return request_state_retry, '{"type":"response.create","request":"retry"}' + + owner_forward = proxy_service._HTTPBridgeOwnerForward( + owner_instance="instance-b", + owner_endpoint="http://instance-b", + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", api_key.id), + ) + session_retry = proxy_service._HTTPBridgeSession( + key=proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", api_key.id), + headers={"x-codex-session-id": "sid-123"}, + affinity=proxy_service._AffinityPolicy( + key="sid-123", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + ) + + submitted_reservations: list[proxy_service.ApiKeyUsageReservationData | None] = [] + + async def fake_forward_http_bridge_request_to_owner(**kwargs: object): + del kwargs + 
raise ProxyResponseError(400, proxy_service.openai_error("previous_response_not_found", "missing")) + yield "" + + async def fake_submit_http_bridge_request( + _session: proxy_service._HTTPBridgeSession, + *, + request_state: proxy_service._WebSocketRequestState, + text_data: str, + queue_limit: int, + ) -> None: + del _session, text_data, queue_limit + submitted_reservations.append(request_state.api_key_reservation) + event_queue = request_state.event_queue + assert event_queue is not None + await event_queue.put('data: {"type":"response.completed"}\n\n') + await event_queue.put(None) + + reserve_retry = AsyncMock(return_value=retried_reservation) + get_or_create = AsyncMock(side_effect=[owner_forward, session_retry]) + + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=1800, + http_responses_session_bridge_prompt_cache_idle_ttl_seconds=3600, + http_responses_session_bridge_gateway_safe_mode=False, + ) + ) + ), + ), + ) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(service._durable_bridge, "lookup_request_targets", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_resolve_websocket_previous_response_owner", AsyncMock(return_value="acc-1")) + monkeypatch.setattr(service, "_prepare_http_bridge_request", fake_prepare) + monkeypatch.setattr(service, "_get_or_create_http_bridge_session", get_or_create) + monkeypatch.setattr(service, "_forward_http_bridge_request_to_owner", fake_forward_http_bridge_request_to_owner) + monkeypatch.setattr(service, "_submit_http_bridge_request", fake_submit_http_bridge_request) + monkeypatch.setattr(service, "_detach_http_bridge_request", AsyncMock()) + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_retry) + + chunks = [ + chunk + async for chunk in 
service._stream_via_http_bridge( + payload, + headers={"x-codex-session-id": "sid-123"}, + codex_session_affinity=True, + propagate_http_errors=False, + openai_cache_affinity=False, + api_key=api_key, + api_key_reservation=initial_reservation, + suppress_text_done_events=False, + idle_ttl_seconds=120.0, + codex_idle_ttl_seconds=900.0, + max_sessions=8, + queue_limit=4, + ) + ] + + assert chunks == ['data: {"type":"response.completed"}\n\n'] + assert prepare_reservations == [initial_reservation, retried_reservation] + assert submitted_reservations == [retried_reservation] + reserve_retry.assert_awaited_once() + + @pytest.mark.asyncio async def test_stream_via_http_bridge_local_previous_response_rebind_fails_existing_pending_requests( monkeypatch: pytest.MonkeyPatch, @@ -3407,6 +3638,87 @@ async def test_get_or_create_http_bridge_session_recovers_from_previous_response assert "http_turn_missing_alias" in recovered_session.downstream_turn_state_aliases +@pytest.mark.asyncio +async def test_get_or_create_http_bridge_session_drops_stale_previous_response_mapping( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("turn_state_header", "http_turn_missing_alias", None) + stale_key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-stale", None) + created_key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-new", None) + stale_session = proxy_service._HTTPBridgeSession( + key=stale_key, + headers={"x-codex-session-id": "sid-stale"}, + affinity=proxy_service._AffinityPolicy( + key="sid-stale", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-1", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + 
pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=2.0, + idle_ttl_seconds=120.0, + closed=True, + previous_response_ids={"resp_prev_1"}, + ) + created_session = proxy_service._HTTPBridgeSession( + key=created_key, + headers={"x-codex-session-id": "sid-new"}, + affinity=proxy_service._AffinityPolicy( + key="sid-new", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + request_model="gpt-5.4", + account=cast(Any, SimpleNamespace(id="acc-2", status=AccountStatus.ACTIVE)), + upstream=cast(UpstreamResponsesWebSocket, SimpleNamespace(close=AsyncMock())), + upstream_control=proxy_service._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=0, + last_used_at=3.0, + idle_ttl_seconds=120.0, + ) + alias_key = proxy_service._http_bridge_previous_response_alias_key("resp_prev_1", None) + service._http_bridge_sessions[stale_key] = stale_session + service._http_bridge_previous_response_index[alias_key] = stale_key + monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) + monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-a")) + monkeypatch.setattr( + proxy_service, + "_active_http_bridge_instance_ring", + AsyncMock(return_value=("instance-a", ["instance-a", "instance-b"])), + ) + + resolved = await service._get_or_create_http_bridge_session( + key, + headers={ + "x-codex-turn-state": "http_turn_missing_alias", + "x-codex-session-id": "sid-new", + }, + affinity=proxy_service._AffinityPolicy( + key="sid-new", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + 
api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + previous_response_id="resp_prev_1", + ) + + assert resolved is created_session + assert alias_key not in service._http_bridge_previous_response_index + + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_allows_local_rebind_for_previous_response_recovery( monkeypatch: pytest.MonkeyPatch, @@ -4209,7 +4521,7 @@ async def test_claim_durable_http_bridge_session_rejects_remote_owner_without_ta @pytest.mark.asyncio -async def test_get_or_create_http_bridge_session_allows_local_bootstrap_when_ring_lookup_fails( +async def test_get_or_create_http_bridge_session_hard_continuity_lookup_failure_fails_closed( monkeypatch: pytest.MonkeyPatch, ) -> None: service = proxy_service.ProxyService(cast(Any, nullcontext())) @@ -4247,20 +4559,25 @@ async def test_get_or_create_http_bridge_session_allows_local_bootstrap_when_rin AsyncMock(side_effect=ConnectionRefusedError("db unavailable")), ) - resolved = await service._get_or_create_http_bridge_session( - key, - headers={"x-codex-session-id": "sid-123"}, - affinity=proxy_service._AffinityPolicy( - key="sid-123", - kind=proxy_service.StickySessionKind.CODEX_SESSION, - ), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) + with pytest.raises(ProxyResponseError) as exc_info: + await service._get_or_create_http_bridge_session( + key, + headers={"x-codex-session-id": "sid-123"}, + affinity=proxy_service._AffinityPolicy( + key="sid-123", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) - assert resolved is created_session + service._create_http_bridge_session.assert_not_awaited() + exc = exc_info.value + assert exc.status_code == 502 + assert exc.payload["error"]["code"] == "upstream_unavailable" + assert exc.payload["error"]["message"] == "HTTP bridge owner metadata unavailable; retry later." 
@pytest.mark.asyncio @@ -4422,6 +4739,70 @@ async def test_get_or_create_http_bridge_session_soft_mismatch_rebinds_locally( assert resolved is created_session +@pytest.mark.asyncio +async def test_create_http_bridge_session_fails_closed_when_previous_response_owner_is_unavailable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = proxy_service.ProxyService(cast(Any, nullcontext())) + key = proxy_service._HTTPBridgeSessionKey("session_header", "sid-123", None) + preferred_account = cast(Any, SimpleNamespace(id="acc-owner", status=AccountStatus.ACTIVE)) + fallback_account = cast(Any, SimpleNamespace(id="acc-fallback", status=AccountStatus.ACTIVE)) + select_account = AsyncMock( + side_effect=[ + proxy_service.AccountSelection(account=preferred_account, error_message=None, error_code=None), + proxy_service.AccountSelection(account=fallback_account, error_message=None, error_code=None), + ] + ) + ensure_fresh = AsyncMock(side_effect=[aiohttp.ClientError("preferred connect failed"), fallback_account]) + open_upstream = AsyncMock( + return_value=cast(Any, SimpleNamespace(response_header=lambda _name: None, close=AsyncMock())) + ) + + async def fake_relay(_session: proxy_service._HTTPBridgeSession) -> None: + return None + + monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) + monkeypatch.setattr( + proxy_service, + "get_settings_cache", + lambda: cast( + Any, + SimpleNamespace( + get=AsyncMock( + return_value=SimpleNamespace( + prefer_earlier_reset_accounts=False, + routing_strategy=None, + ) + ) + ), + ), + ) + monkeypatch.setattr(service, "_select_account_with_budget_compatible", select_account) + monkeypatch.setattr(service, "_ensure_fresh_with_budget", ensure_fresh) + monkeypatch.setattr(service, "_open_upstream_websocket_with_budget", open_upstream) + monkeypatch.setattr(service, "_relay_http_bridge_upstream_messages", fake_relay) + + with pytest.raises(ProxyResponseError) as exc_info: + await 
service._create_http_bridge_session( + key, + headers={"x-codex-session-id": "sid-123"}, + affinity=proxy_service._AffinityPolicy( + key="sid-123", + kind=proxy_service.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + request_stage="reattach", + preferred_account_id="acc-owner", + require_preferred_account=True, + ) + + assert exc_info.value.status_code == 502 + assert exc_info.value.payload["error"]["code"] == "upstream_unavailable" + open_upstream.assert_not_awaited() + + @pytest.mark.asyncio async def test_get_or_create_http_bridge_session_prompt_cache_mismatch_stays_local_when_gateway_safe_mode_disabled( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index a242f3c8..35133cc7 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -152,6 +152,20 @@ async def list_active(self, stale_threshold_seconds: int = 120, *, require_endpo return list(self.members) +class _ObservedCounter: + def __init__(self) -> None: + self.samples: list[dict[str, object]] = [] + + def labels(self, **labels: str): + sample: dict[str, object] = {"labels": dict(labels), "value": 0.0} + self.samples.append(sample) + + def inc(amount: float = 1.0) -> None: + sample["value"] = cast(float, sample["value"]) + amount + + return SimpleNamespace(inc=inc) + + @pytest.mark.anyio async def test_owner_instance_uses_rendezvous_hash() -> None: settings = Settings( @@ -202,6 +216,79 @@ async def test_ring_raises_on_db_error() -> None: ) +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_records_request_log_source(monkeypatch, caplog): + request_logs = _RequestLogsRecorder() + request_logs.response_owner_by_id[("resp_prev_owner_metric", None, "turn_scope_owner_metric")] = "acc_owner_prev" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + counter = _ObservedCounter() + + monkeypatch.setattr(proxy_service, 
"PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_owner_resolution_total", counter, raising=False) + caplog.set_level(logging.INFO, logger="app.modules.proxy.service") + + owner = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_owner_metric", + api_key=None, + session_id="turn_scope_owner_metric", + surface="websocket", + ) + + assert owner == "acc_owner_prev" + assert "continuity_owner_resolution surface=websocket source=request_logs outcome=hit" in caplog.text + assert "previous_response_id=sha256:" in caplog.text + assert "session_id=sha256:" in caplog.text + assert "resp_prev_owner_metric" not in caplog.text + assert "turn_scope_owner_metric" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "websocket", "source": "request_logs", "outcome": "hit"}, + "value": 1.0, + } + ] + + +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_fail_closed_records_metric_and_log(monkeypatch, caplog): + request_logs = _RequestLogsRecorder() + request_logs.lookup_error = RuntimeError("lookup unavailable") + service = proxy_service.ProxyService(_repo_factory(request_logs)) + resolution_counter = _ObservedCounter() + fail_closed_counter = _ObservedCounter() + + monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_owner_resolution_total", resolution_counter, raising=False) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", fail_closed_counter, raising=False) + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") + + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_owner_metric_fail", + api_key=None, + session_id="turn_scope_owner_metric_fail", + surface="websocket", + ) + + assert exc_info.value.status_code == 502 + assert 
exc_info.value.payload["error"]["code"] == "upstream_unavailable" + assert "continuity_owner_resolution surface=websocket source=request_logs outcome=fail_closed" in caplog.text + assert "continuity_fail_closed surface=websocket reason=owner_lookup_failed" in caplog.text + assert "resp_prev_owner_metric_fail" not in caplog.text + assert "turn_scope_owner_metric_fail" not in caplog.text + assert resolution_counter.samples == [ + { + "labels": {"surface": "websocket", "source": "request_logs", "outcome": "fail_closed"}, + "value": 1.0, + } + ] + assert fail_closed_counter.samples == [ + { + "labels": {"surface": "websocket", "reason": "owner_lookup_failed"}, + "value": 1.0, + } + ] + + def test_build_upstream_websocket_headers_strip_accept_and_content_type_case_insensitively(): headers = proxy_module._build_upstream_websocket_headers( { @@ -492,6 +579,7 @@ def __init__(self) -> None: self.latest_response_by_session: dict[tuple[str, str | None], str] = {} self.lookup_calls: list[tuple[str, str | None, str | None]] = [] self.session_lookup_calls: list[tuple[str, str | None]] = [] + self.lookup_error: Exception | None = None async def add_log(self, **kwargs: object) -> None: self.calls.append(dict(kwargs)) @@ -505,6 +593,8 @@ async def find_latest_account_id_for_response_id( ) -> str | None: key = (response_id, api_key_id, session_id) self.lookup_calls.append(key) + if self.lookup_error is not None: + raise self.lookup_error owner = self.response_owner_by_id.get(key) if owner is not None: return owner @@ -1562,6 +1652,133 @@ class Settings: ] +@pytest.mark.asyncio +async def test_stream_responses_websocket_rejects_oversized_response_create_before_connect(monkeypatch): + class Settings: + upstream_base_url = "https://chatgpt.com/backend-api" + upstream_stream_transport = "websocket" + upstream_connect_timeout_seconds = 8.0 + stream_idle_timeout_seconds = 45.0 + max_sse_event_bytes = 1024 + image_inline_fetch_enabled = False + log_upstream_request_payload = False + 
proxy_request_budget_seconds = 75.0 + log_upstream_request_summary = False + + monkeypatch.setattr(proxy_module, "get_settings", lambda: Settings()) + monkeypatch.setattr(proxy_module, "_maybe_log_upstream_request_start", lambda **kwargs: None) + monkeypatch.setattr(proxy_module, "_maybe_log_upstream_request_complete", lambda **kwargs: None) + monkeypatch.setattr(proxy_module, "_UPSTREAM_RESPONSE_CREATE_WARN_BYTES", 64, raising=False) + monkeypatch.setattr(proxy_module, "_UPSTREAM_RESPONSE_CREATE_MAX_BYTES", 128, raising=False) + + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.1", + "instructions": "hi", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "x" * 256}]}], + } + ) + session = _WsSession(_WsResponse([])) + + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + _ = [ + event + async for event in proxy_module.stream_responses( + payload, + headers={}, + access_token="token", + account_id="acc_1", + session=cast(proxy_module.aiohttp.ClientSession, session), + raise_for_status=True, + ) + ] + + assert exc_info.value.status_code == 413 + assert exc_info.value.payload["error"]["code"] == "payload_too_large" + assert session.ws_calls == [] + + +@pytest.mark.asyncio +async def test_stream_responses_websocket_slims_historical_inline_artifacts_and_succeeds(monkeypatch): + class Settings: + upstream_base_url = "https://chatgpt.com/backend-api" + upstream_stream_transport = "websocket" + upstream_connect_timeout_seconds = 8.0 + stream_idle_timeout_seconds = 45.0 + max_sse_event_bytes = 1024 + image_inline_fetch_enabled = False + log_upstream_request_payload = False + proxy_request_budget_seconds = 75.0 + log_upstream_request_summary = False + + monkeypatch.setattr(proxy_module, "get_settings", lambda: Settings()) + monkeypatch.setattr(proxy_module, "_maybe_log_upstream_request_start", lambda **kwargs: None) + monkeypatch.setattr(proxy_module, "_maybe_log_upstream_request_complete", lambda **kwargs: None) + 
monkeypatch.setattr(proxy_module, "_UPSTREAM_RESPONSE_CREATE_WARN_BYTES", 64, raising=False) + monkeypatch.setattr(proxy_module, "_UPSTREAM_RESPONSE_CREATE_MAX_BYTES", 640, raising=False) + + messages = [ + SimpleNamespace( + type=proxy_module.aiohttp.WSMsgType.TEXT, + data='{"type":"response.created","response":{"id":"resp_ws_slim","service_tier":"auto"}}', + ), + SimpleNamespace( + type=proxy_module.aiohttp.WSMsgType.TEXT, + data='{"type":"response.completed","response":{"id":"resp_ws_slim","service_tier":"default"}}', + ), + ] + websocket = _WsResponse(messages) + session = _WsSession(websocket) + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": [ + {"role": "user", "content": [{"type": "input_text", "text": "old turn"}]}, + { + "type": "function_call_output", + "call_id": "call_1", + "output": "data:image/png;base64," + ("A" * 1200), + }, + { + "role": "assistant", + "content": [ + { + "type": "input_image", + "image_url": "data:image/png;base64," + ("B" * 1200), + } + ], + }, + {"role": "user", "content": [{"type": "input_text", "text": "latest turn"}]}, + ], + } + ) + + events = [ + event + async for event in proxy_module.stream_responses( + payload, + headers={}, + access_token="token", + account_id="acc_1", + session=cast(proxy_module.aiohttp.ClientSession, session), + ) + ] + + assert len(events) == 2 + assert len(session.ws_calls) == 1 + request_payload = websocket.sent_json[0] + request_input = cast(list[dict[str, object]], request_payload["input"]) + assert request_input[1]["output"] == proxy_service._RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE.format( + bytes=len(("data:image/png;base64," + ("A" * 1200)).encode("utf-8")) + ) + assistant_item = request_input[2] + assert assistant_item["content"] == [ + {"type": "input_text", "text": proxy_service._RESPONSE_CREATE_IMAGE_OMISSION_NOTICE} + ] + assert request_input[-1] == {"role": "user", "content": [{"type": "input_text", "text": 
"latest turn"}]} + + @pytest.mark.asyncio async def test_stream_responses_websocket_forces_response_create_event_type(monkeypatch): class Settings: @@ -3879,6 +4096,79 @@ async def test_connect_proxy_websocket_fails_over_on_handshake_usage_limit_reach assert request_logs.calls == [] +@pytest.mark.asyncio +async def test_connect_proxy_websocket_previous_response_owner_usage_limit_fails_closed(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account_owner = _make_account("acc_ws_prev_owner") + account_other = _make_account("acc_ws_other") + seen_excluded_account_ids: list[set[str]] = [] + + async def select_account(deadline: float, **kwargs: object) -> AccountSelection: + del deadline + excluded_account_ids = kwargs.get("exclude_account_ids") + seen_excluded_account_ids.append(set(cast(set[str], excluded_account_ids))) + if len(seen_excluded_account_ids) == 1: + return AccountSelection(account=account_owner, error_message=None) + return AccountSelection(account=account_other, error_message=None) + + mark_rate_limit = AsyncMock() + first_handshake_error = proxy_module.ProxyResponseError( + 429, + openai_error("usage_limit_reached", "usage limit reached"), + ) + + monkeypatch.setattr(service, "_select_account_with_budget_compatible", select_account) + monkeypatch.setattr(service._load_balancer, "mark_rate_limit", mark_rate_limit) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account_owner)) + monkeypatch.setattr(service, "_open_upstream_websocket", AsyncMock(side_effect=[first_handshake_error])) + monkeypatch.setattr(service, "_release_websocket_reservation", AsyncMock()) + + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_owner_handshake_429", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_prev_anchor", + preferred_account_id=account_owner.id, + ) 
+ + websocket_send = AsyncMock() + websocket = cast(WebSocket, SimpleNamespace(send_text=websocket_send)) + selected_account, selected_upstream = await service._connect_proxy_websocket( + {}, + sticky_key=None, + sticky_kind=None, + prefer_earlier_reset=False, + routing_strategy="usage_weighted", + model="gpt-5.1", + request_state=request_state, + api_key=None, + client_send_lock=anyio.Lock(), + websocket=websocket, + ) + + assert selected_account is None + assert selected_upstream is None + assert seen_excluded_account_ids == [set(), {account_owner.id}] + mark_rate_limit.assert_awaited_once() + mark_call = mark_rate_limit.await_args + assert mark_call is not None + assert mark_call.args[0] == account_owner + assert mark_call.args[1]["message"] == "usage limit reached" + await_args = websocket_send.await_args + assert await_args is not None + sent_payload = json.loads(await_args.args[0]) + assert sent_payload["status"] == 502 + assert sent_payload["error"]["code"] == "upstream_unavailable" + assert sent_payload["error"]["message"] == "Previous response owner account is unavailable; retry later." 
+ assert request_logs.calls[0]["request_id"] == "ws_req_prev_owner_handshake_429" + assert request_logs.calls[0]["error_code"] == "upstream_unavailable" + assert request_logs.calls[0]["account_id"] == account_owner.id + + @pytest.mark.asyncio async def test_connect_proxy_websocket_surfaces_local_connect_overload_without_penalizing_account(monkeypatch): settings = _make_proxy_settings(log_proxy_service_tier_trace=False) @@ -4045,35 +4335,39 @@ async def test_select_websocket_connect_account_requires_preferred_account_for_p @pytest.mark.asyncio -async def test_connect_proxy_websocket_surfaces_forced_refresh_transport_error(monkeypatch): +async def test_select_websocket_connect_account_records_fail_closed_for_preferred_account_mismatch( + monkeypatch, + caplog, +): request_logs = _RequestLogsRecorder() service = proxy_service.ProxyService(_repo_factory(request_logs)) - account = _make_account("acc_ws_forced_refresh_timeout") - initial_error = proxy_module.ProxyResponseError(401, openai_error("invalid_api_key", "expired")) - release_reservation = AsyncMock() - - monkeypatch.setattr( - service._load_balancer, - "select_account", - AsyncMock(return_value=AccountSelection(account=account, error_message=None)), - ) - monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(side_effect=[account, asyncio.TimeoutError()])) - monkeypatch.setattr(service, "_open_upstream_websocket", AsyncMock(side_effect=initial_error)) - monkeypatch.setattr(service, "_release_websocket_reservation", release_reservation) - request_state = proxy_service._WebSocketRequestState( - request_id="ws_req_forced_refresh_timeout", + request_id="ws_req_prev_owner_mismatch_metric", model="gpt-5.1", service_tier=None, reasoning_effort=None, api_key_reservation=None, started_at=0.0, + previous_response_id="resp_prev_owner", + preferred_account_id="acc_owner", + session_id="turn_ws_owner_mismatch", + ) + selected_account = _make_account("acc_other") + counter = _ObservedCounter() + + 
monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", counter, raising=False) + monkeypatch.setattr( + service, + "_select_account_with_budget_compatible", + AsyncMock(return_value=AccountSelection(account=selected_account, error_message=None)), ) + monkeypatch.setattr(service, "_release_websocket_reservation", AsyncMock()) + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") websocket_send = AsyncMock() - websocket = cast(WebSocket, SimpleNamespace(send_text=websocket_send)) - selected_account, selected_upstream = await service._connect_proxy_websocket( - {}, + result = await service._select_websocket_connect_account( + 10_000.0, sticky_key=None, sticky_kind=None, prefer_earlier_reset=False, @@ -4082,42 +4376,174 @@ async def test_connect_proxy_websocket_surfaces_forced_refresh_transport_error(m request_state=request_state, api_key=None, client_send_lock=anyio.Lock(), - websocket=websocket, + websocket=cast(WebSocket, SimpleNamespace(send_text=websocket_send)), + reallocate_sticky=False, + sticky_max_age_seconds=None, + exclude_account_ids=set(), + preferred_account_id="acc_owner", + require_preferred_account=True, ) - assert selected_account is None - assert selected_upstream is None - release_reservation.assert_awaited_once_with(None) - await_args = websocket_send.await_args - assert await_args is not None - sent_payload = json.loads(await_args.args[0]) + assert result is None + sent_payload = json.loads(websocket_send.await_args.args[0]) assert sent_payload["status"] == 502 assert sent_payload["error"]["code"] == "upstream_unavailable" - assert sent_payload["error"]["message"] == "Request to upstream timed out" - assert request_logs.calls[0]["request_id"] == "ws_req_forced_refresh_timeout" - assert request_logs.calls[0]["error_code"] == "upstream_unavailable" - assert request_logs.calls[0]["transport"] == "websocket" + assert "continuity_fail_closed 
surface=websocket_connect reason=owner_account_unavailable" in caplog.text + assert "resp_prev_owner" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "websocket_connect", "reason": "owner_account_unavailable"}, + "value": 1.0, + } + ] @pytest.mark.asyncio -async def test_connect_proxy_websocket_maps_handshake_budget_exhaustion_to_timeout_error(monkeypatch): +async def test_select_websocket_connect_account_preferred_owner_missing_fails_closed( + monkeypatch, + caplog, +): request_logs = _RequestLogsRecorder() service = proxy_service.ProxyService(_repo_factory(request_logs)) - account = _make_account("acc_ws_handshake_budget") - handle_connect_error = AsyncMock() - - monkeypatch.setattr( - service._load_balancer, - "select_account", - AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_owner_missing", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_prev_owner", + preferred_account_id="acc_owner", + session_id="turn_ws_owner_missing", ) - monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account)) + counter = _ObservedCounter() + + monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", counter, raising=False) monkeypatch.setattr( service, - "_open_upstream_websocket", + "_select_account_with_budget_compatible", AsyncMock( - side_effect=proxy_module.ProxyResponseError( - 502, + return_value=AccountSelection( + account=None, + error_message="No active accounts available", + error_code="no_accounts", + ) + ), + ) + monkeypatch.setattr(service, "_release_websocket_reservation", AsyncMock()) + + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") + websocket_send = AsyncMock() + result = await 
service._select_websocket_connect_account( + 10_000.0, + sticky_key=None, + sticky_kind=None, + prefer_earlier_reset=False, + routing_strategy="usage_weighted", + model="gpt-5.1", + request_state=request_state, + api_key=None, + client_send_lock=anyio.Lock(), + websocket=cast(WebSocket, SimpleNamespace(send_text=websocket_send)), + reallocate_sticky=False, + sticky_max_age_seconds=None, + exclude_account_ids=set(), + preferred_account_id="acc_owner", + require_preferred_account=True, + ) + + assert result is None + sent_payload = json.loads(websocket_send.await_args.args[0]) + assert sent_payload["status"] == 502 + assert sent_payload["error"]["code"] == "upstream_unavailable" + assert sent_payload["error"]["message"] == "Previous response owner account is unavailable; retry later." + assert request_logs.calls[0]["account_id"] == "acc_owner" + assert request_logs.calls[0]["error_code"] == "upstream_unavailable" + assert "continuity_fail_closed surface=websocket_connect reason=owner_account_unavailable" in caplog.text + assert "resp_prev_owner" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "websocket_connect", "reason": "owner_account_unavailable"}, + "value": 1.0, + } + ] + + +@pytest.mark.asyncio +async def test_connect_proxy_websocket_surfaces_forced_refresh_transport_error(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account = _make_account("acc_ws_forced_refresh_timeout") + initial_error = proxy_module.ProxyResponseError(401, openai_error("invalid_api_key", "expired")) + release_reservation = AsyncMock() + + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + ) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(side_effect=[account, asyncio.TimeoutError()])) + monkeypatch.setattr(service, "_open_upstream_websocket", 
AsyncMock(side_effect=initial_error)) + monkeypatch.setattr(service, "_release_websocket_reservation", release_reservation) + + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_forced_refresh_timeout", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + ) + + websocket_send = AsyncMock() + websocket = cast(WebSocket, SimpleNamespace(send_text=websocket_send)) + selected_account, selected_upstream = await service._connect_proxy_websocket( + {}, + sticky_key=None, + sticky_kind=None, + prefer_earlier_reset=False, + routing_strategy="usage_weighted", + model="gpt-5.1", + request_state=request_state, + api_key=None, + client_send_lock=anyio.Lock(), + websocket=websocket, + ) + + assert selected_account is None + assert selected_upstream is None + release_reservation.assert_awaited_once_with(None) + await_args = websocket_send.await_args + assert await_args is not None + sent_payload = json.loads(await_args.args[0]) + assert sent_payload["status"] == 502 + assert sent_payload["error"]["code"] == "upstream_unavailable" + assert sent_payload["error"]["message"] == "Request to upstream timed out" + assert request_logs.calls[0]["request_id"] == "ws_req_forced_refresh_timeout" + assert request_logs.calls[0]["error_code"] == "upstream_unavailable" + assert request_logs.calls[0]["transport"] == "websocket" + + +@pytest.mark.asyncio +async def test_connect_proxy_websocket_maps_handshake_budget_exhaustion_to_timeout_error(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account = _make_account("acc_ws_handshake_budget") + handle_connect_error = AsyncMock() + + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + ) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account)) + monkeypatch.setattr( + 
service, + "_open_upstream_websocket", + AsyncMock( + side_effect=proxy_module.ProxyResponseError( + 502, openai_error("upstream_unavailable", "Proxy request budget exhausted"), ) ), @@ -4858,6 +5284,85 @@ async def test_process_upstream_websocket_text_transparently_retries_precreated_ assert list(pending_requests) == [] +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_maps_previous_response_usage_limit_to_upstream_unavailable( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_prev_quota_owner") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "previous_response_id": "resp_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + } + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_quota_unavailable", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps(request_payload, separators=(",", ":")), + previous_response_id="resp_anchor", + preferred_account_id=account.id, + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + upstream_payload = { + "type": "error", + "status": 429, + "error": { + "type": "invalid_request_error", + "code": "usage_limit_reached", + "message": "The usage limit has been reached", + }, + } + upstream_text = json.dumps(upstream_payload, separators=(",", ":")) + + downstream_text = await service._process_upstream_websocket_text( + upstream_text, + account=account, + 
account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"code":"upstream_unavailable"' in downstream_text + handle_stream_error.assert_awaited_once() + handle_call = handle_stream_error.await_args + assert handle_call is not None + assert handle_call.args[0] == account + assert handle_call.args[2] == "usage_limit_reached" + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.kwargs["event_type"] == "response.failed" + payload = finalize_call.kwargs["payload"] + assert isinstance(payload, dict) + response_payload = cast(dict[str, JsonValue], payload["response"]) + error_payload = cast(dict[str, JsonValue], response_payload["error"]) + assert error_payload["code"] == "upstream_unavailable" + assert error_payload["message"] == "Previous response owner account is unavailable; retry later." 
+ assert upstream_control.reconnect_requested is False + assert upstream_control.suppress_downstream_event is False + assert upstream_control.replay_request_state is None + assert pending_request.replay_count == 0 + assert list(pending_requests) == [] + + @pytest.mark.asyncio async def test_proxy_responses_websocket_transparent_replay_preserves_sticky_thread_affinity( monkeypatch, @@ -5751,6 +6256,83 @@ async def fake_connect_proxy_websocket( assert [event["type"] for event in emitted_events] == ["response.created", "response.completed"] +@pytest.mark.asyncio +async def test_proxy_responses_websocket_previous_response_owner_lookup_failure_returns_upstream_unavailable( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + request_logs.lookup_error = RuntimeError("lookup unavailable") + service = proxy_service.ProxyService(_repo_factory(request_logs)) + + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + settings.stream_idle_timeout_seconds = 300.0 + settings.proxy_downstream_websocket_idle_timeout_seconds = 120.0 + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + + class _FakeDownstreamWebSocket: + def __init__(self, request_text: str) -> None: + self._request_text = request_text + self._request_sent = False + self._disconnect_sent = False + self._done = asyncio.Event() + self.sent_text: list[str] = [] + + async def receive(self) -> dict[str, object]: + if not self._request_sent: + self._request_sent = True + return {"type": "websocket.receive", "text": self._request_text} + if not self._disconnect_sent: + await self._done.wait() + self._disconnect_sent = True + return {"type": "websocket.disconnect"} + await asyncio.sleep(0) + return {"type": "websocket.disconnect"} + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + self._done.set() + + async def send_bytes(self, _data: bytes) -> None: + return 
None + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + del code, reason + self._done.set() + + async def fail_connect_proxy_websocket(*args, **kwargs): + del args, kwargs + raise AssertionError("owner lookup failure must fail before websocket connect") + + monkeypatch.setattr(proxy_service.ProxyService, "_connect_proxy_websocket", fail_connect_proxy_websocket) + + request_payload = { + "type": "response.create", + "model": "gpt-5.1", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "previous_response_id": "resp_prev_lookup_failure", + "stream": True, + } + downstream = _FakeDownstreamWebSocket(json.dumps(request_payload, separators=(",", ":"))) + + await service.proxy_responses_websocket( + cast(WebSocket, downstream), + {"session_id": "sid_owner_lookup_failure"}, + codex_session_affinity=False, + openai_cache_affinity=False, + api_key=None, + ) + + assert request_logs.lookup_calls == [("resp_prev_lookup_failure", None, "sid_owner_lookup_failure")] + assert len(downstream.sent_text) == 1 + payload = json.loads(downstream.sent_text[0]) + assert payload["type"] == "response.failed" + assert payload["response"]["status"] == "failed" + assert payload["response"]["error"]["code"] == "upstream_unavailable" + assert payload["response"]["error"]["message"] == "Previous response owner lookup failed; retry later." 
+ + @pytest.mark.asyncio async def test_resolve_websocket_previous_response_owner_rechecks_same_scope_after_initial_miss(monkeypatch): request_logs = _RequestLogsRecorder() @@ -5762,12 +6344,14 @@ async def test_resolve_websocket_previous_response_owner_rechecks_same_scope_aft previous_response_id="resp_prev_missing", api_key=None, session_id="req_scope_1", + surface="websocket", ) clock["value"] = 102.0 owner_2 = await service._resolve_websocket_previous_response_owner( previous_response_id="resp_prev_missing", api_key=None, session_id="req_scope_1", + surface="websocket", ) request_logs.response_owner_by_id[("resp_prev_missing", None, None)] = "acc_owner_after_commit" clock["value"] = 103.0 @@ -5775,6 +6359,7 @@ async def test_resolve_websocket_previous_response_owner_rechecks_same_scope_aft previous_response_id="resp_prev_missing", api_key=None, session_id="req_scope_1", + surface="websocket", ) assert owner_1 is None @@ -5822,6 +6407,7 @@ async def test_resolve_websocket_previous_response_owner_miss_does_not_evict_kno previous_response_id="resp_prev_shared", api_key=api_key, session_id="req_terminal_a", + surface="websocket", ) assert owner == "acc_owner" @@ -5856,6 +6442,7 @@ async def test_resolve_websocket_previous_response_owner_prefers_scoped_lookup_o previous_response_id="resp_prev_shared", api_key=api_key, session_id="turn_scope_a", + surface="websocket", ) assert owner == "acc_owner_scoped" @@ -6239,6 +6826,212 @@ def test_maybe_rewrite_websocket_previous_response_not_found_leaves_non_previous assert rewritten_text == original_text +def test_sanitize_websocket_connect_failure_rewrites_previous_response_not_found(monkeypatch, caplog): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_connect_failure", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_prev_anchor", + ) + counter = _ObservedCounter() + monkeypatch.setattr(proxy_service, 
"PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", counter, raising=False) + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") + original_payload = proxy_module.openai_error( + "previous_response_not_found", + "Previous response with id 'resp_prev_anchor' not found.", + error_type="invalid_request_error", + ) + original_payload["error"]["param"] = "previous_response_id" + + ( + rewritten_status, + rewritten_payload, + rewritten_error_code, + rewritten_error_message, + ) = proxy_service._sanitize_websocket_connect_failure( + request_state=request_state, + status_code=400, + payload=original_payload, + error_code="previous_response_not_found", + error_message="Previous response with id 'resp_prev_anchor' not found.", + ) + + assert rewritten_status == 502 + assert rewritten_payload["error"]["code"] == "stream_incomplete" + assert rewritten_payload["error"]["message"] == "Upstream websocket closed before response.completed" + assert rewritten_payload["error"]["type"] == "server_error" + assert rewritten_error_code == "stream_incomplete" + assert rewritten_error_message == "Upstream websocket closed before response.completed" + assert "continuity_fail_closed surface=websocket_connect reason=previous_response_not_found" in caplog.text + assert "resp_prev_anchor" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "websocket_connect", "reason": "previous_response_not_found"}, + "value": 1.0, + } + ] + + +def test_sanitize_websocket_connect_failure_rewrites_invalid_request_previous_response_not_found(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_connect_failure_invalid_request", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_prev_anchor", + ) + original_payload = proxy_module.openai_error( + "invalid_request_error", + "Previous response with id 
'resp_prev_anchor' not found.", + error_type="invalid_request_error", + ) + original_payload["error"]["param"] = "previous_response_id" + + ( + rewritten_status, + rewritten_payload, + rewritten_error_code, + rewritten_error_message, + ) = proxy_service._sanitize_websocket_connect_failure( + request_state=request_state, + status_code=400, + payload=original_payload, + error_code="invalid_request_error", + error_message="Previous response with id 'resp_prev_anchor' not found.", + ) + + assert rewritten_status == 502 + assert rewritten_payload["error"]["code"] == "stream_incomplete" + assert rewritten_payload["error"]["message"] == "Upstream websocket closed before response.completed" + assert rewritten_payload["error"]["type"] == "server_error" + assert rewritten_error_code == "stream_incomplete" + assert rewritten_error_message == "Upstream websocket closed before response.completed" + + +@pytest.mark.asyncio +async def test_emit_websocket_connect_failure_releases_response_create_gate(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_connect_failure_gate", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + ) + response_create_gate = asyncio.Semaphore(1) + await response_create_gate.acquire() + request_state.response_create_gate_acquired = True + request_state.response_create_gate = response_create_gate + + release_reservation = AsyncMock() + monkeypatch.setattr(service, "_release_websocket_reservation", release_reservation) + + websocket_send = AsyncMock() + websocket = cast(WebSocket, SimpleNamespace(send_text=websocket_send)) + + await service._emit_websocket_connect_failure( + websocket, + client_send_lock=anyio.Lock(), + account_id="acc_connect_failure", + api_key=None, + request_state=request_state, + status_code=502, + 
payload=openai_error( + "upstream_unavailable", + "Previous response owner account is unavailable; retry later.", + error_type="server_error", + ), + error_code="upstream_unavailable", + error_message="Previous response owner account is unavailable; retry later.", + ) + + release_reservation.assert_awaited_once_with(None) + assert response_create_gate.locked() is False + assert request_state.awaiting_response_created is False + assert request_state.response_create_gate_acquired is False + assert request_state.response_create_gate is None + + +@pytest.mark.asyncio +async def test_emit_websocket_terminal_error_releases_response_create_gate(): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_terminal_failure_gate", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + ) + response_create_gate = asyncio.Semaphore(1) + await response_create_gate.acquire() + request_state.response_create_gate_acquired = True + request_state.response_create_gate = response_create_gate + + websocket_send = AsyncMock() + websocket = cast(WebSocket, SimpleNamespace(send_text=websocket_send)) + + await service._emit_websocket_terminal_error( + websocket, + client_send_lock=anyio.Lock(), + request_state=request_state, + error_code="upstream_unavailable", + error_message="Previous response owner lookup failed; retry later.", + ) + + assert response_create_gate.locked() is False + assert request_state.awaiting_response_created is False + assert request_state.response_create_gate_acquired is False + assert request_state.response_create_gate is None + + +def test_sanitize_websocket_connect_failure_leaves_unrelated_previous_response_error_unchanged(): + request_state = proxy_service._WebSocketRequestState( + request_id="ws_req_prev_connect_failure_unrelated", + 
model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_prev_anchor", + ) + original_payload = proxy_module.openai_error( + "invalid_request_error", + "Invalid request payload", + error_type="invalid_request_error", + ) + original_payload["error"]["param"] = "previous_response_id" + + ( + rewritten_status, + rewritten_payload, + rewritten_error_code, + rewritten_error_message, + ) = proxy_service._sanitize_websocket_connect_failure( + request_state=request_state, + status_code=400, + payload=original_payload, + error_code="invalid_request_error", + error_message="Invalid request payload", + ) + + assert rewritten_status == 400 + assert rewritten_payload == original_payload + assert rewritten_error_code == "invalid_request_error" + assert rewritten_error_message == "Invalid request payload" + + @pytest.mark.asyncio async def test_stream_responses_budget_exhaustion_emits_timeout_event(monkeypatch): settings = _make_proxy_settings(log_proxy_service_tier_trace=False) @@ -6626,6 +7419,189 @@ async def fake_stream(payload, headers, access_token, account_id, base_url=None, assert request_logs.calls[0]["error_code"] is None +@pytest.mark.asyncio +async def test_stream_previous_response_not_found_proxy_error_is_masked_to_stream_incomplete(monkeypatch, caplog): + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account = _make_account("acc_prev_missing_stream") + record_error = AsyncMock() + record_success = AsyncMock() + counter = _ObservedCounter() + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", counter, 
raising=False) + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=account, error_message=None)), + ) + monkeypatch.setattr(service._load_balancer, "record_error", record_error) + monkeypatch.setattr(service._load_balancer, "record_success", record_success) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account)) + monkeypatch.setattr(service, "_settle_stream_api_key_usage", AsyncMock(return_value=True)) + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **kwargs): + del payload, headers, access_token, account_id, base_url, raise_for_status, kwargs + error_payload = openai_error( + "previous_response_not_found", + "Previous response with id 'resp_prev_anchor' not found.", + error_type="invalid_request_error", + ) + error_payload["error"]["param"] = "previous_response_id" + raise proxy_module.ProxyResponseError(400, error_payload) + if False: + yield "" + + monkeypatch.setattr(proxy_service, "core_stream_responses", fake_stream) + + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.1", + "instructions": "hi", + "input": [], + "stream": True, + "previous_response_id": "resp_prev_anchor", + } + ) + + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") + chunks = [chunk async for chunk in service.stream_responses(payload, {"session_id": "sid-stream"})] + + event = json.loads(chunks[0].split("data: ", 1)[1]) + assert event["type"] == "response.failed" + assert event["response"]["error"]["code"] == "stream_incomplete" + assert event["response"]["error"]["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in chunks[0] + assert request_logs.lookup_calls == [("resp_prev_anchor", None, "sid-stream")] + assert request_logs.calls[0]["error_code"] == "stream_incomplete" + assert "continuity_fail_closed surface=http_stream 
reason=previous_response_not_found" in caplog.text + assert "resp_prev_anchor" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "http_stream", "reason": "previous_response_not_found"}, + "value": 1.0, + } + ] + record_error.assert_not_awaited() + record_success.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_stream_previous_response_owner_usage_limit_fails_closed(monkeypatch): + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + account_owner = _make_account("acc_prev_owner_stream") + account_other = _make_account("acc_other_stream") + request_logs.response_owner_by_id[("resp_prev_anchor", None, "sid-stream")] = account_owner.id + select_account_calls: list[dict[str, object]] = [] + handle_stream_error = AsyncMock(return_value={"failure_class": "rate_limit"}) + record_success = AsyncMock() + + async def fake_select_account(**kwargs): + select_account_calls.append(dict(kwargs)) + account_ids = kwargs.get("account_ids") + exclude_account_ids = set(cast(set[str], kwargs.get("exclude_account_ids", set()))) + if account_ids == {account_owner.id}: + return AccountSelection(account=account_owner, error_message=None) + if account_owner.id in exclude_account_ids: + return AccountSelection(account=account_other, error_message=None) + return AccountSelection(account=account_owner, error_message=None) + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + monkeypatch.setattr(service._load_balancer, "select_account", AsyncMock(side_effect=fake_select_account)) + monkeypatch.setattr(service._load_balancer, "record_success", record_success) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + monkeypatch.setattr(service, "_ensure_fresh", AsyncMock(return_value=account_owner)) + 
monkeypatch.setattr(service, "_settle_stream_api_key_usage", AsyncMock(return_value=True)) + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **kwargs): + del payload, headers, access_token, account_id, base_url, raise_for_status, kwargs + yield ( + 'data: {"type":"response.failed","response":{"id":"resp_owner_limit","status":"failed",' + '"error":{"code":"usage_limit_reached","message":"usage limit reached"},' + '"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}\n\n' + ) + + monkeypatch.setattr(proxy_service, "core_stream_responses", fake_stream) + + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.1", + "instructions": "hi", + "input": [], + "stream": True, + "previous_response_id": "resp_prev_anchor", + } + ) + + chunks = [chunk async for chunk in service.stream_responses(payload, {"session_id": "sid-stream"})] + + event = json.loads(chunks[0].split("data: ", 1)[1]) + assert event["type"] == "response.failed" + assert event["response"]["error"]["code"] == "upstream_unavailable" + assert event["response"]["error"]["message"] == "Previous response owner account is unavailable; retry later." 
+ assert request_logs.lookup_calls == [("resp_prev_anchor", None, "sid-stream")] + assert request_logs.calls[0]["error_code"] == "upstream_unavailable" + assert request_logs.calls[0]["account_id"] == account_owner.id + assert len(select_account_calls) == 1 + assert select_account_calls[0]["account_ids"] == {account_owner.id} + handle_stream_error.assert_awaited_once() + assert handle_stream_error.await_args.args[0] == account_owner + assert handle_stream_error.await_args.args[2] == "usage_limit_reached" + record_success.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_stream_selection_fail_closed_records_owner_unavailable_metric(monkeypatch, caplog): + settings = _make_proxy_settings(log_proxy_service_tier_trace=False) + request_logs = _RequestLogsRecorder() + request_logs.response_owner_by_id[("resp_prev_anchor", None, "sid-stream")] = "acc_prev_owner_stream" + service = proxy_service.ProxyService(_repo_factory(request_logs)) + counter = _ObservedCounter() + + monkeypatch.setattr(proxy_service, "get_settings_cache", lambda: _SettingsCache(settings)) + monkeypatch.setattr(proxy_service, "get_settings", lambda: settings) + monkeypatch.setattr(proxy_service, "PROMETHEUS_AVAILABLE", True) + monkeypatch.setattr(proxy_service, "continuity_fail_closed_total", counter, raising=False) + monkeypatch.setattr( + service._load_balancer, + "select_account", + AsyncMock(return_value=AccountSelection(account=None, error_message="No active accounts available")), + ) + + payload = ResponsesRequest.model_validate( + { + "model": "gpt-5.1", + "instructions": "hi", + "input": [], + "stream": True, + "previous_response_id": "resp_prev_anchor", + } + ) + + caplog.set_level(logging.WARNING, logger="app.modules.proxy.service") + chunks = [chunk async for chunk in service.stream_responses(payload, {"session_id": "sid-stream"})] + + event = json.loads(chunks[0].split("data: ", 1)[1]) + assert event["type"] == "response.failed" + assert event["response"]["error"]["code"] == 
"upstream_unavailable" + assert event["response"]["error"]["message"] == "Previous response owner account is unavailable; retry later." + assert request_logs.calls[0]["account_id"] == "acc_prev_owner_stream" + assert "continuity_fail_closed surface=http_stream reason=owner_account_unavailable" in caplog.text + assert "resp_prev_anchor" not in caplog.text + assert counter.samples == [ + { + "labels": {"surface": "http_stream", "reason": "owner_account_unavailable"}, + "value": 1.0, + } + ] + + @pytest.mark.asyncio async def test_compact_responses_budget_exhaustion_returns_upstream_unavailable(monkeypatch): settings = _make_proxy_settings(log_proxy_service_tier_trace=False) From 7ba531bf0127ff30b459ef5ee80696cc70e6f56a Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 19:48:43 +0200 Subject: [PATCH 09/18] fix(proxy): resolve ty diagnostics in continuity tests --- app/modules/proxy/service.py | 2 +- tests/integration/test_proxy_responses.py | 7 +++++-- tests/unit/test_proxy_http_bridge.py | 17 +++++++++++------ tests/unit/test_proxy_utils.py | 14 ++++++++++---- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 86f3c0ba..42ef04d3 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -6812,7 +6812,7 @@ async def _stream_once( settlement.record_success = False settlement.account_health_error = _should_penalize_stream_error(error_code) if event_type in ("response.completed", "response.incomplete"): - usage = event.response.usage if event.response else None + usage = event.response.usage if event is not None and event.response else None if event_type == "response.incomplete": status = "error" if latency_first_token_ms is None and event_type in _TEXT_DELTA_EVENT_TYPES: diff --git a/tests/integration/test_proxy_responses.py b/tests/integration/test_proxy_responses.py index 70af704b..5087ced6 100644 --- a/tests/integration/test_proxy_responses.py +++ 
b/tests/integration/test_proxy_responses.py @@ -3,6 +3,7 @@ import base64 import json from types import SimpleNamespace +from typing import cast import pytest from httpx import ASGITransport, AsyncClient @@ -558,10 +559,12 @@ async def fake_open_upstream_websocket(**kwargs): assert fake_upstream.sent_json request_input = fake_upstream.sent_json[0]["input"] assert isinstance(request_input, list) - assert request_input[1]["output"] == proxy_module._RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE.format( + tool_input = cast(dict[str, object], request_input[1]) + assistant_input = cast(dict[str, object], request_input[2]) + assert tool_input["output"] == proxy_module._RESPONSE_CREATE_TOOL_OUTPUT_OMISSION_NOTICE.format( bytes=len(("data:image/png;base64," + ("A" * 1200)).encode("utf-8")) ) - assert request_input[2]["content"] == [ + assert assistant_input["content"] == [ {"type": "input_text", "text": proxy_module._RESPONSE_CREATE_IMAGE_OMISSION_NOTICE} ] diff --git a/tests/unit/test_proxy_http_bridge.py b/tests/unit/test_proxy_http_bridge.py index b108e993..9087e16b 100644 --- a/tests/unit/test_proxy_http_bridge.py +++ b/tests/unit/test_proxy_http_bridge.py @@ -3688,7 +3688,8 @@ async def test_get_or_create_http_bridge_session_drops_stale_previous_response_m service._http_bridge_sessions[stale_key] = stale_session service._http_bridge_previous_response_index[alias_key] = stale_key monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) - monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + create_http_bridge_session = AsyncMock(return_value=created_session) + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", 
AsyncMock(return_value="instance-a")) @@ -3744,7 +3745,8 @@ async def test_get_or_create_http_bridge_session_allows_local_rebind_for_previou idle_ttl_seconds=120.0, ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) - monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + create_http_bridge_session = AsyncMock(return_value=created_session) + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) @@ -3798,7 +3800,8 @@ async def test_get_or_create_http_bridge_session_allows_local_rebind_for_bootstr idle_ttl_seconds=120.0, ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) - monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + create_http_bridge_session = AsyncMock(return_value=created_session) + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) monkeypatch.setattr(proxy_service, "_http_bridge_owner_instance", AsyncMock(return_value="instance-b")) @@ -3875,7 +3878,8 @@ async def test_get_or_create_http_bridge_session_recovers_locally_when_owner_end idle_ttl_seconds=120.0, ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) - monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + create_http_bridge_session = AsyncMock(return_value=created_session) + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) claim_durable = AsyncMock() 
monkeypatch.setattr(service, "_claim_durable_http_bridge_session", claim_durable) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) @@ -4545,7 +4549,8 @@ async def test_get_or_create_http_bridge_session_hard_continuity_lookup_failure_ idle_ttl_seconds=120.0, ) monkeypatch.setattr(service, "_prune_http_bridge_sessions_locked", AsyncMock()) - monkeypatch.setattr(service, "_create_http_bridge_session", AsyncMock(return_value=created_session)) + create_http_bridge_session = AsyncMock(return_value=created_session) + monkeypatch.setattr(service, "_create_http_bridge_session", create_http_bridge_session) monkeypatch.setattr(service, "_claim_durable_http_bridge_session", AsyncMock()) monkeypatch.setattr(proxy_service, "get_settings", lambda: _make_app_settings()) monkeypatch.setattr( @@ -4573,7 +4578,7 @@ async def test_get_or_create_http_bridge_session_hard_continuity_lookup_failure_ max_sessions=8, ) - service._create_http_bridge_session.assert_not_awaited() + create_http_bridge_session.assert_not_awaited() exc = exc_info.value assert exc.status_code == 502 assert exc.payload["error"]["code"] == "upstream_unavailable" diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index 35133cc7..343eb995 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -4385,7 +4385,9 @@ async def test_select_websocket_connect_account_records_fail_closed_for_preferre ) assert result is None - sent_payload = json.loads(websocket_send.await_args.args[0]) + await_args = websocket_send.await_args + assert await_args is not None + sent_payload = json.loads(await_args.args[0]) assert sent_payload["status"] == 502 assert sent_payload["error"]["code"] == "upstream_unavailable" assert "continuity_fail_closed surface=websocket_connect reason=owner_account_unavailable" in caplog.text @@ -4454,7 +4456,9 @@ async def test_select_websocket_connect_account_preferred_owner_missing_fails_cl ) assert result is None - 
sent_payload = json.loads(websocket_send.await_args.args[0]) + await_args = websocket_send.await_args + assert await_args is not None + sent_payload = json.loads(await_args.args[0]) assert sent_payload["status"] == 502 assert sent_payload["error"]["code"] == "upstream_unavailable" assert sent_payload["error"]["message"] == "Previous response owner account is unavailable; retry later." @@ -7551,8 +7555,10 @@ async def fake_stream(payload, headers, access_token, account_id, base_url=None, assert len(select_account_calls) == 1 assert select_account_calls[0]["account_ids"] == {account_owner.id} handle_stream_error.assert_awaited_once() - assert handle_stream_error.await_args.args[0] == account_owner - assert handle_stream_error.await_args.args[2] == "usage_limit_reached" + handle_await_args = handle_stream_error.await_args + assert handle_await_args is not None + assert handle_await_args.args[0] == account_owner + assert handle_await_args.args[2] == "usage_limit_reached" record_success.assert_not_awaited() From e4b55e1b7d75df6104dfc179bb9237a027e9d218 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Thu, 16 Apr 2026 20:17:20 +0200 Subject: [PATCH 10/18] fix(proxy): persist non-bridge continuity anchors --- app/modules/proxy/service.py | 14 ++- tests/integration/test_proxy_api_extended.py | 2 +- tests/integration/test_proxy_responses.py | 107 ++++++++++++++++++- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 42ef04d3..55808cf8 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -6621,10 +6621,12 @@ async def _stream_once( service_tier = requested_service_tier actual_service_tier: str | None = None reasoning_effort = payload.reasoning.effort if payload.reasoning else None + session_id = _owner_lookup_session_id_from_headers(headers) start = time.monotonic() status = "success" error_code = None error_message = None + response_id = request_id usage = None 
saw_text_delta = False latency_first_token_ms: int | None = None @@ -6728,6 +6730,8 @@ async def _stream_once( if event and event.type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None + if event.response and event.response.id: + response_id = event.response.id if event.type == "response.incomplete": status = "error" @@ -6768,6 +6772,8 @@ async def _stream_once( if event_type == "response.failed": response = event.response error = response.error if response else None + if response and response.id: + response_id = response.id else: error = event.error raw_error_code = _normalize_error_code( @@ -6812,7 +6818,10 @@ async def _stream_once( settlement.record_success = False settlement.account_health_error = _should_penalize_stream_error(error_code) if event_type in ("response.completed", "response.incomplete"): - usage = event.response.usage if event is not None and event.response else None + response = event.response if event is not None else None + usage = response.usage if response else None + if response and response.id: + response_id = response.id if event_type == "response.incomplete": status = "error" if latency_first_token_ms is None and event_type in _TEXT_DELTA_EVENT_TYPES: @@ -6881,7 +6890,7 @@ async def _stream_once( await self._write_request_log( account_id=account_id_value, api_key=api_key, - request_id=request_id, + request_id=response_id, model=model, latency_ms=int((time.monotonic() - start) * 1000), status=status, @@ -6897,6 +6906,7 @@ async def _stream_once( requested_service_tier=requested_service_tier, actual_service_tier=actual_service_tier, latency_first_token_ms=latency_first_token_ms, + session_id=session_id, ) _maybe_log_proxy_service_tier_trace( "stream", diff --git a/tests/integration/test_proxy_api_extended.py b/tests/integration/test_proxy_api_extended.py index 040f3435..77611db4 100644 --- a/tests/integration/test_proxy_api_extended.py +++ 
b/tests/integration/test_proxy_api_extended.py @@ -134,7 +134,7 @@ async def fake_stream(payload, headers, access_token, account_id, base_url=None, ) log = result.scalars().first() assert log is not None - assert log.request_id == request_id + assert log.request_id == "resp_1" assert log.input_tokens == 10 assert log.output_tokens == 5 assert log.cached_input_tokens == 3 diff --git a/tests/integration/test_proxy_responses.py b/tests/integration/test_proxy_responses.py index 5087ced6..e5de5440 100644 --- a/tests/integration/test_proxy_responses.py +++ b/tests/integration/test_proxy_responses.py @@ -13,7 +13,7 @@ import app.modules.proxy.service as proxy_module from app.core.auth import generate_unique_account_id from app.core.config.settings import Settings -from app.db.models import DashboardSettings, RequestLog +from app.db.models import Account, DashboardSettings, RequestLog from app.db.session import SessionLocal from app.modules.request_logs.repository import RequestLogsRepository @@ -339,6 +339,109 @@ async def fail_stream(*args, **kwargs): assert response.json()["error"]["message"] == "Previous response owner lookup failed; retry later." 
+@pytest.mark.asyncio +async def test_v1_responses_previous_response_followup_without_http_bridge_recovers_owner_from_request_logs( + async_client, + monkeypatch, +): + owner_email = "prev-http-owner-anchor@example.com" + owner_raw_account_id = "acc_prev_http_owner_anchor" + owner_auth_json = _make_auth_json(owner_raw_account_id, owner_email) + owner_files = {"auth_json": ("auth.json", json.dumps(owner_auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=owner_files) + assert response.status_code == 200 + + other_email = "prev-http-other-anchor@example.com" + other_raw_account_id = "acc_prev_http_other_anchor" + other_auth_json = _make_auth_json(other_raw_account_id, other_email) + other_files = {"auth_json": ("auth.json", json.dumps(other_auth_json), "application/json")} + response = await async_client.post("/api/accounts/import", files=other_files) + assert response.status_code == 200 + + async with SessionLocal() as session: + accounts = { + account.chatgpt_account_id: account + for account in (await session.execute(select(Account))).scalars().all() + if account.chatgpt_account_id in {owner_raw_account_id, other_raw_account_id} + } + + owner_account = accounts[owner_raw_account_id] + other_account = accounts[other_raw_account_id] + selection_preferred_ids: list[str | None] = [] + + async def fake_select_account(self, deadline: float, **kwargs): + del self, deadline + preferred_account_id = cast(str | None, kwargs.get("preferred_account_id")) + selection_preferred_ids.append(preferred_account_id) + if not selection_preferred_ids[:-1]: + return proxy_module.AccountSelection(account=owner_account, error_message=None, error_code=None) + if preferred_account_id == owner_account.id: + return proxy_module.AccountSelection(account=owner_account, error_message=None, error_code=None) + return proxy_module.AccountSelection(account=other_account, error_message=None, error_code=None) + + async def fake_stream(payload, headers, 
access_token, account_id, base_url=None, raise_for_status=False, **kwargs): + del headers, access_token, base_url, raise_for_status, kwargs + if payload.previous_response_id is None: + assert account_id == owner_raw_account_id + yield ( + 'data: {"type":"response.completed","response":{"id":"resp_prev_http_anchor",' + '"object":"response","status":"completed","usage":{"input_tokens":3,"output_tokens":1,"total_tokens":4}}}\n\n' + ) + return + if payload.previous_response_id == "resp_prev_http_anchor" and account_id == owner_raw_account_id: + yield ( + 'data: {"type":"response.completed","response":{"id":"resp_prev_http_followup",' + '"object":"response","status":"completed","usage":{"input_tokens":2,"output_tokens":1,"total_tokens":3}}}\n\n' + ) + return + error_payload = proxy_module.openai_error( + "previous_response_not_found", + "Previous response with id 'resp_prev_http_anchor' not found.", + error_type="invalid_request_error", + ) + error_payload["error"]["param"] = "previous_response_id" + raise proxy_module.ProxyResponseError(400, error_payload) + if False: + yield "" + + async def fake_ensure_fresh(self, account, **kwargs): + del self, kwargs + return account + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget_compatible", fake_select_account) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh) + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + + first_response = await async_client.post( + "/v1/responses", + json={"model": "gpt-5.1", "input": "start"}, + headers={"session_id": "sid_prev_http_anchor"}, + ) + + assert first_response.status_code == 200 + assert first_response.json()["id"] == "resp_prev_http_anchor" + async with SessionLocal() as session: + persisted_log = ( + await session.execute(select(RequestLog).where(RequestLog.request_id == "resp_prev_http_anchor").limit(1)) + ).scalar_one_or_none() + assert persisted_log is not None + assert 
persisted_log.session_id == "sid_prev_http_anchor" + + second_response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "input": "continue", + "previous_response_id": "resp_prev_http_anchor", + }, + headers={"session_id": "sid_prev_http_anchor"}, + ) + + assert second_response.status_code == 200 + assert second_response.json()["id"] == "resp_prev_http_followup" + assert selection_preferred_ids == [None, owner_account.id] + + @pytest.mark.asyncio async def test_v1_responses_without_http_bridge_websocket_upstream_rejects_oversized_response_create_before_connect( async_client, @@ -671,7 +774,7 @@ async def fake_stream(payload, headers, access_token, account_id, base_url=None, ) log = result.scalars().first() assert log is not None - assert log.request_id == request_id + assert log.request_id == "resp_1" assert log.transport == "http" From 86d9afed2978cf183c5509082a97358b1d912907 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Mon, 20 Apr 2026 19:49:38 +0200 Subject: [PATCH 11/18] fix(proxy): mask previous_response_not_found without breaking inflight response routing --- app/modules/proxy/service.py | 125 +++++++++-- .../integration/test_http_responses_bridge.py | 198 +++++++++++++++++ .../test_proxy_websocket_responses.py | 200 ++++++++++++++++-- tests/unit/test_proxy_utils.py | 141 ++++++++++++ 4 files changed, 626 insertions(+), 38 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 55808cf8..d89e9d53 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4782,10 +4782,19 @@ async def _process_http_bridge_upstream_text( event = parse_sse_event(event_block) event_type = _event_type_from_payload(event, payload) response_id = _websocket_response_id(event, payload) + is_previous_response_not_found_event = _is_previous_response_not_found_error( + code=_normalize_error_code( + _websocket_event_error_code(event_type, payload), + _websocket_event_error_type(event_type, payload), + 
), + param=_websocket_event_error_param(event_type, payload), + message=_websocket_event_error_message(event_type, payload), + ) async with session.pending_lock: matched_request_state = None created_request_state = None + has_other_pending_requests = False if event_type == "response.created": matched_request_state = _assign_websocket_response_id(session.pending_requests, response_id) created_request_state = matched_request_state @@ -4796,8 +4805,11 @@ async def _process_http_bridge_upstream_text( response_id, ) release_create_gate = False - elif response_id is None and len(session.pending_requests) == 1: - matched_request_state = session.pending_requests[0] + elif response_id is None: + matched_request_state = _match_websocket_request_state_for_anonymous_event( + session.pending_requests, + prefer_previous_response_not_found=is_previous_response_not_found_event, + ) release_create_gate = False else: release_create_gate = False @@ -4814,13 +4826,34 @@ async def _process_http_bridge_upstream_text( session.pending_requests, response_id=response_id, fallback_request_state=matched_request_state, + prefer_previous_response_not_found=is_previous_response_not_found_event, + allow_precreated_terminal_fallback=event_type in {"response.failed", "response.incomplete", "error"}, ) if terminal_request_state is not None: session.queued_request_count = max(0, session.queued_request_count - 1) + has_other_pending_requests = bool(session.pending_requests) + + status_request_state = terminal_request_state or matched_request_state + if ( + event_type == "error" + and status_request_state is not None + and status_request_state.previous_response_id is not None + and is_previous_response_not_found_event + and has_other_pending_requests + ): + status_request_state.error_http_status_override = 502 + event, payload, event_type, rewritten_text = _maybe_rewrite_websocket_previous_response_not_found_event( + request_state=status_request_state, + event=event, + payload=payload, + 
event_type=event_type, + upstream_control=session.upstream_control, + original_text=text, + ) + event_block = f"data: {rewritten_text}\n\n" if event_type == "error": http_status = _http_error_status_from_payload(payload) - status_request_state = terminal_request_state or matched_request_state if status_request_state is not None: status_request_state.error_http_status_override = http_status ( @@ -5229,6 +5262,14 @@ async def _process_upstream_websocket_text( event = parse_sse_event(event_block) event_type = _event_type_from_payload(event, payload) response_id = _websocket_response_id(event, payload) + is_previous_response_not_found_event = _is_previous_response_not_found_error( + code=_normalize_error_code( + _websocket_event_error_code(event_type, payload), + _websocket_event_error_type(event_type, payload), + ), + param=_websocket_event_error_param(event_type, payload), + message=_websocket_event_error_message(event_type, payload), + ) async with pending_lock: request_state = None @@ -5241,8 +5282,11 @@ async def _process_upstream_websocket_text( elif response_id is not None: request_state = _find_websocket_request_state_by_response_id(pending_requests, response_id) release_create_gate = False - elif response_id is None and len(pending_requests) == 1: - request_state = pending_requests[0] + elif response_id is None: + request_state = _match_websocket_request_state_for_anonymous_event( + pending_requests, + prefer_previous_response_not_found=is_previous_response_not_found_event, + ) release_create_gate = False else: release_create_gate = False @@ -5259,6 +5303,8 @@ async def _process_upstream_websocket_text( pending_requests, response_id=response_id, fallback_request_state=request_state, + prefer_previous_response_not_found=is_previous_response_not_found_event, + allow_precreated_terminal_fallback=event_type in {"response.failed", "response.incomplete", "error"}, ) has_other_pending_requests = bool(pending_requests) else: @@ -5270,14 +5316,7 @@ async def 
_process_upstream_websocket_text( if request_state is None: return text - retry_is_previous_response_not_found = _is_previous_response_not_found_error( - code=_normalize_error_code( - _websocket_event_error_code(event_type, payload), - _websocket_event_error_type(event_type, payload), - ), - param=_websocket_event_error_param(event_type, payload), - message=_websocket_event_error_message(event_type, payload), - ) + retry_is_previous_response_not_found = is_previous_response_not_found_event retry_error_code = _websocket_precreated_retry_error_code( request_state, event_type=event_type, @@ -8005,6 +8044,44 @@ def _assign_websocket_response_id( return None +def _match_websocket_request_state_for_anonymous_event( + pending_requests: deque[_WebSocketRequestState], + *, + prefer_previous_response_not_found: bool, +) -> _WebSocketRequestState | None: + if len(pending_requests) == 1: + return pending_requests[0] + + unresolved_requests = [request_state for request_state in pending_requests if request_state.response_id is None] + if len(unresolved_requests) == 1: + return unresolved_requests[0] + + if not prefer_previous_response_not_found: + return None + + followup_requests = [ + request_state + for request_state in unresolved_requests + if request_state.previous_response_id is not None and request_state.awaiting_response_created + ] + if len(followup_requests) == 1: + return followup_requests[0] + return None + + +def _match_websocket_request_state_for_precreated_terminal_event( + pending_requests: deque[_WebSocketRequestState], +) -> _WebSocketRequestState | None: + unresolved_requests = [ + request_state + for request_state in pending_requests + if request_state.response_id is None and request_state.awaiting_response_created + ] + if len(unresolved_requests) == 1: + return unresolved_requests[0] + return None + + def _release_websocket_response_create_gate( request_state: _WebSocketRequestState, response_create_gate: asyncio.Semaphore, @@ -8428,6 +8505,8 @@ def 
_pop_terminal_websocket_request_state( *, response_id: str | None, fallback_request_state: _WebSocketRequestState | None, + prefer_previous_response_not_found: bool = False, + allow_precreated_terminal_fallback: bool = False, ) -> _WebSocketRequestState | None: if response_id is not None: request_state = _find_websocket_request_state_by_response_id(pending_requests, response_id) @@ -8437,13 +8516,19 @@ def _pop_terminal_websocket_request_state( if fallback_request_state is not None and fallback_request_state in pending_requests: pending_requests.remove(fallback_request_state) return fallback_request_state - unresolved_requests = [request_state for request_state in pending_requests if request_state.response_id is None] - if len(unresolved_requests) == 1: - request_state = unresolved_requests[0] - pending_requests.remove(request_state) - return request_state - if response_id is None and len(pending_requests) == 1: - return pending_requests.popleft() + if response_id is not None and allow_precreated_terminal_fallback: + request_state = _match_websocket_request_state_for_precreated_terminal_event(pending_requests) + if request_state is not None and request_state in pending_requests: + pending_requests.remove(request_state) + return request_state + if response_id is None: + request_state = _match_websocket_request_state_for_anonymous_event( + pending_requests, + prefer_previous_response_not_found=prefer_previous_response_not_found, + ) + if request_state is not None and request_state in pending_requests: + pending_requests.remove(request_state) + return request_state return None diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index d8bbb008..6dcc1ec8 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -15,6 +15,7 @@ import anyio import pytest import pytest_asyncio +from httpx import ASGITransport, AsyncClient from sqlalchemy import select 
import app.modules.proxy.service as proxy_module @@ -461,6 +462,85 @@ async def send_text(self, text: str) -> None: ) +class _AnonymousPreviousResponseNotFoundWithInflightUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + def __init__(self) -> None: + super().__init__() + self.first_request_created = asyncio.Event() + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + if len(self.sent_text) == 1: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_inflight", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + self.first_request_created.set() + return + + payload = json.loads(text) + previous_response_id = payload.get("previous_response_id") + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": f"Previous response with id '{previous_response_id}' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_inflight", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + + class _InvalidRequestPreviousResponseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): async def send_text(self, text: str) -> None: self.sent_text.append(text) @@ -6928,6 +7008,124 @@ async def 
fake_connect_responses_websocket( assert connect_count == 2 +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_masks_anonymous_previous_response_not_found_with_inflight_request( + app_instance, + monkeypatch, +): + _install_bridge_settings(monkeypatch, enabled=True) + upstream = _AnonymousPreviousResponseNotFoundWithInflightUpstreamWebSocket() + connect_count = 0 + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + api_key, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + nonlocal connect_count + connect_count += 1 + return upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + async with app_instance.router.lifespan_context(app_instance): + async with ( + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as admin_client, + AsyncClient(transport=ASGITransport(app=app_instance), 
base_url="http://testserver") as first_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as second_client, + ): + account_id = await _import_account( + admin_client, + "acc_http_bridge_prev_nf_inflight", + "http-bridge-prev-nf-inflight@example.com", + ) + account = await _get_account(account_id) + + first = asyncio.create_task( + first_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "previous-response-inflight-mixed", + }, + ) + ) + await _wait_for_event(upstream.first_request_created) + + second = asyncio.create_task( + second_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "previous-response-inflight-mixed", + "previous_response_id": "resp_bridge_prev_anchor", + }, + ) + ) + + first_response, second_response = await asyncio.wait_for( + asyncio.gather(first, second), + timeout=_TEST_SYNC_TIMEOUT_SECONDS, + ) + + assert first_response.status_code == 200 + assert first_response.json()["output"][0]["content"][0]["text"] == "OK" + assert second_response.status_code >= 400 + assert second_response.json()["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in second_response.json()["error"].get("code", "") + assert "previous_response_not_found" not in second_response.json()["error"].get("message", "") + assert connect_count == 1 + + @pytest.mark.asyncio async def test_v1_responses_http_bridge_send_retry_keeps_session_open_for_followup_request( async_client, diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index d8872048..1fe7d0ec 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1514,6 +1514,171 @@ async def fake_try_open_websocket_connect_attempt( assert 
event["error"]["message"] == "Upstream websocket closed before response.completed" +def test_backend_responses_websocket_masks_anonymous_previous_response_not_found_with_inflight_request( + app_instance, + monkeypatch, +): + fake_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_inflight", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_ws_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_inflight", + "status": "completed", + "usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}, + }, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + log_calls: list[dict[str, object]] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + return SimpleNamespace(id="acct_ws_prev_followup"), fake_upstream + + async def fake_write_request_log(self, 
**kwargs): + del self + log_calls.append(kwargs) + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, previous_response_id, api_key, session_id, surface + return None + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + first_request = { + "type": "response.create", + "model": "gpt-5.4", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "first"}]}], + "stream": True, + } + followup_request = { + "type": "response.create", + "model": "gpt-5.4", + "instructions": "", + "previous_response_id": "resp_ws_prev_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect( + "/backend-api/codex/responses", + headers={"Authorization": "Bearer external-token"}, + ) as websocket: + websocket.send_text(json.dumps(first_request)) + created_event = json.loads(websocket.receive_text()) + + websocket.send_text(json.dumps(followup_request)) + failed_event = json.loads(websocket.receive_text()) + completed_event = json.loads(websocket.receive_text()) + + assert created_event["type"] == "response.created" + assert created_event["response"]["id"] == "resp_ws_inflight" + assert failed_event["type"] == "response.failed" + assert 
failed_event["response"]["error"]["code"] == "stream_incomplete" + assert failed_event["response"]["error"]["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in json.dumps(failed_event) + assert completed_event["type"] == "response.completed" + assert completed_event["response"]["id"] == "resp_ws_inflight" + assert any(call["status"] == "error" and call["error_code"] == "stream_incomplete" for call in log_calls) + assert any(call["status"] == "success" and call["request_id"] == "resp_ws_inflight" for call in log_calls) + assert fake_upstream.closed is True + + @pytest.mark.parametrize("frame", ['{"type":"response.create"', "[]"]) def test_backend_responses_websocket_rejects_malformed_first_frame_as_invalid_payload(app_instance, monkeypatch, frame): called = {"connect": False} @@ -3018,27 +3183,26 @@ async def fake_connect_proxy_websocket( def test_backend_responses_websocket_matches_terminal_events_by_response_id(app_instance, monkeypatch): - upstream_messages = [ - _FakeUpstreamMessage( - "text", - text=json.dumps( - {"type": "response.created", "response": {"id": "resp_ws_a", "status": "in_progress"}}, - separators=(",", ":"), - ), - ), - _FakeUpstreamMessage( - "text", - text=json.dumps( - {"type": "response.created", "response": {"id": "resp_ws_b", "status": "in_progress"}}, - separators=(",", ":"), - ), - ), - ] fake_upstream = _SequencedUpstreamWebSocket( - upstream_messages, + [], deferred_message_batches=[ - [], [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_a", "status": "in_progress"}}, + separators=(",", ":"), + ), + ) + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_b", "status": "in_progress"}}, + separators=(",", ":"), + ), + ), _FakeUpstreamMessage( "text", text=json.dumps( diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py 
index 343eb995..f1c5531b 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -5154,6 +5154,59 @@ async def test_process_upstream_websocket_text_does_not_match_foreign_response_i assert list(pending_requests) == [pending_request] +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_does_not_match_foreign_completed_event_to_only_unresolved_request( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + account = _make_account("acc_ws_pending_precreated") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_pending_precreated", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps( + { + "type": "response.create", + "model": "gpt-5.1", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "first"}]}], + }, + separators=(",", ":"), + ), + ) + pending_requests = deque([pending_request]) + payload = { + "type": "response.completed", + "response": { + "id": "resp_ws_foreign_completed", + "usage": {"input_tokens": 7, "output_tokens": 11, "total_tokens": 18}, + }, + } + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=proxy_service._WebSocketUpstreamControl(), + response_create_gate=asyncio.Semaphore(1), + ) + + assert downstream_text == json.dumps(payload, separators=(",", ":")) + finalize_request_state.assert_not_awaited() + assert list(pending_requests) == [pending_request] + + @pytest.mark.asyncio async def 
test_process_upstream_websocket_text_transparently_retries_precreated_usage_limit_failure( monkeypatch, @@ -6538,6 +6591,94 @@ async def test_process_upstream_websocket_text_retries_precreated_previous_respo assert list(pending_requests) == [] +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_masks_previous_response_not_found_for_unique_followup_request( + monkeypatch, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_prev_not_found_followup_match") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + inflight_request = proxy_service._WebSocketRequestState( + request_id="ws_req_inflight", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps( + { + "type": "response.create", + "model": "gpt-5.1", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "first"}]}], + }, + separators=(",", ":"), + ), + ) + followup_request = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_prev_not_found", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + awaiting_response_created=True, + request_text=json.dumps( + { + "type": "response.create", + "model": "gpt-5.1", + "previous_response_id": "resp_anchor", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "continue"}]}], + }, + separators=(",", ":"), + ), + previous_response_id="resp_anchor", + ) + pending_requests = deque([inflight_request, followup_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + upstream_payload = { + "type": "error", + "status": 400, + "error": { + 
"type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + } + upstream_text = json.dumps(upstream_payload, separators=(",", ":")) + + downstream_text = await service._process_upstream_websocket_text( + upstream_text, + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert "previous_response_not_found" not in downstream_text + handle_stream_error.assert_not_awaited() + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request + assert finalize_call.kwargs["event_type"] == "response.failed" + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is False + assert list(pending_requests) == [inflight_request] + + def test_maybe_rewrite_websocket_previous_response_not_found_rewrites_response_failed_event(): request_state = proxy_service._WebSocketRequestState( request_id="ws_req_prev_nf", From c2bec8538769b83732990ba250b8333d0e378e6a Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Mon, 20 Apr 2026 19:57:01 +0200 Subject: [PATCH 12/18] style(proxy): fix ruff line length --- app/modules/proxy/service.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index d89e9d53..9b0ce9ff 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4827,7 +4827,11 @@ async def _process_http_bridge_upstream_text( response_id=response_id, fallback_request_state=matched_request_state, 
prefer_previous_response_not_found=is_previous_response_not_found_event, - allow_precreated_terminal_fallback=event_type in {"response.failed", "response.incomplete", "error"}, + allow_precreated_terminal_fallback=event_type in { + "response.failed", + "response.incomplete", + "error", + }, ) if terminal_request_state is not None: session.queued_request_count = max(0, session.queued_request_count - 1) @@ -5304,7 +5308,11 @@ async def _process_upstream_websocket_text( response_id=response_id, fallback_request_state=request_state, prefer_previous_response_not_found=is_previous_response_not_found_event, - allow_precreated_terminal_fallback=event_type in {"response.failed", "response.incomplete", "error"}, + allow_precreated_terminal_fallback=event_type in { + "response.failed", + "response.incomplete", + "error", + }, ) has_other_pending_requests = bool(pending_requests) else: From 432005619a436a8c58ce85e31bedebfe2cbdf81c Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Mon, 20 Apr 2026 20:00:43 +0200 Subject: [PATCH 13/18] style(proxy): format service.py with ruff --- app/modules/proxy/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 9b0ce9ff..3d71f76b 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4827,7 +4827,8 @@ async def _process_http_bridge_upstream_text( response_id=response_id, fallback_request_state=matched_request_state, prefer_previous_response_not_found=is_previous_response_not_found_event, - allow_precreated_terminal_fallback=event_type in { + allow_precreated_terminal_fallback=event_type + in { "response.failed", "response.incomplete", "error", @@ -5308,7 +5309,8 @@ async def _process_upstream_websocket_text( response_id=response_id, fallback_request_state=request_state, prefer_previous_response_not_found=is_previous_response_not_found_event, - allow_precreated_terminal_fallback=event_type in { + 
allow_precreated_terminal_fallback=event_type + in { "response.failed", "response.incomplete", "error", From e25922d582e148009629ae77d79e6c99d3cd5649 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Tue, 21 Apr 2026 00:47:57 +0200 Subject: [PATCH 14/18] fix(proxy): harden previous_response anchor matching for multiplexed follow-ups --- app/modules/proxy/service.py | 118 +- .../design.md | 5 + .../specs/responses-api-compat/spec.md | 7 + .../tasks.md | 12 +- .../integration/test_http_responses_bridge.py | 3481 +++++++++++------ .../test_proxy_websocket_responses.py | 922 +++++ tests/unit/test_proxy_utils.py | 672 ++++ 7 files changed, 3982 insertions(+), 1235 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 3d71f76b..dbea2752 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4782,14 +4782,16 @@ async def _process_http_bridge_upstream_text( event = parse_sse_event(event_block) event_type = _event_type_from_payload(event, payload) response_id = _websocket_response_id(event, payload) + error_message = _websocket_event_error_message(event_type, payload) is_previous_response_not_found_event = _is_previous_response_not_found_error( code=_normalize_error_code( _websocket_event_error_code(event_type, payload), _websocket_event_error_type(event_type, payload), ), param=_websocket_event_error_param(event_type, payload), - message=_websocket_event_error_message(event_type, payload), + message=error_message, ) + previous_response_id_hint = _previous_response_id_from_not_found_message(error_message) async with session.pending_lock: matched_request_state = None @@ -4809,6 +4811,8 @@ async def _process_http_bridge_upstream_text( matched_request_state = _match_websocket_request_state_for_anonymous_event( session.pending_requests, prefer_previous_response_not_found=is_previous_response_not_found_event, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, ) release_create_gate = 
False else: @@ -4827,6 +4831,8 @@ async def _process_http_bridge_upstream_text( response_id=response_id, fallback_request_state=matched_request_state, prefer_previous_response_not_found=is_previous_response_not_found_event, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, allow_precreated_terminal_fallback=event_type in { "response.failed", @@ -4840,11 +4846,10 @@ async def _process_http_bridge_upstream_text( status_request_state = terminal_request_state or matched_request_state if ( - event_type == "error" - and status_request_state is not None + status_request_state is not None and status_request_state.previous_response_id is not None and is_previous_response_not_found_event - and has_other_pending_requests + and (response_id is not None or has_other_pending_requests) ): status_request_state.error_http_status_override = 502 event, payload, event_type, rewritten_text = _maybe_rewrite_websocket_previous_response_not_found_event( @@ -5267,14 +5272,16 @@ async def _process_upstream_websocket_text( event = parse_sse_event(event_block) event_type = _event_type_from_payload(event, payload) response_id = _websocket_response_id(event, payload) + error_message = _websocket_event_error_message(event_type, payload) is_previous_response_not_found_event = _is_previous_response_not_found_error( code=_normalize_error_code( _websocket_event_error_code(event_type, payload), _websocket_event_error_type(event_type, payload), ), param=_websocket_event_error_param(event_type, payload), - message=_websocket_event_error_message(event_type, payload), + message=error_message, ) + previous_response_id_hint = _previous_response_id_from_not_found_message(error_message) async with pending_lock: request_state = None @@ -5291,6 +5298,8 @@ async def _process_upstream_websocket_text( request_state = _match_websocket_request_state_for_anonymous_event( pending_requests, prefer_previous_response_not_found=is_previous_response_not_found_event, + 
previous_response_id_hint=previous_response_id_hint,
+                    error_message=error_message,
                 )
                 release_create_gate = False
             else:
@@ -5309,6 +5318,8 @@
                 response_id=response_id,
                 fallback_request_state=request_state,
                 prefer_previous_response_not_found=is_previous_response_not_found_event,
+                previous_response_id_hint=previous_response_id_hint,
+                error_message=error_message,
                 allow_precreated_terminal_fallback=event_type
                 in {
                     "response.failed",
                     "response.incomplete",
                     "error",
@@ -7754,6 +7765,38 @@ def _is_previous_response_not_found_message(message: str | None) -> bool:
     return "previous response" in normalized and "not found" in normalized
 
 
+def _previous_response_id_from_not_found_message(message: str | None) -> str | None:
+    if message is None:
+        return None
+    normalized = " ".join(message.split())
+    match = re.search(
+        r"""previous\s+response\s+with\s+id\s+['"](?P<response_id>[^'"]+)['"]\s+not\s+found""",
+        normalized,
+        re.IGNORECASE,
+    )
+    if match is None:
+        return None
+    response_id = match.group("response_id").strip()
+    return response_id or None
+
+
+def _message_mentions_previous_response_id(message: str | None, previous_response_id: str | None) -> bool:
+    if message is None or previous_response_id is None:
+        return False
+    normalized_message = " ".join(message.split())
+    normalized_previous_response_id = previous_response_id.strip()
+    if not normalized_previous_response_id:
+        return False
+    identifier_pattern = re.escape(normalized_previous_response_id)
+    return (
+        re.search(
+            rf"(?<![A-Za-z0-9_]){identifier_pattern}(?![A-Za-z0-9_])",
+            normalized_message,
+        )
+        is not None
+    )
+
+
 def _normalize_websocket_session_id(session_id: str | None) -> str | None:
     if not isinstance(session_id, str):
         return None
@@ -8058,24 +8101,22 @@ def _match_websocket_request_state_for_anonymous_event(
     pending_requests: deque[_WebSocketRequestState],
     *,
     prefer_previous_response_not_found: bool,
+    previous_response_id_hint: str | None = None,
+    error_message: str | None = None,
 ) -> _WebSocketRequestState | None:
+    if prefer_previous_response_not_found:
+        return _match_websocket_request_state_for_previous_response_error(
+            pending_requests,
+            previous_response_id_hint=previous_response_id_hint,
+            error_message=error_message,
+        )
+
     if len(pending_requests) == 1:
         return pending_requests[0]
 
     unresolved_requests = [request_state for request_state in pending_requests if request_state.response_id is None]
     if len(unresolved_requests) == 1:
         return unresolved_requests[0]
-
-    if not prefer_previous_response_not_found:
-        return None
-
-    followup_requests = [
-        request_state
-        for request_state in unresolved_requests
-        if request_state.previous_response_id is not None and request_state.awaiting_response_created
-    ]
-    if len(followup_requests) == 1:
-        return followup_requests[0]
     return None
 
 
@@ -8092,6 +8133,38 @@ def _match_websocket_request_state_for_precreated_terminal_event(
     return None
 
 
+def _match_websocket_request_state_for_previous_response_error(
+    pending_requests: deque[_WebSocketRequestState],
+    *,
+    previous_response_id_hint: str | None = None,
+    error_message: str | None = None,
+) -> _WebSocketRequestState | None:
+    followup_requests = [
+        request_state for request_state in pending_requests if request_state.previous_response_id is not None
+    ]
+    if previous_response_id_hint is not None:
+        matching_requests = [
+            request_state
+            for request_state in followup_requests
+            if request_state.previous_response_id == previous_response_id_hint
+        ]
+        if len(matching_requests) == 1:
+            return matching_requests[0]
+        return None
+    if error_message is not None:
+        matching_requests = [
+            request_state
+            for request_state in followup_requests
+            
if _message_mentions_previous_response_id(error_message, request_state.previous_response_id) + ] + if len(matching_requests) == 1: + return matching_requests[0] + return None + if len(followup_requests) == 1: + return followup_requests[0] + return None + + def _release_websocket_response_create_gate( request_state: _WebSocketRequestState, response_create_gate: asyncio.Semaphore, @@ -8516,6 +8589,8 @@ def _pop_terminal_websocket_request_state( response_id: str | None, fallback_request_state: _WebSocketRequestState | None, prefer_previous_response_not_found: bool = False, + previous_response_id_hint: str | None = None, + error_message: str | None = None, allow_precreated_terminal_fallback: bool = False, ) -> _WebSocketRequestState | None: if response_id is not None: @@ -8531,10 +8606,21 @@ def _pop_terminal_websocket_request_state( if request_state is not None and request_state in pending_requests: pending_requests.remove(request_state) return request_state + if response_id is not None and prefer_previous_response_not_found: + request_state = _match_websocket_request_state_for_previous_response_error( + pending_requests, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, + ) + if request_state is not None and request_state in pending_requests: + pending_requests.remove(request_state) + return request_state if response_id is None: request_state = _match_websocket_request_state_for_anonymous_event( pending_requests, prefer_previous_response_not_found=prefer_previous_response_not_found, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, ) if request_state is not None and request_state in pending_requests: pending_requests.remove(request_state) diff --git a/openspec/changes/harden-continuity-fail-closed-edges/design.md b/openspec/changes/harden-continuity-fail-closed-edges/design.md index 5fa7afad..30bb0421 100644 --- a/openspec/changes/harden-continuity-fail-closed-edges/design.md +++ 
b/openspec/changes/harden-continuity-fail-closed-edges/design.md @@ -25,6 +25,11 @@ When a request depends on `previous_response_id` or hard bridge continuity keys, Alternative considered: continue current degrade-open behavior. Rejected because it allows continuity fragmentation precisely when the proxy has lost the data needed to enforce owner correctness. +### Match multiplexed continuity failures to the referenced anchor +When one upstream websocket carries multiple pending follow-up requests, fail-closed continuity handling must target the follow-up whose `previous_response_id` anchor is actually referenced by the upstream failure. Matching should prefer structured identifiers when present and otherwise use the referenced anchor from the upstream error payload/message, with conservative fallback only when the target remains unique. + +Alternative considered: keep count-based heuristics and treat any single follow-up as the failing request. Rejected because it can rewrite the wrong request, leak raw `previous_response_not_found`, or interrupt unrelated in-flight work when multiple anchors share one upstream session. + ## Risks / Trade-offs - [Risk] More requests can fail fast during transient owner/ring metadata outages. → Mitigation: failures become retryable and avoid silent continuity forks. 
diff --git a/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md index d560e9a2..c94a2b24 100644 --- a/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md +++ b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md @@ -15,6 +15,13 @@ When a Responses follow-up depends on previously established continuity state, t - **THEN** the service returns a retryable OpenAI-format error - **AND** the error code is not `previous_response_not_found` +#### Scenario: multiplexed follow-ups fail closed only for the matching continuity anchor +- **WHEN** a websocket or HTTP bridge session has multiple pending follow-up requests with different `previous_response_id` anchors +- **AND** continuity loss is detected for exactly one of those anchors +- **THEN** the service applies the retryable fail-closed continuity error only to the matching follow-up request +- **AND** it does not expose raw `previous_response_not_found` +- **AND** unrelated pending requests continue on their own response lifecycle + ### Requirement: Hard continuity owner lookup fails closed When a request depends on hard continuity ownership, the service MUST fail closed if owner or ring lookup errors prevent safe pinning. The service MUST NOT continue with local recovery or account selection that bypasses hard owner enforcement. diff --git a/openspec/changes/harden-continuity-fail-closed-edges/tasks.md b/openspec/changes/harden-continuity-fail-closed-edges/tasks.md index b77dfe0a..e1d2867f 100644 --- a/openspec/changes/harden-continuity-fail-closed-edges/tasks.md +++ b/openspec/changes/harden-continuity-fail-closed-edges/tasks.md @@ -1,14 +1,14 @@ ## 1. Continuity Contract -- [ ] 1.1 Update bridge-local continuity-loss paths to return retryable errors instead of raw `previous_response_not_found`. 
-- [ ] 1.2 Fail closed on hard-continuity owner/ring lookup errors instead of degrading into unpinned or local recovery. +- [x] 1.1 Update bridge-local continuity-loss paths to return retryable errors instead of raw `previous_response_not_found`. +- [x] 1.2 Fail closed on hard-continuity owner/ring lookup errors instead of degrading into unpinned or local recovery. ## 2. Regression Coverage -- [ ] 2.1 Add bridge regression tests for missing turn-state alias and inflight-follower continuity loss. -- [ ] 2.2 Add lookup-failure regression tests for websocket or HTTP fallback `previous_response_id` flows and hard bridge owner lookup failures. +- [x] 2.1 Add bridge regression tests for missing turn-state alias and inflight-follower continuity loss. +- [x] 2.2 Add lookup-failure regression tests for websocket or HTTP fallback `previous_response_id` flows and hard bridge owner lookup failures. ## 3. Verification -- [ ] 3.1 Run targeted continuity test suites covering bridge, websocket, and HTTP fallback paths. -- [ ] 3.2 Run full pytest and confirm no broader regressions. +- [x] 3.1 Run targeted continuity test suites covering bridge, websocket, and HTTP fallback paths. +- [x] 3.2 Run full pytest and confirm no broader regressions. 
diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index 6dcc1ec8..9cbbd30f 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -566,69 +566,243 @@ async def send_text(self, text: str) -> None: ) -class _FailingSendThenCloseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): +class _ForeignPreviousResponseNotFoundAfterCreatedUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): async def send_text(self, text: str) -> None: self.sent_text.append(text) - await self._messages.put(_FakeUpstreamMessage("close", close_code=1011)) - raise RuntimeError("socket closed during send") + if len(self.sent_text) == 1: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_prev_anchor", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_prev_anchor", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + if len(self.sent_text) == 2: + payload = json.loads(text) + previous_response_id = payload.get("previous_response_id") + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_followup_created", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + 
_FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.failed", + "response": { + "id": "resp_bridge_foreign_prev_nf", + "object": "response", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": f"Previous response with id '{previous_response_id}' not found.", + "param": "previous_response_id", + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return -def _make_dummy_bridge_session(session_key: proxy_module._HTTPBridgeSessionKey) -> SimpleNamespace: - async def _close() -> None: - return None + response_id = "resp_bridge_after_error" + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": response_id, "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": response_id, + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) - return SimpleNamespace( - key=session_key, - headers={}, - closed=False, - account=SimpleNamespace(id=None, status=AccountStatus.ACTIVE), - request_model="gpt-5.4", - pending_lock=anyio.Lock(), - pending_requests=deque(), - queued_request_count=0, - last_used_at=time.monotonic(), - idle_ttl_seconds=120.0, - codex_session=False, - downstream_turn_state=None, - downstream_turn_state_aliases=set(), - previous_response_ids=set(), - durable_session_id=None, - durable_owner_epoch=None, - upstream_reader=None, - 
upstream_control=proxy_module._WebSocketUpstreamControl(), - upstream=SimpleNamespace(close=_close), - ) +class _AnonymousPreviousResponseNotFoundAfterCreatedUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + def __init__(self) -> None: + super().__init__() + self.first_request_created = asyncio.Event() -class _PrewarmingBridgeUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): async def send_text(self, text: str) -> None: self.sent_text.append(text) - payload = json.loads(text) - response_id = f"resp_prewarm_{len(self.sent_text)}" - output = [] - usage = { - "input_tokens": 12, - "output_tokens": 0, - "total_tokens": 12, - "input_tokens_details": {"cached_tokens": 0}, - "output_tokens_details": {"reasoning_tokens": 0}, - } - if payload.get("generate") is not False: - response_id = f"resp_actual_{len(self.sent_text)}" - output = [ - { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "OK"}], - } - ] - usage = { - "input_tokens": 24, - "output_tokens": 2, - "total_tokens": 26, - "input_tokens_details": {"cached_tokens": 20}, - "output_tokens_details": {"reasoning_tokens": 0}, - } + if len(self.sent_text) == 1: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_inflight", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + self.first_request_created.set() + return + + if len(self.sent_text) == 2: + payload = json.loads(text) + previous_response_id = payload.get("previous_response_id") + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_followup_created", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + 
"type": "invalid_request_error", + "code": "previous_response_not_found", + "message": f"Previous response with id '{previous_response_id}' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_inflight", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + response_id = "resp_bridge_after_error" await self._messages.put( _FakeUpstreamMessage( "text", @@ -651,8 +825,20 @@ async def send_text(self, text: str) -> None: "id": response_id, "object": "response", "status": "completed", - "output": output, - "usage": usage, + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, }, }, separators=(",", ":"), @@ -661,7 +847,341 @@ async def send_text(self, text: str) -> None: ) -class _TurnStateBridgeUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): +class _TwoFollowupsPreviousResponseNotFoundUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + def __init__(self) -> None: + super().__init__() + self.first_followup_created = asyncio.Event() + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + if len(self.sent_text) == 1: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": 
{ + "id": "resp_bridge_prev_anchor_a", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_prev_anchor_a", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + if len(self.sent_text) == 2: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_prev_anchor_b", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_prev_anchor_b", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + if len(self.sent_text) == 3: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_followup_a", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + self.first_followup_created.set() + return + + if len(self.sent_text) == 4: + await 
self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_followup_b", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": ( + "Cannot continue conversation because upstream lost resp_bridge_prev_anchor_a." + ), + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_followup_b", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + response_id = "resp_bridge_after_error" + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": response_id, "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": response_id, + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + 
"output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + + +class _FailingSendThenCloseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + await self._messages.put(_FakeUpstreamMessage("close", close_code=1011)) + raise RuntimeError("socket closed during send") + + +def _make_dummy_bridge_session(session_key: proxy_module._HTTPBridgeSessionKey) -> SimpleNamespace: + async def _close() -> None: + return None + + return SimpleNamespace( + key=session_key, + headers={}, + closed=False, + account=SimpleNamespace(id=None, status=AccountStatus.ACTIVE), + request_model="gpt-5.4", + pending_lock=anyio.Lock(), + pending_requests=deque(), + queued_request_count=0, + last_used_at=time.monotonic(), + idle_ttl_seconds=120.0, + codex_session=False, + downstream_turn_state=None, + downstream_turn_state_aliases=set(), + previous_response_ids=set(), + durable_session_id=None, + durable_owner_epoch=None, + upstream_reader=None, + upstream_control=proxy_module._WebSocketUpstreamControl(), + upstream=SimpleNamespace(close=_close), + ) + + +class _PrewarmingBridgeUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + payload = json.loads(text) + response_id = f"resp_prewarm_{len(self.sent_text)}" + output = [] + usage = { + "input_tokens": 12, + "output_tokens": 0, + "total_tokens": 12, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}, + } + if payload.get("generate") is not False: + response_id = f"resp_actual_{len(self.sent_text)}" + output = [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ] + usage = { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + } + await 
self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": response_id, "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": response_id, + "object": "response", + "status": "completed", + "output": output, + "usage": usage, + }, + }, + separators=(",", ":"), + ), + ) + ) + + +class _TurnStateBridgeUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): def __init__(self, turn_state: str) -> None: super().__init__() self._turn_state = turn_state @@ -4942,14 +5462,178 @@ async def fake_connect_responses_websocket( monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - first_payload = proxy_module.ResponsesRequest( + first_payload = proxy_module.ResponsesRequest( + model="gpt-5.1", + instructions="Return exactly OK.", + input="queued-session", + prompt_cache_key="queued-session-a", + ) + first_affinity = proxy_module._sticky_key_for_responses_request( + first_payload, + {}, + codex_session_affinity=False, + openai_cache_affinity=True, + openai_cache_affinity_max_age_seconds=300, + sticky_threads_enabled=False, + api_key=None, + ) + first_key = proxy_module._make_http_bridge_session_key( + first_payload, + headers={}, + affinity=first_affinity, + api_key=None, + request_id="req_queue_a", + ) + first_session = await service._get_or_create_http_bridge_session( + first_key, + headers={}, + affinity=first_affinity, + api_key=None, + request_model="gpt-5.1", + idle_ttl_seconds=120.0, + max_sessions=1, + ) + + await first_session.response_create_gate.acquire() + request_state, text_data = service._prepare_http_bridge_request( + first_payload, + {}, + api_key=None, + 
api_key_reservation=None, + ) + request_state.transport = "http" + submit_task = asyncio.create_task( + service._submit_http_bridge_request( + first_session, + request_state=request_state, + text_data=text_data, + queue_limit=8, + ) + ) + await asyncio.sleep(0) + + assert await service._http_bridge_pending_count(first_session) == 1 + async with first_session.pending_lock: + assert list(first_session.pending_requests) == [] + assert first_session.queued_request_count == 1 + + second_payload = proxy_module.ResponsesRequest( + model="gpt-5.1", + instructions="Return exactly OK.", + input="new-session", + prompt_cache_key="queued-session-b", + ) + second_affinity = proxy_module._sticky_key_for_responses_request( + second_payload, + {}, + codex_session_affinity=False, + openai_cache_affinity=True, + openai_cache_affinity_max_age_seconds=300, + sticky_threads_enabled=False, + api_key=None, + ) + second_key = proxy_module._make_http_bridge_session_key( + second_payload, + headers={}, + affinity=second_affinity, + api_key=None, + request_id="req_queue_b", + ) + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + await service._get_or_create_http_bridge_session( + second_key, + headers={}, + affinity=second_affinity, + api_key=None, + request_model="gpt-5.1", + idle_ttl_seconds=120.0, + max_sessions=1, + ) + + exc = exc_info.value + assert exc.status_code == 429 + assert hanging_upstream.closed is False + + submit_task.cancel() + with pytest.raises(asyncio.CancelledError): + await submit_task + first_session.response_create_gate.release() + await service._close_http_bridge_session(first_session) + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_enforces_queue_limit_atomically_for_same_session( + async_client, + app_instance, + monkeypatch, +): + _install_bridge_settings_with_limits(monkeypatch, enabled=True, queue_limit=1) + account_id = await _import_account(async_client, "acc_http_bridge_queue", "http-bridge-queue@example.com") + service = 
get_proxy_service_for_app(app_instance) + account = await _get_account(account_id) + hanging_upstream = _SilentUpstreamWebSocket() + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + return hanging_upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + payload = proxy_module.ResponsesRequest( model="gpt-5.1", instructions="Return exactly OK.", - input="queued-session", - prompt_cache_key="queued-session-a", + input="same-session", + prompt_cache_key="same-session-key", ) - first_affinity = proxy_module._sticky_key_for_responses_request( - first_payload, + affinity = proxy_module._sticky_key_for_responses_request( + payload, {}, codex_session_affinity=False, openai_cache_affinity=True, @@ -4957,211 +5641,377 @@ async def fake_connect_responses_websocket( 
sticky_threads_enabled=False, api_key=None, ) - first_key = proxy_module._make_http_bridge_session_key( - first_payload, + key = proxy_module._make_http_bridge_session_key( + payload, headers={}, - affinity=first_affinity, + affinity=affinity, api_key=None, - request_id="req_queue_a", + request_id="req_queue", ) - first_session = await service._get_or_create_http_bridge_session( - first_key, + session = await service._get_or_create_http_bridge_session( + key, headers={}, - affinity=first_affinity, + affinity=affinity, api_key=None, request_model="gpt-5.1", idle_ttl_seconds=120.0, - max_sessions=1, + max_sessions=128, ) - await first_session.response_create_gate.acquire() - request_state, text_data = service._prepare_http_bridge_request( - first_payload, - {}, - api_key=None, - api_key_reservation=None, + first_state, first_text = service._prepare_http_bridge_request(payload, {}, api_key=None, api_key_reservation=None) + first_state.transport = "http" + await service._submit_http_bridge_request(session, request_state=first_state, text_data=first_text, queue_limit=1) + + second_state, second_text = service._prepare_http_bridge_request( + payload, {}, api_key=None, api_key_reservation=None ) - request_state.transport = "http" - submit_task = asyncio.create_task( - service._submit_http_bridge_request( - first_session, - request_state=request_state, - text_data=text_data, - queue_limit=8, + second_state.transport = "http" + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + await service._submit_http_bridge_request( + session, + request_state=second_state, + text_data=second_text, + queue_limit=1, ) + + exc = exc_info.value + assert exc.status_code == 429 + assert session.queued_request_count == 1 + await service._close_http_bridge_session(session) + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_creates_different_session_keys_in_parallel(app_instance, monkeypatch): + service = get_proxy_service_for_app(app_instance) + 
service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=8, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + ), + dashboard_settings=_make_dashboard_settings(), ) - await asyncio.sleep(0) - assert await service._http_bridge_pending_count(first_session) == 1 - async with first_session.pending_lock: - assert list(first_session.pending_requests) == [] - assert first_session.queued_request_count == 1 + create_started: list[str] = [] + + async def fake_create_http_bridge_session( + self, + key, + *, + headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, + ): + del self, headers, affinity, request_model, idle_ttl_seconds + create_started.append(key.affinity_key) + await asyncio.sleep(0.2) + return _make_dummy_bridge_session(key) + + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + + key_one = proxy_module._HTTPBridgeSessionKey("request", "bridge-a", None) + key_two = proxy_module._HTTPBridgeSessionKey("request", "bridge-b", None) + t0 = time.monotonic() + + try: + first = asyncio.create_task( + service._get_or_create_http_bridge_session( + key_one, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + second = asyncio.create_task( + service._get_or_create_http_bridge_session( + key_two, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + session_one, session_two = await asyncio.gather(first, second) + elapsed = time.monotonic() - t0 + + assert elapsed < 0.35 + assert sorted(create_started) == ["bridge-a", "bridge-b"] + assert session_one.key == key_one + assert session_two.key 
== key_two + assert service._http_bridge_sessions[key_one] is session_one + assert service._http_bridge_sessions[key_two] is session_two + finally: + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_singleflights_same_session_key_during_creation(app_instance, monkeypatch): + service = get_proxy_service_for_app(app_instance) + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=8, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + ), + dashboard_settings=_make_dashboard_settings(), + ) + + create_started: list[str] = [] + + async def fake_create_http_bridge_session( + self, + key, + *, + headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, + ): + del self, headers, affinity, request_model, idle_ttl_seconds + create_started.append(key.affinity_key) + await asyncio.sleep(0.2) + return _make_dummy_bridge_session(key) + + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-singleflight", None) + t0 = time.monotonic() + + try: + first = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + second = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + session_one, session_two = await asyncio.gather(first, second) + elapsed 
= time.monotonic() - t0 + + assert elapsed < 0.35 + assert create_started == ["bridge-singleflight"] + assert session_one is session_two + assert service._http_bridge_sessions[key] is session_one + finally: + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_waits_for_inflight_capacity_before_rate_limiting_other_keys( + app_instance, monkeypatch +): + service = get_proxy_service_for_app(app_instance) + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=1, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + ), + dashboard_settings=_make_dashboard_settings(), + ) + + first_create_started = asyncio.Event() + release_first_create = asyncio.Event() + create_attempts: list[str] = [] + + async def fake_create_http_bridge_session( + self, + key, + *, + headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, + ): + del self, headers, affinity, request_model, idle_ttl_seconds + create_attempts.append(key.affinity_key) + if key.affinity_key == "bridge-capacity-a": + first_create_started.set() + await _wait_for_event(release_first_create) + raise RuntimeError("first create failed") + return _make_dummy_bridge_session(key) + + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - second_payload = proxy_module.ResponsesRequest( - model="gpt-5.1", - instructions="Return exactly OK.", - input="new-session", - prompt_cache_key="queued-session-b", - ) - second_affinity = proxy_module._sticky_key_for_responses_request( - second_payload, - {}, - codex_session_affinity=False, - openai_cache_affinity=True, - 
openai_cache_affinity_max_age_seconds=300, - sticky_threads_enabled=False, - api_key=None, - ) - second_key = proxy_module._make_http_bridge_session_key( - second_payload, - headers={}, - affinity=second_affinity, - api_key=None, - request_id="req_queue_b", + key_one = proxy_module._HTTPBridgeSessionKey("request", "bridge-capacity-a", None) + key_two = proxy_module._HTTPBridgeSessionKey("request", "bridge-capacity-b", None) + + first = asyncio.create_task( + service._get_or_create_http_bridge_session( + key_one, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=1, + ) ) - with pytest.raises(proxy_module.ProxyResponseError) as exc_info: - await service._get_or_create_http_bridge_session( - second_key, + await _wait_for_event(first_create_started) + + second = asyncio.create_task( + service._get_or_create_http_bridge_session( + key_two, headers={}, - affinity=second_affinity, + affinity=proxy_module._AffinityPolicy(), api_key=None, - request_model="gpt-5.1", + request_model="gpt-5.4", idle_ttl_seconds=120.0, max_sessions=1, ) + ) + await asyncio.sleep(0.01) + assert not second.done() - exc = exc_info.value - assert exc.status_code == 429 - assert hanging_upstream.closed is False + release_first_create.set() - submit_task.cancel() - with pytest.raises(asyncio.CancelledError): - await submit_task - first_session.response_create_gate.release() - await service._close_http_bridge_session(first_session) + with pytest.raises(RuntimeError, match="first create failed"): + await first + created_session = await asyncio.wait_for(second, timeout=1.0) + + assert create_attempts == ["bridge-capacity-a", "bridge-capacity-b"] + assert service._http_bridge_sessions[key_two] is created_session + assert key_one not in service._http_bridge_inflight_sessions + assert key_two not in service._http_bridge_inflight_sessions @pytest.mark.asyncio -async def 
test_v1_responses_http_bridge_enforces_queue_limit_atomically_for_same_session( - async_client, - app_instance, - monkeypatch, -): - _install_bridge_settings_with_limits(monkeypatch, enabled=True, queue_limit=1) - account_id = await _import_account(async_client, "acc_http_bridge_queue", "http-bridge-queue@example.com") +async def test_v1_responses_http_bridge_singleflight_follower_refreshes_session_model(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) - account = await _get_account(account_id) - hanging_upstream = _SilentUpstreamWebSocket() + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() - async def fake_select_account_with_budget( - self, - deadline, - *, - request_id, - kind, - sticky_key, - sticky_kind, - reallocate_sticky, - sticky_max_age_seconds, - prefer_earlier_reset_accounts, - routing_strategy, - model, - exclude_account_ids=None, - additional_limit_name=None, - api_key=None, - ): - del ( - self, - deadline, - request_id, - kind, - sticky_key, - sticky_kind, - reallocate_sticky, - sticky_max_age_seconds, - prefer_earlier_reset_accounts, - routing_strategy, - model, - exclude_account_ids, - additional_limit_name, - ) - return AccountSelection(account=account, error_message=None, error_code=None) + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=8, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + ), + dashboard_settings=_make_dashboard_settings(), + ) - async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): - del self, force, timeout_seconds - return target + create_started = asyncio.Event() + release_create = asyncio.Event() - async def fake_connect_responses_websocket( - headers, - access_token, - account_id_header, + async def fake_create_http_bridge_session( + self, + key, *, - base_url=None, - session=None, + 
headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, ): - del headers, access_token, account_id_header, base_url, session - return hanging_upstream - - monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) - monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) - monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + del self, headers, affinity, request_model, idle_ttl_seconds + create_started.set() + await _wait_for_event(release_create) + session = _make_dummy_bridge_session(key) + session.request_model = "gpt-5.1" + return session - payload = proxy_module.ResponsesRequest( - model="gpt-5.1", - instructions="Return exactly OK.", - input="same-session", - prompt_cache_key="same-session-key", - ) - affinity = proxy_module._sticky_key_for_responses_request( - payload, - {}, - codex_session_affinity=False, - openai_cache_affinity=True, - openai_cache_affinity_max_age_seconds=300, - sticky_threads_enabled=False, - api_key=None, - ) - key = proxy_module._make_http_bridge_session_key( - payload, - headers={}, - affinity=affinity, - api_key=None, - request_id="req_queue", - ) - session = await service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=affinity, - api_key=None, - request_model="gpt-5.1", - idle_ttl_seconds=120.0, - max_sessions=128, - ) + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - first_state, first_text = service._prepare_http_bridge_request(payload, {}, api_key=None, api_key_reservation=None) - first_state.transport = "http" - await service._submit_http_bridge_request(session, request_state=first_state, text_data=first_text, queue_limit=1) + key = proxy_module._HTTPBridgeSessionKey("session_header", "shared-session", None) - second_state, second_text = service._prepare_http_bridge_request( - 
payload, {}, api_key=None, api_key_reservation=None - ) - second_state.transport = "http" - with pytest.raises(proxy_module.ProxyResponseError) as exc_info: - await service._submit_http_bridge_request( - session, - request_state=second_state, - text_data=second_text, - queue_limit=1, + try: + creator = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={"session_id": "shared-session"}, + affinity=proxy_module._AffinityPolicy( + key="shared-session", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.1", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + await _wait_for_event(create_started) + follower = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={"session_id": "shared-session"}, + affinity=proxy_module._AffinityPolicy( + key="shared-session", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) ) + release_create.set() + created_session, follower_session = await asyncio.gather(creator, follower) - exc = exc_info.value - assert exc.status_code == 429 - assert session.queued_request_count == 1 - await service._close_http_bridge_session(session) + assert created_session is follower_session + assert follower_session.request_model == "gpt-5.4" + finally: + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() @pytest.mark.asyncio -async def test_v1_responses_http_bridge_creates_different_session_keys_in_parallel(app_instance, monkeypatch): +async def test_v1_responses_http_bridge_singleflight_follower_replaces_session_when_account_is_no_longer_assigned( + async_client, app_instance, monkeypatch +): service = get_proxy_service_for_app(app_instance) service._http_bridge_sessions.clear() service._http_bridge_inflight_sessions.clear() @@ -5179,7 +6029,19 @@ async def 
test_v1_responses_http_bridge_creates_different_session_keys_in_parall dashboard_settings=_make_dashboard_settings(), ) - create_started: list[str] = [] + create_started = asyncio.Event() + release_create = asyncio.Event() + create_calls: list[list[str]] = [] + stale_account_id = await _import_account( + async_client, + "acc_http_bridge_stale", + "http-bridge-stale@example.com", + ) + fresh_account_id = await _import_account( + async_client, + "acc_http_bridge_fresh", + "http-bridge-fresh@example.com", + ) async def fake_create_http_bridge_session( self, @@ -5192,48 +6054,61 @@ async def fake_create_http_bridge_session( idle_ttl_seconds, ): del self, headers, affinity, request_model, idle_ttl_seconds - create_started.append(key.affinity_key) - await asyncio.sleep(0.2) - return _make_dummy_bridge_session(key) + create_calls.append(list(api_key.assigned_account_ids if api_key is not None else [])) + if len(create_calls) == 1: + create_started.set() + await _wait_for_event(release_create) + session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) + cast(Any, session).account = SimpleNamespace(id=stale_account_id, status=AccountStatus.ACTIVE) + return session + session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) + cast(Any, session).account = SimpleNamespace(id=fresh_account_id, status=AccountStatus.ACTIVE) + return session monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - key_one = proxy_module._HTTPBridgeSessionKey("request", "bridge-a", None) - key_two = proxy_module._HTTPBridgeSessionKey("request", "bridge-b", None) - t0 = time.monotonic() + key = proxy_module._HTTPBridgeSessionKey("session_header", "shared-session", "key-assignments") + stale_api_key = _make_api_key_data(key_id="key-assignments", assigned_account_ids=[stale_account_id]) + refreshed_api_key = _make_api_key_data(key_id="key-assignments", assigned_account_ids=[fresh_account_id]) try: - 
first = asyncio.create_task( + creator = asyncio.create_task( service._get_or_create_http_bridge_session( - key_one, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", + key, + headers={"session_id": "shared-session"}, + affinity=proxy_module._AffinityPolicy( + key="shared-session", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + api_key=stale_api_key, + request_model="gpt-5.1", idle_ttl_seconds=120.0, max_sessions=8, ) ) - second = asyncio.create_task( + await _wait_for_event(create_started) + follower = asyncio.create_task( service._get_or_create_http_bridge_session( - key_two, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, + key, + headers={"session_id": "shared-session"}, + affinity=proxy_module._AffinityPolicy( + key="shared-session", + kind=proxy_module.StickySessionKind.CODEX_SESSION, + ), + api_key=refreshed_api_key, request_model="gpt-5.4", idle_ttl_seconds=120.0, max_sessions=8, ) ) - session_one, session_two = await asyncio.gather(first, second) - elapsed = time.monotonic() - t0 + release_create.set() + created_session, follower_session = await asyncio.gather(creator, follower) - assert elapsed < 0.35 - assert sorted(create_started) == ["bridge-a", "bridge-b"] - assert session_one.key == key_one - assert session_two.key == key_two - assert service._http_bridge_sessions[key_one] is session_one - assert service._http_bridge_sessions[key_two] is session_two + assert created_session is not follower_session + assert created_session.account.id == stale_account_id + assert follower_session.account.id == fresh_account_id + assert service._http_bridge_sessions[key] is follower_session + assert create_calls == [[stale_account_id], [fresh_account_id]] finally: service._http_bridge_sessions.clear() service._http_bridge_inflight_sessions.clear() @@ -5241,7 +6116,7 @@ async def fake_create_http_bridge_session( @pytest.mark.asyncio -async def 
test_v1_responses_http_bridge_singleflights_same_session_key_during_creation(app_instance, monkeypatch): +async def test_v1_responses_http_bridge_singleflights_stale_session_replacement(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) service._http_bridge_sessions.clear() service._http_bridge_inflight_sessions.clear() @@ -5278,8 +6153,10 @@ async def fake_create_http_bridge_session( monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-singleflight", None) - t0 = time.monotonic() + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-stale-replace", None) + stale_session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) + stale_session.closed = True + service._http_bridge_sessions[key] = stale_session try: first = asyncio.create_task( @@ -5305,10 +6182,8 @@ async def fake_create_http_bridge_session( ) ) session_one, session_two = await asyncio.gather(first, second) - elapsed = time.monotonic() - t0 - assert elapsed < 0.35 - assert create_started == ["bridge-singleflight"] + assert create_started == ["bridge-stale-replace"] assert session_one is session_two assert service._http_bridge_sessions[key] is session_one finally: @@ -5318,7 +6193,85 @@ async def fake_create_http_bridge_session( @pytest.mark.asyncio -async def test_v1_responses_http_bridge_waits_for_inflight_capacity_before_rate_limiting_other_keys( +async def test_v1_responses_http_bridge_cleans_up_cancelled_singleflight_creator(app_instance, monkeypatch): + service = get_proxy_service_for_app(app_instance) + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=8, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + 
), + dashboard_settings=_make_dashboard_settings(), + ) + + first_create_started = asyncio.Event() + create_attempts = 0 + + async def fake_create_http_bridge_session( + self, + key, + *, + headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, + ): + del self, headers, affinity, request_model, idle_ttl_seconds + nonlocal create_attempts + create_attempts += 1 + if create_attempts == 1: + first_create_started.set() + await _wait_for_event(asyncio.Event()) + return _make_dummy_bridge_session(key) + + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-cancelled-create", None) + + creator = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + await _wait_for_event(first_create_started) + creator.cancel() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(creator, timeout=_TEST_SYNC_TIMEOUT_SECONDS) + + replacement = await asyncio.wait_for( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ), + timeout=1.0, + ) + + assert create_attempts == 2 + assert service._http_bridge_sessions[key] is replacement + assert key not in service._http_bridge_inflight_sessions + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_cleans_up_cancelled_singleflight_creator_after_create( app_instance, monkeypatch ): service = get_proxy_service_for_app(app_instance) @@ -5330,7 +6283,7 @@ async def test_v1_responses_http_bridge_waits_for_inflight_capacity_before_rate_ monkeypatch, app_settings=_make_app_settings( enabled=True, - max_sessions=1, + max_sessions=8, codex_idle_ttl_seconds=120.0, 
instance_id="instance-a", instance_ring=[], @@ -5338,9 +6291,90 @@ async def test_v1_responses_http_bridge_waits_for_inflight_capacity_before_rate_ dashboard_settings=_make_dashboard_settings(), ) - first_create_started = asyncio.Event() - release_first_create = asyncio.Event() - create_attempts: list[str] = [] + create_finished = asyncio.Event() + allow_return = asyncio.Event() + create_attempts = 0 + + async def fake_create_http_bridge_session( + self, + key, + *, + headers, + affinity, + api_key, + request_model, + idle_ttl_seconds, + ): + del self, headers, affinity, request_model, idle_ttl_seconds + nonlocal create_attempts + create_attempts += 1 + if create_attempts == 1: + create_finished.set() + await _wait_for_event(allow_return) + return _make_dummy_bridge_session(key) + + monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-cancelled-after-create", None) + creator = asyncio.create_task( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ) + ) + await _wait_for_event(create_finished) + async with service._http_bridge_lock: + allow_return.set() + await asyncio.sleep(0) + creator.cancel() + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(creator, timeout=_TEST_SYNC_TIMEOUT_SECONDS) + + replacement = await asyncio.wait_for( + service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, + ), + timeout=1.0, + ) + + assert create_attempts == 2 + assert service._http_bridge_sessions[key] is replacement + assert key not in service._http_bridge_inflight_sessions + + +@pytest.mark.asyncio +async def 
test_v1_responses_http_bridge_waits_for_inflight_session_before_continuity_error(app_instance, monkeypatch): + service = get_proxy_service_for_app(app_instance) + service._http_bridge_sessions.clear() + service._http_bridge_inflight_sessions.clear() + service._http_bridge_turn_state_index.clear() + + _install_proxy_settings( + monkeypatch, + app_settings=_make_app_settings( + enabled=True, + max_sessions=8, + codex_idle_ttl_seconds=120.0, + instance_id="instance-a", + instance_ring=[], + ), + dashboard_settings=_make_dashboard_settings(), + ) + + create_started = asyncio.Event() + release_create = asyncio.Event() async def fake_create_http_bridge_session( self, @@ -5353,59 +6387,59 @@ async def fake_create_http_bridge_session( idle_ttl_seconds, ): del self, headers, affinity, request_model, idle_ttl_seconds - create_attempts.append(key.affinity_key) - if key.affinity_key == "bridge-capacity-a": - first_create_started.set() - await _wait_for_event(release_first_create) - raise RuntimeError("first create failed") + create_started.set() + await _wait_for_event(release_create) return _make_dummy_bridge_session(key) monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - key_one = proxy_module._HTTPBridgeSessionKey("request", "bridge-capacity-a", None) - key_two = proxy_module._HTTPBridgeSessionKey("request", "bridge-capacity-b", None) + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-waits-for-inflight", None) - first = asyncio.create_task( + creator = asyncio.create_task( service._get_or_create_http_bridge_session( - key_one, + key, headers={}, affinity=proxy_module._AffinityPolicy(), api_key=None, request_model="gpt-5.4", idle_ttl_seconds=120.0, - max_sessions=1, + max_sessions=8, ) ) - await _wait_for_event(first_create_started) + await _wait_for_event(create_started) - second = asyncio.create_task( + follower = asyncio.create_task( service._get_or_create_http_bridge_session( - key_two, + key, 
headers={}, affinity=proxy_module._AffinityPolicy(), api_key=None, request_model="gpt-5.4", idle_ttl_seconds=120.0, - max_sessions=1, + max_sessions=8, + previous_response_id="resp_inflight", ) ) await asyncio.sleep(0.01) - assert not second.done() - - release_first_create.set() + assert follower.done() - with pytest.raises(RuntimeError, match="first create failed"): - await first - created_session = await asyncio.wait_for(second, timeout=1.0) + release_create.set() + created_session = await creator + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + await follower - assert create_attempts == ["bridge-capacity-a", "bridge-capacity-b"] - assert service._http_bridge_sessions[key_two] is created_session - assert key_one not in service._http_bridge_inflight_sessions - assert key_two not in service._http_bridge_inflight_sessions + assert service._http_bridge_sessions[key] is created_session + exc = exc_info.value + assert exc.status_code == 502 + assert exc.payload["error"] == { + "message": "Upstream websocket closed before response.completed", + "type": "server_error", + "code": "stream_incomplete", + } @pytest.mark.asyncio -async def test_v1_responses_http_bridge_singleflight_follower_refreshes_session_model(app_instance, monkeypatch): +async def test_v1_responses_http_bridge_prunes_idle_session_before_reuse(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) service._http_bridge_sessions.clear() service._http_bridge_inflight_sessions.clear() @@ -5423,8 +6457,7 @@ async def test_v1_responses_http_bridge_singleflight_follower_refreshes_session_ dashboard_settings=_make_dashboard_settings(), ) - create_started = asyncio.Event() - release_create = asyncio.Event() + create_started: list[str] = [] async def fake_create_http_bridge_session( self, @@ -5437,51 +6470,31 @@ async def fake_create_http_bridge_session( idle_ttl_seconds, ): del self, headers, affinity, request_model, idle_ttl_seconds - create_started.set() - await 
_wait_for_event(release_create) - session = _make_dummy_bridge_session(key) - session.request_model = "gpt-5.1" - return session + create_started.append(key.affinity_key) + return _make_dummy_bridge_session(key) monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - key = proxy_module._HTTPBridgeSessionKey("session_header", "shared-session", None) + key = proxy_module._HTTPBridgeSessionKey("request", "bridge-idle-prune", None) + stale_session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) + stale_session.last_used_at = time.monotonic() - 300.0 + stale_session.idle_ttl_seconds = 120.0 + service._http_bridge_sessions[key] = stale_session try: - creator = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={"session_id": "shared-session"}, - affinity=proxy_module._AffinityPolicy( - key="shared-session", - kind=proxy_module.StickySessionKind.CODEX_SESSION, - ), - api_key=None, - request_model="gpt-5.1", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - await _wait_for_event(create_started) - follower = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={"session_id": "shared-session"}, - affinity=proxy_module._AffinityPolicy( - key="shared-session", - kind=proxy_module.StickySessionKind.CODEX_SESSION, - ), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) + replacement = await service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=proxy_module._AffinityPolicy(), + api_key=None, + request_model="gpt-5.4", + idle_ttl_seconds=120.0, + max_sessions=8, ) - release_create.set() - created_session, follower_session = await asyncio.gather(creator, follower) - assert created_session is follower_session - assert follower_session.request_model == "gpt-5.4" + assert create_started == ["bridge-idle-prune"] + assert replacement is not stale_session + assert 
service._http_bridge_sessions[key] is replacement finally: service._http_bridge_sessions.clear() service._http_bridge_inflight_sessions.clear() @@ -5489,508 +6502,608 @@ async def fake_create_http_bridge_session( @pytest.mark.asyncio -async def test_v1_responses_http_bridge_singleflight_follower_replaces_session_when_account_is_no_longer_assigned( - async_client, app_instance, monkeypatch -): - service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() - - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], - ), - dashboard_settings=_make_dashboard_settings(), - ) - - create_started = asyncio.Event() - release_create = asyncio.Event() - create_calls: list[list[str]] = [] - stale_account_id = await _import_account( +async def test_v1_responses_http_bridge_stream_failure_remains_valid_sse(async_client, monkeypatch): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( async_client, - "acc_http_bridge_stale", - "http-bridge-stale@example.com", + "acc_http_bridge_sse_failure", + "http-bridge-sse-failure@example.com", ) - fresh_account_id = await _import_account( + account = await _get_account(account_id) + upstream = _CreatedThenCloseUpstreamWebSocket() + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + 
additional_limit_name, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + return upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + async with async_client.stream( + "POST", + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "trigger-sse-failure", + "prompt_cache_key": "sse-failure-key", + "stream": True, + }, + ) as response: + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line.startswith("data: ")] + + events = [json.loads(line[6:]) for line in lines] + assert [event["type"] for event in events] == ["response.created", "response.failed"] + assert events[-1]["response"]["error"]["code"] == "stream_incomplete" + + +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_surfaces_upstream_error_event_as_http_400(async_client, monkeypatch): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( async_client, - "acc_http_bridge_fresh", - "http-bridge-fresh@example.com", + "acc_http_bridge_error_norm", + "http-bridge-error-norm@example.com", ) + account = await _get_account(account_id) + fake_upstream = _ErrorOnlyUpstreamWebSocket() + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + 
reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + ) + return AccountSelection(account=account, error_message=None, error_code=None) - async def fake_create_http_bridge_session( - self, - key, - *, + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, + access_token, + account_id_header, + *, + base_url=None, + session=None, ): - del self, headers, affinity, request_model, idle_ttl_seconds - create_calls.append(list(api_key.assigned_account_ids if api_key is not None else [])) - if len(create_calls) == 1: - create_started.set() - await _wait_for_event(release_create) - session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) - cast(Any, session).account = SimpleNamespace(id=stale_account_id, status=AccountStatus.ACTIVE) - return session - session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) - cast(Any, session).account = SimpleNamespace(id=fresh_account_id, status=AccountStatus.ACTIVE) - return session - - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + del headers, access_token, account_id_header, base_url, session + return fake_upstream - key = proxy_module._HTTPBridgeSessionKey("session_header", "shared-session", "key-assignments") - stale_api_key = _make_api_key_data(key_id="key-assignments", assigned_account_ids=[stale_account_id]) - refreshed_api_key = 
_make_api_key_data(key_id="key-assignments", assigned_account_ids=[fresh_account_id]) + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - try: - creator = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={"session_id": "shared-session"}, - affinity=proxy_module._AffinityPolicy( - key="shared-session", - kind=proxy_module.StickySessionKind.CODEX_SESSION, - ), - api_key=stale_api_key, - request_model="gpt-5.1", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - await _wait_for_event(create_started) - follower = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={"session_id": "shared-session"}, - affinity=proxy_module._AffinityPolicy( - key="shared-session", - kind=proxy_module.StickySessionKind.CODEX_SESSION, - ), - api_key=refreshed_api_key, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - release_create.set() - created_session, follower_session = await asyncio.gather(creator, follower) + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.3-codex-spark", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "http-bridge-error-norm-key", + "stream": True, + }, + ) - assert created_session is not follower_session - assert created_session.account.id == stale_account_id - assert follower_session.account.id == fresh_account_id - assert service._http_bridge_sessions[key] is follower_session - assert create_calls == [[stale_account_id], [fresh_account_id]] - finally: - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() + assert response.status_code == 400 + 
assert response.json() == { + "error": { + "message": "The 'gpt-5.3-codex-spark' model is not supported when using Codex with a ChatGPT account.", + "type": "invalid_request_error", + "code": "invalid_request_error", + } + } @pytest.mark.asyncio -async def test_v1_responses_http_bridge_singleflights_stale_session_replacement(app_instance, monkeypatch): - service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() - - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], - ), - dashboard_settings=_make_dashboard_settings(), +async def test_v1_responses_http_bridge_preserves_rate_limit_metadata_in_429(async_client, monkeypatch): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( + async_client, + "acc_http_bridge_ratelimit", + "http-bridge-ratelimit@example.com", ) + account = await _get_account(account_id) + fake_upstream = _RateLimitErrorUpstreamWebSocket() - create_started: list[str] = [] - - async def fake_create_http_bridge_session( + async def fake_select_account_with_budget( self, - key, + deadline, *, - headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, ): - del self, headers, affinity, request_model, idle_ttl_seconds - create_started.append(key.affinity_key) - await asyncio.sleep(0.2) - return _make_dummy_bridge_session(key) + return AccountSelection(account=account, error_message=None, error_code=None) - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) 
+ async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + return target - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-stale-replace", None) - stale_session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) - stale_session.closed = True - service._http_bridge_sessions[key] = stale_session + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + return fake_upstream - try: - first = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - second = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - session_one, session_two = await asyncio.gather(first, second) + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - assert create_started == ["bridge-stale-replace"] - assert session_one is session_two - assert service._http_bridge_sessions[key] is session_one - finally: - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() + response = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-4o", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "http-bridge-ratelimit-key", + "stream": True, + }, + ) + + assert response.status_code == 429 + body = response.json() + assert 
body["error"]["code"] == "rate_limit_exceeded" + assert body["error"]["plan_type"] == "team" + assert body["error"]["resets_at"] == 1700000000 + assert body["error"]["resets_in_seconds"] == 3600 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_cleans_up_cancelled_singleflight_creator(app_instance, monkeypatch): +async def test_v1_responses_http_bridge_cancellation_releases_queued_slot(async_client, app_instance, monkeypatch): + _install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account(async_client, "acc_http_bridge_cancel", "http-bridge-cancel@example.com") service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() + account = await _get_account(account_id) + upstream = _SilentUpstreamWebSocket() - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], - ), - dashboard_settings=_make_dashboard_settings(), - ) + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + ) + return AccountSelection(account=account, error_message=None, error_code=None) - first_create_started = asyncio.Event() - create_attempts = 0 + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target - async def 
fake_create_http_bridge_session( - self, - key, - *, + async def fake_connect_responses_websocket( headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, + access_token, + account_id_header, + *, + base_url=None, + session=None, ): - del self, headers, affinity, request_model, idle_ttl_seconds - nonlocal create_attempts - create_attempts += 1 - if create_attempts == 1: - first_create_started.set() - await _wait_for_event(asyncio.Event()) - return _make_dummy_bridge_session(key) + del headers, access_token, account_id_header, base_url, session + return upstream - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-cancelled-create", None) + payload = proxy_module.ResponsesRequest( + model="gpt-5.1", + instructions="Return exactly OK.", + input="cancel-me", + prompt_cache_key="cancel-key", + ) + affinity = proxy_module._sticky_key_for_responses_request( + payload, + {}, + codex_session_affinity=False, + openai_cache_affinity=True, + openai_cache_affinity_max_age_seconds=300, + sticky_threads_enabled=False, + api_key=None, + ) + key = proxy_module._make_http_bridge_session_key( + payload, + headers={}, + affinity=affinity, + api_key=None, + request_id="req_cancel", + ) + session = await service._get_or_create_http_bridge_session( + key, + headers={}, + affinity=affinity, + api_key=None, + request_model="gpt-5.1", + idle_ttl_seconds=120.0, + max_sessions=128, + ) - creator = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, 
- request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, + await session.response_create_gate.acquire() + request_state, text_data = service._prepare_http_bridge_request(payload, {}, api_key=None, api_key_reservation=None) + request_state.transport = "http" + task = asyncio.create_task( + service._submit_http_bridge_request( + session, + request_state=request_state, + text_data=text_data, + queue_limit=8, ) ) - await _wait_for_event(first_create_started) - creator.cancel() + await asyncio.sleep(0) + task.cancel() with pytest.raises(asyncio.CancelledError): - await asyncio.wait_for(creator, timeout=_TEST_SYNC_TIMEOUT_SECONDS) - - replacement = await asyncio.wait_for( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ), - timeout=1.0, - ) + await task - assert create_attempts == 2 - assert service._http_bridge_sessions[key] is replacement - assert key not in service._http_bridge_inflight_sessions + assert session.queued_request_count == 0 + async with session.pending_lock: + assert list(session.pending_requests) == [] + session.response_create_gate.release() + await service._close_http_bridge_session(session) @pytest.mark.asyncio -async def test_v1_responses_http_bridge_cleans_up_cancelled_singleflight_creator_after_create( - app_instance, monkeypatch -): - service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() - - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], - ), - dashboard_settings=_make_dashboard_settings(), +async def test_v1_responses_http_bridge_send_retry_restarts_reader(async_client, monkeypatch): + 
_install_bridge_settings(monkeypatch, enabled=True) + account_id = await _import_account( + async_client, + "acc_http_bridge_send_retry", + "http-bridge-send-retry@example.com", ) + account = await _get_account(account_id) + upstreams = [_FailingSendThenCloseUpstreamWebSocket(), _FakeBridgeUpstreamWebSocket()] + connect_count = 0 - create_finished = asyncio.Event() - allow_return = asyncio.Event() - create_attempts = 0 - - async def fake_create_http_bridge_session( + async def fake_select_account_with_budget( self, - key, + deadline, *, - headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, ): - del self, headers, affinity, request_model, idle_ttl_seconds - nonlocal create_attempts - create_attempts += 1 - if create_attempts == 1: - create_finished.set() - await _wait_for_event(allow_return) - return _make_dummy_bridge_session(key) + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + ) + return AccountSelection(account=account, error_message=None, error_code=None) - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-cancelled-after-create", None) - creator = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) 
- await _wait_for_event(create_finished) - async with service._http_bridge_lock: - allow_return.set() - await asyncio.sleep(0) - creator.cancel() + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + nonlocal connect_count + upstream = upstreams[connect_count] + connect_count += 1 + if isinstance(upstream, _FakeBridgeUpstreamWebSocket) and not upstream._messages.qsize(): + await upstream._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_retry_send", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ) + await upstream._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_retry_send", + "object": "response", + "status": "completed", + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return upstream - with pytest.raises(asyncio.CancelledError): - await asyncio.wait_for(creator, timeout=_TEST_SYNC_TIMEOUT_SECONDS) + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - replacement = await asyncio.wait_for( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ), - timeout=1.0, + response = await async_client.post( + "/v1/responses", + 
json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "retry-send", + "prompt_cache_key": "retry-send-key", + }, ) - assert create_attempts == 2 - assert service._http_bridge_sessions[key] is replacement - assert key not in service._http_bridge_inflight_sessions + assert response.status_code == 200 + assert response.json()["id"] == "resp_retry_send" + assert connect_count == 2 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_waits_for_inflight_session_before_continuity_error(app_instance, monkeypatch): +async def test_retry_http_bridge_precreated_request_releases_pending_lock_before_reconnect(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() - - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], + session = proxy_module._HTTPBridgeSession( + key=proxy_module._HTTPBridgeSessionKey("prompt_cache", "retry-lock-key", None), + headers={}, + affinity=proxy_module._AffinityPolicy( + key="retry-lock-key", + kind=proxy_module.StickySessionKind.PROMPT_CACHE, + max_age_seconds=300, ), - dashboard_settings=_make_dashboard_settings(), + request_model="gpt-5.1", + account=cast(Account, SimpleNamespace(id="acct-retry", status=AccountStatus.ACTIVE)), + upstream=cast(proxy_module.UpstreamResponsesWebSocket, _SilentUpstreamWebSocket()), + upstream_control=proxy_module._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=1, + last_used_at=time.monotonic(), + idle_ttl_seconds=120.0, ) + request_state = proxy_module._WebSocketRequestState( + request_id="req-precreated-retry", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + 
api_key_reservation=None, + started_at=time.monotonic(), + awaiting_response_created=True, + response_create_gate_acquired=True, + request_text=json.dumps({"type": "response.create", "model": "gpt-5.1", "input": []}), + ) + session.pending_requests.append(request_state) + reconnect_started = asyncio.Event() + allow_reconnect_finish = asyncio.Event() + lock_reacquired = asyncio.Event() + replacement_upstream = _RecordingUpstreamWebSocket() - create_started = asyncio.Event() - release_create = asyncio.Event() - - async def fake_create_http_bridge_session( - self, - key, - *, - headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, - ): - del self, headers, affinity, request_model, idle_ttl_seconds - create_started.set() - await _wait_for_event(release_create) - return _make_dummy_bridge_session(key) - - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) + async def fake_reconnect(self, target_session, *, request_state, restart_reader=False): + del self, request_state, restart_reader + reconnect_started.set() + await _wait_for_event(allow_reconnect_finish) + target_session.upstream = replacement_upstream - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-waits-for-inflight", None) + monkeypatch.setattr(proxy_module.ProxyService, "_reconnect_http_bridge_session", fake_reconnect) - creator = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - ) - await _wait_for_event(create_started) + retry_task = asyncio.create_task(service._retry_http_bridge_precreated_request(session)) + await _wait_for_event(reconnect_started) - follower = asyncio.create_task( - service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - 
idle_ttl_seconds=120.0, - max_sessions=8, - previous_response_id="resp_inflight", - ) - ) - await asyncio.sleep(0.01) - assert follower.done() + async def acquire_pending_lock() -> None: + async with session.pending_lock: + lock_reacquired.set() - release_create.set() - created_session = await creator - with pytest.raises(proxy_module.ProxyResponseError) as exc_info: - await follower + lock_task = asyncio.create_task(acquire_pending_lock()) + await asyncio.wait_for(lock_reacquired.wait(), timeout=1.0) + allow_reconnect_finish.set() - assert service._http_bridge_sessions[key] is created_session - exc = exc_info.value - assert exc.status_code == 502 - assert exc.payload["error"] == { - "message": "Upstream websocket closed before response.completed", - "type": "server_error", - "code": "stream_incomplete", - } + assert await retry_task is True + await lock_task + assert replacement_upstream.sent_text == [request_state.request_text] @pytest.mark.asyncio -async def test_v1_responses_http_bridge_prunes_idle_session_before_reuse(app_instance, monkeypatch): +async def test_retry_http_bridge_precreated_request_ignores_existing_response_id_entries(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() - - _install_proxy_settings( - monkeypatch, - app_settings=_make_app_settings( - enabled=True, - max_sessions=8, - codex_idle_ttl_seconds=120.0, - instance_id="instance-a", - instance_ring=[], + session = proxy_module._HTTPBridgeSession( + key=proxy_module._HTTPBridgeSessionKey("prompt_cache", "retry-race-key", None), + headers={}, + affinity=proxy_module._AffinityPolicy( + key="retry-race-key", + kind=proxy_module.StickySessionKind.PROMPT_CACHE, + max_age_seconds=300, ), - dashboard_settings=_make_dashboard_settings(), + request_model="gpt-5.1", + account=cast(Account, SimpleNamespace(id="acct-race", 
status=AccountStatus.ACTIVE)), + upstream=cast(proxy_module.UpstreamResponsesWebSocket, _SilentUpstreamWebSocket()), + upstream_control=proxy_module._WebSocketUpstreamControl(), + pending_requests=deque(), + pending_lock=anyio.Lock(), + response_create_gate=asyncio.Semaphore(1), + queued_request_count=2, + last_used_at=time.monotonic(), + idle_ttl_seconds=120.0, ) + existing_request = proxy_module._WebSocketRequestState( + request_id="req-existing", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=time.monotonic(), + response_id="resp-existing", + awaiting_response_created=False, + ) + retry_request = proxy_module._WebSocketRequestState( + request_id="req-precreated-race", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=time.monotonic(), + awaiting_response_created=True, + request_text=json.dumps({"type": "response.create", "model": "gpt-5.1", "input": ["retry"]}), + ) + session.pending_requests.extend([existing_request, retry_request]) + replacement_upstream = _RecordingUpstreamWebSocket() - create_started: list[str] = [] - - async def fake_create_http_bridge_session( - self, - key, - *, - headers, - affinity, - api_key, - request_model, - idle_ttl_seconds, - ): - del self, headers, affinity, request_model, idle_ttl_seconds - create_started.append(key.affinity_key) - return _make_dummy_bridge_session(key) - - monkeypatch.setattr(proxy_module.ProxyService, "_create_http_bridge_session", fake_create_http_bridge_session) - - key = proxy_module._HTTPBridgeSessionKey("request", "bridge-idle-prune", None) - stale_session = cast(proxy_module._HTTPBridgeSession, _make_dummy_bridge_session(key)) - stale_session.last_used_at = time.monotonic() - 300.0 - stale_session.idle_ttl_seconds = 120.0 - service._http_bridge_sessions[key] = stale_session - - try: - replacement = await service._get_or_create_http_bridge_session( - key, - headers={}, - 
affinity=proxy_module._AffinityPolicy(), - api_key=None, - request_model="gpt-5.4", - idle_ttl_seconds=120.0, - max_sessions=8, - ) - - assert create_started == ["bridge-idle-prune"] - assert replacement is not stale_session - assert service._http_bridge_sessions[key] is replacement - finally: - service._http_bridge_sessions.clear() - service._http_bridge_inflight_sessions.clear() - service._http_bridge_turn_state_index.clear() + async def fake_reconnect(self, target_session, *, request_state, restart_reader=False): + del self, request_state, restart_reader + target_session.upstream = replacement_upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_reconnect_http_bridge_session", fake_reconnect) + + assert await service._retry_http_bridge_precreated_request(session) is True + assert replacement_upstream.sent_text == [retry_request.request_text] @pytest.mark.asyncio -async def test_v1_responses_http_bridge_stream_failure_remains_valid_sse(async_client, monkeypatch): +async def test_v1_responses_http_bridge_send_failure_returns_upstream_unavailable( + async_client, + app_instance, + monkeypatch, +): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, - "acc_http_bridge_sse_failure", - "http-bridge-sse-failure@example.com", + "acc_http_bridge_send_failure_previous_response", + "http-bridge-send-failure-previous-response@example.com", ) account = await _get_account(account_id) - upstream = _CreatedThenCloseUpstreamWebSocket() + fake_upstream = _FakeBridgeUpstreamWebSocket() + failing_upstream = _FailingSendThenCloseUpstreamWebSocket() + connect_count = 0 async def fake_select_account_with_budget( self, @@ -6039,41 +7152,64 @@ async def fake_connect_responses_websocket( session=None, ): del headers, access_token, account_id_header, base_url, session - return upstream + nonlocal connect_count + connect_count += 1 + return fake_upstream if connect_count == 1 else failing_upstream 
monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - async with async_client.stream( - "POST", + first = await async_client.post( "/v1/responses", json={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "trigger-sse-failure", - "prompt_cache_key": "sse-failure-key", - "stream": True, + "input": "hello", + "prompt_cache_key": "send-failure-previous-response", }, - ) as response: - assert response.status_code == 200 - lines = [line async for line in response.aiter_lines() if line.startswith("data: ")] + ) + assert first.status_code == 200 + first_body = first.json() - events = [json.loads(line[6:]) for line in lines] - assert [event["type"] for event in events] == ["response.created", "response.failed"] - assert events[-1]["response"]["error"]["code"] == "stream_incomplete" + service = get_proxy_service_for_app(app_instance) + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + session.upstream = cast(proxy_module.UpstreamResponsesWebSocket, failing_upstream) + + second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "send-failure-previous-response", + "previous_response_id": first_body["id"], + }, + ) + + assert second.status_code == 502 + assert second.json()["error"]["code"] in ("upstream_unavailable", "stream_incomplete", "bridge_owner_unreachable") + assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert connect_count == 1 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_surfaces_upstream_error_event_as_http_400(async_client, monkeypatch): +async def 
test_v1_responses_http_bridge_precreated_disconnect_returns_upstream_unavailable( + async_client, + app_instance, + monkeypatch, +): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, - "acc_http_bridge_error_norm", - "http-bridge-error-norm@example.com", + "acc_http_bridge_precreated_previous_response", + "http-bridge-precreated-previous-response@example.com", ) account = await _get_account(account_id) - fake_upstream = _ErrorOnlyUpstreamWebSocket() + fake_upstream = _FakeBridgeUpstreamWebSocket() + precreated_close_upstream = _PrecreatedCloseUpstreamWebSocket() + connect_count = 0 async def fake_select_account_with_budget( self, @@ -6122,106 +7258,68 @@ async def fake_connect_responses_websocket( session=None, ): del headers, access_token, account_id_header, base_url, session - return fake_upstream + nonlocal connect_count + connect_count += 1 + return fake_upstream if connect_count == 1 else precreated_close_upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - response = await async_client.post( + first = await async_client.post( "/v1/responses", json={ - "model": "gpt-5.3-codex-spark", + "model": "gpt-5.1", "instructions": "Return exactly OK.", "input": "hello", - "prompt_cache_key": "http-bridge-error-norm-key", - "stream": True, + "prompt_cache_key": "precreated-previous-response", }, ) + assert first.status_code == 200 + first_body = first.json() - assert response.status_code == 400 - assert response.json() == { - "error": { - "message": "The 'gpt-5.3-codex-spark' model is not supported when using Codex with a ChatGPT account.", - "type": "invalid_request_error", - "code": "invalid_request_error", - } - } - - -@pytest.mark.asyncio -async 
def test_v1_responses_http_bridge_preserves_rate_limit_metadata_in_429(async_client, monkeypatch): - _install_bridge_settings(monkeypatch, enabled=True) - account_id = await _import_account( - async_client, - "acc_http_bridge_ratelimit", - "http-bridge-ratelimit@example.com", - ) - account = await _get_account(account_id) - fake_upstream = _RateLimitErrorUpstreamWebSocket() - - async def fake_select_account_with_budget( - self, - deadline, - *, - request_id, - kind, - sticky_key, - sticky_kind, - reallocate_sticky, - sticky_max_age_seconds, - prefer_earlier_reset_accounts, - routing_strategy, - model, - exclude_account_ids=None, - additional_limit_name=None, - api_key=None, - ): - return AccountSelection(account=account, error_message=None, error_code=None) - - async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): - return target - - async def fake_connect_responses_websocket( - headers, - access_token, - account_id_header, - *, - base_url=None, - session=None, - ): - return fake_upstream - - monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) - monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) - monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + service = get_proxy_service_for_app(app_instance) + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + await _replace_http_bridge_upstream_reader( + service, + session, + cast(proxy_module.UpstreamResponsesWebSocket, precreated_close_upstream), + ) - response = await async_client.post( + second = await async_client.post( "/v1/responses", json={ - "model": "gpt-4o", + "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "hello", - "prompt_cache_key": "http-bridge-ratelimit-key", - "stream": True, + "input": "hello-again", + "prompt_cache_key": 
"precreated-previous-response", + "previous_response_id": first_body["id"], }, ) - assert response.status_code == 429 - body = response.json() - assert body["error"]["code"] == "rate_limit_exceeded" - assert body["error"]["plan_type"] == "team" - assert body["error"]["resets_at"] == 1700000000 - assert body["error"]["resets_in_seconds"] == 3600 + assert second.status_code == 502 + assert second.json()["error"]["code"] in ("upstream_unavailable", "stream_incomplete", "upstream_request_timeout") + assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert connect_count == 1 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_cancellation_releases_queued_slot(async_client, app_instance, monkeypatch): +async def test_v1_responses_http_bridge_rebinds_after_upstream_previous_response_not_found( + async_client, + app_instance, + monkeypatch, +): _install_bridge_settings(monkeypatch, enabled=True) - account_id = await _import_account(async_client, "acc_http_bridge_cancel", "http-bridge-cancel@example.com") - service = get_proxy_service_for_app(app_instance) + account_id = await _import_account( + async_client, + "acc_http_bridge_previous_response_rebind", + "http-bridge-previous-response-rebind@example.com", + ) account = await _get_account(account_id) - upstream = _SilentUpstreamWebSocket() + first_upstream = _FakeBridgeUpstreamWebSocket() + recovered_upstream = _FakeBridgeUpstreamWebSocket() + connect_count = 0 async def fake_select_account_with_budget( self, @@ -6254,6 +7352,7 @@ async def fake_select_account_with_budget( model, exclude_account_ids, additional_limit_name, + api_key, ) return AccountSelection(account=account, error_message=None, error_code=None) @@ -6270,77 +7369,68 @@ async def fake_connect_responses_websocket( session=None, ): del headers, access_token, account_id_header, base_url, session - return upstream + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return first_upstream + return 
recovered_upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - payload = proxy_module.ResponsesRequest( - model="gpt-5.1", - instructions="Return exactly OK.", - input="cancel-me", - prompt_cache_key="cancel-key", - ) - affinity = proxy_module._sticky_key_for_responses_request( - payload, - {}, - codex_session_affinity=False, - openai_cache_affinity=True, - openai_cache_affinity_max_age_seconds=300, - sticky_threads_enabled=False, - api_key=None, - ) - key = proxy_module._make_http_bridge_session_key( - payload, - headers={}, - affinity=affinity, - api_key=None, - request_id="req_cancel", - ) - session = await service._get_or_create_http_bridge_session( - key, - headers={}, - affinity=affinity, - api_key=None, - request_model="gpt-5.1", - idle_ttl_seconds=120.0, - max_sessions=128, + first = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "previous-response-rebind", + }, ) + assert first.status_code == 200 + first_body = first.json() - await session.response_create_gate.acquire() - request_state, text_data = service._prepare_http_bridge_request(payload, {}, api_key=None, api_key_reservation=None) - request_state.transport = "http" - task = asyncio.create_task( - service._submit_http_bridge_request( + service = get_proxy_service_for_app(app_instance) + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + await _replace_http_bridge_upstream_reader( + service, session, - request_state=request_state, - text_data=text_data, - queue_limit=8, + cast(proxy_module.UpstreamResponsesWebSocket, _PreviousResponseNotFoundUpstreamWebSocket()), ) + + 
second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "previous-response-rebind", + "previous_response_id": first_body["id"], + }, ) - await asyncio.sleep(0) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - assert session.queued_request_count == 0 - async with session.pending_lock: - assert list(session.pending_requests) == [] - session.response_create_gate.release() - await service._close_http_bridge_session(session) + assert second.status_code == 200 + assert second.json()["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 2 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_send_retry_restarts_reader(async_client, monkeypatch): +async def test_v1_responses_http_bridge_rebinds_after_upstream_invalid_request_previous_response_not_found_param( + async_client, + app_instance, + monkeypatch, +): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, - "acc_http_bridge_send_retry", - "http-bridge-send-retry@example.com", + "acc_http_bridge_invalid_request_rebind", + "http-bridge-invalid-request-rebind@example.com", ) account = await _get_account(account_id) - upstreams = [_FailingSendThenCloseUpstreamWebSocket(), _FakeBridgeUpstreamWebSocket()] + first_upstream = _FakeBridgeUpstreamWebSocket() + recovered_upstream = _FakeBridgeUpstreamWebSocket() connect_count = 0 async def fake_select_account_with_budget( @@ -6374,6 +7464,7 @@ async def fake_select_account_with_budget( model, exclude_account_ids, additional_limit_name, + api_key, ) return AccountSelection(account=account, error_message=None, error_code=None) @@ -6391,198 +7482,59 @@ async def fake_connect_responses_websocket( ): del headers, access_token, account_id_header, base_url, session nonlocal connect_count - upstream = upstreams[connect_count] connect_count += 1 - if isinstance(upstream, 
_FakeBridgeUpstreamWebSocket) and not upstream._messages.qsize(): - await upstream._messages.put( - _FakeUpstreamMessage( - "text", - text=json.dumps( - { - "type": "response.created", - "response": {"id": "resp_retry_send", "object": "response", "status": "in_progress"}, - }, - separators=(",", ":"), - ), - ) - ) - await upstream._messages.put( - _FakeUpstreamMessage( - "text", - text=json.dumps( - { - "type": "response.completed", - "response": { - "id": "resp_retry_send", - "object": "response", - "status": "completed", - "usage": { - "input_tokens": 24, - "output_tokens": 2, - "total_tokens": 26, - "input_tokens_details": {"cached_tokens": 20}, - "output_tokens_details": {"reasoning_tokens": 0}, - }, - }, - }, - separators=(",", ":"), - ), - ) - ) - return upstream + if connect_count == 1: + return first_upstream + return recovered_upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - response = await async_client.post( + first = await async_client.post( "/v1/responses", json={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "retry-send", - "prompt_cache_key": "retry-send-key", + "input": "hello", + "prompt_cache_key": "invalid-request-rebind", }, ) + assert first.status_code == 200 + first_body = first.json() - assert response.status_code == 200 - assert response.json()["id"] == "resp_retry_send" - assert connect_count == 2 - - -@pytest.mark.asyncio -async def test_retry_http_bridge_precreated_request_releases_pending_lock_before_reconnect(app_instance, monkeypatch): service = get_proxy_service_for_app(app_instance) - session = proxy_module._HTTPBridgeSession( - key=proxy_module._HTTPBridgeSessionKey("prompt_cache", "retry-lock-key", None), - headers={}, - 
affinity=proxy_module._AffinityPolicy( - key="retry-lock-key", - kind=proxy_module.StickySessionKind.PROMPT_CACHE, - max_age_seconds=300, - ), - request_model="gpt-5.1", - account=cast(Account, SimpleNamespace(id="acct-retry", status=AccountStatus.ACTIVE)), - upstream=cast(proxy_module.UpstreamResponsesWebSocket, _SilentUpstreamWebSocket()), - upstream_control=proxy_module._WebSocketUpstreamControl(), - pending_requests=deque(), - pending_lock=anyio.Lock(), - response_create_gate=asyncio.Semaphore(1), - queued_request_count=1, - last_used_at=time.monotonic(), - idle_ttl_seconds=120.0, - ) - request_state = proxy_module._WebSocketRequestState( - request_id="req-precreated-retry", - model="gpt-5.1", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=time.monotonic(), - awaiting_response_created=True, - response_create_gate_acquired=True, - request_text=json.dumps({"type": "response.create", "model": "gpt-5.1", "input": []}), - ) - session.pending_requests.append(request_state) - reconnect_started = asyncio.Event() - allow_reconnect_finish = asyncio.Event() - lock_reacquired = asyncio.Event() - replacement_upstream = _RecordingUpstreamWebSocket() - - async def fake_reconnect(self, target_session, *, request_state, restart_reader=False): - del self, request_state, restart_reader - reconnect_started.set() - await _wait_for_event(allow_reconnect_finish) - target_session.upstream = replacement_upstream - - monkeypatch.setattr(proxy_module.ProxyService, "_reconnect_http_bridge_session", fake_reconnect) - - retry_task = asyncio.create_task(service._retry_http_bridge_precreated_request(session)) - await _wait_for_event(reconnect_started) - - async def acquire_pending_lock() -> None: - async with session.pending_lock: - lock_reacquired.set() - - lock_task = asyncio.create_task(acquire_pending_lock()) - await asyncio.wait_for(lock_reacquired.wait(), timeout=1.0) - allow_reconnect_finish.set() - - assert await retry_task is True - await 
lock_task - assert replacement_upstream.sent_text == [request_state.request_text] - + async with service._http_bridge_lock: + session = next(iter(service._http_bridge_sessions.values())) + await _replace_http_bridge_upstream_reader( + service, + session, + cast(proxy_module.UpstreamResponsesWebSocket, _InvalidRequestPreviousResponseUpstreamWebSocket()), + ) -@pytest.mark.asyncio -async def test_retry_http_bridge_precreated_request_ignores_existing_response_id_entries(app_instance, monkeypatch): - service = get_proxy_service_for_app(app_instance) - session = proxy_module._HTTPBridgeSession( - key=proxy_module._HTTPBridgeSessionKey("prompt_cache", "retry-race-key", None), - headers={}, - affinity=proxy_module._AffinityPolicy( - key="retry-race-key", - kind=proxy_module.StickySessionKind.PROMPT_CACHE, - max_age_seconds=300, - ), - request_model="gpt-5.1", - account=cast(Account, SimpleNamespace(id="acct-race", status=AccountStatus.ACTIVE)), - upstream=cast(proxy_module.UpstreamResponsesWebSocket, _SilentUpstreamWebSocket()), - upstream_control=proxy_module._WebSocketUpstreamControl(), - pending_requests=deque(), - pending_lock=anyio.Lock(), - response_create_gate=asyncio.Semaphore(1), - queued_request_count=2, - last_used_at=time.monotonic(), - idle_ttl_seconds=120.0, - ) - existing_request = proxy_module._WebSocketRequestState( - request_id="req-existing", - model="gpt-5.1", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=time.monotonic(), - response_id="resp-existing", - awaiting_response_created=False, - ) - retry_request = proxy_module._WebSocketRequestState( - request_id="req-precreated-race", - model="gpt-5.1", - service_tier=None, - reasoning_effort=None, - api_key_reservation=None, - started_at=time.monotonic(), - awaiting_response_created=True, - request_text=json.dumps({"type": "response.create", "model": "gpt-5.1", "input": ["retry"]}), + second = await async_client.post( + "/v1/responses", + json={ + "model": 
"gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": "invalid-request-rebind", + "previous_response_id": first_body["id"], + }, ) - session.pending_requests.extend([existing_request, retry_request]) - replacement_upstream = _RecordingUpstreamWebSocket() - - async def fake_reconnect(self, target_session, *, request_state, restart_reader=False): - del self, request_state, restart_reader - target_session.upstream = replacement_upstream - - monkeypatch.setattr(proxy_module.ProxyService, "_reconnect_http_bridge_session", fake_reconnect) - assert await service._retry_http_bridge_precreated_request(session) is True - assert replacement_upstream.sent_text == [retry_request.request_text] + assert second.status_code == 200 + assert second.json()["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 2 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_send_failure_returns_upstream_unavailable( - async_client, +async def test_v1_responses_http_bridge_masks_anonymous_previous_response_not_found_with_inflight_request( app_instance, monkeypatch, ): _install_bridge_settings(monkeypatch, enabled=True) - account_id = await _import_account( - async_client, - "acc_http_bridge_send_failure_previous_response", - "http-bridge-send-failure-previous-response@example.com", - ) - account = await _get_account(account_id) - fake_upstream = _FakeBridgeUpstreamWebSocket() - failing_upstream = _FailingSendThenCloseUpstreamWebSocket() + upstream = _AnonymousPreviousResponseNotFoundWithInflightUpstreamWebSocket() connect_count = 0 async def fake_select_account_with_budget( @@ -6616,6 +7568,7 @@ async def fake_select_account_with_budget( model, exclude_account_ids, additional_limit_name, + api_key, ) return AccountSelection(account=account, error_message=None, error_code=None) @@ -6634,61 +7587,78 @@ async def fake_connect_responses_websocket( del headers, access_token, account_id_header, base_url, session nonlocal connect_count 
connect_count += 1 - return fake_upstream if connect_count == 1 else failing_upstream + return upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - first = await async_client.post( - "/v1/responses", - json={ - "model": "gpt-5.1", - "instructions": "Return exactly OK.", - "input": "hello", - "prompt_cache_key": "send-failure-previous-response", - }, - ) - assert first.status_code == 200 - first_body = first.json() + async with app_instance.router.lifespan_context(app_instance): + async with ( + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as admin_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as first_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as second_client, + ): + account_id = await _import_account( + admin_client, + "acc_http_bridge_prev_nf_inflight", + "http-bridge-prev-nf-inflight@example.com", + ) + account = await _get_account(account_id) + + first = asyncio.create_task( + first_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "previous-response-inflight-mixed", + }, + ) + ) + await _wait_for_event(upstream.first_request_created) - service = get_proxy_service_for_app(app_instance) - async with service._http_bridge_lock: - session = next(iter(service._http_bridge_sessions.values())) - session.upstream = cast(proxy_module.UpstreamResponsesWebSocket, failing_upstream) + second = asyncio.create_task( + second_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello-again", + "prompt_cache_key": 
"previous-response-inflight-mixed", + "previous_response_id": "resp_bridge_prev_anchor", + }, + ) + ) - second = await async_client.post( - "/v1/responses", - json={ - "model": "gpt-5.1", - "instructions": "Return exactly OK.", - "input": "hello-again", - "prompt_cache_key": "send-failure-previous-response", - "previous_response_id": first_body["id"], - }, - ) + first_response, second_response = await asyncio.wait_for( + asyncio.gather(first, second), + timeout=_TEST_SYNC_TIMEOUT_SECONDS, + ) - assert second.status_code == 502 - assert second.json()["error"]["code"] in ("upstream_unavailable", "stream_incomplete", "bridge_owner_unreachable") - assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert first_response.status_code == 200 + assert first_response.json()["output"][0]["content"][0]["text"] == "OK" + assert second_response.status_code >= 400 + assert second_response.json()["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in second_response.json()["error"].get("code", "") + assert "previous_response_not_found" not in second_response.json()["error"].get("message", "") assert connect_count == 1 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_precreated_disconnect_returns_upstream_unavailable( +async def test_v1_responses_http_bridge_keeps_session_alive_after_foreign_previous_response_not_found( async_client, - app_instance, monkeypatch, ): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, - "acc_http_bridge_precreated_previous_response", - "http-bridge-precreated-previous-response@example.com", + "acc_http_bridge_foreign_prev_nf_created", + "http-bridge-foreign-prev-nf-created@example.com", ) account = await _get_account(account_id) - fake_upstream = _FakeBridgeUpstreamWebSocket() - precreated_close_upstream = _PrecreatedCloseUpstreamWebSocket() + upstream = _ForeignPreviousResponseNotFoundAfterCreatedUpstreamWebSocket() 
connect_count = 0 async def fake_select_account_with_budget( @@ -6722,6 +7692,7 @@ async def fake_select_account_with_budget( model, exclude_account_ids, additional_limit_name, + api_key, ) return AccountSelection(account=account, error_message=None, error_code=None) @@ -6740,7 +7711,7 @@ async def fake_connect_responses_websocket( del headers, access_token, account_id_header, base_url, session nonlocal connect_count connect_count += 1 - return fake_upstream if connect_count == 1 else precreated_close_upstream + return upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) @@ -6752,53 +7723,56 @@ async def fake_connect_responses_websocket( "model": "gpt-5.1", "instructions": "Return exactly OK.", "input": "hello", - "prompt_cache_key": "precreated-previous-response", + "prompt_cache_key": "foreign-previous-response-created", }, ) assert first.status_code == 200 first_body = first.json() - service = get_proxy_service_for_app(app_instance) - async with service._http_bridge_lock: - session = next(iter(service._http_bridge_sessions.values())) - await _replace_http_bridge_upstream_reader( - service, - session, - cast(proxy_module.UpstreamResponsesWebSocket, precreated_close_upstream), - ) - second = await async_client.post( "/v1/responses", json={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "hello-again", - "prompt_cache_key": "precreated-previous-response", + "input": "continue", + "prompt_cache_key": "foreign-previous-response-created", "previous_response_id": first_body["id"], }, ) + third = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "after error", + "prompt_cache_key": "foreign-previous-response-created", + }, + ) + assert second.status_code == 502 - assert second.json()["error"]["code"] in 
("upstream_unavailable", "stream_incomplete", "upstream_request_timeout") + assert second.json()["error"]["code"] == "stream_incomplete" assert "previous_response_not_found" not in second.json()["error"].get("code", "") + assert "previous_response_not_found" not in second.json()["error"].get("message", "") + assert third.status_code == 200 + assert third.json()["output"][0]["content"][0]["text"] == "OK" assert connect_count == 1 + assert upstream.closed is False @pytest.mark.asyncio -async def test_v1_responses_http_bridge_rebinds_after_upstream_previous_response_not_found( +async def test_v1_responses_http_bridge_stream_keeps_session_alive_after_foreign_previous_response_not_found( async_client, - app_instance, monkeypatch, ): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, - "acc_http_bridge_previous_response_rebind", - "http-bridge-previous-response-rebind@example.com", + "acc_http_bridge_foreign_prev_nf_created_stream", + "http-bridge-foreign-prev-nf-created-stream@example.com", ) account = await _get_account(account_id) - first_upstream = _FakeBridgeUpstreamWebSocket() - recovered_upstream = _FakeBridgeUpstreamWebSocket() + upstream = _ForeignPreviousResponseNotFoundAfterCreatedUpstreamWebSocket() connect_count = 0 async def fake_select_account_with_budget( @@ -6851,66 +7825,67 @@ async def fake_connect_responses_websocket( del headers, access_token, account_id_header, base_url, session nonlocal connect_count connect_count += 1 - if connect_count == 1: - return first_upstream - return recovered_upstream + return upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - first = await async_client.post( + first_events = await 
_collect_sse_events( + async_client, "/v1/responses", - json={ + json_body={ "model": "gpt-5.1", "instructions": "Return exactly OK.", "input": "hello", - "prompt_cache_key": "previous-response-rebind", + "prompt_cache_key": "foreign-previous-response-created-stream", + "stream": True, }, ) - assert first.status_code == 200 - first_body = first.json() + first_response = first_events[-1]["response"] - service = get_proxy_service_for_app(app_instance) - async with service._http_bridge_lock: - session = next(iter(service._http_bridge_sessions.values())) - await _replace_http_bridge_upstream_reader( - service, - session, - cast(proxy_module.UpstreamResponsesWebSocket, _PreviousResponseNotFoundUpstreamWebSocket()), - ) + second_events = await _collect_sse_events( + async_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "continue", + "prompt_cache_key": "foreign-previous-response-created-stream", + "previous_response_id": first_response["id"], + "stream": True, + }, + ) - second = await async_client.post( + third_events = await _collect_sse_events( + async_client, "/v1/responses", - json={ + json_body={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "hello-again", - "prompt_cache_key": "previous-response-rebind", - "previous_response_id": first_body["id"], + "input": "after error", + "prompt_cache_key": "foreign-previous-response-created-stream", + "stream": True, }, ) - assert second.status_code == 200 - assert second.json()["output"][0]["content"][0]["text"] == "OK" - assert connect_count == 2 + assert [event["type"] for event in first_events] == ["response.created", "response.completed"] + assert [event["type"] for event in second_events] == ["response.created", "response.failed"] + assert [event["type"] for event in third_events] == ["response.created", "response.completed"] + assert second_events[-1]["response"]["error"]["code"] == "stream_incomplete" + assert 
"previous_response_not_found" not in json.dumps(second_events[-1]) + assert third_events[-1]["response"]["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 1 + assert upstream.closed is False @pytest.mark.asyncio -async def test_v1_responses_http_bridge_rebinds_after_upstream_invalid_request_previous_response_not_found_param( - async_client, +async def test_v1_responses_http_bridge_stream_keeps_session_alive_after_anonymous_prev_nf_created_followup( app_instance, monkeypatch, ): _install_bridge_settings(monkeypatch, enabled=True) - account_id = await _import_account( - async_client, - "acc_http_bridge_invalid_request_rebind", - "http-bridge-invalid-request-rebind@example.com", - ) - account = await _get_account(account_id) - first_upstream = _FakeBridgeUpstreamWebSocket() - recovered_upstream = _FakeBridgeUpstreamWebSocket() + upstream = _AnonymousPreviousResponseNotFoundAfterCreatedUpstreamWebSocket() connect_count = 0 async def fake_select_account_with_budget( @@ -6963,58 +7938,91 @@ async def fake_connect_responses_websocket( del headers, access_token, account_id_header, base_url, session nonlocal connect_count connect_count += 1 - if connect_count == 1: - return first_upstream - return recovered_upstream + return upstream monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) - first = await async_client.post( - "/v1/responses", - json={ - "model": "gpt-5.1", - "instructions": "Return exactly OK.", - "input": "hello", - "prompt_cache_key": "invalid-request-rebind", - }, - ) - assert first.status_code == 200 - first_body = first.json() + async with app_instance.router.lifespan_context(app_instance): + async with ( + AsyncClient(transport=ASGITransport(app=app_instance), 
base_url="http://testserver") as admin_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as first_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as second_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as third_client, + ): + account_id = await _import_account( + admin_client, + "acc_http_bridge_anonymous_prev_nf_created_followup", + "http-bridge-anonymous-prev-nf-created-followup@example.com", + ) + account = await _get_account(account_id) - service = get_proxy_service_for_app(app_instance) - async with service._http_bridge_lock: - session = next(iter(service._http_bridge_sessions.values())) - await _replace_http_bridge_upstream_reader( - service, - session, - cast(proxy_module.UpstreamResponsesWebSocket, _InvalidRequestPreviousResponseUpstreamWebSocket()), - ) + first = asyncio.create_task( + _collect_sse_events( + first_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "hello", + "prompt_cache_key": "anonymous-created-followup-stream", + "stream": True, + }, + ) + ) + await _wait_for_event(upstream.first_request_created) - second = await async_client.post( - "/v1/responses", - json={ - "model": "gpt-5.1", - "instructions": "Return exactly OK.", - "input": "hello-again", - "prompt_cache_key": "invalid-request-rebind", - "previous_response_id": first_body["id"], - }, - ) + second = asyncio.create_task( + _collect_sse_events( + second_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "continue", + "prompt_cache_key": "anonymous-created-followup-stream", + "previous_response_id": "resp_bridge_prev_anchor", + "stream": True, + }, + ) + ) - assert second.status_code == 200 - assert second.json()["output"][0]["content"][0]["text"] == "OK" - assert connect_count == 2 + first_events, second_events = await 
asyncio.wait_for( + asyncio.gather(first, second), + timeout=_TEST_SYNC_TIMEOUT_SECONDS, + ) + + third_events = await _collect_sse_events( + third_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "after error", + "prompt_cache_key": "anonymous-created-followup-stream", + "stream": True, + }, + ) + + assert [event["type"] for event in first_events] == ["response.created", "response.completed"] + assert [event["type"] for event in second_events] == ["response.created", "response.failed"] + assert second_events[0]["response"]["id"] == "resp_bridge_followup_created" + assert second_events[1]["response"]["id"] == "resp_bridge_followup_created" + assert second_events[1]["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(second_events[1]) + assert [event["type"] for event in third_events] == ["response.created", "response.completed"] + assert third_events[-1]["response"]["output"][0]["content"][0]["text"] == "OK" + assert connect_count == 1 @pytest.mark.asyncio -async def test_v1_responses_http_bridge_masks_anonymous_previous_response_not_found_with_inflight_request( +async def test_v1_responses_http_bridge_stream_matches_previous_response_error_to_anchor_with_two_followups( app_instance, monkeypatch, ): _install_bridge_settings(monkeypatch, enabled=True) - upstream = _AnonymousPreviousResponseNotFoundWithInflightUpstreamWebSocket() + upstream = _TwoFollowupsPreviousResponseNotFoundUpstreamWebSocket() connect_count = 0 async def fake_select_account_with_budget( @@ -7078,51 +8086,98 @@ async def fake_connect_responses_websocket( AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as admin_client, AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as first_client, AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as second_client, + 
AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as third_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as fourth_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as fifth_client, ): account_id = await _import_account( admin_client, - "acc_http_bridge_prev_nf_inflight", - "http-bridge-prev-nf-inflight@example.com", + "acc_http_bridge_two_followups_prev_nf", + "http-bridge-two-followups-prev-nf@example.com", ) account = await _get_account(account_id) - first = asyncio.create_task( - first_client.post( + first_response = await first_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "anchor-a", + "prompt_cache_key": "two-followups-prev-nf-stream", + }, + ) + assert first_response.status_code == 200 + + second_response = await second_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "anchor-b", + "prompt_cache_key": "two-followups-prev-nf-stream", + }, + ) + assert second_response.status_code == 200 + + third = asyncio.create_task( + _collect_sse_events( + third_client, "/v1/responses", - json={ + json_body={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "hello", - "prompt_cache_key": "previous-response-inflight-mixed", + "input": "continue-a", + "prompt_cache_key": "two-followups-prev-nf-stream", + "previous_response_id": first_response.json()["id"], + "stream": True, }, ) ) - await _wait_for_event(upstream.first_request_created) + await _wait_for_event(upstream.first_followup_created) - second = asyncio.create_task( - second_client.post( + fourth = asyncio.create_task( + _collect_sse_events( + fourth_client, "/v1/responses", - json={ + json_body={ "model": "gpt-5.1", "instructions": "Return exactly OK.", - "input": "hello-again", - "prompt_cache_key": "previous-response-inflight-mixed", - 
"previous_response_id": "resp_bridge_prev_anchor", + "input": "continue-b", + "prompt_cache_key": "two-followups-prev-nf-stream", + "previous_response_id": second_response.json()["id"], + "stream": True, }, ) ) - first_response, second_response = await asyncio.wait_for( - asyncio.gather(first, second), + third_events, fourth_events = await asyncio.wait_for( + asyncio.gather(third, fourth), timeout=_TEST_SYNC_TIMEOUT_SECONDS, ) - assert first_response.status_code == 200 - assert first_response.json()["output"][0]["content"][0]["text"] == "OK" - assert second_response.status_code >= 400 - assert second_response.json()["error"]["code"] == "stream_incomplete" - assert "previous_response_not_found" not in second_response.json()["error"].get("code", "") - assert "previous_response_not_found" not in second_response.json()["error"].get("message", "") + fifth_events = await _collect_sse_events( + fifth_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "after error", + "prompt_cache_key": "two-followups-prev-nf-stream", + "stream": True, + }, + ) + + assert [event["type"] for event in third_events] == ["response.created", "response.failed"] + assert [event["type"] for event in fourth_events] == ["response.created", "response.completed"] + assert third_events[0]["response"]["id"] == "resp_bridge_followup_a" + assert third_events[1]["response"]["id"] == "resp_bridge_followup_a" + assert third_events[1]["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(third_events[1]) + assert fourth_events[0]["response"]["id"] == "resp_bridge_followup_b" + assert fourth_events[1]["response"]["id"] == "resp_bridge_followup_b" + assert fourth_events[1]["response"]["output"][0]["content"][0]["text"] == "OK" + assert [event["type"] for event in fifth_events] == ["response.created", "response.completed"] + assert fifth_events[-1]["response"]["output"][0]["content"][0]["text"] 
== "OK" assert connect_count == 1 diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 1fe7d0ec..3d30aef1 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1514,6 +1514,194 @@ async def fake_try_open_websocket_connect_attempt( assert event["error"]["message"] == "Upstream websocket closed before response.completed" +def test_backend_responses_websocket_masks_previous_response_not_found_and_recovers_on_retry( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_ws_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_retry", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_retry", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ], + ) + connect_count = 0 + captured_preferred_accounts: list[str | None] 
= [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + captured_preferred_accounts.append(request_state.preferred_account_id) + if connect_count == 1: + return SimpleNamespace(id="acct_ws_prev_mask"), first_upstream + return SimpleNamespace(id="acct_ws_prev_mask"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, session_id, surface + assert previous_response_id == "resp_ws_prev_anchor" + return "acct_ws_prev_mask" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "hello", 
+ "stream": True, + } + ) + ) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert completed_2["type"] == "response.completed" + assert created_2["response"]["id"] == "resp_ws_prev_retry" + assert completed_2["response"]["id"] == "resp_ws_prev_retry" + assert connect_count == 2 + assert captured_preferred_accounts == [None, "acct_ws_prev_mask"] + assert first_upstream.closed is True + + def test_backend_responses_websocket_masks_anonymous_previous_response_not_found_with_inflight_request( app_instance, monkeypatch, @@ -1679,6 +1867,740 @@ async def fake_resolve_previous_response_owner( assert fake_upstream.closed is True +def test_backend_responses_websocket_keeps_session_alive_after_foreign_previous_response_not_found( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_created", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + 
"text", + text=json.dumps( + { + "type": "response.failed", + "response": { + "id": "resp_ws_foreign_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_ws_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_after_error", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_after_error", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_followup_prev_nf"), first_upstream + return SimpleNamespace(id="acct_ws_followup_prev_nf"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, 
session_id, surface + assert previous_response_id == "resp_ws_prev_anchor" + return "acct_ws_followup_prev_nf" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "hello", + "stream": True, + } + ) + ) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + failed_2 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "after error", + "stream": True, + } + ) + ) + created_3 = json.loads(websocket.receive_text()) + completed_3 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert created_2["response"]["id"] == "resp_ws_followup_created" + assert failed_2["type"] == "response.failed" + assert failed_2["response"]["id"] == "resp_ws_followup_created" + assert failed_2["response"]["error"]["code"] == "stream_incomplete" + assert 
failed_2["response"]["error"]["message"] == "Upstream websocket closed before response.completed" + assert "previous_response_not_found" not in json.dumps(failed_2) + assert created_3["type"] == "response.created" + assert completed_3["type"] == "response.completed" + assert created_3["response"]["id"] == "resp_ws_after_error" + assert completed_3["response"]["id"] == "resp_ws_after_error" + assert connect_count == 2 + assert first_upstream.closed is True + + +def test_backend_responses_websocket_keeps_session_alive_after_anonymous_prev_nf_created_followup( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_inflight", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_created", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_ws_prev_anchor' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_inflight", "status": "completed"}, + }, + 
separators=(",", ":"), + ), + ), + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_after_error", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_after_error", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_followup_prev_nf"), first_upstream + return SimpleNamespace(id="acct_ws_followup_prev_nf"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, session_id, surface + assert previous_response_id == "resp_ws_prev_anchor" + return "acct_ws_followup_prev_nf" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, 
"get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "hello", + "stream": True, + } + ) + ) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "first inflight", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_3 = json.loads(websocket.receive_text()) + failed_3 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "after error", + "stream": True, + } + ) + ) + created_4 = json.loads(websocket.receive_text()) + completed_4 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert created_2["response"]["id"] == "resp_ws_inflight" + assert created_3["type"] == "response.created" + assert created_3["response"]["id"] == "resp_ws_followup_created" + assert failed_3["type"] == "response.failed" + assert failed_3["response"]["id"] == "resp_ws_followup_created" + assert failed_3["response"]["error"]["code"] == "stream_incomplete" + assert 
"previous_response_not_found" not in json.dumps(failed_3) + assert completed_2["type"] == "response.completed" + assert completed_2["response"]["id"] == "resp_ws_inflight" + assert created_4["type"] == "response.created" + assert created_4["response"]["id"] == "resp_ws_after_error" + assert completed_4["type"] == "response.completed" + assert completed_4["response"]["id"] == "resp_ws_after_error" + assert connect_count == 2 + assert first_upstream.closed is True + + +def test_backend_responses_websocket_matches_previous_response_error_to_anchor_with_two_followups( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor_a", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor_a", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor_b", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor_b", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_a", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_b", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 
400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_ws_prev_anchor_a.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_followup_b", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_after_error", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_after_error", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_followup_prev_nf"), first_upstream + return SimpleNamespace(id="acct_ws_followup_prev_nf"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + 
previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, session_id, surface + assert previous_response_id in {"resp_ws_prev_anchor_a", "resp_ws_prev_anchor_b"} + return "acct_ws_followup_prev_nf" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", "input": "anchor-a"})) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", "input": "anchor-b"})) + created_2 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-a", + "previous_response_id": "resp_ws_prev_anchor_a", + "stream": True, + } + ) + ) + created_3 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-b", + "previous_response_id": "resp_ws_prev_anchor_b", + "stream": True, + } + ) + ) + created_4 = json.loads(websocket.receive_text()) + failed_3 = json.loads(websocket.receive_text()) + completed_4 = json.loads(websocket.receive_text()) + + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", 
"input": "after-error"})) + created_5 = json.loads(websocket.receive_text()) + completed_5 = json.loads(websocket.receive_text()) + + assert created_1["response"]["id"] == "resp_ws_prev_anchor_a" + assert completed_1["response"]["id"] == "resp_ws_prev_anchor_a" + assert created_2["response"]["id"] == "resp_ws_prev_anchor_b" + assert completed_2["response"]["id"] == "resp_ws_prev_anchor_b" + assert created_3["response"]["id"] == "resp_ws_followup_a" + assert created_4["response"]["id"] == "resp_ws_followup_b" + assert failed_3["type"] == "response.failed" + assert failed_3["response"]["id"] == "resp_ws_followup_a" + assert failed_3["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(failed_3) + assert completed_4["type"] == "response.completed" + assert completed_4["response"]["id"] == "resp_ws_followup_b" + assert created_5["response"]["id"] == "resp_ws_after_error" + assert completed_5["response"]["id"] == "resp_ws_after_error" + assert connect_count == 2 + assert first_upstream.closed is True + + @pytest.mark.parametrize("frame", ['{"type":"response.create"', "[]"]) def test_backend_responses_websocket_rejects_malformed_first_frame_as_invalid_payload(app_instance, monkeypatch, frame): called = {"connect": False} diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index f1c5531b..e16d6ccc 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -5207,6 +5207,678 @@ async def test_process_upstream_websocket_text_does_not_match_foreign_completed_ assert list(pending_requests) == [pending_request] +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_a.", + "param": "previous_response_id", + }, + "response": {"id": "resp_ws_foreign_prev_nf"}, + }, + { + "type": 
"response.failed", + "response": { + "id": "resp_ws_foreign_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_a.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_skips_foreign_prev_nf_for_mismatched_created_followup( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_foreign_prev_nf_created_mismatch") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_prev_nf_mismatch", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_b", + previous_response_id="resp_anchor_b", + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert downstream_text == json.dumps(payload, separators=(",", ":")) + finalize_request_state.assert_not_awaited() + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is False + assert list(pending_requests) == [pending_request] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + 
"status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + "response": {"id": "resp_ws_foreign_prev_nf"}, + }, + { + "type": "response.failed", + "response": { + "id": "resp_ws_foreign_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_masks_foreign_previous_response_not_found_for_only_created_followup( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_foreign_prev_nf_created") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + pending_request = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_prev_nf", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created", + previous_response_id="resp_anchor", + ) + pending_requests = deque([pending_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert 
'"code":"stream_incomplete"' in downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is pending_request + assert finalize_call.kwargs["event_type"] == "response.failed" + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is False + assert list(pending_requests) == [] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor_a' not found.", + "param": "previous_response_id", + }, + "response": {"id": "resp_ws_foreign_prev_nf"}, + }, + { + "type": "response.failed", + "response": { + "id": "resp_ws_foreign_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor_a' not found.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_matches_foreign_prev_nf_to_anchor_with_two_followups( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_foreign_prev_nf_multiple_followups") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request_a = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_prev_nf_a", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + 
api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_a", + previous_response_id="resp_anchor_a", + ) + followup_request_b = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_prev_nf_b", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_b", + previous_response_id="resp_anchor_b", + ) + pending_requests = deque([followup_request_a, followup_request_b]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert '"id":"resp_ws_followup_created_a"' in downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request_a + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert list(pending_requests) == [followup_request_b] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_1234.", + "param": "previous_response_id", + }, + "response": {"id": "resp_ws_foreign_prev_nf"}, + }, + { + "type": "response.failed", + "response": { + "id": "resp_ws_foreign_prev_nf", + "status": "failed", + "error": { + "type": "invalid_request_error", + 
"code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_1234.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_matches_foreign_prev_nf_with_overlapping_anchors( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_foreign_prev_nf_overlap_followups") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request_a = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_overlap_a", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_overlap_a", + previous_response_id="resp_anchor_123", + ) + followup_request_b = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_overlap_b", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_overlap_b", + previous_response_id="resp_anchor_1234", + ) + pending_requests = deque([followup_request_a, followup_request_b]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert '"id":"resp_ws_followup_overlap_b"' in 
downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request_b + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert list(pending_requests) == [followup_request_a] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_a.", + "param": "previous_response_id", + }, + }, + { + "type": "response.failed", + "response": { + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_a.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_skips_anonymous_prev_nf_for_mismatched_created_followup( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_anonymous_prev_nf_created_followup_mismatch") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_anonymous_prev_nf_mismatch", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_b", + previous_response_id="resp_anchor_b", + ) + pending_requests = deque([followup_request]) + 
upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert downstream_text == json.dumps(payload, separators=(",", ":")) + finalize_request_state.assert_not_awaited() + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is False + assert list(pending_requests) == [followup_request] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor_a' not found.", + "param": "previous_response_id", + }, + }, + { + "type": "response.failed", + "response": { + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor_a' not found.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_matches_anonymous_prev_nf_to_anchor_with_two_followups( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_anonymous_prev_nf_multiple_followups") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request_a = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_anonymous_prev_nf_a", + model="gpt-5.1", + service_tier=None, + 
reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_a", + previous_response_id="resp_anchor_a", + ) + followup_request_b = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_created_anonymous_prev_nf_b", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created_b", + previous_response_id="resp_anchor_b", + ) + pending_requests = deque([followup_request_a, followup_request_b]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert '"id":"resp_ws_followup_created_a"' in downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request_a + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert list(pending_requests) == [followup_request_b] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Cannot continue conversation because upstream lost resp_anchor_1234.", + "param": "previous_response_id", + }, + }, + { + "type": "response.failed", + "response": { + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + 
"message": "Cannot continue conversation because upstream lost resp_anchor_1234.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_matches_anonymous_prev_nf_with_overlapping_anchors( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_anonymous_prev_nf_overlap_followups") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request_a = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_overlap_anonymous_a", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_overlap_anonymous_a", + previous_response_id="resp_anchor_123", + ) + followup_request_b = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_overlap_anonymous_b", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_overlap_anonymous_b", + previous_response_id="resp_anchor_1234", + ) + pending_requests = deque([followup_request_a, followup_request_b]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert 
'"id":"resp_ws_followup_overlap_anonymous_b"' in downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request_b + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert list(pending_requests) == [followup_request_a] + + +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + }, + { + "type": "response.failed", + "response": { + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_masks_anonymous_previous_response_not_found_for_created_followup( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_anonymous_prev_nf_created_followup") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + inflight_request = proxy_service._WebSocketRequestState( + request_id="ws_req_inflight_created_followup_prev_nf", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_inflight", + ) + followup_request = proxy_service._WebSocketRequestState( + 
request_id="ws_req_followup_created_anonymous_prev_nf", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + response_id="resp_ws_followup_created", + previous_response_id="resp_anchor", + ) + pending_requests = deque([inflight_request, followup_request]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(1), + ) + + assert '"type":"response.failed"' in downstream_text + assert '"code":"stream_incomplete"' in downstream_text + assert '"id":"resp_ws_followup_created"' in downstream_text + assert "previous_response_not_found" not in downstream_text + finalize_request_state.assert_awaited_once() + finalize_call = finalize_request_state.await_args + assert finalize_call is not None + assert finalize_call.args[0] is followup_request + assert finalize_call.kwargs["event_type"] == "response.failed" + handle_stream_error.assert_not_awaited() + assert upstream_control.reconnect_requested is True + assert upstream_control.suppress_downstream_event is False + assert list(pending_requests) == [inflight_request] + + @pytest.mark.asyncio async def test_process_upstream_websocket_text_transparently_retries_precreated_usage_limit_failure( monkeypatch, From 090da753218d987ff1e052201db88061c9dc3586 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Tue, 21 Apr 2026 11:08:21 +0200 Subject: [PATCH 15/18] fix(proxy): fail-closed previous_response_not_found and keep WS/HTTP bridge run continuity --- app/modules/proxy/service.py | 195 ++++++++- .../specs/responses-api-compat/spec.md | 15 + .../integration/test_http_responses_bridge.py | 310 ++++++++++++++ .../test_proxy_websocket_responses.py | 404 
++++++++++++++++++ tests/unit/test_proxy_utils.py | 130 ++++++ 5 files changed, 1041 insertions(+), 13 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index dbea2752..ba737168 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -4797,6 +4797,7 @@ async def _process_http_bridge_upstream_text( matched_request_state = None created_request_state = None has_other_pending_requests = False + grouped_previous_response_request_states: list[_WebSocketRequestState] = [] if event_type == "response.created": matched_request_state = _assign_websocket_response_id(session.pending_requests, response_id) created_request_state = matched_request_state @@ -4842,9 +4843,57 @@ async def _process_http_bridge_upstream_text( ) if terminal_request_state is not None: session.queued_request_count = max(0, session.queued_request_count - 1) + elif is_previous_response_not_found_event: + grouped_previous_response_request_states = _pop_matching_websocket_request_states( + session.pending_requests, + _matching_websocket_request_states_for_previous_response_error( + session.pending_requests, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, + ), + ) + if grouped_previous_response_request_states: + session.queued_request_count = max( + 0, + session.queued_request_count - len(grouped_previous_response_request_states), + ) has_other_pending_requests = bool(session.pending_requests) + if len(grouped_previous_response_request_states) > 1: + session.upstream_control.reconnect_requested = True + for grouped_request_state in grouped_previous_response_request_states: + grouped_request_state.error_http_status_override = 502 + ( + _grouped_downstream_text, + grouped_event_block, + grouped_event, + grouped_payload, + grouped_event_type, + ) = _build_stream_incomplete_terminal_event_for_request(grouped_request_state) + if grouped_request_state.event_queue is not None: + await 
grouped_request_state.event_queue.put(grouped_event_block) + await grouped_request_state.event_queue.put(None) + await self._finalize_websocket_request_state( + grouped_request_state, + account=session.account, + account_id_value=session.account.id, + event=grouped_event, + event_type=grouped_event_type, + payload=grouped_payload, + api_key=grouped_request_state.api_key, + upstream_control=session.upstream_control, + response_create_gate=session.response_create_gate, + ) + return + + if len(grouped_previous_response_request_states) == 1 and terminal_request_state is None: + terminal_request_state = grouped_previous_response_request_states[0] + status_request_state = terminal_request_state or matched_request_state + if status_request_state is None and is_previous_response_not_found_event: + session.upstream_control.reconnect_requested = True + return + if ( status_request_state is not None and status_request_state.previous_response_id is not None @@ -4953,9 +5002,13 @@ def _remember_websocket_previous_response_owner( account_id_value = account_id.strip() if not account_id_value: return - cache_key = (response_id, api_key_id, _normalize_session_id(session_id)) - self._websocket_previous_response_account_index.pop(cache_key, None) - self._websocket_previous_response_account_index[cache_key] = account_id_value + cache_keys = [(response_id, api_key_id, None)] + normalized_session_id = _normalize_session_id(session_id) + if normalized_session_id is not None: + cache_keys.append((response_id, api_key_id, normalized_session_id)) + for cache_key in cache_keys: + self._websocket_previous_response_account_index.pop(cache_key, None) + self._websocket_previous_response_account_index[cache_key] = account_id_value while len(self._websocket_previous_response_account_index) > _WEBSOCKET_PREVIOUS_RESPONSE_ACCOUNT_CACHE_LIMIT: self._websocket_previous_response_account_index.pop( next(iter(self._websocket_previous_response_account_index)) @@ -5167,8 +5220,18 @@ async def 
_relay_upstream_websocket_messages( response_create_gate=response_create_gate, ) suppress_downstream_event = upstream_control.suppress_downstream_event + downstream_texts = upstream_control.downstream_texts upstream_control.suppress_downstream_event = False - if not suppress_downstream_event: + upstream_control.downstream_texts = None + if downstream_texts is not None: + for emitted_text in downstream_texts: + await self._send_downstream_websocket_text( + websocket, + client_send_lock=client_send_lock, + text=emitted_text, + downstream_activity=downstream_activity, + ) + elif not suppress_downstream_event: await self._send_downstream_websocket_text( websocket, client_send_lock=client_send_lock, @@ -5287,6 +5350,7 @@ async def _process_upstream_websocket_text( request_state = None created_request_state = None has_other_pending_requests = False + grouped_previous_response_request_states: list[_WebSocketRequestState] = [] if event_type == "response.created": request_state = _assign_websocket_response_id(pending_requests, response_id) created_request_state = request_state @@ -5327,6 +5391,15 @@ async def _process_upstream_websocket_text( "error", }, ) + if request_state is None and is_previous_response_not_found_event: + grouped_previous_response_request_states = _pop_matching_websocket_request_states( + pending_requests, + _matching_websocket_request_states_for_previous_response_error( + pending_requests, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, + ), + ) has_other_pending_requests = bool(pending_requests) else: request_state = None @@ -5334,7 +5407,39 @@ async def _process_upstream_websocket_text( if event_type == "response.created" and release_create_gate and created_request_state is not None: _release_websocket_response_create_gate(created_request_state, response_create_gate) + if len(grouped_previous_response_request_states) > 1: + upstream_control.reconnect_requested = True + downstream_texts: list[str] = [] + for 
grouped_request_state in grouped_previous_response_request_states: + ( + grouped_downstream_text, + _grouped_event_block, + grouped_event, + grouped_payload, + grouped_event_type, + ) = _build_stream_incomplete_terminal_event_for_request(grouped_request_state) + downstream_texts.append(grouped_downstream_text) + await self._finalize_websocket_request_state( + grouped_request_state, + account=account, + account_id_value=account_id_value, + event=grouped_event, + event_type=grouped_event_type, + payload=grouped_payload, + api_key=api_key, + upstream_control=upstream_control, + response_create_gate=response_create_gate, + ) + upstream_control.suppress_downstream_event = True + upstream_control.downstream_texts = downstream_texts + return downstream_texts[0] + + if len(grouped_previous_response_request_states) == 1 and request_state is None: + request_state = grouped_previous_response_request_states[0] + if request_state is None: + if is_previous_response_not_found_event: + upstream_control.suppress_downstream_event = True return text retry_is_previous_response_not_found = is_previous_response_not_found_event @@ -7535,6 +7640,7 @@ class _WebSocketUpstreamControl: reconnect_requested: bool = False suppress_downstream_event: bool = False replay_request_state: _WebSocketRequestState | None = None + downstream_texts: list[str] | None = None @dataclass(slots=True) @@ -8139,30 +8245,93 @@ def _match_websocket_request_state_for_previous_response_error( previous_response_id_hint: str | None = None, error_message: str | None = None, ) -> _WebSocketRequestState | None: + matching_requests = _matching_websocket_request_states_for_previous_response_error( + pending_requests, + previous_response_id_hint=previous_response_id_hint, + error_message=error_message, + ) + if len(matching_requests) == 1: + return matching_requests[0] + return None + + +def _matching_websocket_request_states_for_previous_response_error( + pending_requests: deque[_WebSocketRequestState], + *, + 
previous_response_id_hint: str | None = None, + error_message: str | None = None, +) -> list[_WebSocketRequestState]: followup_requests = [ request_state for request_state in pending_requests if request_state.previous_response_id is not None ] + if not followup_requests: + return [] if previous_response_id_hint is not None: matching_requests = [ request_state for request_state in followup_requests if request_state.previous_response_id == previous_response_id_hint ] - if len(matching_requests) == 1: - return matching_requests[0] - return None + if matching_requests: + return matching_requests if error_message is not None: matching_requests = [ request_state for request_state in followup_requests if _message_mentions_previous_response_id(error_message, request_state.previous_response_id) ] - if len(matching_requests) == 1: - return matching_requests[0] - return None - if len(followup_requests) == 1: - return followup_requests[0] - return None + if matching_requests: + return matching_requests + unresolved_followups = [request_state for request_state in followup_requests if request_state.response_id is None] + if len(unresolved_followups) == 1: + return unresolved_followups + if len(unresolved_followups) > 1: + unique_previous_response_ids = { + request_state.previous_response_id + for request_state in unresolved_followups + if request_state.previous_response_id + } + if len(unique_previous_response_ids) == 1: + return unresolved_followups + return [] + + +def _pop_matching_websocket_request_states( + pending_requests: deque[_WebSocketRequestState], + matching_requests: list[_WebSocketRequestState], +) -> list[_WebSocketRequestState]: + popped_requests: list[_WebSocketRequestState] = [] + for request_state in matching_requests: + try: + pending_requests.remove(request_state) + except ValueError: + continue + popped_requests.append(request_state) + return popped_requests + + +def _build_stream_incomplete_terminal_event_for_request( + request_state: 
_WebSocketRequestState, +) -> tuple[str, str, OpenAIEvent | None, dict[str, JsonValue] | None, str | None]: + event_block, event, payload, event_type = _build_rewritten_stream_response_failed_event( + response_id=request_state.response_id or request_state.request_id, + error_code="stream_incomplete", + error_message="Upstream websocket closed before response.completed", + ) + downstream_text = json.dumps( + cast( + dict[str, JsonValue], + response_failed_event( + "stream_incomplete", + "Upstream websocket closed before response.completed", + error_type="server_error", + response_id=request_state.response_id or request_state.request_id, + ), + ), + ensure_ascii=True, + separators=(",", ":"), + ) + return downstream_text, event_block, event, payload, event_type def _release_websocket_response_create_gate( diff --git a/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md index c94a2b24..e21b8f7d 100644 --- a/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md +++ b/openspec/changes/harden-continuity-fail-closed-edges/specs/responses-api-compat/spec.md @@ -22,6 +22,21 @@ When a Responses follow-up depends on previously established continuity state, t - **AND** it does not expose raw `previous_response_not_found` - **AND** unrelated pending requests continue on their own response lifecycle +#### Scenario: multiplexed follow-ups sharing one anchor fail closed together without leaking raw continuity errors +- **WHEN** a websocket or HTTP bridge session has multiple pending follow-up requests that share the same `previous_response_id` anchor +- **AND** upstream emits an anonymous continuity loss event such as `previous_response_not_found` for that shared anchor +- **THEN** the service rewrites each affected follow-up into a retryable continuity error +- **AND** no affected follow-up exposes raw 
`previous_response_not_found` +- **AND** the run remains usable for subsequent requests after the rewritten failures + +#### Scenario: single pre-created follow-up still fails closed when continuity loss omits explicit response id in message +- **WHEN** a websocket follow-up request is pending with `previous_response_id` and has not received a stable upstream `response.id` yet +- **AND** upstream emits `previous_response_not_found` with `param=previous_response_id` +- **AND** the upstream error message omits the literal previous response identifier +- **THEN** the service still maps that continuity loss to the pending follow-up +- **AND** it rewrites the downstream terminal event to a retryable continuity error +- **AND** it does not surface raw `previous_response_not_found` to the client + ### Requirement: Hard continuity owner lookup fails closed When a request depends on hard continuity ownership, the service MUST fail closed if owner or ring lookup errors prevent safe pinning. The service MUST NOT continue with local recovery or account selection that bypasses hard owner enforcement. 
diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index 9cbbd30f..36c9cfed 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -1086,6 +1086,164 @@ async def send_text(self, text: str) -> None: ) +class _TwoSameAnchorFollowupsPreviousResponseNotFoundUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): + def __init__(self) -> None: + super().__init__() + self.first_followup_created = asyncio.Event() + + async def send_text(self, text: str) -> None: + self.sent_text.append(text) + if len(self.sent_text) == 1: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_prev_anchor_shared", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_bridge_prev_anchor_shared", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + if len(self.sent_text) == 2: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": { + "id": "resp_bridge_followup_same_anchor_a", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + self.first_followup_created.set() + return + + if len(self.sent_text) == 3: + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": 
"response.created", + "response": { + "id": "resp_bridge_followup_same_anchor_b", + "object": "response", + "status": "in_progress", + }, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_bridge_prev_anchor_shared' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ) + ) + return + + response_id = "resp_bridge_after_same_anchor_error" + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": response_id, "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ) + ) + await self._messages.put( + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": response_id, + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "OK"}], + } + ], + "usage": { + "input_tokens": 24, + "output_tokens": 2, + "total_tokens": 26, + "input_tokens_details": {"cached_tokens": 20}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + }, + }, + separators=(",", ":"), + ), + ) + ) + + class _FailingSendThenCloseUpstreamWebSocket(_FakeBridgeUpstreamWebSocket): async def send_text(self, text: str) -> None: self.sent_text.append(text) @@ -8181,6 +8339,158 @@ async def fake_connect_responses_websocket( assert connect_count == 1 +@pytest.mark.asyncio +async def test_v1_responses_http_bridge_stream_masks_anonymous_previous_response_not_found_for_same_anchor_followups( + app_instance, + monkeypatch, +): + _install_bridge_settings(monkeypatch, enabled=True) + upstream = _TwoSameAnchorFollowupsPreviousResponseNotFoundUpstreamWebSocket() + connect_count 
= 0 + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids=None, + additional_limit_name=None, + api_key=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + exclude_account_ids, + additional_limit_name, + api_key, + ) + return AccountSelection(account=account, error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, target, *, force=False, timeout_seconds): + del self, force, timeout_seconds + return target + + async def fake_connect_responses_websocket( + headers, + access_token, + account_id_header, + *, + base_url=None, + session=None, + ): + del headers, access_token, account_id_header, base_url, session + nonlocal connect_count + connect_count += 1 + return upstream + + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + + async with app_instance.router.lifespan_context(app_instance): + async with ( + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as admin_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as first_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as second_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as third_client, + AsyncClient(transport=ASGITransport(app=app_instance), base_url="http://testserver") as fourth_client, + ): + 
account_id = await _import_account( + admin_client, + "acc_http_bridge_two_same_anchor_followups_prev_nf", + "http-bridge-two-same-anchor-followups-prev-nf@example.com", + ) + account = await _get_account(account_id) + + first_response = await first_client.post( + "/v1/responses", + json={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "anchor", + "prompt_cache_key": "two-same-anchor-followups-prev-nf-stream", + }, + ) + assert first_response.status_code == 200 + + second = asyncio.create_task( + _collect_sse_events( + second_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "continue-a", + "prompt_cache_key": "two-same-anchor-followups-prev-nf-stream", + "previous_response_id": first_response.json()["id"], + "stream": True, + }, + ) + ) + await _wait_for_event(upstream.first_followup_created) + + third = asyncio.create_task( + _collect_sse_events( + third_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "continue-b", + "prompt_cache_key": "two-same-anchor-followups-prev-nf-stream", + "previous_response_id": first_response.json()["id"], + "stream": True, + }, + ) + ) + + second_events, third_events = await asyncio.wait_for( + asyncio.gather(second, third), + timeout=_TEST_SYNC_TIMEOUT_SECONDS, + ) + + fourth_events = await _collect_sse_events( + fourth_client, + "/v1/responses", + json_body={ + "model": "gpt-5.1", + "instructions": "Return exactly OK.", + "input": "after-error", + "prompt_cache_key": "two-same-anchor-followups-prev-nf-stream", + "stream": True, + }, + ) + + assert [event["type"] for event in second_events] == ["response.created", "response.failed"] + assert [event["type"] for event in third_events] == ["response.created", "response.failed"] + assert second_events[0]["response"]["id"] == "resp_bridge_followup_same_anchor_a" + assert third_events[0]["response"]["id"] == 
"resp_bridge_followup_same_anchor_b" + assert second_events[1]["response"]["error"]["code"] == "stream_incomplete" + assert third_events[1]["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(second_events[1]) + assert "previous_response_not_found" not in json.dumps(third_events[1]) + assert [event["type"] for event in fourth_events] == ["response.created", "response.completed"] + assert fourth_events[-1]["response"]["id"] == "resp_bridge_after_same_anchor_error" + assert connect_count == 1 + + @pytest.mark.asyncio async def test_v1_responses_http_bridge_send_retry_keeps_session_open_for_followup_request( async_client, diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 3d30aef1..99761be5 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1867,6 +1867,185 @@ async def fake_resolve_previous_response_owner( assert fake_upstream.closed is True +def test_backend_responses_websocket_recovers_when_previous_response_not_found_message_omits_response_id( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev_anchor", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ), + 
], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_replayed", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_followup_replayed", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_prev_nf_omitted_id"), first_upstream + return SimpleNamespace(id="acct_ws_prev_nf_omitted_id"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, session_id, surface + assert previous_response_id == "resp_ws_prev_anchor" + return "acct_ws_prev_nf_omitted_id" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: 
_FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", "input": "hello"})) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + assert created_1["type"] == "response.created" + assert completed_1["type"] == "response.completed" + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue", + "previous_response_id": "resp_ws_prev_anchor", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + completed_2 = json.loads(websocket.receive_text()) + + assert created_2["type"] == "response.created" + assert completed_2["type"] == "response.completed" + assert created_2["response"]["id"] == "resp_ws_followup_replayed" + assert completed_2["response"]["id"] == "resp_ws_followup_replayed" + assert "previous_response_not_found" not in json.dumps(created_2) + assert "previous_response_not_found" not in json.dumps(completed_2) + assert connect_count == 2 + assert first_upstream.closed is True + + def test_backend_responses_websocket_keeps_session_alive_after_foreign_previous_response_not_found( app_instance, monkeypatch, @@ -2601,6 +2780,231 @@ async def fake_resolve_previous_response_owner( assert first_upstream.closed is True +def test_backend_responses_websocket_masks_anonymous_previous_response_not_found_for_same_anchor_followups_and_recovers( + app_instance, + monkeypatch, +): + first_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + 
"type": "response.created", + "response": {"id": "resp_ws_prev_anchor_shared", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_prev_anchor_shared", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_same_anchor_a", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + ], + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_followup_same_anchor_b", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_ws_prev_anchor_shared' not found.", + "param": "previous_response_id", + }, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + recovered_upstream = _SequencedUpstreamWebSocket( + [], + deferred_message_batches=[ + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_after_same_anchor_error", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_after_same_anchor_error", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ], + ], + ) + connect_count = 0 + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + 
reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ): + del ( + self, + headers, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + ) + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + return SimpleNamespace(id="acct_ws_same_anchor"), first_upstream + return SimpleNamespace(id="acct_ws_same_anchor"), recovered_upstream + + async def fake_resolve_previous_response_owner( + self, + *, + previous_response_id, + api_key, + session_id=None, + surface, + ): + del self, api_key, session_id, surface + assert previous_response_id == "resp_ws_prev_anchor_shared" + return "acct_ws_same_anchor" + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr( + proxy_module.ProxyService, + "_resolve_websocket_previous_response_owner", + fake_resolve_previous_response_owner, + ) + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", "input": "anchor"})) + created_1 = json.loads(websocket.receive_text()) + completed_1 = json.loads(websocket.receive_text()) + + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-a", + "previous_response_id": "resp_ws_prev_anchor_shared", + "stream": True, + } + ) + ) + created_2 = json.loads(websocket.receive_text()) + + 
websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.4", + "input": "continue-b", + "previous_response_id": "resp_ws_prev_anchor_shared", + "stream": True, + } + ) + ) + created_3 = json.loads(websocket.receive_text()) + failed_2 = json.loads(websocket.receive_text()) + failed_3 = json.loads(websocket.receive_text()) + + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.4", "input": "after-error"})) + created_4 = json.loads(websocket.receive_text()) + completed_4 = json.loads(websocket.receive_text()) + + assert created_1["response"]["id"] == "resp_ws_prev_anchor_shared" + assert completed_1["response"]["id"] == "resp_ws_prev_anchor_shared" + assert created_2["response"]["id"] == "resp_ws_followup_same_anchor_a" + assert created_3["response"]["id"] == "resp_ws_followup_same_anchor_b" + assert failed_2["type"] == "response.failed" + assert failed_3["type"] == "response.failed" + assert failed_2["response"]["id"] == "resp_ws_followup_same_anchor_a" + assert failed_3["response"]["id"] == "resp_ws_followup_same_anchor_b" + assert failed_2["response"]["error"]["code"] == "stream_incomplete" + assert failed_3["response"]["error"]["code"] == "stream_incomplete" + assert "previous_response_not_found" not in json.dumps(failed_2) + assert "previous_response_not_found" not in json.dumps(failed_3) + assert created_4["response"]["id"] == "resp_ws_after_same_anchor_error" + assert completed_4["response"]["id"] == "resp_ws_after_same_anchor_error" + assert connect_count == 2 + assert first_upstream.closed is True + + @pytest.mark.parametrize("frame", ['{"type":"response.create"', "[]"]) def test_backend_responses_websocket_rejects_malformed_first_frame_as_invalid_payload(app_instance, monkeypatch, frame): called = {"connect": False} diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index e16d6ccc..e359539f 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -5879,6 
+5879,99 @@ async def test_process_upstream_websocket_text_masks_anonymous_previous_response assert list(pending_requests) == [inflight_request] +@pytest.mark.parametrize( + "payload", + [ + { + "type": "error", + "status": 400, + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + }, + { + "type": "response.failed", + "response": { + "status": "failed", + "error": { + "type": "invalid_request_error", + "code": "previous_response_not_found", + "message": "Previous response with id 'resp_anchor' not found.", + "param": "previous_response_id", + }, + }, + }, + ], +) +@pytest.mark.asyncio +async def test_process_upstream_websocket_text_masks_anonymous_previous_response_not_found_for_same_anchor_followups( + monkeypatch, + payload, +): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + finalize_request_state = AsyncMock() + handle_stream_error = AsyncMock() + account = _make_account("acc_ws_anonymous_prev_nf_same_anchor") + + monkeypatch.setattr(service, "_finalize_websocket_request_state", finalize_request_state) + monkeypatch.setattr(service, "_handle_stream_error", handle_stream_error) + + followup_request_a = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_same_anchor_a", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_anchor", + request_text='{"type":"response.create","previous_response_id":"resp_anchor"}', + ) + followup_request_b = proxy_service._WebSocketRequestState( + request_id="ws_req_followup_same_anchor_b", + model="gpt-5.1", + service_tier=None, + reasoning_effort=None, + api_key_reservation=None, + started_at=0.0, + previous_response_id="resp_anchor", + request_text='{"type":"response.create","previous_response_id":"resp_anchor"}', + ) + 
pending_requests = deque([followup_request_a, followup_request_b]) + upstream_control = proxy_service._WebSocketUpstreamControl() + + downstream_text = await service._process_upstream_websocket_text( + json.dumps(payload, separators=(",", ":")), + account=account, + account_id_value=account.id, + pending_requests=pending_requests, + pending_lock=anyio.Lock(), + api_key=None, + upstream_control=upstream_control, + response_create_gate=asyncio.Semaphore(2), + ) + + assert "previous_response_not_found" not in downstream_text + assert upstream_control.suppress_downstream_event is True + assert upstream_control.reconnect_requested is True + assert upstream_control.downstream_texts is not None + assert len(upstream_control.downstream_texts) == 2 + for emitted_text in upstream_control.downstream_texts: + assert '"type":"response.failed"' in emitted_text + assert '"code":"stream_incomplete"' in emitted_text + assert "previous_response_not_found" not in emitted_text + assert finalize_request_state.await_count == 2 + finalized_requests = [call.args[0] for call in finalize_request_state.await_args_list] + assert finalized_requests == [followup_request_a, followup_request_b] + for call in finalize_request_state.await_args_list: + assert call.kwargs["event_type"] == "response.failed" + handle_stream_error.assert_not_awaited() + assert list(pending_requests) == [] + + @pytest.mark.asyncio async def test_process_upstream_websocket_text_transparently_retries_precreated_usage_limit_failure( monkeypatch, @@ -7181,6 +7274,43 @@ async def test_resolve_websocket_previous_response_owner_prefers_scoped_lookup_o ) +@pytest.mark.asyncio +async def test_resolve_websocket_previous_response_owner_uses_unique_scoped_cache_fallback_on_lookup_failure() -> None: + request_logs = _RequestLogsRecorder() + request_logs.lookup_error = RuntimeError("request log lookup unavailable") + service = proxy_service.ProxyService(_repo_factory(request_logs)) + api_key = ApiKeyData( + id="key_shared", + 
name="shared-key", + key_prefix="sk-shared", + allowed_models=None, + enforced_model=None, + enforced_reasoning_effort=None, + enforced_service_tier=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + + service._remember_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key_id=api_key.id, + account_id="acc_owner_scoped", + session_id="turn_scope_a", + ) + + owner = await service._resolve_websocket_previous_response_owner( + previous_response_id="resp_prev_shared", + api_key=api_key, + session_id="turn_scope_b", + surface="websocket", + ) + + assert owner == "acc_owner_scoped" + assert request_logs.lookup_calls == [("resp_prev_shared", api_key.id, "turn_scope_b")] + + def test_remember_websocket_previous_response_owner_eviction_keeps_latest_entries(): request_logs = _RequestLogsRecorder() service = proxy_service.ProxyService(_repo_factory(request_logs)) From 53e24384710f89a3b920eaa3a09a602a202db1eb Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Tue, 21 Apr 2026 11:31:54 +0200 Subject: [PATCH 16/18] fix(db): linearize request_logs migration chain after main merge --- .../versions/20260417_000000_add_request_log_plan_type.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py b/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py index a6a143f7..ca592a75 100644 --- a/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py +++ b/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py @@ -1,7 +1,7 @@ """add plan_type snapshot to request_logs Revision ID: 20260417_000000_add_request_log_plan_type -Revises: 20260413_000000_add_accounts_blocked_at +Revises: 20260415_160000_add_request_logs_response_lookup_index Create Date: 2026-04-17 """ @@ -11,7 +11,7 @@ from alembic import op revision = "20260417_000000_add_request_log_plan_type" -down_revision = 
"20260413_000000_add_accounts_blocked_at" +down_revision = "20260415_160000_add_request_logs_response_lookup_index" branch_labels = None depends_on = None From 5b1367c4acc3112fc627b7c76c2435355e846b5d Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Tue, 21 Apr 2026 12:13:09 +0200 Subject: [PATCH 17/18] fix(db): add alembic merge revision for request_logs heads --- ...260417_000000_add_request_log_plan_type.py | 4 +-- ..._request_log_lookup_and_plan_type_heads.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 app/db/alembic/versions/20260421_120000_merge_request_log_lookup_and_plan_type_heads.py diff --git a/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py b/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py index ca592a75..a6a143f7 100644 --- a/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py +++ b/app/db/alembic/versions/20260417_000000_add_request_log_plan_type.py @@ -1,7 +1,7 @@ """add plan_type snapshot to request_logs Revision ID: 20260417_000000_add_request_log_plan_type -Revises: 20260415_160000_add_request_logs_response_lookup_index +Revises: 20260413_000000_add_accounts_blocked_at Create Date: 2026-04-17 """ @@ -11,7 +11,7 @@ from alembic import op revision = "20260417_000000_add_request_log_plan_type" -down_revision = "20260415_160000_add_request_logs_response_lookup_index" +down_revision = "20260413_000000_add_accounts_blocked_at" branch_labels = None depends_on = None diff --git a/app/db/alembic/versions/20260421_120000_merge_request_log_lookup_and_plan_type_heads.py b/app/db/alembic/versions/20260421_120000_merge_request_log_lookup_and_plan_type_heads.py new file mode 100644 index 00000000..f5d4bb30 --- /dev/null +++ b/app/db/alembic/versions/20260421_120000_merge_request_log_lookup_and_plan_type_heads.py @@ -0,0 +1,25 @@ +"""merge request log lookup and plan type heads + +Revision ID: 20260421_120000_merge_request_log_lookup_and_plan_type_heads 
+Revises: 20260415_160000_add_request_logs_response_lookup_index, +20260417_000000_add_request_log_plan_type +Create Date: 2026-04-21 +""" + +from __future__ import annotations + +revision = "20260421_120000_merge_request_log_lookup_and_plan_type_heads" +down_revision = ( + "20260415_160000_add_request_logs_response_lookup_index", + "20260417_000000_add_request_log_plan_type", +) +branch_labels = None +depends_on = None + + +def upgrade() -> None: + return + + +def downgrade() -> None: + return From 712cb571c4c1d94b6675db92107277537ed4e427 Mon Sep 17 00:00:00 2001 From: Kazet111 Date: Tue, 21 Apr 2026 14:00:58 +0200 Subject: [PATCH 18/18] test(proxy): make owner-lookup reservation-release regression test ty-compatible --- tests/unit/test_proxy_utils.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index 3f352b49..f15e1fda 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -7244,11 +7244,19 @@ async def fail_connect_proxy_websocket(*args, **kwargs): @pytest.mark.asyncio async def test_stream_with_retry_releases_api_key_reservation_when_owner_lookup_fails(monkeypatch): request_logs = _RequestLogsRecorder() - api_keys_repo = cast(ApiKeysRepository, AsyncMock()) - api_keys_repo.get_usage_reservation = AsyncMock(return_value=SimpleNamespace(status="reserved", items=[])) - api_keys_repo.transition_usage_reservation_status = AsyncMock(return_value=True) - api_keys_repo.settle_usage_reservation = AsyncMock() - api_keys_repo.commit = AsyncMock() + get_usage_reservation_mock = AsyncMock(return_value=SimpleNamespace(status="reserved", items=[])) + transition_usage_reservation_status_mock = AsyncMock(return_value=True) + settle_usage_reservation_mock = AsyncMock() + commit_mock = AsyncMock() + api_keys_repo = cast( + ApiKeysRepository, + SimpleNamespace( + get_usage_reservation=get_usage_reservation_mock, + 
transition_usage_reservation_status=transition_usage_reservation_status_mock, + settle_usage_reservation=settle_usage_reservation_mock, + commit=commit_mock, + ), + ) class _RepoContextWithApiKeys: def __init__(self) -> None: @@ -7330,14 +7338,14 @@ async def __aexit__(self, exc_type, exc, tb) -> bool: assert _proxy_error_code(exc_info.value) == "upstream_unavailable" owner_lookup.assert_awaited_once() select_account.assert_not_called() - api_keys_repo.get_usage_reservation.assert_awaited_once_with(reservation.reservation_id) - api_keys_repo.transition_usage_reservation_status.assert_awaited_once_with( + get_usage_reservation_mock.assert_awaited_once_with(reservation.reservation_id) + transition_usage_reservation_status_mock.assert_awaited_once_with( reservation.reservation_id, expected_status="reserved", new_status="released", ) - api_keys_repo.settle_usage_reservation.assert_awaited_once() - api_keys_repo.commit.assert_awaited_once() + settle_usage_reservation_mock.assert_awaited_once() + commit_mock.assert_awaited_once() @pytest.mark.asyncio