From e55a928415895e0f7b475d37291d1909efe7b8b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Wed, 3 Sep 2025 20:19:43 +0200 Subject: [PATCH 1/8] feat: wrist image, inverted gripper action, jpeg encoding --- README.md | 7 +++++++ src/agents/policies.py | 42 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9e01f69..e49ff86 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,13 @@ python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"c python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "/{checkpoint_step}", "model_name": "pi0_rcs", "checkpoint_step": }' # leave "{checkpoint_step}" it will be replaced, "model_name" is the key for the training config ``` +### RCS run commands +```shell +# openpi +/home/juelg/miniconda3/envs/rcs_openpi/bin/python -m agents start-server openpi --port=20997 --host=0.0.0.0 --kwargs='{"checkpoint_path": "/mnt/dataset_drive/juelg/checkpoints/rcs_paper/pi0/pi0_rcs_utn/openpi_utn_wrist/{checkpoint_step}", "model_name": "pi0_rcs_utn", "checkpoint_step": 29999}' # leave "{checkpoint_step}" it will be replaced, "model_name" is the key for the training config +``` + + There is also the `run-eval-during-training` command to evaluate a model during training, so a single checkpoint. The `run-eval-post-training` command evaluates a range of checkpoints in parallel. In both cases environment and arguments as well as policy and arguments and wandb config for logging can be passed as CLI arguments. diff --git a/src/agents/policies.py b/src/agents/policies.py index 78220f4..1428dc7 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -1,3 +1,4 @@ +import base64 import copy import json import logging @@ -9,9 +10,12 @@ from operator import getitem from pathlib import Path from typing import Any, Union +from torchvision.io import decode_jpeg +from torchvision.transforms import v2 import numpy as np from PIL import Image +import torch @dataclass(kw_only=True) @@ -140,6 +144,9 @@ def __init__( self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step) self.cfg = config.get_config(model_name) + self.chunks = 20 + self.s = self.chunks + self.a = None def initialize(self): from openpi.policies import policy_config @@ -153,18 +160,49 @@ def initialize(self): def act(self, obs: Obs) -> Act: # Run inference on a dummy example.
# observation = {f"observation/{k}": v for k, v in obs.cameras.items()} + + if self.s < self.chunks: + self.s += 1 + return Act(action=self.a[self.s]) + + else: + self.s = 0 + + side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) + side = torch.frombuffer(bytearray(side), dtype=torch.uint8) + side = decode_jpeg(side) + side = v2.Resize((256, 256))(side) + + wrist = base64.urlsafe_b64decode(obs.cameras["rgb_wrist"]) + wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) + wrist = decode_jpeg(wrist) + wrist = v2.Resize((256, 256))(wrist) + + + + + # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) + # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) + # return Act(action=np.array([])) observation = {} observation.update( { - "observation/image": np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1), - "observation/state": np.concatenate([obs.info["joints"], [obs.gripper]]), + "observation/image": side, + "observation/wrist_image": wrist, + "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), "prompt": self.instruction, } ) action_chunk = self.policy.infer(observation)["actions"] + # convert gripper action + action_chunk[:,-1] = 1 - action_chunk[:,-1] + self.a = action_chunk + + # return Act(action=action_chunk[0]) return Act(action=action_chunk[0]) + class OpenVLAModel(Agent): # === Utilities === SYSTEM_PROMPT = ( From 281f17710618ad23b9e6239cf4a60756c5c681ae Mon Sep 17 00:00:00 2001 From: nisarganc Date: Tue, 16 Dec 2025 15:50:05 +0100 Subject: [PATCH 2/8] initial commit --- src/agents/policies.py | 74 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/agents/policies.py b/src/agents/policies.py index 1428dc7..a440408 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -127,6 +127,80 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: } return info +class VjepaAC(Agent): + + def __init__( + self, + model_name: str = "vjepa2_ac_vit_giant", + default_checkpoint_path: str = "vjepa2_ac_vit_giant", + **kwargs, + ) -> None: + super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) + + from vjepa2.configs.inference import vjepa2-ac-vitg.utn-robot.yaml + + logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}") + self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step) + + self.cfg = config.get_config(model_name) + self.chunks = 20 + self.s = self.chunks + self.a = None + + def initialize(self): + from openpi.policies import policy_config + from openpi.shared import download + + encoder, predictor = torch.hub.load("~/vjepa2", # root of the vjepa source code + "vjepa2_ac_vit_giant", # model type + source="local", + pretrained=True) + + checkpoint_dir = download.maybe_download(self.openpi_path) + + # Create a trained policy. + self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir) + + def act(self, obs: Obs) -> Act: + # Run inference on a dummy example. 
+ # observation = {f"observation/{k}": v for k, v in obs.cameras.items()} + + if self.s < self.chunks: + self.s += 1 + return Act(action=self.a[self.s]) + + else: + self.s = 0 + + side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) + side = torch.frombuffer(bytearray(side), dtype=torch.uint8) + side = decode_jpeg(side) + side = v2.Resize((256, 256))(side) + + wrist = base64.urlsafe_b64decode(obs.cameras["rgb_wrist"]) + wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) + wrist = decode_jpeg(wrist) + wrist = v2.Resize((256, 256))(wrist) + + # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) + # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) + # return Act(action=np.array([])) + observation = {} + observation.update( + { + "observation/image": side, + "observation/wrist_image": wrist, + "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), + "prompt": self.instruction, + } + ) + action_chunk = self.policy.infer(observation)["actions"] + # convert gripper action + action_chunk[:,-1] = 1 - action_chunk[:,-1] + self.a = action_chunk + + # return Act(action=action_chunk[0]) + return Act(action=action_chunk[0]) class OpenPiModel(Agent): From a2c56398c9a92d4627675ef369ccc5f4abf69859 Mon Sep 17 00:00:00 2001 From: nisarganc Date: Tue, 16 Dec 2025 15:51:11 +0100 Subject: [PATCH 3/8] initial commit --- mydiff.diff | 199 +++++++++++++++++++++++++++++++++++++++++ src/agents/policies.py | 160 +++++++++++++++++++++++---------- 2 files changed, 310 insertions(+), 49 deletions(-) create mode 100644 mydiff.diff diff --git a/mydiff.diff b/mydiff.diff new file mode 100644 index 0000000..e06945d --- /dev/null +++ b/mydiff.diff @@ -0,0 +1,199 @@ +diff --git a/src/agents/policies.py b/src/agents/policies.py +index a440408..827ec58 100644 +--- a/src/agents/policies.py ++++ b/src/agents/policies.py +@@ -131,76 +131,147 @@ class VjepaAC(Agent): + + def __init__( + self, ++ cfg_path: str, + model_name: str = "vjepa2_ac_vit_giant", + default_checkpoint_path: str = "vjepa2_ac_vit_giant", + **kwargs, + ) -> None: + super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) ++ import yaml + +- from vjepa2.configs.inference import vjepa2-ac-vitg.utn-robot.yaml ++ self.cfg_path = cfg_path ++ with open(self.cfg_path, "r") as f: ++ self.cfg = yaml.safe_load(f) + +- logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}") +- self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step) +- +- self.cfg = config.get_config(model_name) +- self.chunks = 20 +- self.s = self.chunks +- self.a = None + + def initialize(self): +- from openpi.policies import policy_config +- from openpi.shared import download ++ # VJEPA imports ++ from vjepa2.app.vjepa_droid.transforms import make_transforms ++ from vjepa2.notebooks.utils.world_model_wrapper import WorldModel ++ ++ device = self.cfg.get("device", 'cuda') ++ save_path = self.cfg.get("save_path", 'exp_1.png') ++ ++ # data config ++ cfgs_data = self.cfg.get("data") ++ fps = cfgs_data.get("fps", 4) ++ crop_size = cfgs_data.get("crop_size", 256) ++ patch_size = cfgs_data.get("patch_size") ++ pin_mem = cfgs_data.get("pin_mem", False) ++ num_workers = cfgs_data.get("num_workers", 1) ++ persistent_workers = cfgs_data.get("persistent_workers", True) ++ ++ # data augs ++ cfgs_data_aug = self.cfg.get("data_aug") ++ horizontal_flip = cfgs_data_aug.get("horizontal_flip", False) ++ ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) ++ 
rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) ++ motion_shift = cfgs_data_aug.get("motion_shift", False) ++ reprob = cfgs_data_aug.get("reprob", 0.0) ++ use_aa = cfgs_data_aug.get("auto_augment", False) ++ ++ # exp config ++ cfgs_mpc_args= self.cfg.get("mpc_args") ++ self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2) ++ samples = cfgs_mpc_args.get("samples", 25) ++ topk = cfgs_mpc_args.get("topk", 10) ++ cem_steps = cfgs_mpc_args.get("cem_steps", 1) ++ momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15) ++ momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15) ++ momentum_std = cfgs_mpc_args.get("momentum_std", 0.75) ++ momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", .15) ++ maxnorm = cfgs_mpc_args.get("maxnorm", 0.075) ++ verbose = cfgs_mpc_args.get("verbose", True) ++ ++ ++ ++ # Initialize transform (random-resize-crop augmentations) ++ self.transform = make_transforms( ++ random_horizontal_flip=horizontal_flip, ++ random_resize_aspect_ratio=ar_range, ++ random_resize_scale=rr_scale, ++ reprob=reprob, ++ auto_augment=use_aa, ++ motion_shift=motion_shift, ++ crop_size=crop_size, ++ ) + + encoder, predictor = torch.hub.load("~/vjepa2", # root of the vjepa source code + "vjepa2_ac_vit_giant", # model type + source="local", + pretrained=True) + +- checkpoint_dir = download.maybe_download(self.openpi_path) ++ # check if model weights are loaded on cuda ++ encoder.to(device) ++ predictor.to(device) ++ ++ # World model wrapper initialization ++ ++ tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) ++ self.world_model = WorldModel( ++ encoder=encoder, ++ predictor=predictor, ++ tokens_per_frame=tokens_per_frame, ++ mpc_args={ ++ "rollout": self.rollout_horizon, ++ "samples": samples, ++ "topk": topk, ++ "cem_steps": cem_steps, ++ "momentum_mean": momentum_mean, ++ "momentum_mean_gripper": momentum_mean_gripper, ++ "momentum_std": momentum_std, ++ "momentum_std_gripper": momentum_std_gripper, ++ "maxnorm": maxnorm, ++ "verbose": verbose, ++ }, ++ normalize_reps=True, ++ device=device ++ ) + +- # Create a trained policy. +- self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir) + + def act(self, obs: Obs) -> Act: +- # Run inference on a dummy example. +- # observation = {f"observation/{k}": v for k, v in obs.cameras.items()} + +- if self.s < self.chunks: +- self.s += 1 +- return Act(action=self.a[self.s]) +- +- else: +- self.s = 0 ++ with torch.no_grad(): + +- side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) +- side = torch.frombuffer(bytearray(side), dtype=torch.uint8) +- side = decode_jpeg(side) +- side = v2.Resize((256, 256))(side) ++ # Pre-trained VJEPA 2 ENCODER: # [1, 3, 1, 256, 256] -> [1, 3, 1, 256, 1408] i.e, [B, C, Time, Patches, dim] ++ z_n = self.world_model.encode(self.transform(obs.cameras["rgb_side"])) ++ B = z_n.shape[0] ++ T = z_n.shape[2] + +- wrist = base64.urlsafe_b64decode(obs.cameras["rgb_wrist"]) +- wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) +- wrist = decode_jpeg(wrist) +- wrist = v2.Resize((256, 256))(wrist) ++ # [1, 1, 76] -> [B, Time, state] ++ # TODO: gripper state in DROID? 
In rcs 0: is close and 1: is open ++ s_n = np.concatenate(([obs.info["xyzrpy"], [1-obs.gripper]]), axis=0).reshape(B, T, -1) + +- # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) +- # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) +- # return Act(action=np.array([])) +- observation = {} +- observation.update( +- { +- "observation/image": side, +- "observation/wrist_image": wrist, +- "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), +- "prompt": self.instruction, +- } +- ) +- action_chunk = self.policy.infer(observation)["actions"] +- # convert gripper action +- action_chunk[:,-1] = 1 - action_chunk[:,-1] +- self.a = action_chunk ++ # Action conditioned predictor and zero-shot action inference with CEM ++ actions = self.world_model.infer_next_action( ++ z_n, ++ s_n, ++ z_n ++ ) # [4, 7] ++ ++ # compute predicted next states ++ s_n_k = s_n # [1, 1, 7] ++ predicted_states = s_n # [1, 1, 7] ++ for i in range(self.rollout_horizon): ++ a_n = actions[i].unsqueeze(0).unsqueeze(1) # [1, 1, 7] ++ s_next = compute_new_pose(s_n_k, a_n) # [1, 1, 7] ++ predicted_states = torch.cat((predicted_states, s_next), dim=1) # [1, i+2, 7] ++ s_n_k = s_next ++ ++ predicted_state_trajs.append((k, predicted_states.cpu())) ++ ++ print(f"Predicted new state: {predicted_states.cpu()}") # [1, rollout_horizon+1, 7] ++ print(f"Ground truth new state: {ground_truth_traj}") # [1, rollout_horizon+1, 7] ++ ++ ++ return Act(action=np.array([])) ++ ++ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: ++ super().reset(obs, instruction, **kwargs) ++ # TODO: actually calculate goal representation ++ self.goal_rep = instruction ++ return {} + +- # return Act(action=action_chunk[0]) +- return Act(action=action_chunk[0]) + + class OpenPiModel(Agent): + diff --git a/src/agents/policies.py b/src/agents/policies.py index a440408..b27812e 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -131,76 +131,138 @@ class VjepaAC(Agent): def __init__( self, + cfg_path: str, model_name: str = "vjepa2_ac_vit_giant", - default_checkpoint_path: str = "vjepa2_ac_vit_giant", + default_checkpoint_path: str = "~/.cache/torch/hub/checkpoints/", **kwargs, ) -> None: super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) + import yaml - from vjepa2.configs.inference import vjepa2-ac-vitg.utn-robot.yaml + self.cfg_path = cfg_path + with open(self.cfg_path, "r") as f: + self.cfg = yaml.safe_load(f) - logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}") - self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step) + self.model_name = model_name - self.cfg = config.get_config(model_name) - self.chunks = 20 - self.s = self.chunks - self.a = None def initialize(self): - from openpi.policies import policy_config - from openpi.shared import download + # VJEPA imports + from vjepa2.app.vjepa_droid.transforms import make_transforms + from vjepa2.notebooks.utils.world_model_wrapper import WorldModel + + device = self.cfg.get("device", 'cuda') + save_path = self.cfg.get("save_path", 'exp_1.png') + # data config + cfgs_data = self.cfg.get("data") + fps = cfgs_data.get("fps", 4) + crop_size = cfgs_data.get("crop_size", 256) + + # data augs + cfgs_data_aug = self.cfg.get("data_aug") + use_aa = cfgs_data_aug.get("auto_augment", False) + horizontal_flip = cfgs_data_aug.get("horizontal_flip", False) + motion_shift = cfgs_data_aug.get("motion_shift", False) + ar_range = 
cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) + rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) + reprob = cfgs_data_aug.get("reprob", 0.0) + + # cfgs_mpc_args config + cfgs_mpc_args= self.cfg.get("mpc_args") + self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2) + samples = cfgs_mpc_args.get("samples", 25) + topk = cfgs_mpc_args.get("topk", 10) + cem_steps = cfgs_mpc_args.get("cem_steps", 1) + momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15) + momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15) + momentum_std = cfgs_mpc_args.get("momentum_std", 0.75) + momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", .15) + maxnorm = cfgs_mpc_args.get("maxnorm", 0.075) + verbose = cfgs_mpc_args.get("verbose", True) + + + + # Initialize transform (random-resize-crop augmentations) + self.transform = make_transforms( + random_horizontal_flip=horizontal_flip, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size, + ) + + # load model encoder, predictor = torch.hub.load("~/vjepa2", # root of the vjepa source code - "vjepa2_ac_vit_giant", # model type + self.model_name, # model type source="local", pretrained=True) - checkpoint_dir = download.maybe_download(self.openpi_path) + # load model to cuda + encoder.to(device) + predictor.to(device) + + # World model wrapper initialization + tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) + self.world_model = WorldModel( + encoder=encoder, + predictor=predictor, + tokens_per_frame=tokens_per_frame, + mpc_args={ + "rollout": self.rollout_horizon, + "samples": samples, + "topk": topk, + "cem_steps": cem_steps, + "momentum_mean": momentum_mean, + "momentum_mean_gripper": momentum_mean_gripper, + "momentum_std": momentum_std, + "momentum_std_gripper": momentum_std_gripper, + "maxnorm": maxnorm, + "verbose": verbose, + }, + normalize_reps=True, + device=device + ) - # Create a trained policy. - self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir) def act(self, obs: Obs) -> Act: - # Run inference on a dummy example. - # observation = {f"observation/{k}": v for k, v in obs.cameras.items()} - - if self.s < self.chunks: - self.s += 1 - return Act(action=self.a[self.s]) - - else: - self.s = 0 - side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) - side = torch.frombuffer(bytearray(side), dtype=torch.uint8) - side = decode_jpeg(side) - side = v2.Resize((256, 256))(side) + with torch.no_grad(): + + # Pre-trained VJEPA 2 ENCODER + # [256, 256, 3] -> [1, 3, 1, 256, 256] -> [1, 3, 1, 256, 1408] i.e, + # [B, C, Time, Patches, dim] + z_n = self.world_model.encode(self.transform(obs.cameras["rgb_side"])) + B = z_n.shape[0] + T = z_n.shape[2] + + # [1, 1, 76] -> [B, Time, state] + # TODO: gripper state in DROID? 
In rcs 0: is close and 1: is open + s_n = torch.tensor(np.concatenate(([obs.info["xyzrpy"], + [1-obs.gripper]]), + axis=0).reshape(B, T, -1)).to(self.device, + dtype=torch.float, + non_blocking=True) + + # Action conditioned predictor and zero-shot action inference with CEM + actions = self.world_model.infer_next_action( + z_n, + s_n, + self.goal_rep + ) # [rollout_horizon, 7] + + return Act(action=np.array(actions[0])) - wrist = base64.urlsafe_b64decode(obs.cameras["rgb_wrist"]) - wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) - wrist = decode_jpeg(wrist) - wrist = v2.Resize((256, 256))(wrist) + def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: + super().reset(obs, instruction, **kwargs) + + with torch.no_grad(): + self.goal_rep = self.world_model.encode(self.transform(instruction)) - # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) - # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) - # return Act(action=np.array([])) - observation = {} - observation.update( - { - "observation/image": side, - "observation/wrist_image": wrist, - "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), - "prompt": self.instruction, - } - ) - action_chunk = self.policy.infer(observation)["actions"] - # convert gripper action - action_chunk[:,-1] = 1 - action_chunk[:,-1] - self.a = action_chunk + return {} - # return Act(action=action_chunk[0]) - return Act(action=action_chunk[0]) class OpenPiModel(Agent): From ba88daba5912d85378b79243a4244b4ef70d8fb0 Mon Sep 17 00:00:00 2001 From: nisarganc Date: Mon, 22 Dec 2025 12:45:01 +0100 Subject: [PATCH 4/8] vjepa2 integration code --- src/agents/policies.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/agents/policies.py b/src/agents/policies.py index b27812e..136ae30 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -133,7 +133,7 @@ def __init__( self, cfg_path: str, model_name: str = "vjepa2_ac_vit_giant", - default_checkpoint_path: str = "~/.cache/torch/hub/checkpoints/", + default_checkpoint_path: str = "", **kwargs, ) -> None: super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) @@ -148,8 +148,8 @@ def __init__( def initialize(self): # VJEPA imports - from vjepa2.app.vjepa_droid.transforms import make_transforms - from vjepa2.notebooks.utils.world_model_wrapper import WorldModel + from app.vjepa_droid.transforms import make_transforms + from inference.utils.world_model_wrapper import WorldModel device = self.cfg.get("device", 'cuda') save_path = self.cfg.get("save_path", 'exp_1.png') @@ -195,7 +195,7 @@ def initialize(self): ) # load model - encoder, predictor = torch.hub.load("~/vjepa2", # root of the vjepa source code + encoder, predictor = torch.hub.load("./", # root of the vjepa source code self.model_name, # model type source="local", pretrained=True) @@ -673,4 +673,5 @@ def act(self, obs: Obs) -> Act: octodist=OctoActionDistribution, openvladist=OpenVLADistribution, openpi=OpenPiModel, + vjepa=VjepaAC, ) From 3a45187dbaadf7f314fbc742303876c82ff50fe3 Mon Sep 17 00:00:00 2001 From: nisarganc Date: Mon, 22 Dec 2025 12:47:23 +0100 Subject: [PATCH 5/8] vjepa2 integration code --- README.md | 22 +++++++++++++++ src/agents/policies.py | 61 ++++++++++++++++++++++++++++++------------ 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index e49ff86..7e15508 100644 --- a/README.md +++ b/README.md @@ -117,16 +117,38 @@ GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . 
``` For more details see [openpi's github](https://github.com/Physical-Intelligence/openpi). +### vjEPA2-ac +To use VJEPA2-AC, create a new conda environment: +```shell +conda create -n vjepa2 python=3.12 +conda activate vjepa2 +``` +Clone the repo and install it. +```shell +git clone git@github.com:nisarganc/vjepa2.git +cd vjepa2 +pip install -e . + +pip install git+https://github.com/juelg/agents.git +git checkout nilavadi/vjepa-ac +pip install -ve . + +``` ## Usage To start an agents server use the `start-server` command where `kwargs` is a dictionary of the constructor arguments of the policy you want to start e.g. ```shell # octo python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}' + # openvla python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}' + # openpi python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "/{checkpoint_step}", "model_name": "pi0_rcs", "checkpoint_step": }' # leave "{checkpoint_step}" it will be replaced, "model_name" is the key for the training config + +# vjepa2-ac +python -m agents start-server vjepa --port=20997 --host=0.0.0.0 --kwargs='{"cfg_path": "configs/inference/vjepa2-ac-vitg/utn-robot.yaml", "model_name": "vjepa2_ac_vit_giant", "default_checkpoint_path": "../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt"}' ``` ### RCS run commands diff --git a/src/agents/policies.py b/src/agents/policies.py index 136ae30..22f7d32 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -151,8 +151,9 @@ def initialize(self): from app.vjepa_droid.transforms import make_transforms from inference.utils.world_model_wrapper import WorldModel - device = self.cfg.get("device", 'cuda') - save_path = self.cfg.get("save_path", 'exp_1.png') + self.device = self.cfg.get("device", 'cuda') + self.save_path = self.cfg.get("save_path", 'exp_1.png') + self.goal_img = self.cfg.get("goal_img", 'exp_1.png') # data config cfgs_data = self.cfg.get("data") @@ -201,8 +202,8 @@ def initialize(self): pretrained=True) # load model to cuda - encoder.to(device) - predictor.to(device) + encoder.to(self.device) + predictor.to(self.device) # World model wrapper initialization tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) @@ -223,7 +224,7 @@ def initialize(self): "verbose": verbose, }, normalize_reps=True, - device=device + device=self.device ) @@ -231,18 +232,29 @@ def act(self, obs: Obs) -> Act: with torch.no_grad(): - # Pre-trained VJEPA 2 ENCODER - # [256, 256, 3] -> [1, 3, 1, 256, 256] -> [1, 3, 1, 256, 1408] i.e, - # [B, C, Time, Patches, dim] - z_n = self.world_model.encode(self.transform(obs.cameras["rgb_side"])) - B = z_n.shape[0] - T = z_n.shape[2] - - # [1, 1, 76] -> [B, Time, state] + # read from camera-stream + side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) + side = torch.frombuffer(bytearray(side), + dtype=torch.uint8 + ) + side = decode_jpeg(side) + + # [3, 720, 1280] -> [1, 720, 1280, 3] i.e, [T, C, Patches, dim] + side = torch.permute(side, (1, 2, 0)).unsqueeze(0) + + # [1, 720, 1280, 3] -> [1, 3, 1, 256, 1408] i.e, [B, C, T, Patches, dim] + input_image_tensor = (self.transform(side)[None, :]).to(device=self.device, + dtype=torch.float, + 
non_blocking=True + ) + # Pre-trained VJEPA 2 ENCODER: [1, 3, 1, 256, 1408] -> [1, 256, 1408] + z_n = self.world_model.encode(input_image_tensor) + + # [1, 7] -> [B, state_dim] # TODO: gripper state in DROID? In rcs 0: is close and 1: is open - s_n = torch.tensor(np.concatenate(([obs.info["xyzrpy"], + s_n = torch.tensor((np.concatenate(([obs.info["xyzrpy"], [1-obs.gripper]]), - axis=0).reshape(B, T, -1)).to(self.device, + axis=0))).unsqueeze(0).to(self.device, dtype=torch.float, non_blocking=True) @@ -252,14 +264,29 @@ def act(self, obs: Obs) -> Act: s_n, self.goal_rep ) # [rollout_horizon, 7] + + first_action = actions[0].cpu() + first_action[-1] = 1 - first_action[-1] + - return Act(action=np.array(actions[0])) + return Act(action=np.array(first_action)) def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: super().reset(obs, instruction, **kwargs) + from PIL import Image + img = Image.open(self.goal_img) + + # TODO: change goal image resolution + # time dim exp + goal_image = np.expand_dims(np.array(img), axis=0) + # batch dim exp + goal_image_tensor = torch.tensor(self.transform(goal_image)[None, :]).to(device=self.device, + dtype=torch.float, + non_blocking=True) + with torch.no_grad(): - self.goal_rep = self.world_model.encode(self.transform(instruction)) + self.goal_rep = self.world_model.encode(goal_image_tensor) return {} From 7384d4a74cf26276ed4645cdb405f11f8c877afd Mon Sep 17 00:00:00 2001 From: nisarganc Date: Wed, 14 Jan 2026 16:17:05 +0100 Subject: [PATCH 6/8] chore: formatting and clean up --- mydiff.diff | 199 ----------------------------------------- src/agents/policies.py | 98 +++++++++----------- 2 files changed, 41 insertions(+), 256 deletions(-) delete mode 100644 mydiff.diff diff --git a/mydiff.diff b/mydiff.diff deleted file mode 100644 index e06945d..0000000 --- a/mydiff.diff +++ /dev/null @@ -1,199 +0,0 @@ -diff --git a/src/agents/policies.py b/src/agents/policies.py -index a440408..827ec58 100644 ---- a/src/agents/policies.py -+++ b/src/agents/policies.py -@@ -131,76 +131,147 @@ class VjepaAC(Agent): - - def __init__( - self, -+ cfg_path: str, - model_name: str = "vjepa2_ac_vit_giant", - default_checkpoint_path: str = "vjepa2_ac_vit_giant", - **kwargs, - ) -> None: - super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) -+ import yaml - -- from vjepa2.configs.inference import vjepa2-ac-vitg.utn-robot.yaml -+ self.cfg_path = cfg_path -+ with open(self.cfg_path, "r") as f: -+ self.cfg = yaml.safe_load(f) - -- logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}") -- self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step) -- -- self.cfg = config.get_config(model_name) -- self.chunks = 20 -- self.s = self.chunks -- self.a = None - - def initialize(self): -- from openpi.policies import policy_config -- from openpi.shared import download -+ # VJEPA imports -+ from vjepa2.app.vjepa_droid.transforms import make_transforms -+ from vjepa2.notebooks.utils.world_model_wrapper import WorldModel -+ -+ device = self.cfg.get("device", 'cuda') -+ save_path = self.cfg.get("save_path", 'exp_1.png') -+ -+ # data config -+ cfgs_data = self.cfg.get("data") -+ fps = cfgs_data.get("fps", 4) -+ crop_size = cfgs_data.get("crop_size", 256) -+ patch_size = cfgs_data.get("patch_size") -+ pin_mem = cfgs_data.get("pin_mem", False) -+ num_workers = cfgs_data.get("num_workers", 1) -+ persistent_workers = cfgs_data.get("persistent_workers", True) -+ -+ # data augs -+ cfgs_data_aug =
self.cfg.get("data_aug") -+ horizontal_flip = cfgs_data_aug.get("horizontal_flip", False) -+ ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) -+ rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) -+ motion_shift = cfgs_data_aug.get("motion_shift", False) -+ reprob = cfgs_data_aug.get("reprob", 0.0) -+ use_aa = cfgs_data_aug.get("auto_augment", False) -+ -+ # exp config -+ cfgs_mpc_args= self.cfg.get("mpc_args") -+ self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2) -+ samples = cfgs_mpc_args.get("samples", 25) -+ topk = cfgs_mpc_args.get("topk", 10) -+ cem_steps = cfgs_mpc_args.get("cem_steps", 1) -+ momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15) -+ momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15) -+ momentum_std = cfgs_mpc_args.get("momentum_std", 0.75) -+ momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", .15) -+ maxnorm = cfgs_mpc_args.get("maxnorm", 0.075) -+ verbose = cfgs_mpc_args.get("verbose", True) -+ -+ -+ -+ # Initialize transform (random-resize-crop augmentations) -+ self.transform = make_transforms( -+ random_horizontal_flip=horizontal_flip, -+ random_resize_aspect_ratio=ar_range, -+ random_resize_scale=rr_scale, -+ reprob=reprob, -+ auto_augment=use_aa, -+ motion_shift=motion_shift, -+ crop_size=crop_size, -+ ) - - encoder, predictor = torch.hub.load("~/vjepa2", # root of the vjepa source code - "vjepa2_ac_vit_giant", # model type - source="local", - pretrained=True) - -- checkpoint_dir = download.maybe_download(self.openpi_path) -+ # check if model weights are loaded on cuda -+ encoder.to(device) -+ predictor.to(device) -+ -+ # World model wrapper initialization -+ -+ tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) -+ self.world_model = WorldModel( -+ encoder=encoder, -+ predictor=predictor, -+ tokens_per_frame=tokens_per_frame, -+ mpc_args={ -+ "rollout": self.rollout_horizon, -+ "samples": samples, -+ "topk": topk, -+ "cem_steps": cem_steps, -+ "momentum_mean": momentum_mean, -+ "momentum_mean_gripper": momentum_mean_gripper, -+ "momentum_std": momentum_std, -+ "momentum_std_gripper": momentum_std_gripper, -+ "maxnorm": maxnorm, -+ "verbose": verbose, -+ }, -+ normalize_reps=True, -+ device=device -+ ) - -- # Create a trained policy. -- self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir) - - def act(self, obs: Obs) -> Act: -- # Run inference on a dummy example. -- # observation = {f"observation/{k}": v for k, v in obs.cameras.items()} - -- if self.s < self.chunks: -- self.s += 1 -- return Act(action=self.a[self.s]) -- -- else: -- self.s = 0 -+ with torch.no_grad(): - -- side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) -- side = torch.frombuffer(bytearray(side), dtype=torch.uint8) -- side = decode_jpeg(side) -- side = v2.Resize((256, 256))(side) -+ # Pre-trained VJEPA 2 ENCODER: # [1, 3, 1, 256, 256] -> [1, 3, 1, 256, 1408] i.e, [B, C, Time, Patches, dim] -+ z_n = self.world_model.encode(self.transform(obs.cameras["rgb_side"])) -+ B = z_n.shape[0] -+ T = z_n.shape[2] - -- wrist = base64.urlsafe_b64decode(obs.cameras["rgb_wrist"]) -- wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) -- wrist = decode_jpeg(wrist) -- wrist = v2.Resize((256, 256))(wrist) -+ # [1, 1, 76] -> [B, Time, state] -+ # TODO: gripper state in DROID? 
In rcs 0: is close and 1: is open -+ s_n = np.concatenate(([obs.info["xyzrpy"], [1-obs.gripper]]), axis=0).reshape(B, T, -1) - -- # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) -- # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) -- # return Act(action=np.array([])) -- observation = {} -- observation.update( -- { -- "observation/image": side, -- "observation/wrist_image": wrist, -- "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), -- "prompt": self.instruction, -- } -- ) -- action_chunk = self.policy.infer(observation)["actions"] -- # convert gripper action -- action_chunk[:,-1] = 1 - action_chunk[:,-1] -- self.a = action_chunk -+ # Action conditioned predictor and zero-shot action inference with CEM -+ actions = self.world_model.infer_next_action( -+ z_n, -+ s_n, -+ z_n -+ ) # [4, 7] -+ -+ # compute predicted next states -+ s_n_k = s_n # [1, 1, 7] -+ predicted_states = s_n # [1, 1, 7] -+ for i in range(self.rollout_horizon): -+ a_n = actions[i].unsqueeze(0).unsqueeze(1) # [1, 1, 7] -+ s_next = compute_new_pose(s_n_k, a_n) # [1, 1, 7] -+ predicted_states = torch.cat((predicted_states, s_next), dim=1) # [1, i+2, 7] -+ s_n_k = s_next -+ -+ predicted_state_trajs.append((k, predicted_states.cpu())) -+ -+ print(f"Predicted new state: {predicted_states.cpu()}") # [1, rollout_horizon+1, 7] -+ print(f"Ground truth new state: {ground_truth_traj}") # [1, rollout_horizon+1, 7] -+ -+ -+ return Act(action=np.array([])) -+ -+ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: -+ super().reset(obs, instruction, **kwargs) -+ # TODO: actually calculate goal representation -+ self.goal_rep = instruction -+ return {} - -- # return Act(action=action_chunk[0]) -- return Act(action=action_chunk[0]) - - class OpenPiModel(Agent): - diff --git a/src/agents/policies.py b/src/agents/policies.py index 22f7d32..0e3be15 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -10,12 +10,12 @@ from operator import getitem from pathlib import Path from typing import Any, Union -from torchvision.io import decode_jpeg -from torchvision.transforms import v2 import numpy as np -from PIL import Image import torch +from PIL import Image +from torchvision.io import decode_jpeg +from torchvision.transforms import v2 @dataclass(kw_only=True) @@ -127,6 +127,7 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: } return info + class VjepaAC(Agent): def __init__( @@ -137,7 +138,7 @@ def __init__( **kwargs, ) -> None: super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs) - import yaml + import yaml self.cfg_path = cfg_path with open(self.cfg_path, "r") as f: @@ -145,21 +146,19 @@ def __init__( self.model_name = model_name - def initialize(self): # VJEPA imports from app.vjepa_droid.transforms import make_transforms from inference.utils.world_model_wrapper import WorldModel - self.device = self.cfg.get("device", 'cuda') - self.save_path = self.cfg.get("save_path", 'exp_1.png') - self.goal_img = self.cfg.get("goal_img", 'exp_1.png') + self.device = self.cfg.get("device", "cuda") + self.goal_img = self.cfg.get("goal_img", "exp_1.png") # data config cfgs_data = self.cfg.get("data") fps = cfgs_data.get("fps", 4) crop_size = cfgs_data.get("crop_size", 256) - + # data augs cfgs_data_aug = self.cfg.get("data_aug") use_aa = cfgs_data_aug.get("auto_augment", False) @@ -168,9 +167,9 @@ def initialize(self): ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) rr_scale = 
cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) reprob = cfgs_data_aug.get("reprob", 0.0) - + # cfgs_mpc_args config - cfgs_mpc_args= self.cfg.get("mpc_args") + cfgs_mpc_args = self.cfg.get("mpc_args") self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2) samples = cfgs_mpc_args.get("samples", 25) topk = cfgs_mpc_args.get("topk", 10) @@ -178,12 +177,10 @@ def initialize(self): momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15) momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15) momentum_std = cfgs_mpc_args.get("momentum_std", 0.75) - momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", .15) + momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", 0.15) maxnorm = cfgs_mpc_args.get("maxnorm", 0.075) verbose = cfgs_mpc_args.get("verbose", True) - - # Initialize transform (random-resize-crop augmentations) self.transform = make_transforms( random_horizontal_flip=horizontal_flip, @@ -196,23 +193,22 @@ def initialize(self): ) # load model - encoder, predictor = torch.hub.load("./", # root of the vjepa source code - self.model_name, # model type - source="local", - pretrained=True) + encoder, predictor = torch.hub.load( + "./", self.model_name, source="local", pretrained=True # root of the vjepa source code # model type + ) # load model to cuda encoder.to(self.device) - predictor.to(self.device) + predictor.to(self.device) # World model wrapper initialization - tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) + tokens_per_frame = int((crop_size // encoder.patch_size) ** 2) self.world_model = WorldModel( encoder=encoder, predictor=predictor, tokens_per_frame=tokens_per_frame, mpc_args={ - "rollout": self.rollout_horizon, + "rollout": self.rollout_horizon, "samples": samples, "topk": topk, "cem_steps": cem_steps, @@ -224,69 +220,61 @@ def initialize(self): "verbose": verbose, }, normalize_reps=True, - device=self.device + device=self.device, ) - def act(self, obs: Obs) -> Act: with torch.no_grad(): # read from camera-stream side = base64.urlsafe_b64decode(obs.cameras["rgb_side"]) - side = torch.frombuffer(bytearray(side), - dtype=torch.uint8 - ) + side = torch.frombuffer(bytearray(side), dtype=torch.uint8) side = decode_jpeg(side) # [3, 720, 1280] -> [1, 720, 1280, 3] i.e, [T, C, Patches, dim] side = torch.permute(side, (1, 2, 0)).unsqueeze(0) # [1, 720, 1280, 3] -> [1, 3, 1, 256, 1408] i.e, [B, C, T, Patches, dim] - input_image_tensor = (self.transform(side)[None, :]).to(device=self.device, - dtype=torch.float, - non_blocking=True - ) + input_image_tensor = (self.transform(side)[None, :]).to( + device=self.device, dtype=torch.float, non_blocking=True + ) # Pre-trained VJEPA 2 ENCODER: [1, 3, 1, 256, 1408] -> [1, 256, 1408] - z_n = self.world_model.encode(input_image_tensor) + z_n = self.world_model.encode(input_image_tensor) # [1, 7] -> [B, state_dim] - # TODO: gripper state in DROID? In rcs 0: is close and 1: is open - s_n = torch.tensor((np.concatenate(([obs.info["xyzrpy"], - [1-obs.gripper]]), - axis=0))).unsqueeze(0).to(self.device, - dtype=torch.float, - non_blocking=True) + # TODO: check gripper state convention + # in DROID: 0: is close to 0.86: is open? 
+ In rcs 0: is close and 1: is open + s_n = ( + torch.tensor((np.concatenate(([obs.info["xyzrpy"], [1 - obs.gripper]]), axis=0))) # [1-obs.gripper] + .unsqueeze(0) + .to(self.device, dtype=torch.float, non_blocking=True) + ) # Action conditioned predictor and zero-shot action inference with CEM - actions = self.world_model.infer_next_action( - z_n, - s_n, - self.goal_rep - ) # [rollout_horizon, 7] + actions = self.world_model.infer_next_action(z_n, s_n, self.goal_rep) # [rollout_horizon, 7] first_action = actions[0].cpu() first_action[-1] = 1 - first_action[-1] - return Act(action=np.array(first_action)) def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: super().reset(obs, instruction, **kwargs) from PIL import Image + img = Image.open(self.goal_img) - # TODO: change goal image resolution # time dim exp - goal_image = np.expand_dims(np.array(img), axis=0) + goal_image = np.expand_dims(np.array(img), axis=0) # batch dim exp - goal_image_tensor = torch.tensor(self.transform(goal_image)[None, :]).to(device=self.device, - dtype=torch.float, - non_blocking=True) + goal_image_tensor = torch.tensor(self.transform(goal_image)[None, :]).to( + device=self.device, dtype=torch.float, non_blocking=True + ) - with torch.no_grad(): - self.goal_rep = self.world_model.encode(goal_image_tensor) + self.goal_rep = self.world_model.encode(goal_image_tensor) return {} @@ -327,7 +315,7 @@ def act(self, obs: Obs) -> Act: if self.s < self.chunks: self.s += 1 return Act(action=self.a[self.s]) - + else: self.s = 0 @@ -340,9 +328,6 @@ def act(self, obs: Obs) -> Act: wrist = torch.frombuffer(bytearray(wrist), dtype=torch.uint8) wrist = decode_jpeg(wrist) wrist = v2.Resize((256, 256))(wrist) - - - # side = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) # wrist = np.copy(obs.cameras["rgb_side"]).transpose(2, 0, 1) @@ -352,20 +337,19 @@ def act(self, obs: Obs) -> Act: { "observation/image": side, "observation/wrist_image": wrist, - "observation/state": np.concatenate([obs.info["joints"], [1-obs.gripper]]), + "observation/state": np.concatenate([obs.info["joints"], [1 - obs.gripper]]), "prompt": self.instruction, } ) action_chunk = self.policy.infer(observation)["actions"] # convert gripper action - action_chunk[:,-1] = 1 - action_chunk[:,-1] + action_chunk[:, -1] = 1 - action_chunk[:, -1] self.a = action_chunk # return Act(action=action_chunk[0]) return Act(action=action_chunk[0]) - class OpenVLAModel(Agent): # === Utilities === SYSTEM_PROMPT = ( From 2dfc6a35f530bec9335f1d8b0f7abe3812225f06 Mon Sep 17 00:00:00 2001 From: nisarganc Date: Wed, 14 Jan 2026 16:27:56 +0100 Subject: [PATCH 7/8] chore: readme cleanup --- README.md | 11 ++--------- src/agents/policies.py | 2 +- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7e15508..58b3a3d 100644 --- a/README.md +++ b/README.md @@ -125,12 +125,11 @@ conda activate vjepa2 ``` Clone the repo and install it. ```shell -git clone git@github.com:nisarganc/vjepa2.git +git clone git@github.com:facebookresearch/vjepa2.git cd vjepa2 pip install -e . pip install git+https://github.com/juelg/agents.git -git checkout nilavadi/vjepa-ac pip install -ve .
``` @@ -148,13 +147,7 @@ python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"c python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "/{checkpoint_step}", "model_name": "pi0_rcs", "checkpoint_step": }' # leave "{checkpoint_step}" it will be replaced, "model_name" is the key for the training config # vjepa2-ac -python -m agents start-server vjepa --port=20997 --host=0.0.0.0 --kwargs='{"cfg_path": "configs/inference/vjepa2-ac-vitg/utn-robot.yaml", "model_name": "vjepa2_ac_vit_giant", "default_checkpoint_path": "../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt"}' -``` - -### RCS run commands -```shell -# openpi -/home/juelg/miniconda3/envs/rcs_openpi/bin/python -m agents start-server openpi --port=20997 --host=0.0.0.0 --kwargs='{"checkpoint_path": "/mnt/dataset_drive/juelg/checkpoints/rcs_paper/pi0/pi0_rcs_utn/openpi_utn_wrist/{checkpoint_step}", "model_name": "pi0_rcs_utn", "checkpoint_step": 29999}' # leave "{checkpoint_step}" it will be replaced, "model_name" is the key for the training config +python -m agents start-server vjepa --port=20997 --host=0.0.0.0 --kwargs='{"cfg_path": "configs/inference/vjepa2-ac-vitg/.yaml", "model_name": "vjepa2_ac_vit_giant", "default_checkpoint_path": "../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt"}' ``` diff --git a/src/agents/policies.py b/src/agents/policies.py index 0e3be15..c949ea8 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -149,7 +149,7 @@ def __init__( def initialize(self): # VJEPA imports from app.vjepa_droid.transforms import make_transforms - from inference.utils.world_model_wrapper import WorldModel + from notebooks.utils.world_model_wrapper import WorldModel self.device = self.cfg.get("device", "cuda") self.goal_img = self.cfg.get("goal_img", "exp_1.png") From eece4bd5534c3b52d3252a70a9ca139d64aa7bca Mon Sep 17 00:00:00 2001 From: nisarganc Date: Wed, 14 Jan 2026 16:41:12 +0100 Subject: [PATCH 8/8] chore: clean imports --- src/agents/policies.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/agents/policies.py b/src/agents/policies.py index 68a1c0e..a56a52e 100644 --- a/src/agents/policies.py +++ b/src/agents/policies.py @@ -12,10 +12,7 @@ from typing import Any, Union import numpy as np -import torch from PIL import Image -from torchvision.io import decode_jpeg -from torchvision.transforms import v2 @dataclass(kw_only=True) @@ -147,6 +144,9 @@ def __init__( self.model_name = model_name def initialize(self): + # torch import + import torch + # VJEPA imports from app.vjepa_droid.transforms import make_transforms from notebooks.utils.world_model_wrapper import WorldModel @@ -156,7 +156,6 @@ def initialize(self): # data config cfgs_data = self.cfg.get("data") - fps = cfgs_data.get("fps", 4) crop_size = cfgs_data.get("crop_size", 256) # data augs @@ -224,6 +223,9 @@ def initialize(self): ) def act(self, obs: Obs) -> Act: + # torch imports + import torch + from torchvision.io import decode_jpeg with torch.no_grad(): @@ -262,7 +264,8 @@ def act(self, obs: Obs) -> Act: def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]: super().reset(obs, instruction, **kwargs) - from PIL import Image + # imports + import torch img = Image.open(self.goal_img)
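The series leaves two small mechanisms inside `OpenPiModel.act` that are easy to lose in the diffs: decoding the base64 JPEG camera stream (patch 1) and replaying a cached action chunk so the policy is only queried once every `chunks` steps (patches 1 and 6). Below is a minimal, self-contained sketch of both; `decode_camera` and `ChunkReplay` are illustrative names, not code from the repo, and `infer_chunk` stands in for `self.policy.infer(observation)["actions"]`. Note that the committed replay loop increments `self.s` before indexing, so it actually reads indices 0 through `chunks` (21 actions per cycle); that only works while the policy's action horizon exceeds `self.chunks`, so the sketch indexes before incrementing to stay within a `chunks`-long window.

```python
import base64

import numpy as np
import torch
from torchvision.io import decode_jpeg
from torchvision.transforms import v2


def decode_camera(b64_jpeg: bytes, size: int = 256) -> torch.Tensor:
    """urlsafe-base64 JPEG (as sent by the camera stream) -> [3, size, size] uint8 tensor."""
    raw = base64.urlsafe_b64decode(b64_jpeg)
    buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
    img = decode_jpeg(buf)  # [3, H, W], uint8
    return v2.Resize((size, size))(img)


class ChunkReplay:
    """Query the policy once, then replay the returned chunk for `chunks` steps."""

    def __init__(self, chunks: int = 20):
        self.chunks = chunks
        self.s = chunks  # forces inference on the first step
        self.a = None    # cached [horizon, action_dim] chunk

    def step(self, infer_chunk) -> np.ndarray:
        if self.a is None or self.s >= self.chunks:
            self.a = infer_chunk()  # e.g. self.policy.infer(observation)["actions"]
            self.a[:, -1] = 1 - self.a[:, -1]  # flip gripper back to the rcs convention (1 = open)
            self.s = 0
        action = self.a[self.s]
        self.s += 1
        return action
```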
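For the VJEPA2-AC agent, the logic that patches 3 through 8 converge on is: `reset()` encodes a goal image loaded from the `goal_img` path in the YAML config (the language `instruction` is not used for goal conditioning), and each `act()` call encodes the side camera, then runs CEM through the action-conditioned predictor toward that goal representation. Here is a condensed sketch of one control step, under the signatures that appear in the diffs (`WorldModel.encode` and `WorldModel.infer_next_action` from the vjepa2 repo's wrapper); `vjepa_ac_step` itself is a hypothetical helper, and the `1 - gripper` flips are the still-unverified convention mapping flagged in the TODOs.

```python
import numpy as np
import torch


def vjepa_ac_step(world_model, transform, side_chw, xyzrpy, gripper, goal_rep, device="cuda"):
    """One VjepaAC control step, mirroring the final act() in the series.

    world_model: the WorldModel wrapper imported from the vjepa2 repo
    side_chw:    decoded [3, H, W] uint8 side-camera frame
    xyzrpy:      6-dof end-effector pose; gripper: rcs convention (1 = open)
    """
    with torch.no_grad():
        # [3, H, W] -> [1, H, W, 3], then transform to a clip and add a batch dim
        frame = torch.permute(side_chw, (1, 2, 0)).unsqueeze(0)
        clip = transform(frame)[None, :].to(device=device, dtype=torch.float, non_blocking=True)

        z_n = world_model.encode(clip)  # latent patch representation

        # end-effector pose + flipped gripper as the [1, state_dim] state
        s_n = torch.tensor(np.concatenate([xyzrpy, [1 - gripper]]))
        s_n = s_n.unsqueeze(0).to(device, dtype=torch.float, non_blocking=True)

        # zero-shot action inference: CEM rollouts through the action-conditioned
        # predictor, scored against the goal representation computed at reset()
        actions = world_model.infer_next_action(z_n, s_n, goal_rep)  # [rollout_horizon, 7]

        first = actions[0].cpu()
        first[-1] = 1 - first[-1]  # flip gripper back to the rcs convention
        return np.array(first)
```

Only `actions[0]` is executed per step, so unlike the OpenPi path there is no chunk replay: the CEM optimization runs at every control step, which is why `rollout_horizon`, `samples`, and `cem_steps` in `mpc_args` dominate inference latency.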