diff --git a/src/agents/__main__.py b/src/agents/__main__.py index fabb3f6..5ae29f8 100644 --- a/src/agents/__main__.py +++ b/src/agents/__main__.py @@ -67,6 +67,7 @@ def _per_process( step, _agent_cfg, eval_cfgs, episodes, n_processes, nth_gpu = args logging.info(f"Starting evaluation for step {step}") os.environ["CUDA_VISIBLE_DEVICES"] = str(nth_gpu) + os.environ["CAM_PATH"] = f"{os.environ['RUN_PATH']}/videos/{step}" agent_cfg = copy.deepcopy(_agent_cfg) agent_cfg.agent_kwargs["checkpoint_step"] = step diff --git a/src/agents/evaluator_envs.py b/src/agents/evaluator_envs.py index 5b878dd..739ff12 100644 --- a/src/agents/evaluator_envs.py +++ b/src/agents/evaluator_envs.py @@ -8,11 +8,13 @@ from abc import ABC from contextlib import contextmanager from dataclasses import asdict, dataclass +from pathlib import Path from time import sleep from typing import Any import gymnasium as gym import numpy as np +from PIL import Image from simple_slurm import Slurm from tqdm import tqdm @@ -62,17 +64,22 @@ def do_import(): class RCSPickUpCubeEval(EvaluatorEnv): INSTRUCTIONS = { - "rcs/FR3SimplePickUpSim-v0": "pick up the red cube", + "rcs/FR3SimplePickUpSim-v0": "pick the green box", + "rcs/FR3LabPickUpSimDigitHand-v0": "pick the green box", } def translate_obs(self, obs: dict[str, Any]) -> Obs: # does not include history + # side = obs["frames"]["arro"]["rgb"]["data"] side = obs["frames"]["side"]["rgb"]["data"] + wrist = obs["frames"]["wrist"]["rgb"]["data"] # depth_side = obs["frames"]["side"]["depth"]["data"], return Obs( - cameras=dict(rgb_side=side), + cameras=dict(rgb_side=side, rgb_wrist=wrist), + # cameras=dict(rgb_side=side), gripper=obs["gripper"], + info=dict(joints=obs["joints"]), ) def step(self, action: Act) -> tuple[Obs, float, bool, bool, dict]: @@ -99,9 +106,11 @@ def language_instruction(self) -> str: @staticmethod def do_import(): import rcs + import rcs_toolbox EvaluatorEnv.register("rcs/FR3SimplePickUpSim-v0", RCSPickUpCubeEval) +EvaluatorEnv.register("rcs/FR3LabPickUpSimDigitHand-v0", RCSPickUpCubeEval) class ManiSkill(EvaluatorEnv): @@ -212,7 +221,7 @@ class AgentConfig: port: int = 8080 -def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int) -> tuple[list[float], list[float], list[float]]: +def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int, i) -> tuple[list[float], list[float], list[float]]: logging.debug(f"Starting evaluation of {env.env.unwrapped.spec.id}") obs, _ = env.reset(options={}) logging.debug(f"Reset env {env.env.unwrapped.spec.id}") @@ -222,6 +231,7 @@ def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int) -> tuple[list[f truncated = False step = 0.0 rewards = [] + im = [] while not done and not truncated and max_steps > step: action = agent.act(obs) obs, reward, done, truncated, _ = env.step(action) @@ -229,6 +239,22 @@ def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int) -> tuple[list[f done, truncated = bool(done), bool(truncated) step += 1 rewards.append(reward) + im.append(obs.cameras) + + Path(f"{os.environ['CAM_PATH']}").mkdir(exist_ok=True, parents=True) + for camera in im[0].keys(): + imgs = [] + for img in im: + # skip images that have timestamps closer together than 0.5s + imgs.append(Image.fromarray(img[camera])) + + imgs[0].save( + f"{os.environ['CAM_PATH']}/{i}_{camera}_{str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))}.gif", + save_all=True, + append_images=imgs[1:], + duration=0.2 * 1000, + loop=0, + ) env.reset(options={}) logging.debug( @@ -262,7 +288,7 @@ def run_episode(args: tuple[int, list[EvalConfig], int, AgentConfig]) -> tuple[f while not agent.is_initialized(): logging.info("Waiting for agent to initialize...") sleep(5) - return single_eval(env, agent, cfg.max_steps_per_episode) + return single_eval(env, agent, cfg.max_steps_per_episode, i) def multi_eval( @@ -277,6 +303,7 @@ def multi_eval( # single_results = p.map(run_episode, args) # without process + np.random.seed(42) args = [(i, cfgs, episodes, agent_cfg) for i in range(len(cfgs) * episodes)] single_results = [run_episode(arg) for arg in tqdm(args)] @@ -321,7 +348,6 @@ def start_server( ] logging.info("Server starting: %s", " ".join(cmd)) p = subprocess.Popen(cmd) - sleep(5) try: yield p finally: @@ -352,6 +378,7 @@ def evaluation( with start_server( agent_cfg.agent_name, agent_cfg.agent_kwargs, agent_cfg.port, agent_cfg.host, agent_cfg.python_path ): + sleep(30) res = multi_eval(agent_cfg, eval_cfgs, episodes, n_processes) except Exception: # Ensures you SEE the client's stack trace and any logged errors.