diff --git a/README.md b/README.md
index 6f1acd0..5700aac 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,24 @@ pip install git+https://github.com/juelg/agents.git
 
 For more details, see the [OpenVLA github page](https://github.com/openvla/openvla).
 
+### OpenPi / Pi0
+To use OpenPi, create a new conda environment:
+```shell
+conda create -n openpi python=3.11 -y
+conda activate openpi
+```
+Clone the repo and install it:
+```shell
+git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
+# Or if you already cloned the repo:
+git submodule update --init --recursive
+# install dependencies
+GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
+```
+For more details, see the [openpi GitHub page](https://github.com/Physical-Intelligence/openpi).
+
+
 ## Usage
 To start an agents server use the `start-server` command where `kwargs` is a dictionary of the constructor arguments of the policy you want to start e.g.
 ```shell
@@ -107,6 +125,8 @@ To start an agents server use the `start-server` command where `kwargs` is a dic
 python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}'
 # openvla
 python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}'
+# openpi
+python -m agents start-server openpi --host localhost --port 8080 --kwargs '{"checkpoint_path": "/{checkpoint_step}", "train_config_name": "pi0_rcs", "checkpoint_step": <step>}' # "{checkpoint_step}" in checkpoint_path is replaced with the value of "checkpoint_step"; "train_config_name" selects the openpi training config
 ```
 
 There is also the `run-eval-during-training` command to evaluate a model during training, so a single checkpoint.
diff --git a/src/agents/policies.py b/src/agents/policies.py
index e88ccd7..be8fdb4 100644
--- a/src/agents/policies.py
+++ b/src/agents/policies.py
@@ -124,6 +124,66 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
         return info
 
 
+class OpenPiModel(Agent):
+
+    def __init__(
+        self,
+        train_config_name: str = "pi0_droid",
+        default_checkpoint_path: str = "gs://openpi-assets/checkpoints/pi0_droid",
+        execution_horizon: int = 20,
+        **kwargs,
+    ) -> None:
+        super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs)
+        from openpi.training import config
+
+        logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}")
+        # "{checkpoint_step}" in the checkpoint path is filled in with the configured step
+        self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step)
+
+        self.cfg = config.get_config(train_config_name)
+        self.execution_horizon = execution_horizon
+
+        # start with an exhausted chunk so the first act() call runs inference
+        self.chunk_counter = self.execution_horizon
+        self._cached_action_chunk = None
+
+    def initialize(self):
+        from openpi.policies import policy_config
+        from openpi.shared import download
+
+        # resolve remote checkpoints (e.g. gs:// paths) to a local cache; local paths pass through unchanged
+        checkpoint_dir = download.maybe_download(self.openpi_path)
+
+        # Create a trained policy.
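+        # This loads the model weights and the normalization statistics stored with
+        # the checkpoint for the selected training config.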
+        self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir)
+
+    def act(self, obs: Obs) -> Act:
+        # replay actions from the cached chunk until `execution_horizon` of them have been executed
+        if self.chunk_counter < self.execution_horizon:
+            action = self._cached_action_chunk[self.chunk_counter]
+            self.chunk_counter += 1
+            return Act(action=action)
+
+        else:
+            # chunk exhausted (or first step after a reset): infer a new action chunk
+            self.chunk_counter = 1  # index 0 is returned below
+            # convert images from HWC to CHW
+            observation = {f"observation/{k}": np.copy(v).transpose(2, 0, 1) for k, v in obs.cameras.items()}
+            observation.update(
+                {
+                    # openpi expects 0 as gripper open and 1 as closed
+                    "observation/state": np.concatenate([obs.info["joints"], [1 - obs.gripper]]),
+                    "prompt": self.instruction,
+                }
+            )
+            action_chunk = self.policy.infer(observation)["actions"]
+
+            # convert the gripper action back into the agents format (inverse of the mapping above)
+            action_chunk[:, -1] = 1 - action_chunk[:, -1]
+            self._cached_action_chunk = action_chunk
+
+            return Act(action=action_chunk[0])
+
+    def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
+        super().reset(obs, instruction, **kwargs)
+        # discard the cached chunk so the next act() call runs inference again
+        self.chunk_counter = self.execution_horizon
+        self._cached_action_chunk = None
+        return {}
+
+
 class OpenVLAModel(Agent):
     # === Utilities ===
     SYSTEM_PROMPT = (
@@ -457,4 +517,5 @@ def act(self, obs: Obs) -> Act:
     openvla=OpenVLAModel,
     octodist=OctoActionDistribution,
     openvladist=OpenVLADistribution,
+    openpi=OpenPiModel,
 )