import json
import logging
import shlex
import shutil
import signal
import subprocess
from abc import ABC
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from time import sleep
from typing import Any


def _interpreter_argv(python_path: str, agent_name: str) -> list[str]:
    """Expand the (possibly templated) ``python_path`` into argv tokens."""
    launcher = python_path.format(agent_name=agent_name)
    # A string that already resolves to an executable (including a path that
    # contains spaces) is used verbatim; only otherwise is it treated as a
    # multi-word launcher such as "conda run -n <env> python".
    if shutil.which(launcher) is not None:
        return [launcher]
    return shlex.split(launcher) or [launcher]


def _stop_server(p: subprocess.Popen) -> None:
    """Stop the server process, escalating SIGINT -> terminate -> kill."""
    try:
        p.send_signal(signal.SIGINT)
        p.wait(timeout=5)
    except (OSError, ValueError, subprocess.TimeoutExpired) as err:
        # ValueError: SIGINT cannot be sent on this platform (e.g. Windows).
        # OSError: the signal could not be delivered.
        # TimeoutExpired: the server ignored the interrupt.
        logging.warning("Graceful server shutdown failed (%r); escalating", err)
    if p.poll() is None:
        p.terminate()
        try:
            p.wait(timeout=3)
        except subprocess.TimeoutExpired:
            p.kill()
            # Reap the killed child so it is not left behind as a zombie.
            p.wait()
    logging.info("Server stopped")


@contextmanager
def start_server(
    agent_name: str,
    kwargs: dict[str, Any],
    port: int = 8080,
    host: str = "localhost",
    python_path: str = "python",
    startup_wait: float = 5.0,
) -> Iterator[subprocess.Popen]:
    """Start the agent server in a subprocess as a context manager.

    The server is launched as
    ``<python_path> -m agents start-server <agent_name> --port=... --host=... --kwargs=<json>``
    and is stopped when the ``with`` block exits, whether normally or through
    an exception. No stdout/stderr redirection is configured for the child, so
    its output goes wherever this process's output goes.

    Example:
        with start_server(name, kwargs, port, host, python_path):
            ...  # talk to the server

    Args:
        agent_name (str): Name of the agent to start.
        kwargs (dict[str, Any]): Agent keyword arguments; JSON-encoded and
            passed to the server through ``--kwargs``.
        port (int): Value passed as ``--port``. Defaults to 8080.
        host (str): Value passed as ``--host``. Defaults to "localhost".
        python_path (str): Path to the Python interpreter to use. It can also
            be a format string that will be formatted with the agent_name,
            e.g. "conda run -n {agent_name} python"; a multi-word value is
            split into separate command tokens. Defaults to "python".
        startup_wait (float): Seconds to give the server to come up before the
            ``with`` body runs. The wait ends early (and an error is logged)
            if the process exits in the meantime. Defaults to 5.0.

    Yields:
        subprocess.Popen: The process running the server.
    """
    cmd = [
        *_interpreter_argv(python_path, agent_name),
        "-m",
        "agents",
        "start-server",
        str(agent_name),
        f"--port={port}",
        f"--host={host}",
        f"--kwargs={json.dumps(kwargs)}",
    ]
    # shlex.join quotes each token, so the logged line can be pasted into a shell.
    logging.info("Server starting: %s", shlex.join(cmd))
    p = subprocess.Popen(cmd)
    try:
        # The startup wait lives inside the try so that an interrupt during
        # startup still tears the server down.
        try:
            exit_code = p.wait(timeout=startup_wait)
        except subprocess.TimeoutExpired:
            pass  # Still running after the grace period: the expected case.
        else:
            logging.error("Server exited during startup with code %s", exit_code)
        yield p
    finally:
        # Stop the server no matter how we exit the with-block (success or exception).
        _stop_server(p)