24 changes: 23 additions & 1 deletion README.md
@@ -117,18 +117,40 @@ GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```
For more details see [openpi's github](https://github.com/Physical-Intelligence/openpi).

### VJEPA2-AC
To use VJEPA2-AC, create a new conda environment:
```shell
conda create -n vjepa2 python=3.12
conda activate vjepa2
```
Clone the repo and install it together with the `agents` package:
```shell
git clone git@github.com:facebookresearch/vjepa2.git
cd vjepa2
pip install -e .

pip install git+https://github.com/juelg/agents.git
pip install -ve .

```
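
To check that the pretrained weights can be loaded, the following sketch mirrors the `torch.hub` call used by the `vjepa` policy added in this PR (run it from the root of the `vjepa2` checkout; the model name is the policy's default and the print is only illustrative):
```python
import torch

# Load the action-conditioned V-JEPA 2 encoder and predictor from the local
# vjepa2 checkout; "./" must point at the repo root containing hubconf.py.
encoder, predictor = torch.hub.load(
    "./", "vjepa2_ac_vit_giant", source="local", pretrained=True
)
print(type(encoder).__name__, type(predictor).__name__)
```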

## Usage
To start an agents server, use the `start-server` command, where `kwargs` is a dictionary of constructor arguments for the policy you want to start, e.g.
```shell
# octo
python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}'

# openvla
python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}'

# openpi
python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "<path to checkpoint>/{checkpoint_step}", "train_config_name": "pi0_rcs", "checkpoint_step": <checkpoint_step>}' # leave "{checkpoint_step}" as-is, it will be replaced; "train_config_name" is the key of the training config
python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "<path to checkpoint>/{checkpoint_step}", "model_name": "pi0_rcs", "checkpoint_step": <checkpoint_step>}' # leave "{checkpoint_step}" as-is, it will be replaced; "model_name" is the key of the training config

# vjepa2-ac
python -m agents start-server vjepa --port=20997 --host=0.0.0.0 --kwargs='{"cfg_path": "configs/inference/vjepa2-ac-vitg/<your_config>.yaml", "model_name": "vjepa2_ac_vit_giant", "default_checkpoint_path": "../.cache/torch/hub/checkpoints/vjepa2-ac-vitg.pt"}'
```
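
The `cfg_path` passed to the `vjepa` policy points to a YAML file that is read by the `VjepaAC` class added in `src/agents/policies.py` below. As a sketch, the expected keys (taken from that class) can be written out like this; the output path and all values are illustrative only and match the class defaults:
```python
import os

import yaml

# Illustrative inference config for the vjepa policy. The key names mirror the
# cfg.get() calls in VjepaAC; the values are the defaults used there.
example_cfg = {
    "device": "cuda",
    "goal_img": "exp_1.png",  # goal image encoded on reset()
    "data": {"crop_size": 256},
    "data_aug": {
        "auto_augment": False,
        "horizontal_flip": False,
        "motion_shift": False,
        "random_resize_aspect_ratio": [3 / 4, 4 / 3],
        "random_resize_scale": [0.3, 1.0],
        "reprob": 0.0,
    },
    "mpc_args": {
        "rollout_horizon": 2,
        "samples": 25,
        "topk": 10,
        "cem_steps": 1,
        "momentum_mean": 0.15,
        "momentum_mean_gripper": 0.15,
        "momentum_std": 0.75,
        "momentum_std_gripper": 0.15,
        "maxnorm": 0.075,
        "verbose": True,
    },
}

os.makedirs("configs/inference/vjepa2-ac-vitg", exist_ok=True)
with open("configs/inference/vjepa2-ac-vitg/example.yaml", "w") as f:
    yaml.safe_dump(example_cfg, f)
```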


There is also the `run-eval-during-training` command to evaluate a model during training, i.e. a single checkpoint at a time.
The `run-eval-post-training` command evaluates a range of checkpoints in parallel.
In both cases, the environment and its arguments, the policy and its arguments, and the wandb config for logging can be passed as CLI arguments.
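
For quick manual checks outside of the server and eval commands, a policy can also be driven directly from Python. Below is a minimal sketch for the new `VjepaAC` policy; the `Obs` stand-in, the image file, and the config path are assumptions for illustration only, not the framework's actual API (run from the vjepa2 repo root, since the policy loads the model via a local `torch.hub` call):
```python
import base64
import types

import numpy as np

from agents.policies import VjepaAC

# Stand-in observation: a urlsafe-base64 JPEG for the side camera, an xyzrpy
# end-effector pose and a gripper value in the rcs convention (1: open).
with open("side_view.jpg", "rb") as f:  # hypothetical camera frame
    rgb_side = base64.urlsafe_b64encode(f.read())

obs = types.SimpleNamespace(
    cameras={"rgb_side": rgb_side},
    info={"xyzrpy": np.zeros(6)},
    gripper=1.0,
)

policy = VjepaAC(cfg_path="configs/inference/vjepa2-ac-vitg/example.yaml")
policy.initialize()                    # loads the V-JEPA 2-AC encoder/predictor
policy.reset(obs, instruction=None)    # encodes the goal image from the config
act = policy.act(obs)                  # one CEM planning step -> 7-dim action
print(act.action)
```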
159 changes: 159 additions & 0 deletions src/agents/policies.py
@@ -1,3 +1,4 @@
import base64
import copy
import json
import logging
@@ -124,6 +125,163 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
return info


class VjepaAC(Agent):

def __init__(
self,
cfg_path: str,
model_name: str = "vjepa2_ac_vit_giant",
default_checkpoint_path: str = "",
**kwargs,
) -> None:
super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs)
import yaml

self.cfg_path = cfg_path
with open(self.cfg_path, "r") as f:
self.cfg = yaml.safe_load(f)

self.model_name = model_name

def initialize(self):
# torch import
import torch

# VJEPA2 imports: these only resolve when running from the root of the vjepa2 checkout
from app.vjepa_droid.transforms import make_transforms
from notebooks.utils.world_model_wrapper import WorldModel

self.device = self.cfg.get("device", "cuda")
self.goal_img = self.cfg.get("goal_img", "exp_1.png")

# data config
cfgs_data = self.cfg.get("data")
crop_size = cfgs_data.get("crop_size", 256)

# data augs
cfgs_data_aug = self.cfg.get("data_aug")
use_aa = cfgs_data_aug.get("auto_augment", False)
horizontal_flip = cfgs_data_aug.get("horizontal_flip", False)
motion_shift = cfgs_data_aug.get("motion_shift", False)
ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
reprob = cfgs_data_aug.get("reprob", 0.0)

# MPC / CEM planner config
cfgs_mpc_args = self.cfg.get("mpc_args")
self.rollout_horizon = cfgs_mpc_args.get("rollout_horizon", 2)
samples = cfgs_mpc_args.get("samples", 25)
topk = cfgs_mpc_args.get("topk", 10)
cem_steps = cfgs_mpc_args.get("cem_steps", 1)
momentum_mean = cfgs_mpc_args.get("momentum_mean", 0.15)
momentum_mean_gripper = cfgs_mpc_args.get("momentum_mean_gripper", 0.15)
momentum_std = cfgs_mpc_args.get("momentum_std", 0.75)
momentum_std_gripper = cfgs_mpc_args.get("momentum_std_gripper", 0.15)
maxnorm = cfgs_mpc_args.get("maxnorm", 0.075)
verbose = cfgs_mpc_args.get("verbose", True)

# Initialize transform (random-resize-crop augmentations)
self.transform = make_transforms(
random_horizontal_flip=horizontal_flip,
random_resize_aspect_ratio=ar_range,
random_resize_scale=rr_scale,
reprob=reprob,
auto_augment=use_aa,
motion_shift=motion_shift,
crop_size=crop_size,
)

# load model
encoder, predictor = torch.hub.load(
    "./",  # root of the vjepa2 source code
    self.model_name,  # model type, e.g. "vjepa2_ac_vit_giant"
    source="local",
    pretrained=True,
)

# move encoder and predictor to the configured device
encoder.to(self.device)
predictor.to(self.device)

# World model wrapper initialization
tokens_per_frame = int((crop_size // encoder.patch_size) ** 2)  # e.g. (256 // 16) ** 2 = 256
self.world_model = WorldModel(
encoder=encoder,
predictor=predictor,
tokens_per_frame=tokens_per_frame,
mpc_args={
"rollout": self.rollout_horizon,
"samples": samples,
"topk": topk,
"cem_steps": cem_steps,
"momentum_mean": momentum_mean,
"momentum_mean_gripper": momentum_mean_gripper,
"momentum_std": momentum_std,
"momentum_std_gripper": momentum_std_gripper,
"maxnorm": maxnorm,
"verbose": verbose,
},
normalize_reps=True,
device=self.device,
)

def act(self, obs: Obs) -> Act:
# torch imports
import torch
from torchvision.io import decode_jpeg

with torch.no_grad():

# read from camera-stream
side = base64.urlsafe_b64decode(obs.cameras["rgb_side"])
side = torch.frombuffer(bytearray(side), dtype=torch.uint8)
side = decode_jpeg(side)

# [3, 720, 1280] -> [1, 720, 1280, 3] i.e. [T, H, W, C]
side = torch.permute(side, (1, 2, 0)).unsqueeze(0)

# [1, 720, 1280, 3] -> [1, 3, 1, crop_size, crop_size] i.e. [B, C, T, H, W]
input_image_tensor = (self.transform(side)[None, :]).to(
device=self.device, dtype=torch.float, non_blocking=True
)
# Pre-trained V-JEPA 2 encoder: [1, 3, 1, crop_size, crop_size] -> [1, tokens, embed_dim], e.g. [1, 256, 1408]
z_n = self.world_model.encode(input_image_tensor)

# [1, 7] i.e. [B, state_dim]: xyz + rpy + gripper
# TODO: check gripper state convention
# in DROID: 0 is closed, ~0.86 is open(?)
# in rcs: 0 is closed and 1 is open, hence the 1 - obs.gripper flip
s_n = (
torch.tensor(np.concatenate([obs.info["xyzrpy"], [1 - obs.gripper]], axis=0))
.unsqueeze(0)
.to(self.device, dtype=torch.float, non_blocking=True)
)

# Action-conditioned predictor and zero-shot action inference with CEM
actions = self.world_model.infer_next_action(z_n, s_n, self.goal_rep)  # [rollout_horizon, 7]

first_action = actions[0].cpu()
first_action[-1] = 1 - first_action[-1]  # map the gripper back to the rcs convention (1: open)

return Act(action=np.array(first_action))

def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
super().reset(obs, instruction, **kwargs)
# imports
import torch

img = Image.open(self.goal_img)

# add time dimension
goal_image = np.expand_dims(np.array(img), axis=0)
# add batch dimension
goal_image_tensor = torch.tensor(self.transform(goal_image)[None, :]).to(
device=self.device, dtype=torch.float, non_blocking=True
)

with torch.no_grad():
self.goal_rep = self.world_model.encode(goal_image_tensor)

return {}


class OpenPiModel(Agent):

def __init__(
@@ -518,4 +676,5 @@ def act(self, obs: Obs) -> Act:
octodist=OctoActionDistribution,
openvladist=OpenVLADistribution,
openpi=OpenPiModel,
vjepa=VjepaAC,
)