This repository was archived by the owner on Mar 31, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathselection_timeout_error_reporting.patch
More file actions
190 lines (180 loc) · 7.73 KB
/
selection_timeout_error_reporting.patch
File metadata and controls
190 lines (180 loc) · 7.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
diff --git a/research/world_model/src/gitcg_world_model/env.py b/research/world_model/src/gitcg_world_model/env.py
index fc65220..patched 100644
--- a/research/world_model/src/gitcg_world_model/env.py
+++ b/research/world_model/src/gitcg_world_model/env.py
@@ -1,8 +1,10 @@
from __future__ import annotations
+import os
+import traceback
from dataclasses import dataclass, replace
from importlib import import_module
-from queue import Queue
+from queue import Empty, Queue
from threading import Event, Thread
from typing import Any
@@ -74,6 +76,23 @@ class _Shutdown:
pass
+_GITCG_EVENT_TIMEOUT_SECONDS_ENV = "GITCG_ENV_EVENT_TIMEOUT_SECONDS"
+_GITCG_EVENT_TIMEOUT_SECONDS_DEFAULT = 90.0
+
+
+def _event_timeout_seconds() -> float | None:
+    """Return the queue-wait timeout in seconds, or None to block forever.
+
+    Unset, blank or unparsable env values give the default; values <= 0, or
+    too large for a lock wait (> threading.TIMEOUT_MAX, e.g. inf), give None.
+    """
+    from threading import TIMEOUT_MAX
+    try:
+        value = float(os.environ.get(_GITCG_EVENT_TIMEOUT_SECONDS_ENV) or "")
+    except ValueError:
+        return _GITCG_EVENT_TIMEOUT_SECONDS_DEFAULT
+    return value if 0.0 < value <= TIMEOUT_MAX else None
+
+
@dataclass(frozen=True)
class _DecisionEvent:
context: DecisionContext
@@ -171,7 +190,17 @@ class GitcgDecisionEnv:
self._worker = None
def _next_event(self) -> _DecisionEvent | _TerminalEvent:
- event = self._event_queue.get()
+ timeout = _event_timeout_seconds()
+ try:
+ event = (
+ self._event_queue.get(timeout=timeout)
+ if timeout is not None
+ else self._event_queue.get()
+ )
+ except Empty as exc:
+ self._shutdown_requested.set()
+ raise TimeoutError(
+ f"world-model event timeout matchup={self._matchup.key} seed={self._seed} timeout_s={timeout}"
+ ) from exc
if isinstance(event, _CrashEvent):
raise RuntimeError("world-model worker failed") from event.error
return event
@@ -182,6 +211,14 @@ class GitcgDecisionEnv:
except _EnvironmentClosed:
pass
except BaseException as exc:
+ print(
+ "[env-worker-crash] "
+ f"matchup={self._matchup.key} "
+ f"seed={self._seed} "
+ f"error={type(exc).__name__}: {exc}\n"
+ f"{traceback.format_exc()}",
+ flush=True,
+ )
self._event_queue.put(_CrashEvent(exc))
def _run_worker_inner(self, seed: int | None, initial_state_json: str | None):
@@ -240,8 +277,19 @@ class GitcgDecisionEnv:
def _await_payload(self, request_type: DecisionType, request: Any) -> dict[str, Any]:
state_ref = game.state()
try:
- raw_state_json = state_ref.json()
- full_state = snapshot_state(state_ref, state_json=raw_state_json)
+ try:
+ raw_state_json = state_ref.json()
+ full_state = snapshot_state(state_ref, state_json=raw_state_json)
+ except BaseException as exc:
+ print(
+ "[env-rpc-error] "
+ f"matchup={env._matchup.key} "
+ f"seed={env._seed} "
+ f"player={self.who} "
+ f"request_type={request_type.name} "
+ f"step={step_index_ref['value']} "
+ f"io_error={self.last_error!r} "
+ f"error={type(exc).__name__}: {exc}",
+ flush=True,
+ )
+ raise
finally:
_safe_release(state_ref)
full_state_json = raw_state_json if env._config.record_full_state_json else None
@@ -311,6 +359,7 @@ class GitcgDecisionEnv:
raw_winner,
status=status_name,
error=error_message,
+ max_rounds=self._config.max_rounds,
)
if sanitized_winner != final_state.winner:
final_state = replace(final_state, winner=sanitized_winner)
@@ -345,7 +394,13 @@ class GitcgDecisionEnv:
_safe_release(final_state_ref)
self._event_queue.put(_TerminalEvent(terminal_context))
finally:
- if final_state_ref is not None:
+ if final_state_ref is not None:
_safe_release(final_state_ref)
if game is not None:
_safe_release(game)
@@ -359,7 +414,13 @@ class GitcgDecisionEnv:
gitcg.thread_cleanup()
-def _validate_terminal_outcome(full_state, winner: int | None, *, status: str | None, error: str | None) -> tuple[int | None, str | None]:
+def _validate_terminal_outcome(
+ full_state,
+ winner: int | None,
+ *,
+ status: str | None,
+ error: str | None,
+ max_rounds: int | None = None,
+) -> tuple[int | None, str | None]:
players = tuple(getattr(full_state, "players", ()) or ())
if len(players) != 2:
if status and status != "FINISHED":
@@ -372,6 +433,7 @@ def _validate_terminal_outcome(full_state, winner: int | None, *, status: str |
player0_all_defeated = _all_defeated(players[0])
player1_all_defeated = _all_defeated(players[1])
+ final_round = int(getattr(full_state, "round_number", 0) or 0)
expected_winner: int | None
if player1_all_defeated and not player0_all_defeated:
expected_winner = 0
@@ -386,6 +448,15 @@ def _validate_terminal_outcome(full_state, winner: int | None, *, status: str |
if error:
return None, f"terminal_error={error}"
+ if winner is None:
+ if player0_all_defeated and player1_all_defeated:
+ return None, None
+ if max_rounds is not None and final_round >= int(max_rounds):
+ return None, None
+ if player0_all_defeated != player1_all_defeated:
+ return None, "invalid_missing_winner_with_single_ko"
+ return None, "invalid_draw_without_dual_ko_or_round_cap"
if winner in (0, 1) and expected_winner is None:
return None, f"inconsistent_terminal_winner={winner}"
if winner in (0, 1) and expected_winner is not None and int(winner) != int(expected_winner):
diff --git a/research/world_model/src/gitcg_world_model/ppo_pipeline.py b/research/world_model/src/gitcg_world_model/ppo_pipeline.py
index 0adccc2..patched 100644
--- a/research/world_model/src/gitcg_world_model/ppo_pipeline.py
+++ b/research/world_model/src/gitcg_world_model/ppo_pipeline.py
@@ -6,6 +6,7 @@ import gc
import json
import math
import os
+import traceback
import random
import shutil
import time
@@ -281,12 +282,20 @@ def _run_episode_job(
if attempt >= retry_limit:
raise
print(
+ "[episode-error] "
+ f"matchup={job.matchup.key} seed={requested_seed} actual_seed={effective_seed} "
+ f"attempt={attempt + 1}/{retry_limit} "
+ f"error={type(exc).__name__}: {exc}\n"
+ f"{traceback.format_exc()}",
+ flush=True,
+ )
+ print(
"[episode-retry] "
f"matchup={job.matchup.key} seed={requested_seed} actual_seed={effective_seed} "
f"attempt={attempt + 1}/{retry_limit} reason=exception:{type(exc).__name__}",
flush=True,
)
gc.collect()
+ time.sleep(0.25)
continue
merged_metadata = {**episode.metadata, **job.metadata} if job.metadata else dict(episode.metadata)