Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agents/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _per_process(
step, _agent_cfg, eval_cfgs, episodes, n_processes, nth_gpu = args
logging.info(f"Starting evaluation for step {step}")
os.environ["CUDA_VISIBLE_DEVICES"] = str(nth_gpu)
os.environ["CAM_PATH"] = f"{os.environ['RUN_PATH']}/videos/{step}"
agent_cfg = copy.deepcopy(_agent_cfg)
agent_cfg.agent_kwargs["checkpoint_step"] = step

Expand Down
37 changes: 32 additions & 5 deletions src/agents/evaluator_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from abc import ABC
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from time import sleep
from typing import Any

import gymnasium as gym
import numpy as np
from PIL import Image
from simple_slurm import Slurm
from tqdm import tqdm

Expand Down Expand Up @@ -62,17 +64,22 @@ def do_import():

class RCSPickUpCubeEval(EvaluatorEnv):
INSTRUCTIONS = {
"rcs/FR3SimplePickUpSim-v0": "pick up the red cube",
"rcs/FR3SimplePickUpSim-v0": "pick the green box",
"rcs/FR3LabPickUpSimDigitHand-v0": "pick the green box",
}

def translate_obs(self, obs: dict[str, Any]) -> Obs:
# does not include history

# side = obs["frames"]["arro"]["rgb"]["data"]
side = obs["frames"]["side"]["rgb"]["data"]
wrist = obs["frames"]["wrist"]["rgb"]["data"]
# depth_side = obs["frames"]["side"]["depth"]["data"],
return Obs(
cameras=dict(rgb_side=side),
cameras=dict(rgb_side=side, rgb_wrist=wrist),
# cameras=dict(rgb_side=side),
gripper=obs["gripper"],
info=dict(joints=obs["joints"]),
)

def step(self, action: Act) -> tuple[Obs, float, bool, bool, dict]:
Expand All @@ -99,9 +106,11 @@ def language_instruction(self) -> str:
@staticmethod
def do_import():
import rcs
import rcs_toolbox


EvaluatorEnv.register("rcs/FR3SimplePickUpSim-v0", RCSPickUpCubeEval)
EvaluatorEnv.register("rcs/FR3LabPickUpSimDigitHand-v0", RCSPickUpCubeEval)


class ManiSkill(EvaluatorEnv):
Expand Down Expand Up @@ -212,7 +221,7 @@ class AgentConfig:
port: int = 8080


def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int) -> tuple[list[float], list[float], list[float]]:
def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int, i) -> tuple[list[float], list[float], list[float]]:
logging.debug(f"Starting evaluation of {env.env.unwrapped.spec.id}")
obs, _ = env.reset(options={})
logging.debug(f"Reset env {env.env.unwrapped.spec.id}")
Expand All @@ -222,13 +231,30 @@ def single_eval(env: EvaluatorEnv, agent: Agent, max_steps: int) -> tuple[list[f
truncated = False
step = 0.0
rewards = []
im = []
while not done and not truncated and max_steps > step:
action = agent.act(obs)
obs, reward, done, truncated, _ = env.step(action)
reward = float(reward)
done, truncated = bool(done), bool(truncated)
step += 1
rewards.append(reward)
im.append(obs.cameras)

Path(f"{os.environ['CAM_PATH']}").mkdir(exist_ok=True, parents=True)
for camera in im[0].keys():
imgs = []
for img in im:
# skip images that have timestamps closer together than 0.5s
imgs.append(Image.fromarray(img[camera]))

imgs[0].save(
f"{os.environ['CAM_PATH']}/{i}_{camera}_{str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))}.gif",
save_all=True,
append_images=imgs[1:],
duration=0.2 * 1000,
loop=0,
)

env.reset(options={})
logging.debug(
Expand Down Expand Up @@ -262,7 +288,7 @@ def run_episode(args: tuple[int, list[EvalConfig], int, AgentConfig]) -> tuple[f
while not agent.is_initialized():
logging.info("Waiting for agent to initialize...")
sleep(5)
return single_eval(env, agent, cfg.max_steps_per_episode)
return single_eval(env, agent, cfg.max_steps_per_episode, i)


def multi_eval(
Expand All @@ -277,6 +303,7 @@ def multi_eval(
# single_results = p.map(run_episode, args)

# without process
np.random.seed(42)
args = [(i, cfgs, episodes, agent_cfg) for i in range(len(cfgs) * episodes)]
single_results = [run_episode(arg) for arg in tqdm(args)]

Expand Down Expand Up @@ -321,7 +348,6 @@ def start_server(
]
logging.info("Server starting: %s", " ".join(cmd))
p = subprocess.Popen(cmd)
sleep(5)
try:
yield p
finally:
Expand Down Expand Up @@ -352,6 +378,7 @@ def evaluation(
with start_server(
agent_cfg.agent_name, agent_cfg.agent_kwargs, agent_cfg.port, agent_cfg.host, agent_cfg.python_path
):
sleep(30)
res = multi_eval(agent_cfg, eval_cfgs, episodes, n_processes)
except Exception:
# Ensures you SEE the client's stack trace and any logged errors.
Expand Down