diff --git a/README.md b/README.md
index 5700aac..58b3a3d 100644
--- a/README.md
+++ b/README.md
@@ -117,18 +117,40 @@
 GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
 ```
 For more details see [openpi's github](https://github.com/Physical-Intelligence/openpi).
+### VJEPA2-AC
+To use VJEPA2-AC, create a new conda environment:
+```shell
+conda create -n vjepa2 python=3.12
+conda activate vjepa2
+```
+Clone the vjepa2 repo and install it, then install this package (agents) into the same environment.
+```shell
+git clone git@github.com:facebookresearch/vjepa2.git
+cd vjepa2
+pip install -e .
+
+pip install git+https://github.com/juelg/agents.git
+pip install -ve .
+```
 
 ## Usage
 To start an agents server, use the `start-server` command, where `kwargs` is a dictionary of the constructor arguments of the policy you want to start, e.g.
 ```shell
 # octo
 python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}'
+
 # openvla
 python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}'
+
 # openpi
-python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "/{checkpoint_step}", "train_config_name": "pi0_rcs", "checkpoint_step": }' # leave "{checkpoint_step}" it will be replaced, "train_config_name" is the key for the training config
+python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "/{checkpoint_step}", "model_name": "pi0_rcs", "checkpoint_step": }' # leave "{checkpoint_step}" as is, it will be replaced; "model_name" is the key of the training config
+
+# vjepa2-ac
+python -m agents start-server vjepa --port=20997 --host=0.0.0.0 --kwargs='{"cfg_path": "configs/inference/vjepa2-ac-vitg/.yaml", "model_name": "vjepa2_ac_vit_giant", "default_checkpoint_path": "../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt"}'
 ```
+
 There is also the `run-eval-during-training` command to evaluate a model during training, i.e. a single checkpoint. The `run-eval-post-training` command evaluates a range of checkpoints in parallel. In both cases the environment and its arguments, the policy and its arguments, and the wandb config for logging can be passed as CLI arguments.
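For orientation, the `--kwargs` JSON is unpacked into the constructor of the selected policy, so the vjepa2-ac server call above corresponds roughly to constructing the agent directly. A minimal sketch, assuming the `agents.policies` import path implied by this diff and reusing the placeholder paths from the README:

```python
# Sketch only: building the VjepaAC policy directly instead of via start-server.
# The import path and the config/checkpoint paths are assumptions / placeholders.
from agents.policies import VjepaAC

policy = VjepaAC(
    cfg_path="configs/inference/vjepa2-ac-vitg/.yaml",  # placeholder path from the README
    model_name="vjepa2_ac_vit_giant",
    default_checkpoint_path="../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt",
)
policy.initialize()  # loads encoder/predictor via torch.hub and builds the world-model wrapper
```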
diff --git a/src/agents/policies.py b/src/agents/policies.py
index be8fdb4..a56a52e 100644
--- a/src/agents/policies.py
+++ b/src/agents/policies.py
@@ -1,3 +1,4 @@
+import base64
 import copy
 import json
 import logging
@@ -124,6 +125,163 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
         return info
 
 
+class VjepaAC(Agent):
+
+    def __init__(
+        self,
+        cfg_path: str,
+        model_name: str = "vjepa2_ac_vit_giant",
+        default_checkpoint_path: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs)
+        import yaml
+
+        self.cfg_path = cfg_path
+        with open(self.cfg_path, "r") as f:
+            self.cfg = yaml.safe_load(f)
+
+        self.model_name = model_name
+
+    def initialize(self):
+        # torch import
+        import torch
+
+        # VJEPA imports
+        from app.vjepa_droid.transforms import make_transforms
+        from notebooks.utils.world_model_wrapper import WorldModel
+
+        self.device = self.cfg.get("device", "cuda")
+        self.goal_img = self.cfg.get("goal_img", "exp_1.png")
+
+        # data config
+        cfgs_data = self.cfg.get("data")
+        crop_size = cfgs_data.get("crop_size", 256)
+
+        # data augs
+        cfgs_data_aug = self.cfg.get("data_aug")
+        use_aa = cfgs_data_aug.get("auto_augment", False)
+        horizontal_flip = cfgs_data_aug.get("horizontal_flip", False)
+        motion_shift = cfgs_data_aug.get("motion_shift", False)
+        ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
+        rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
+        reprob = cfgs_data_aug.get("reprob", 0.0)
+
+        # MPC / CEM config
+        cfgs_mpc_args = self.cfg.get("mpc_args")
+        self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2)
+        samples = cfgs_mpc_args.get("samples", 25)
+        topk = cfgs_mpc_args.get("topk", 10)
+        cem_steps = cfgs_mpc_args.get("cem_steps", 1)
+        momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15)
+        momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15)
+        momentum_std = cfgs_mpc_args.get("momentum_std", 0.75)
+        momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", 0.15)
+        maxnorm = cfgs_mpc_args.get("maxnorm", 0.075)
+        verbose = cfgs_mpc_args.get("verbose", True)
+
+        # Initialize transform (random-resize-crop augmentations)
+        self.transform = make_transforms(
+            random_horizontal_flip=horizontal_flip,
+            random_resize_aspect_ratio=ar_range,
+            random_resize_scale=rr_scale,
+            reprob=reprob,
+            auto_augment=use_aa,
+            motion_shift=motion_shift,
+            crop_size=crop_size,
+        )
+
+        # load model
+        encoder, predictor = torch.hub.load(
+            "./",  # root of the vjepa2 source code
+            self.model_name,  # model type
+            source="local",
+            pretrained=True,
+        )
+
+        # move model to the configured device
+        encoder.to(self.device)
+        predictor.to(self.device)
+
+        # World model wrapper initialization
+        tokens_per_frame = int((crop_size // encoder.patch_size) ** 2)
+        self.world_model = WorldModel(
+            encoder=encoder,
+            predictor=predictor,
+            tokens_per_frame=tokens_per_frame,
+            mpc_args={
+                "rollout": self.rollout_horizon,
+                "samples": samples,
+                "topk": topk,
+                "cem_steps": cem_steps,
+                "momentum_mean": momentum_mean,
+                "momentum_mean_gripper": momentum_mean_gripper,
+                "momentum_std": momentum_std,
+                "momentum_std_gripper": momentum_std_gripper,
+                "maxnorm": maxnorm,
+                "verbose": verbose,
+            },
+            normalize_reps=True,
+            device=self.device,
+        )
+
+    def act(self, obs: Obs) -> Act:
+        # torch imports
+        import torch
+        from torchvision.io import decode_jpeg
+
+        with torch.no_grad():
+            # read from camera-stream: urlsafe-base64-encoded JPEG bytes
+            side = base64.urlsafe_b64decode(obs.cameras["rgb_side"])
+            side = torch.frombuffer(bytearray(side), dtype=torch.uint8)
+            side = decode_jpeg(side)
+
+            # [3, 720, 1280] -> [1, 720, 1280, 3], i.e. [T, H, W, C]
+            side = torch.permute(side, (1, 2, 0)).unsqueeze(0)
+
+            # [1, 720, 1280, 3] -> [1, 3, 1, crop_size, crop_size], i.e. [B, C, T, H, W]
+            input_image_tensor = (self.transform(side)[None, :]).to(
+                device=self.device, dtype=torch.float, non_blocking=True
+            )
+            # Pre-trained VJEPA 2 encoder: [B, C, T, H, W] -> [1, 256, 1408], i.e. [B, tokens, dim]
+            z_n = self.world_model.encode(input_image_tensor)
+
+            # [1, 7] -> [B, state_dim]
+            # TODO: check gripper state convention
+            # in DROID: 0 is closed, ~0.86 is open?
+            # in rcs: 0 is closed and 1 is open
+            s_n = (
+                torch.tensor(np.concatenate([obs.info["xyzrpy"], [1 - obs.gripper]], axis=0))
+                .unsqueeze(0)
+                .to(self.device, dtype=torch.float, non_blocking=True)
+            )
+
+            # Action-conditioned predictor and zero-shot action inference with CEM
+            actions = self.world_model.infer_next_action(z_n, s_n, self.goal_rep)  # [rollout_horizon, 7]
+
+            first_action = actions[0].cpu()
+            # flip the gripper value back to the rcs convention
+            first_action[-1] = 1 - first_action[-1]
+
+        return Act(action=np.array(first_action))
+
+    def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
+        super().reset(obs, instruction, **kwargs)
+        # imports
+        import torch
+        from PIL import Image
+
+        img = Image.open(self.goal_img)
+
+        # add time dimension
+        goal_image = np.expand_dims(np.array(img), axis=0)
+        # add batch dimension
+        goal_image_tensor = torch.tensor(self.transform(goal_image)[None, :]).to(
+            device=self.device, dtype=torch.float, non_blocking=True
+        )
+
+        with torch.no_grad():
+            self.goal_rep = self.world_model.encode(goal_image_tensor)
+
+        return {}
+
+
 class OpenPiModel(Agent):
 
     def __init__(
@@ -518,4 +676,5 @@ def act(self, obs: Obs) -> Act:
     octodist=OctoActionDistribution,
     openvladist=OpenVLADistribution,
     openpi=OpenPiModel,
+    vjepa=VjepaAC,
 )
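A note on the expected observation format: `VjepaAC.act()` decodes `obs.cameras["rgb_side"]` with `base64.urlsafe_b64decode` followed by `torchvision.io.decode_jpeg`, so the camera payload must be a urlsafe-base64-encoded JPEG byte string. A minimal client-side sketch of the matching encoding (the helper name and the dummy frame are illustrative only):

```python
# Sketch: producing a camera payload that VjepaAC.act() can decode.
import base64

import torch
from torchvision.io import encode_jpeg


def encode_rgb_side(frame_chw_uint8: torch.Tensor) -> bytes:
    """Encode a [3, H, W] uint8 RGB frame as urlsafe-base64 JPEG bytes."""
    jpeg = encode_jpeg(frame_chw_uint8)  # 1-D uint8 tensor holding the JPEG byte stream
    return base64.urlsafe_b64encode(jpeg.numpy().tobytes())


# Illustrative dummy frame matching the 720x1280 shape noted in act()
payload = encode_rgb_side(torch.zeros(3, 720, 1280, dtype=torch.uint8))  # -> obs.cameras["rgb_side"]
```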