diff --git a/AGENTS.md b/AGENTS.md index 981e4ac..8b9a78b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,26 +1,31 @@ # Repository Guidelines ## Project Structure & Module Organization -`sumo/` is the active package. Put task code in `sumo/tasks/g1/` or `sumo/tasks/spot/`, controller wrappers in `sumo/controller/`, CLI and MPC entrypoints in `sumo/cli.py` and `sumo/run_mpc/`, and MuJoCo XML/mesh assets under `sumo/models/`. `g1_extensions/` contains the optional pybind11 extension for G1 backends. `tests/` holds the pytest suite. Treat `judo-private/` as an old fork used only as migration reference; match new code to the public `judo-rai` repo instead. Use `.judo-src/` as the local reference checkout when porting behavior. +`sumo/` is the main Python package. Task implementations live in `sumo/tasks/g1/` and `sumo/tasks/spot/`; shared controller code is in `sumo/controller/`; CLI and headless MPC entry points are `sumo/cli.py` and `sumo/run_mpc/`. MuJoCo XML and mesh assets are packaged under `sumo/models/`. Native G1 rollout bindings live in `g1_extensions/` with CMake/pybind11 sources. Tests belong in `tests/`, while generated run artifacts should stay in `out/`, `outputs/`, or `run_mpc/results/`. ## Build, Test, and Development Commands -Use `pixi` by default: +Use Pixi as the default environment manager. ```bash -pixi install -pixi run build -pixi run build-judo-ext -pixi run pytest tests/ -v -pixi run sumo --init-task g1_box --num-episodes 1 +pixi install # create the default dev environment +pixi run build # build g1_extensions and Judo MuJoCo extensions +pixi run sumo task=spot_box_push # launch the interactive app +pixi run python -m sumo.run_mpc --init-task=g1_box --num-episodes=2 +pixi run pytest tests/ -v # run the test suite +pixi run pre-commit run --all-files # run Ruff, formatting, and whitespace hooks +pixi run pyright sumo/ # run static type checks ``` -`pixi run build` compiles `g1_extensions`; `pixi run build-judo-ext` builds Judo's `mujoco_extensions` backend for Spot tasks. `pixi install -e dev` remains equivalent to `pixi install` for this repo, but `pixi run` without `-e` already targets the default full app environment. Use `pip install -e .` only as a lightweight fallback when you do not need the managed `pixi` environment. +Run `pixi run build` before G1 or Spot simulations that require compiled extension backends. ## Coding Style & Naming Conventions -Write Python with 4-space indentation, explicit imports, and type-aware dataclass configs. Follow existing naming: `snake_case` for modules/functions, `PascalCase` for classes, and `*Config` for task/config dataclasses such as `G1BoxConfig`. New task modules should follow patterns like `g1_box.py` or `spot_table_drag.py`. Prefer `ruff check .` before submitting changes. When porting code, preserve `judo-rai` APIs and conventions first, even if `judo-private/` differs. +Python uses 4-space indentation, explicit imports, and Ruff formatting with a 120-character line length. Keep module and function names in `snake_case`, classes in `PascalCase`, and task configs named with a `Config` suffix, for example `G1BoxConfig`. New task files should follow existing patterns such as `g1_table_push.py` or `spot_tire_roll.py`. Prefer dataclass-style configs and match public `judo-rai` APIs rather than introducing local compatibility shims. ## Testing Guidelines -Add tests in `tests/` and name files `test_*.py`. Cover imports, task registration, and task-specific behavior such as `reset()`, `reward()`, and control shape. Mark extension-only tests with `@pytest.mark.g1_extensions`; `tests/conftest.py` skips them when `g1_extensions` is unavailable. When adding a task, update registration in `sumo/tasks/__init__.py` and add at least one smoke test. +Use pytest. Name files `test_*.py` and tests `test_*`. Add coverage for imports, task registration, reset behavior, reward shape, and control dimensions when adding or changing tasks. Mark tests that require the optional native G1 extension with `@pytest.mark.g1_extensions`; `tests/conftest.py` skips those when the extension is not built. Run `pixi run pytest -rsx` to mirror CI output. ## Commit & Pull Request Guidelines -The repo history is minimal, so use short imperative commit subjects. Scoped messages are preferred, for example `feat(tasks): port spot_table_drag`. PRs should summarize the port or feature, note what was aligned to public `judo-rai`, list commands run, and include screenshots or logs for simulator-facing changes. Avoid mixing migration cleanup with unrelated refactors. +The existing history uses short, imperative subjects such as `Update README.md` and `clean up paper tasks to match paper rewards`. Keep commits focused and avoid mixing task behavior changes with cleanup. Before pushing or requesting review, run `pixi run pre-commit run --all-files` and `pixi run pyright sumo/` locally. PRs should include a concise summary, commands run, and screenshots or logs for simulator-facing changes. + +## Security & Configuration Tips +Do not commit generated build directories, local `.judo-src/` checkouts, cache folders, or large simulation outputs. Keep dependency and build changes in `pyproject.toml`, `pixi.lock`, and the relevant CMake files together so CI can reproduce them. diff --git a/pixi.lock b/pixi.lock index 4341469..9dfe0dc 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5,8 +5,6 @@ environments: - url: https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple - options: - pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda @@ -160,7 +158,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/d0/34/9e591954939276bb679b73773836c6684c22e56d05980e31d52a9a8deb18/lxml-6.0.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/44/29/c89ab58a20965d630df347999b55306b146473cadd0fb8b7879c85ca7a54/manifold3d-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/03/30/e54ececd0403a5495c340b693075abec92a6d17dc44283b6cb059534f7ed/mapbox_earcut-2.0.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl @@ -353,7 +351,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/5d/f4/2a94a3d3dfd6c6b433501b8d470a1960a20ecce93245cf2db1706adf6c19/lxml-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/68/d0/b1b5a1dc4db9e4ce844d025a72777a5a2251a33c5a84cff7e64ed9425373/manifold3d-3.4.0-cp313-cp313-macosx_10_14_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8b/7c/c5dd5b255b9828ba5df729e62fdd470a322c938f07ef392ca03c0592bb3a/mapbox_earcut-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl @@ -545,7 +543,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/53/fd/4e8f0540608977aea078bf6d79f128e0e2c2bba8af1acf775c30baa70460/lxml-6.0.2-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/f2/b3/47473d3f3bda127569e9615416ff24b1049b74cc08540572a9747e00bbe0/manifold3d-3.4.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/1a/3f/03f23eac9831e7d0d8da3d6993695a9a3724659c94e9997f6b7aaccc199d/mapbox_earcut-2.0.0-cp313-cp313-macosx_11_0_arm64.whl @@ -585,8 +583,6 @@ environments: - url: https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple - options: - pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda @@ -740,7 +736,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/d0/34/9e591954939276bb679b73773836c6684c22e56d05980e31d52a9a8deb18/lxml-6.0.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/44/29/c89ab58a20965d630df347999b55306b146473cadd0fb8b7879c85ca7a54/manifold3d-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/03/30/e54ececd0403a5495c340b693075abec92a6d17dc44283b6cb059534f7ed/mapbox_earcut-2.0.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl @@ -933,7 +929,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/5d/f4/2a94a3d3dfd6c6b433501b8d470a1960a20ecce93245cf2db1706adf6c19/lxml-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/68/d0/b1b5a1dc4db9e4ce844d025a72777a5a2251a33c5a84cff7e64ed9425373/manifold3d-3.4.0-cp313-cp313-macosx_10_14_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8b/7c/c5dd5b255b9828ba5df729e62fdd470a322c938f07ef392ca03c0592bb3a/mapbox_earcut-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl @@ -1125,7 +1121,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl - - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 + - pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 - pypi: https://files.pythonhosted.org/packages/53/fd/4e8f0540608977aea078bf6d79f128e0e2c2bba8af1acf775c30baa70460/lxml-6.0.2-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/f2/b3/47473d3f3bda127569e9615416ff24b1049b74cc08540572a9747e00bbe0/manifold3d-3.4.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/1a/3f/03f23eac9831e7d0d8da3d6993695a9a3724659c94e9997f6b7aaccc199d/mapbox_earcut-2.0.0-cp313-cp313-macosx_11_0_arm64.whl @@ -2295,6 +2291,7 @@ packages: requires_dist: - ninja ; extra == 'dev' - pybind11 ; extra == 'dev' + editable: true - conda: https://conda.anaconda.org/conda-forge/linux-64/gcc-14.3.0-h0dff253_18.conda sha256: 9b34b57b06b485e33a40d430f71ac88c8f381673592507cf7161c50ff0832772 md5: 52d6457abc42e320787ada5f9033fa99 @@ -2831,12 +2828,12 @@ packages: requires_dist: - referencing>=0.31.0 requires_python: '>=3.9' -- pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#e2605e510c29e191053c7514018409b3944bcc97 +- pypi: git+https://github.com/bdaiinstitute/judo.git?rev=dta%2Ffix_for_sumo#9acaa753131e982a87e5912bcd638a1281e48ca4 name: judo-rai version: 0.0.7 requires_dist: - dora-utils - - mujoco>=3.5.0,<3.6 + - mujoco>=3.5.0 - numpy - pillow - pycparser @@ -5953,17 +5950,18 @@ packages: - pypi: ./ name: sumo version: 0.0.1 - sha256: e3102fcdc1710b3c41ffa38bac188a3122735afa8156f8668fcc5420ed4799dc + sha256: 6e64d863aeaaf5de30159112a5dbb8576b2965de95fc9e2e1d5f20d31243c91f requires_dist: - judo-rai @ git+https://github.com/bdaiinstitute/judo.git@dta/fix_for_sumo - numpy - - mujoco + - mujoco>=3.5.0,<3.6.0 - h5py - tyro - tqdm - scipy - pytest ; extra == 'dev' - ruff ; extra == 'dev' + editable: true - pypi: https://files.pythonhosted.org/packages/3a/83/4f5b250220e1a5acd31345a5ec1c95a7769725d0d8135276f399f44062f8/svg_path-7.0-py2.py3-none-any.whl name: svg-path version: '7.0' diff --git a/pyproject.toml b/pyproject.toml index ebece41..496aafe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ version = "0.0.1" dependencies = [ "judo-rai @ git+https://github.com/bdaiinstitute/judo.git@dta/fix_for_sumo", "numpy", - "mujoco", + "mujoco>=3.5.0, <3.6.0", "h5py", "tyro", "tqdm", diff --git a/sumo/run_mpc/run_mpc.py b/sumo/run_mpc/run_mpc.py index bce2fbb..5eddcb0 100644 --- a/sumo/run_mpc/run_mpc.py +++ b/sumo/run_mpc/run_mpc.py @@ -154,8 +154,6 @@ def run_single_episode(config, task, controller, sim, viser_model=None, episode_ time=curr_time, qpos=np.array(task.data.qpos), qvel=np.array(task.data.qvel), - xpos=np.array(task.data.xpos), - xquat=np.array(task.data.xquat), mocap_pos=np.array(task.data.mocap_pos), mocap_quat=np.array(task.data.mocap_quat), sim_metadata={}, diff --git a/sumo/tasks/__init__.py b/sumo/tasks/__init__.py index 6749feb..b834161 100644 --- a/sumo/tasks/__init__.py +++ b/sumo/tasks/__init__.py @@ -29,12 +29,15 @@ "spot_cone_push", "spot_rack_push", "spot_tire_push", + "spot_box_upright", + "spot_chair_upright", "spot_cone_upright", + "spot_rack_upright", + "spot_tire_upright", "spot_chair_ramp", "spot_barrier_upright", "spot_barrier_drag", "spot_tire_roll", - "spot_tire_upright", "spot_tire_stack", "spot_tire_rack_drag", "spot_rugged_box_push", @@ -45,11 +48,14 @@ from sumo.tasks.spot.spot_barrier_upright import SpotBarrierUpright, SpotBarrierUprightConfig from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig from sumo.tasks.spot.spot_box_push import SpotBoxPush, SpotBoxPushConfig +from sumo.tasks.spot.spot_box_upright import SpotBoxUpright, SpotBoxUprightConfig from sumo.tasks.spot.spot_chair_push import SpotChairPush, SpotChairPushConfig from sumo.tasks.spot.spot_chair_ramp import SpotChairRamp, SpotChairRampConfig +from sumo.tasks.spot.spot_chair_upright import SpotChairUpright, SpotChairUprightConfig from sumo.tasks.spot.spot_cone_push import SpotConePush, SpotConePushConfig from sumo.tasks.spot.spot_cone_upright import SpotConeUpright, SpotConeUprightConfig from sumo.tasks.spot.spot_rack_push import SpotRackPush, SpotRackPushConfig +from sumo.tasks.spot.spot_rack_upright import SpotRackUpright, SpotRackUprightConfig from sumo.tasks.spot.spot_rugged_box_push import SpotRuggedBoxPush, SpotRuggedBoxPushConfig from sumo.tasks.spot.spot_tire_push import SpotTirePush, SpotTirePushConfig from sumo.tasks.spot.spot_tire_rack_drag import SpotTireRackDrag, SpotTireRackDragConfig @@ -63,12 +69,15 @@ register_task("spot_cone_push", SpotConePush, SpotConePushConfig) register_task("spot_rack_push", SpotRackPush, SpotRackPushConfig) register_task("spot_tire_push", SpotTirePush, SpotTirePushConfig) +register_task("spot_box_upright", SpotBoxUpright, SpotBoxUprightConfig) +register_task("spot_chair_upright", SpotChairUpright, SpotChairUprightConfig) register_task("spot_cone_upright", SpotConeUpright, SpotConeUprightConfig) +register_task("spot_rack_upright", SpotRackUpright, SpotRackUprightConfig) +register_task("spot_tire_upright", SpotTireUpright, SpotTireUprightConfig) register_task("spot_chair_ramp", SpotChairRamp, SpotChairRampConfig) register_task("spot_barrier_upright", SpotBarrierUpright, SpotBarrierUprightConfig) register_task("spot_barrier_drag", SpotBarrierDrag, SpotBarrierDragConfig) register_task("spot_tire_roll", SpotTireRoll, SpotTireRollConfig) -register_task("spot_tire_upright", SpotTireUpright, SpotTireUprightConfig) register_task("spot_tire_stack", SpotTireStack, SpotTireStackConfig) register_task("spot_tire_rack_drag", SpotTireRackDrag, SpotTireRackDragConfig) register_task("spot_rugged_box_push", SpotRuggedBoxPush, SpotRuggedBoxPushConfig) diff --git a/sumo/tasks/spot/spot_base.py b/sumo/tasks/spot/spot_base.py index bc773e9..73e915c 100644 --- a/sumo/tasks/spot/spot_base.py +++ b/sumo/tasks/spot/spot_base.py @@ -6,9 +6,10 @@ import re import tempfile from pathlib import Path -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, cast from judo import MODEL_PATH as JUDO_MODEL_PATH +from judo.tasks.base import TaskConfig from judo.tasks.spot.spot_base import SpotBase as _JudoSpotBase from judo.tasks.spot.spot_base import SpotBaseConfig @@ -16,7 +17,7 @@ XML_PATH = str(JUDO_MODEL_PATH / "xml" / "spot_primitive" / "robot.xml") -ConfigT = TypeVar("ConfigT", bound=SpotBaseConfig) +ConfigT = TypeVar("ConfigT", bound=TaskConfig) _INCLUDE_RE = re.compile(r'( None: class SpotBase(SpotAssetMixin, _JudoSpotBase, Generic[ConfigT]): """Sumo SpotBase wrapper that composes local task XML with public Spot definitions.""" + config_t: type[ConfigT] # pyright: ignore[reportIncompatibleVariableOverride] + config: ConfigT + @staticmethod def _is_relative_to(path: Path, parent: Path) -> bool: try: @@ -158,7 +162,7 @@ def __init__( use_gripper: bool = False, use_legs: bool = False, use_torso: bool = False, - config: SpotBaseConfig | None = None, + config: ConfigT | None = None, ) -> None: super().__init__( model_path=str(self._materialize_model_path(model_path)), @@ -166,5 +170,8 @@ def __init__( use_gripper=use_gripper, use_legs=use_legs, use_torso=use_torso, - config=config, + config=cast(Any, config), ) + + +__all__ = ["SpotAssetMixin", "SpotBase", "SpotBaseConfig"] diff --git a/sumo/tasks/spot/spot_box_push.py b/sumo/tasks/spot/spot_box_push.py index c1be174..04c4f4f 100644 --- a/sumo/tasks/spot/spot_box_push.py +++ b/sumo/tasks/spot/spot_box_push.py @@ -1,13 +1,71 @@ # Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np from judo.tasks.spot.spot_box_push import SpotBoxPush as _JudoSpotBoxPush -from judo.tasks.spot.spot_box_push import SpotBoxPushConfig +from judo.tasks.spot.spot_constants import BOX_HALF_LENGTH +from judo.utils.fields import np_1d_field from sumo.tasks.spot.spot_base import SpotAssetMixin +from sumo.tasks.spot.spot_push import ( + SpotPushConfig, + goal_distance_reward, + gripper_distance_reward, + object_linear_velocity_reward, +) + + +@dataclass +class SpotBoxPushConfig(SpotPushConfig): + """Configuration for Sumo's simplified Spot box pushing analysis task.""" + + goal_position: np.ndarray = np_1d_field( + np.array([0.0, 0.0, BOX_HALF_LENGTH]), + names=["x", "y", "z"], + mins=[-5.0, -5.0, 0.0], + maxs=[5.0, 5.0, 3.0], + vis_name="goal_position", + xyz_vis_indices=[0, 1, None], + ) class SpotBoxPush(SpotAssetMixin, _JudoSpotBoxPush): - """Public judo-rai SpotBoxPush with Sumo asset resolution.""" + """Spot box pushing with Sumo's simplified analysis reward.""" + + config_t: type[SpotBoxPushConfig] = SpotBoxPushConfig # type: ignore[assignment] + config: Any + + def __init__(self, config: SpotBoxPushConfig | None = None) -> None: + super().__init__(config=cast(Any, config)) + self.object_vel_idx = self.get_joint_velocity_start_index("box_joint") + + def reward( + self, + states: np.ndarray, + sensors: np.ndarray, + controls: np.ndarray, + system_metadata: dict[str, Any] | None = None, + ) -> np.ndarray: + """Reward using only goal distance, gripper distance, and object velocity.""" + batch_size = states.shape[0] + qpos = states[..., : self.model.nq] + + object_pos = qpos[..., self.object_pose_idx : self.object_pose_idx + 3] + gripper_pos = sensors[..., self.gripper_pos_idx : self.gripper_pos_idx + 3] + object_linear_velocity = states[..., self.object_vel_idx : self.object_vel_idx + 3] + + goal_reward = goal_distance_reward(self.config, object_pos) + gripper_proximity_reward = gripper_distance_reward( + self.config, np.linalg.norm(gripper_pos - object_pos, axis=-1) + ) + object_linear_velocity_penalty = object_linear_velocity_reward(self.config, object_linear_velocity) + + assert goal_reward.shape == (batch_size,) + assert gripper_proximity_reward.shape == (batch_size,) + assert object_linear_velocity_penalty.shape == (batch_size,) + return goal_reward + gripper_proximity_reward + object_linear_velocity_penalty __all__ = ["SpotBoxPush", "SpotBoxPushConfig"] diff --git a/sumo/tasks/spot/spot_box_upright.py b/sumo/tasks/spot/spot_box_upright.py new file mode 100644 index 0000000..a49a124 --- /dev/null +++ b/sumo/tasks/spot/spot_box_upright.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. + +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np +from mujoco import MjData, MjModel + +from sumo.tasks.spot.spot_box_push import SpotBoxPush +from sumo.tasks.spot.spot_constants import LEGS_STANDING_POS, STANDING_HEIGHT +from sumo.tasks.spot.spot_upright import ( + Z_AXIS, + SpotUprightConfig, + gripper_distance_reward, + random_object_pose, + sample_annulus_xy, + z_axis_is_upright, + z_axis_orientation_reward, +) + +RADIUS_MIN = 1.0 +RADIUS_MAX = 2.0 +ORIENTATION_TOLERANCE = 0.1 + + +@dataclass +class SpotBoxUprightConfig(SpotUprightConfig): + """Configuration for Sumo's simplified Spot box upright analysis task.""" + + +class SpotBoxUpright(SpotBoxPush): + """Task getting Spot to upright a randomly oriented box.""" + + name: str = "spot_box_upright" + config_t: type[SpotBoxUprightConfig] = SpotBoxUprightConfig # type: ignore[assignment] + config: SpotBoxUprightConfig + + def __init__(self, config: SpotBoxUprightConfig | None = None) -> None: + super().__init__(config=cast(Any, config)) + self.object_z_axis_idx = self.get_sensor_start_index("object_z_axis") + + def reward( + self, + states: np.ndarray, + sensors: np.ndarray, + controls: np.ndarray, + system_metadata: dict[str, Any] | None = None, + ) -> np.ndarray: + """Reward using only object orientation and gripper distance.""" + batch_size = states.shape[0] + qpos = states[..., : self.model.nq] + + object_pos = qpos[..., self.object_pose_idx : self.object_pose_idx + 3] + object_z_axis = sensors[..., self.object_z_axis_idx : self.object_z_axis_idx + 3] + gripper_pos = sensors[..., self.gripper_pos_idx : self.gripper_pos_idx + 3] + gripper_distance = np.linalg.norm(gripper_pos - object_pos, axis=-1) + + orientation_reward = z_axis_orientation_reward(self.config, object_z_axis) + proximity_reward = gripper_distance_reward(self.config, gripper_distance) + + assert orientation_reward.shape == (batch_size,) + assert proximity_reward.shape == (batch_size,) + return orientation_reward + proximity_reward + + @property + def reset_pose(self) -> np.ndarray: + """Reset pose with a random box attitude that clears the ground.""" + object_pose = random_object_pose( + self.model, + "box_body", + sample_annulus_xy(RADIUS_MIN, RADIUS_MAX), + reject_orientation=lambda quat: z_axis_is_upright(quat, ORIENTATION_TOLERANCE), + ) + robot_pose = np.array([0.0, 0.0, STANDING_HEIGHT, 1.0, 0.0, 0.0, 0.0]) + return np.array([*robot_pose, *LEGS_STANDING_POS, *self.reset_arm_pos, *object_pose]) + + def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: + """Check if the box is upright.""" + object_z_axis = data.sensordata[self.object_z_axis_idx : self.object_z_axis_idx + 3] + return bool(np.dot(object_z_axis, Z_AXIS) >= 1.0 - ORIENTATION_TOLERANCE) + + +__all__ = ["SpotBoxUpright", "SpotBoxUprightConfig"] diff --git a/sumo/tasks/spot/spot_chair_push.py b/sumo/tasks/spot/spot_chair_push.py index de8d2f6..92933a8 100644 --- a/sumo/tasks/spot/spot_chair_push.py +++ b/sumo/tasks/spot/spot_chair_push.py @@ -9,37 +9,31 @@ from typing import Any import numpy as np -from judo.tasks.spot.spot_constants import Z_AXIS from judo.utils.fields import np_1d_field from mujoco import MjData, MjModel from sumo import MODEL_PATH -from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig +from sumo.tasks.spot.spot_base import SpotBase from sumo.tasks.spot.spot_constants import LEGS_STANDING_POS, STANDING_HEIGHT +from sumo.tasks.spot.spot_push import ( + SpotPushConfig, + goal_distance_reward, + gripper_distance_reward, + object_linear_velocity_reward, +) XML_PATH = str(MODEL_PATH / "xml/spot_tasks/spot_yellow_chair.xml") RADIUS_MIN = 1.0 RADIUS_MAX = 2.0 -HARDWARE_FENCE_X = (-2.0, 3.0) -HARDWARE_FENCE_Y = (-3.0, 2.5) +POSITION_TOLERANCE = 0.2 +VELOCITY_TOLERANCE = 0.05 @dataclass -class SpotChairPushConfig(SpotBaseConfig): - """Configuration for the SpotChairPush task.""" - - w_fence: float = 1000.0 - w_goal: float = 60.0 - w_orientation: float = 50.0 - orientation_sparsity: float = 5.0 - w_torso_proximity: float = 50.0 - torso_proximity_threshold: float = 1.0 - w_gripper_proximity: float = 8.0 - orientation_threshold: float = 0.7 - w_controls: float = 2.0 - w_object_velocity: float = 64.0 - fall_penalty: float = 2500.0 +class SpotChairPushConfig(SpotPushConfig): + """Configuration for Sumo's simplified Spot chair pushing analysis task.""" + goal_position: np.ndarray = np_1d_field( np.array([0.0, 0.0, 0.0]), names=["x", "y", "z"], @@ -49,16 +43,13 @@ class SpotChairPushConfig(SpotBaseConfig): xyz_vis_indices=[0, 1, None], ) - # Success criteria (sumo-specific) - orientation_tolerance: float = 0.1 - class SpotChairPush(SpotBase[SpotChairPushConfig]): """Task getting Spot to push a yellow chair to a desired goal location.""" name: str = "spot_chair_push" config_t: type[SpotChairPushConfig] = SpotChairPushConfig # type: ignore[assignment] - config: SpotChairPushConfig + config: Any def __init__( self, @@ -69,8 +60,7 @@ def __init__( self.body_pose_idx = self.get_joint_position_start_index("base") self.object_pose_idx = self.get_joint_position_start_index("yellow_chair_joint") self.gripper_pos_idx = self.get_sensor_start_index("trace_fngr_site") - self.object_z_axis_idx = self.get_sensor_start_index("object_z_axis") - self.object_vel_idx = self.model.jnt_dofadr[self.model.joint("yellow_chair_joint").id] + self.object_vel_idx = self.get_joint_velocity_start_index("yellow_chair_joint") def reward( self, @@ -79,59 +69,24 @@ def reward( controls: np.ndarray, system_metadata: dict[str, Any] | None = None, ) -> np.ndarray: - """Reward function for the chair push task.""" + """Reward using only goal distance, gripper distance, and object velocity.""" batch_size = states.shape[0] qpos = states[..., : self.model.nq] - qvel = states[..., self.model.nq :] - body_height = qpos[..., self.body_pose_idx + 2] - body_pos = qpos[..., self.body_pose_idx : self.body_pose_idx + 3] object_pos = qpos[..., self.object_pose_idx : self.object_pose_idx + 3] - object_linear_velocity = qvel[..., self.object_vel_idx : self.object_vel_idx + 3] - object_z_axis = sensors[..., self.object_z_axis_idx : self.object_z_axis_idx + 3] + object_linear_velocity = states[..., self.object_vel_idx : self.object_vel_idx + 3] gripper_pos = sensors[..., self.gripper_pos_idx : self.gripper_pos_idx + 3] - # Fence penalty - fence_violated_x = (body_pos[..., 0] < HARDWARE_FENCE_X[0]) | (body_pos[..., 0] > HARDWARE_FENCE_X[1]) - fence_violated_y = (body_pos[..., 1] < HARDWARE_FENCE_Y[0]) | (body_pos[..., 1] > HARDWARE_FENCE_Y[1]) - spot_fence_reward = -self.config.w_fence * (fence_violated_x | fence_violated_y).any(axis=-1) - - spot_fallen_reward = -self.config.fall_penalty * (body_height <= self.config.spot_fallen_threshold).any(axis=-1) - - goal_reward = -self.config.w_goal * np.linalg.norm( - object_pos - self.config.goal_position[None, None], axis=-1 - ).mean(-1) - - orientation_alignment = np.minimum(np.dot(object_z_axis, Z_AXIS) - 1, self.config.orientation_threshold) - object_orientation_reward = +self.config.w_orientation * np.exp( - self.config.orientation_sparsity * orientation_alignment - ).sum(axis=-1) - - torso_proximity_reward = self.config.w_torso_proximity * np.minimum( - self.config.torso_proximity_threshold, np.linalg.norm(body_pos - object_pos, axis=-1) - ).mean(-1) - - gripper_proximity_reward = -self.config.w_gripper_proximity * np.linalg.norm( - gripper_pos - object_pos, axis=-1 - ).mean(-1) - - object_linear_velocity_reward = -self.config.w_object_velocity * np.square( - np.linalg.norm(object_linear_velocity, axis=-1).mean(-1) + goal_reward = goal_distance_reward(self.config, object_pos) + gripper_proximity_reward = gripper_distance_reward( + self.config, np.linalg.norm(gripper_pos - object_pos, axis=-1) ) + velocity_reward = object_linear_velocity_reward(self.config, object_linear_velocity) - controls_reward = -self.config.w_controls * np.linalg.norm(controls[..., :3], axis=-1).mean(-1) - - assert spot_fence_reward.shape == (batch_size,) - return ( - spot_fence_reward - + spot_fallen_reward - + goal_reward - + object_orientation_reward - + torso_proximity_reward - + gripper_proximity_reward - + object_linear_velocity_reward - + controls_reward - ) + assert goal_reward.shape == (batch_size,) + assert gripper_proximity_reward.shape == (batch_size,) + assert velocity_reward.shape == (batch_size,) + return goal_reward + gripper_proximity_reward + velocity_reward @property def reset_pose(self) -> np.ndarray: @@ -155,8 +110,11 @@ def reset_pose(self) -> np.ndarray: ) def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: - """Check if the yellow chair is upright, regardless of position.""" - object_z_axis = data.sensordata[self.object_z_axis_idx : self.object_z_axis_idx + 3] - orientation_alignment = np.dot(object_z_axis, Z_AXIS) - orientation_success = orientation_alignment >= (1.0 - self.config.orientation_tolerance) - return bool(orientation_success) + """Check if the yellow chair is in the goal position.""" + object_pos = data.qpos[self.object_pose_idx : self.object_pose_idx + 3] + object_vel = data.qvel[self.object_vel_idx - self.model.nq : self.object_vel_idx - self.model.nq + 3] + position_check = ( + np.linalg.norm(object_pos - self.config.goal_position, axis=-1, ord=np.inf) < POSITION_TOLERANCE + ) + velocity_check = np.linalg.norm(object_vel, axis=-1) < VELOCITY_TOLERANCE + return bool(position_check and velocity_check) diff --git a/sumo/tasks/spot/spot_chair_upright.py b/sumo/tasks/spot/spot_chair_upright.py new file mode 100644 index 0000000..99cb3af --- /dev/null +++ b/sumo/tasks/spot/spot_chair_upright.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. + +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np +from mujoco import MjData, MjModel + +from sumo.tasks.spot.spot_chair_push import SpotChairPush +from sumo.tasks.spot.spot_constants import LEGS_STANDING_POS, STANDING_HEIGHT +from sumo.tasks.spot.spot_upright import ( + Z_AXIS, + SpotUprightConfig, + gripper_distance_reward, + random_object_pose, + sample_annulus_xy, + z_axis_is_upright, + z_axis_orientation_reward, +) + +RADIUS_MIN = 1.0 +RADIUS_MAX = 2.0 +ORIENTATION_TOLERANCE = 0.1 + + +@dataclass +class SpotChairUprightConfig(SpotUprightConfig): + """Configuration for Sumo's simplified Spot chair upright analysis task.""" + + +class SpotChairUpright(SpotChairPush): + """Task getting Spot to upright a randomly oriented chair.""" + + name: str = "spot_chair_upright" + config_t: type[SpotChairUprightConfig] = SpotChairUprightConfig # type: ignore[assignment] + config: SpotChairUprightConfig + + def __init__(self, config: SpotChairUprightConfig | None = None) -> None: + super().__init__(config=cast(Any, config)) + self.object_z_axis_idx = self.get_sensor_start_index("object_z_axis") + + def reward( + self, + states: np.ndarray, + sensors: np.ndarray, + controls: np.ndarray, + system_metadata: dict[str, Any] | None = None, + ) -> np.ndarray: + """Reward using only object orientation and gripper distance.""" + batch_size = states.shape[0] + qpos = states[..., : self.model.nq] + + object_pos = qpos[..., self.object_pose_idx : self.object_pose_idx + 3] + object_z_axis = sensors[..., self.object_z_axis_idx : self.object_z_axis_idx + 3] + gripper_pos = sensors[..., self.gripper_pos_idx : self.gripper_pos_idx + 3] + gripper_distance = np.linalg.norm(gripper_pos - object_pos, axis=-1) + + orientation_reward = z_axis_orientation_reward(self.config, object_z_axis) + proximity_reward = gripper_distance_reward(self.config, gripper_distance) + + assert orientation_reward.shape == (batch_size,) + assert proximity_reward.shape == (batch_size,) + return orientation_reward + proximity_reward + + @property + def reset_pose(self) -> np.ndarray: + """Reset pose with a random chair attitude that clears the ground.""" + object_pose = random_object_pose( + self.model, + "yellow_chair", + sample_annulus_xy(RADIUS_MIN, RADIUS_MAX), + reject_orientation=lambda quat: z_axis_is_upright(quat, ORIENTATION_TOLERANCE), + ) + robot_pose = np.array([0.0, 0.0, STANDING_HEIGHT, 1.0, 0.0, 0.0, 0.0]) + return np.array([*robot_pose, *LEGS_STANDING_POS, *self.reset_arm_pos, *object_pose]) + + def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: + """Check if the chair is upright.""" + object_z_axis = data.sensordata[self.object_z_axis_idx : self.object_z_axis_idx + 3] + return bool(np.dot(object_z_axis, Z_AXIS) >= 1.0 - ORIENTATION_TOLERANCE) + + +__all__ = ["SpotChairUpright", "SpotChairUprightConfig"] diff --git a/sumo/tasks/spot/spot_cone_push.py b/sumo/tasks/spot/spot_cone_push.py index 4b195e1..77a6e50 100644 --- a/sumo/tasks/spot/spot_cone_push.py +++ b/sumo/tasks/spot/spot_cone_push.py @@ -8,29 +8,33 @@ from mujoco import MjData, MjModel from sumo import MODEL_PATH -from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig +from sumo.tasks.spot.spot_base import SpotBase from sumo.tasks.spot.spot_constants import ( LEGS_STANDING_POS, STANDING_HEIGHT, ) +from sumo.tasks.spot.spot_push import ( + SpotPushConfig, + goal_distance_reward, + gripper_distance_reward, + object_linear_velocity_reward, +) XML_PATH = str(MODEL_PATH / "xml/spot_tasks/spot_traffic_cone.xml") -USE_LEGS = False RADIUS_MIN = 1.0 RADIUS_MAX = 2.0 DEFAULT_CONE_HEIGHT = 0.0 -DEFAULT_TORSO_POSITION = np.array([-1.75, 0, STANDING_HEIGHT]) -Z_AXIS = np.array([0.0, 0.0, 1.0]) # Success condition tolerances POSITION_TOLERANCE = 0.2 VELOCITY_TOLERANCE = 0.05 +SPOT_FALLEN_THRESHOLD = 0.35 @dataclass -class SpotConePushConfig(SpotBaseConfig): - """Config for the spot cone pushing task.""" +class SpotConePushConfig(SpotPushConfig): + """Config for Sumo's simplified Spot cone pushing analysis task.""" goal_position: np.ndarray = np_1d_field( np.array([0.0, 0.0, DEFAULT_CONE_HEIGHT], dtype=np.float64), @@ -43,20 +47,12 @@ class SpotConePushConfig(SpotBaseConfig): xyz_vis_defaults=[0.0, 0.0, DEFAULT_CONE_HEIGHT], ) - w_object_orientation: float = 100.0 - w_upright: float = 200.0 - upright_sparsity: float = 5.0 - w_gripper_proximity: float = 4.0 - w_torso_proximity: float = 0.1 - orientation_threshold: float = 0.7 - w_object_velocity: float = 20.0 - -class SpotConePush(SpotBase): +class SpotConePush(SpotBase[SpotConePushConfig]): """Task getting Spot to push a cone to a goal location.""" name = "spot_cone_push" - config_t = SpotConePushConfig + config_t: type[SpotConePushConfig] = SpotConePushConfig config: SpotConePushConfig def __init__(self, config: SpotConePushConfig | None = None) -> None: @@ -65,8 +61,6 @@ def __init__(self, config: SpotConePushConfig | None = None) -> None: self.body_pose_start = self.get_joint_position_start_index("base") self.object_pose_start = self.get_joint_position_start_index("traffic_cone_joint") self.object_vel_start = self.get_joint_velocity_start_index("traffic_cone_joint") - self.object_y_axis_start = self.get_sensor_start_index("object_y_axis") - self.object_z_axis_start = self.get_sensor_start_index("object_z_axis") self.end_effector_to_object_start = self.get_sensor_start_index("sensor_arm_link_fngr") def reward( @@ -76,67 +70,25 @@ def reward( controls: np.ndarray, system_metadata: dict[str, Any] | None = None, ) -> np.ndarray: - """Reward function for the Spot cone pushing task.""" + """Reward using only goal distance, gripper distance, and object velocity.""" batch_size = states.shape[0] - # (batch, horizon, size) - # or (batch, horizon) if scalar qpos = states[..., : self.model.nq] - body_pos = qpos[..., self.body_pose_start : self.body_pose_start + 3] object_pos = qpos[..., self.object_pose_start : self.object_pose_start + 3] - object_y_axis = sensors[..., self.object_y_axis_start : self.object_y_axis_start + 3] - object_z_axis = sensors[..., self.object_z_axis_start : self.object_z_axis_start + 3] object_linear_velocity = states[..., self.object_vel_start : self.object_vel_start + 3] end_effector_to_object = sensors[..., self.end_effector_to_object_start : self.end_effector_to_object_start + 3] - gripper_proximity_reward = -self.config.w_gripper_proximity * np.linalg.norm( - end_effector_to_object, axis=-1 - ).mean(axis=-1) - - object_orientation_reward = -self.config.w_object_orientation * np.abs( - np.dot(object_y_axis, Z_AXIS) > self.config.orientation_threshold - ).sum(axis=-1) - - # Upright reward: incentivize cone z-axis aligned with world z-axis throughout the rollout - upright_alignment = np.minimum(np.dot(object_z_axis, Z_AXIS) - 1, 0.0) - upright_reward = self.config.w_upright * np.exp(self.config.upright_sparsity * upright_alignment).mean(axis=-1) - - goal_reward = -self.config.w_goal * np.linalg.norm( - object_pos - np.array(self.config.goal_position)[None, None], axis=-1 - ).mean(-1) - - torso_proximity_reward = self.config.w_torso_proximity * np.linalg.norm(body_pos - object_pos, axis=-1).mean(-1) - - object_linear_velocity_penalty = -self.config.w_object_velocity * np.square( - np.linalg.norm(object_linear_velocity, axis=-1).mean(-1) + gripper_proximity_reward = gripper_distance_reward( + self.config, + np.linalg.norm(end_effector_to_object, axis=-1), ) - # Check if any state in the rollout has spot fallen - body_height = qpos[..., self.body_pose_start + 2] - spot_fallen_reward = -self.config.fall_penalty * (body_height <= self.config.spot_fallen_threshold).any(axis=-1) + goal_reward = goal_distance_reward(self.config, object_pos) + object_linear_velocity_penalty = object_linear_velocity_reward(self.config, object_linear_velocity) - # Compute a penalty to prefer small commands. - controls_reward = -self.config.w_controls * np.linalg.norm(controls, axis=-1).mean(-1) - - assert object_orientation_reward.shape == (batch_size,) - assert upright_reward.shape == (batch_size,) assert gripper_proximity_reward.shape == (batch_size,) - assert torso_proximity_reward.shape == (batch_size,) assert object_linear_velocity_penalty.shape == (batch_size,) assert goal_reward.shape == (batch_size,) - assert spot_fallen_reward.shape == (batch_size,) - assert controls_reward.shape == (batch_size,) - - reward = ( - +spot_fallen_reward - + goal_reward - + object_orientation_reward - + upright_reward - + torso_proximity_reward - + gripper_proximity_reward - + object_linear_velocity_penalty - + controls_reward - ) - return reward + return goal_reward + gripper_proximity_reward + object_linear_velocity_penalty @property def reset_pose(self) -> np.ndarray: @@ -168,4 +120,4 @@ def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: """Check if Spot has fallen.""" body_height = data.qpos[self.body_pose_start + 2] - return body_height <= self.config.spot_fallen_threshold + return body_height <= SPOT_FALLEN_THRESHOLD diff --git a/sumo/tasks/spot/spot_cone_upright.py b/sumo/tasks/spot/spot_cone_upright.py index c90b1cd..6fe0ba3 100644 --- a/sumo/tasks/spot/spot_cone_upright.py +++ b/sumo/tasks/spot/spot_cone_upright.py @@ -5,49 +5,40 @@ The robot must pick up or nudge the cone back to an upright position. """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any import numpy as np from mujoco import MjData, MjModel from sumo import MODEL_PATH -from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig +from sumo.tasks.spot.spot_base import SpotBase from sumo.tasks.spot.spot_constants import ( LEGS_STANDING_POS, STANDING_HEIGHT, ) +from sumo.tasks.spot.spot_upright import ( + Z_AXIS, + SpotUprightConfig, + gripper_distance_reward, + random_object_pose, + sample_annulus_xy, + z_axis_is_upright, + z_axis_orientation_reward, +) XML_PATH = str(MODEL_PATH / "xml/spot_tasks/spot_traffic_cone.xml") -Z_AXIS = np.array([0.0, 0.0, 1.0]) RADIUS_MIN = 0.2 RADIUS_MAX = 0.5 -HARDWARE_FENCE_X = (-2.0, 3.0) -HARDWARE_FENCE_Y = (-3.0, 2.5) - DEFAULT_SPOT_POS = np.array([-1.5, 0.0]) -DEFAULT_OBJECT_POS = np.array([0.0, 0.0]) +ORIENTATION_TOLERANCE = 0.1 @dataclass -class SpotConeUprightConfig(SpotBaseConfig): - """Config for the spot cone upright task.""" - - goal_position: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0])) - w_fence: float = 1000.0 - w_orientation: float = 50.0 - orientation_sparsity: float = 5.0 - w_torso_proximity: float = 50.0 - torso_proximity_threshold: float = 1.0 - w_gripper_proximity: float = 8.0 - orientation_threshold: float = 0.7 - w_controls: float = 2.0 - w_object_velocity: float = 64.0 - position_tolerance: float = 0.2 - orientation_tolerance: float = 0.1 - velocity_tolerance: float = 0.1 +class SpotConeUprightConfig(SpotUprightConfig): + """Config for Sumo's simplified Spot cone upright analysis task.""" class SpotConeUpright(SpotBase[SpotConeUprightConfig]): @@ -62,8 +53,6 @@ def __init__(self, config: SpotConeUprightConfig | None = None) -> None: self.body_pose_idx = self.get_joint_position_start_index("base") self.object_pose_idx = self.get_joint_position_start_index("traffic_cone_joint") - self.object_vel_idx = self.get_joint_velocity_start_index("traffic_cone_joint") - self.object_y_axis_idx = self.get_sensor_start_index("object_y_axis") self.object_z_axis_idx = self.get_sensor_start_index("object_z_axis") self.end_effector_to_object_idx = self.get_sensor_start_index("sensor_arm_link_fngr") @@ -74,69 +63,29 @@ def reward( controls: np.ndarray, system_metadata: dict[str, Any] | None = None, ) -> np.ndarray: - """Reward function for the cone upright task.""" + """Reward using only object orientation and gripper distance.""" batch_size = states.shape[0] - qpos = states[..., : self.model.nq] - - body_height = qpos[..., self.body_pose_idx + 2] - body_pos = qpos[..., self.body_pose_idx : self.body_pose_idx + 3] - object_pos = qpos[..., self.object_pose_idx : self.object_pose_idx + 3] - object_linear_velocity = states[..., self.object_vel_idx : self.object_vel_idx + 3] - object_z_axis = sensors[..., self.object_z_axis_idx : self.object_z_axis_idx + 3] gripper_to_object = sensors[..., self.end_effector_to_object_idx : self.end_effector_to_object_idx + 3] + gripper_distance = np.linalg.norm(gripper_to_object, axis=-1) - fence_violated_x = (body_pos[..., 0] < HARDWARE_FENCE_X[0]) | (body_pos[..., 0] > HARDWARE_FENCE_X[1]) - fence_violated_y = (body_pos[..., 1] < HARDWARE_FENCE_Y[0]) | (body_pos[..., 1] > HARDWARE_FENCE_Y[1]) - spot_fence_reward = -self.config.w_fence * (fence_violated_x | fence_violated_y).any(axis=-1) - - spot_fallen_reward = -self.config.fall_penalty * (body_height <= self.config.spot_fallen_threshold).any(axis=-1) - - goal_reward = -self.config.w_goal * np.linalg.norm( - object_pos - np.array(self.config.goal_position)[None, None], axis=-1 - ).mean(-1) - - orientation_alignment = np.minimum(np.dot(object_z_axis, Z_AXIS) - 1, self.config.orientation_threshold) - object_orientation_reward = +self.config.w_orientation * np.exp( - self.config.orientation_sparsity * orientation_alignment - ).sum(axis=-1) - - torso_proximity_reward = self.config.w_torso_proximity * np.minimum( - self.config.torso_proximity_threshold, np.linalg.norm(body_pos - object_pos, axis=-1) - ).mean(-1) + orientation_reward = z_axis_orientation_reward(self.config, object_z_axis) + proximity_reward = gripper_distance_reward(self.config, gripper_distance) - gripper_proximity_reward = -self.config.w_gripper_proximity * np.linalg.norm( - gripper_to_object, - axis=-1, - ).mean(-1) - - object_linear_velocity_reward = -self.config.w_object_velocity * np.square( - np.linalg.norm(object_linear_velocity, axis=-1).mean(-1) - ) - - controls_reward = -self.config.w_controls * np.linalg.norm(controls[..., :3], axis=-1).mean(-1) - - assert spot_fence_reward.shape == (batch_size,) - return ( - spot_fence_reward - + spot_fallen_reward - + goal_reward - + object_orientation_reward - + torso_proximity_reward - + gripper_proximity_reward - + object_linear_velocity_reward - + controls_reward - ) + assert orientation_reward.shape == (batch_size,) + assert proximity_reward.shape == (batch_size,) + return orientation_reward + proximity_reward @property def reset_pose(self) -> np.ndarray: - """Reset pose — cone starts fallen on its side.""" - radius = RADIUS_MIN + (RADIUS_MAX - RADIUS_MIN) * np.random.rand() - theta = 2 * np.pi * np.random.rand() - object_pos = np.array([radius * np.cos(theta), radius * np.sin(theta)]) + np.random.randn(2) - # Cone on its side: 90-degree roll about x-axis - reset_object_pose = np.array([*object_pos, 0.275, np.cos(np.pi / 4), -np.sin(np.pi / 4), 0, 0]) + """Reset pose with a random cone attitude that clears the ground.""" + reset_object_pose = random_object_pose( + self.model, + "traffic_cone", + sample_annulus_xy(RADIUS_MIN, RADIUS_MAX), + reject_orientation=lambda quat: z_axis_is_upright(quat, ORIENTATION_TOLERANCE), + ) spot_pos = DEFAULT_SPOT_POS + np.random.randn(2) * 0.001 return np.array( [ @@ -156,10 +105,7 @@ def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None """Check if the traffic cone is upright.""" object_z_axis = data.sensordata[self.object_z_axis_idx : self.object_z_axis_idx + 3] orientation_alignment = np.dot(object_z_axis, Z_AXIS) - orientation_success = orientation_alignment >= (1.0 - self.config.orientation_tolerance) + return bool(orientation_alignment >= (1.0 - ORIENTATION_TOLERANCE)) - vel_offset = self.object_vel_idx - self.model.nq - velocity_success = ( - np.linalg.norm(data.qvel[vel_offset : vel_offset + 3], axis=-1) < self.config.velocity_tolerance - ) - return bool(orientation_success and velocity_success) + +__all__ = ["SpotConeUpright", "SpotConeUprightConfig"] diff --git a/sumo/tasks/spot/spot_push.py b/sumo/tasks/spot/spot_push.py new file mode 100644 index 0000000..a155e38 --- /dev/null +++ b/sumo/tasks/spot/spot_push.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. + +from dataclasses import dataclass +from typing import Protocol + +import numpy as np +from judo.tasks.base import TaskConfig + + +@dataclass +class SpotPushConfig(TaskConfig): + """Configuration for Sumo's simplified Spot pushing analysis tasks.""" + + w_goal: float = 60.0 + w_gripper_proximity: float = 4.0 + w_object_velocity: float = 20.0 + + +class SpotPushRewardConfig(Protocol): + w_goal: float + w_gripper_proximity: float + w_object_velocity: float + goal_position: np.ndarray + + +def goal_distance_reward(config: SpotPushRewardConfig, object_pos: np.ndarray) -> np.ndarray: + """Reward object proximity to the goal position.""" + return -config.w_goal * np.linalg.norm( + object_pos - np.asarray(config.goal_position)[None, None], + axis=-1, + ).mean(axis=-1) + + +def gripper_distance_reward(config: SpotPushRewardConfig, gripper_distance: np.ndarray) -> np.ndarray: + """Reward gripper proximity to the pushed object.""" + return -config.w_gripper_proximity * gripper_distance.mean(axis=-1) + + +def object_linear_velocity_reward(config: SpotPushRewardConfig, object_linear_velocity: np.ndarray) -> np.ndarray: + """Penalize object linear velocity.""" + return -config.w_object_velocity * np.square(np.linalg.norm(object_linear_velocity, axis=-1).mean(axis=-1)) diff --git a/sumo/tasks/spot/spot_rack_push.py b/sumo/tasks/spot/spot_rack_push.py index f9c4156..d21f417 100644 --- a/sumo/tasks/spot/spot_rack_push.py +++ b/sumo/tasks/spot/spot_rack_push.py @@ -8,29 +8,33 @@ from mujoco import MjData, MjModel from sumo import MODEL_PATH -from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig +from sumo.tasks.spot.spot_base import SpotBase from sumo.tasks.spot.spot_constants import ( LEGS_STANDING_POS, STANDING_HEIGHT, ) +from sumo.tasks.spot.spot_push import ( + SpotPushConfig, + goal_distance_reward, + gripper_distance_reward, + object_linear_velocity_reward, +) XML_PATH = str(MODEL_PATH / "xml/spot_tasks/spot_tire_rack.xml") -USE_LEGS = False RADIUS_MIN = 1.0 RADIUS_MAX = 2.0 DEFAULT_RACK_HEIGHT = 0.3 -DEFAULT_TORSO_POSITION = np.array([-1.75, 0, STANDING_HEIGHT]) -Z_AXIS = np.array([0.0, 0.0, 1.0]) # Success condition tolerances POSITION_TOLERANCE = 0.2 VELOCITY_TOLERANCE = 0.05 +SPOT_FALLEN_THRESHOLD = 0.35 @dataclass -class SpotRackPushConfig(SpotBaseConfig): - """Config for the spot rack pushing task.""" +class SpotRackPushConfig(SpotPushConfig): + """Config for Sumo's simplified Spot rack pushing analysis task.""" goal_position: np.ndarray = np_1d_field( np.array([0.0, 0.0, DEFAULT_RACK_HEIGHT], dtype=np.float64), @@ -43,19 +47,13 @@ class SpotRackPushConfig(SpotBaseConfig): xyz_vis_defaults=[0.0, 0.0, DEFAULT_RACK_HEIGHT], ) - w_object_orientation: float = 100.0 - w_gripper_proximity: float = 4.0 - w_torso_proximity: float = 0.1 - orientation_threshold: float = 0.7 - w_object_velocity: float = 20.0 - -class SpotRackPush(SpotBase): +class SpotRackPush(SpotBase[SpotRackPushConfig]): """Task getting Spot to push a rack to a goal location.""" name = "spot_rack_push" - config_t = SpotRackPushConfig - config: SpotRackPushConfig + config_t: type[SpotRackPushConfig] = SpotRackPushConfig + config: Any def __init__(self, config: SpotRackPushConfig | None = None) -> None: super().__init__(model_path=XML_PATH, use_arm=True, config=config) @@ -63,7 +61,6 @@ def __init__(self, config: SpotRackPushConfig | None = None) -> None: self.body_pose_start = self.get_joint_position_start_index("base") self.object_pose_start = self.get_joint_position_start_index("tire_rack_joint") self.object_vel_start = self.get_joint_velocity_start_index("tire_rack_joint") - self.object_y_axis_start = self.get_sensor_start_index("object_y_axis") self.end_effector_to_object_start = self.get_sensor_start_index("sensor_arm_link_fngr") def reward( @@ -73,61 +70,25 @@ def reward( controls: np.ndarray, system_metadata: dict[str, Any] | None = None, ) -> np.ndarray: - """Reward function for the Spot rack pushing task.""" + """Reward using only goal distance, gripper distance, and object velocity.""" batch_size = states.shape[0] - # (batch, horizon, size) - # or (batch, horizon) if scalar qpos = states[..., : self.model.nq] - body_pos = qpos[..., self.body_pose_start : self.body_pose_start + 3] object_pos = qpos[..., self.object_pose_start : self.object_pose_start + 3] - object_y_axis = sensors[..., self.object_y_axis_start : self.object_y_axis_start + 3] object_linear_velocity = states[..., self.object_vel_start : self.object_vel_start + 3] - # Compute unit vector pointing from tire to torso end_effector_to_object = sensors[..., self.end_effector_to_object_start : self.end_effector_to_object_start + 3] - gripper_proximity_reward = -self.config.w_gripper_proximity * np.linalg.norm( - end_effector_to_object, axis=-1 - ).mean(axis=-1) - - object_orientation_reward = -self.config.w_object_orientation * np.abs( - np.dot(object_y_axis, Z_AXIS) > self.config.orientation_threshold - ).sum(axis=-1) - - goal_reward = -self.config.w_goal * np.linalg.norm( - object_pos - np.array(self.config.goal_position)[None, None], axis=-1 - ).mean(-1) - - torso_proximity_reward = self.config.w_torso_proximity * np.linalg.norm(body_pos - object_pos, axis=-1).mean(-1) - - object_linear_velocity_penalty = -self.config.w_object_velocity * np.square( - np.linalg.norm(object_linear_velocity, axis=-1).mean(-1) + gripper_proximity_reward = gripper_distance_reward( + self.config, + np.linalg.norm(end_effector_to_object, axis=-1), ) - # Check if any state in the rollout has spot fallen - body_height = qpos[..., self.body_pose_start + 2] - spot_fallen_reward = -self.config.fall_penalty * (body_height <= self.config.spot_fallen_threshold).any(axis=-1) + goal_reward = goal_distance_reward(self.config, object_pos) + object_linear_velocity_penalty = object_linear_velocity_reward(self.config, object_linear_velocity) - # Compute a penalty to prefer small commands. - controls_reward = -self.config.w_controls * np.linalg.norm(controls, axis=-1).mean(-1) - - assert object_orientation_reward.shape == (batch_size,) assert gripper_proximity_reward.shape == (batch_size,) - assert torso_proximity_reward.shape == (batch_size,) assert object_linear_velocity_penalty.shape == (batch_size,) assert goal_reward.shape == (batch_size,) - assert spot_fallen_reward.shape == (batch_size,) - assert controls_reward.shape == (batch_size,) - - reward = ( - +spot_fallen_reward - + goal_reward - + object_orientation_reward - + torso_proximity_reward - + gripper_proximity_reward - + object_linear_velocity_penalty - + controls_reward - ) - return reward + return goal_reward + gripper_proximity_reward + object_linear_velocity_penalty @property def reset_pose(self) -> np.ndarray: @@ -159,4 +120,4 @@ def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: """Check if Spot has fallen.""" body_height = data.qpos[self.body_pose_start + 2] - return body_height <= self.config.spot_fallen_threshold + return body_height <= SPOT_FALLEN_THRESHOLD diff --git a/sumo/tasks/spot/spot_rack_upright.py b/sumo/tasks/spot/spot_rack_upright.py new file mode 100644 index 0000000..fe4ecb2 --- /dev/null +++ b/sumo/tasks/spot/spot_rack_upright.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. + +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np +from mujoco import MjData, MjModel + +from sumo.tasks.spot.spot_constants import LEGS_STANDING_POS, STANDING_HEIGHT +from sumo.tasks.spot.spot_rack_push import SpotRackPush +from sumo.tasks.spot.spot_upright import ( + Z_AXIS, + SpotUprightConfig, + gripper_distance_reward, + random_object_pose, + sample_annulus_xy, + z_axis_is_upright, + z_axis_orientation_reward, +) + +RADIUS_MIN = 1.0 +RADIUS_MAX = 2.0 +ORIENTATION_TOLERANCE = 0.1 + + +@dataclass +class SpotRackUprightConfig(SpotUprightConfig): + """Configuration for Sumo's simplified Spot rack upright analysis task.""" + + +class SpotRackUpright(SpotRackPush): + """Task getting Spot to upright a randomly oriented tire rack.""" + + name: str = "spot_rack_upright" + config_t: type[SpotRackUprightConfig] = SpotRackUprightConfig # type: ignore[assignment] + config: SpotRackUprightConfig + + def __init__(self, config: SpotRackUprightConfig | None = None) -> None: + super().__init__(config=cast(Any, config)) + self.object_z_axis_idx = self.get_sensor_start_index("object_z_axis") + + def reward( + self, + states: np.ndarray, + sensors: np.ndarray, + controls: np.ndarray, + system_metadata: dict[str, Any] | None = None, + ) -> np.ndarray: + """Reward using only object orientation and gripper distance.""" + batch_size = states.shape[0] + + object_z_axis = sensors[..., self.object_z_axis_idx : self.object_z_axis_idx + 3] + gripper_to_object = sensors[..., self.end_effector_to_object_start : self.end_effector_to_object_start + 3] + gripper_distance = np.linalg.norm(gripper_to_object, axis=-1) + + orientation_reward = z_axis_orientation_reward(self.config, object_z_axis) + proximity_reward = gripper_distance_reward(self.config, gripper_distance) + + assert orientation_reward.shape == (batch_size,) + assert proximity_reward.shape == (batch_size,) + return orientation_reward + proximity_reward + + @property + def reset_pose(self) -> np.ndarray: + """Reset pose with a random rack attitude that clears the ground.""" + object_pose = random_object_pose( + self.model, + "tire_rack", + sample_annulus_xy(RADIUS_MIN, RADIUS_MAX), + reject_orientation=lambda quat: z_axis_is_upright(quat, ORIENTATION_TOLERANCE), + ) + robot_pose = np.array([0.0, 0.0, STANDING_HEIGHT, 1.0, 0.0, 0.0, 0.0]) + return np.array([*robot_pose, *LEGS_STANDING_POS, *self.reset_arm_pos, *object_pose]) + + def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: + """Check if the tire rack is upright.""" + object_z_axis = data.sensordata[self.object_z_axis_idx : self.object_z_axis_idx + 3] + return bool(np.dot(object_z_axis, Z_AXIS) >= 1.0 - ORIENTATION_TOLERANCE) + + +__all__ = ["SpotRackUpright", "SpotRackUprightConfig"] diff --git a/sumo/tasks/spot/spot_tire_push.py b/sumo/tasks/spot/spot_tire_push.py index 8209417..f3ac33f 100644 --- a/sumo/tasks/spot/spot_tire_push.py +++ b/sumo/tasks/spot/spot_tire_push.py @@ -8,30 +8,34 @@ from mujoco import MjData, MjModel from sumo import MODEL_PATH -from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig +from sumo.tasks.spot.spot_base import SpotBase from sumo.tasks.spot.spot_constants import ( LEGS_STANDING_POS, STANDING_HEIGHT, ) +from sumo.tasks.spot.spot_push import ( + SpotPushConfig, + goal_distance_reward, + gripper_distance_reward, + object_linear_velocity_reward, +) XML_PATH = str(MODEL_PATH / "xml/spot_tasks/spot_tire.xml") TIRE_RADIUS: float = 0.339 TIRE_WIDTH: float = 0.175 -USE_LEGS = False RADIUS_MIN = 1.0 RADIUS_MAX = 2.0 -DEFAULT_TORSO_POSITION = np.array([-1.75, 0, STANDING_HEIGHT]) -Z_AXIS = np.array([0.0, 0.0, 1.0]) # Success condition tolerances POSITION_TOLERANCE = 0.1 VELOCITY_TOLERANCE = 0.05 +SPOT_FALLEN_THRESHOLD = 0.35 @dataclass -class SpotTirePushConfig(SpotBaseConfig): - """Config for the spot tire pushing task.""" +class SpotTirePushConfig(SpotPushConfig): + """Config for Sumo's simplified Spot tire pushing analysis task.""" goal_position: np.ndarray = np_1d_field( np.array([0.0, 0.0, TIRE_RADIUS], dtype=np.float64), @@ -44,19 +48,13 @@ class SpotTirePushConfig(SpotBaseConfig): xyz_vis_defaults=[0.0, 0.0, TIRE_RADIUS], ) - w_tire_orientation: float = 100.0 - w_gripper_proximity: float = 4.0 - w_torso_proximity: float = 0.1 - orientation_threshold: float = 0.7 - w_object_velocity: float = 20.0 - -class SpotTirePush(SpotBase): +class SpotTirePush(SpotBase[SpotTirePushConfig]): """Task getting Spot to push a tire to a goal location.""" name = "spot_tire_push" - config_t = SpotTirePushConfig - config: SpotTirePushConfig + config_t: type[SpotTirePushConfig] = SpotTirePushConfig + config: Any def __init__(self, config: SpotTirePushConfig | None = None) -> None: super().__init__(model_path=XML_PATH, use_arm=True, config=config) @@ -64,7 +62,6 @@ def __init__(self, config: SpotTirePushConfig | None = None) -> None: self.body_pose_start = self.get_joint_position_start_index("base") self.object_pose_start = self.get_joint_position_start_index("tire_joint") self.object_vel_start = self.get_joint_velocity_start_index("tire_joint") - self.tire_y_axis_start = self.get_sensor_start_index("object_y_axis") self.end_effector_to_object_start = self.get_sensor_start_index("sensor_arm_link_fngr") def reward( @@ -74,61 +71,25 @@ def reward( controls: np.ndarray, system_metadata: dict[str, Any] | None = None, ) -> np.ndarray: - """Reward function for the Spot tire pushing task.""" + """Reward using only goal distance, gripper distance, and object velocity.""" batch_size = states.shape[0] - # (batch, horizon, size) - # or (batch, horizon) if scalar qpos = states[..., : self.model.nq] - body_pos = qpos[..., self.body_pose_start : self.body_pose_start + 3] object_pos = qpos[..., self.object_pose_start : self.object_pose_start + 3] - tire_y_axis = sensors[..., self.tire_y_axis_start : self.tire_y_axis_start + 3] object_linear_velocity = states[..., self.object_vel_start : self.object_vel_start + 3] - # Compute unit vector pointing from tire to torso end_effector_to_object = sensors[..., self.end_effector_to_object_start : self.end_effector_to_object_start + 3] - gripper_proximity_reward = -self.config.w_gripper_proximity * np.linalg.norm( - end_effector_to_object, axis=-1 - ).mean(axis=-1) - - object_orientation_reward = -self.config.w_tire_orientation * np.abs( - np.dot(tire_y_axis, Z_AXIS) > self.config.orientation_threshold - ).sum(axis=-1) - - goal_reward = -self.config.w_goal * np.linalg.norm( - object_pos - np.array(self.config.goal_position)[None, None], axis=-1 - ).mean(-1) - - torso_proximity_reward = self.config.w_torso_proximity * np.linalg.norm(body_pos - object_pos, axis=-1).mean(-1) - - object_linear_velocity_penalty = -self.config.w_object_velocity * np.square( - np.linalg.norm(object_linear_velocity, axis=-1).mean(-1) + gripper_proximity_reward = gripper_distance_reward( + self.config, + np.linalg.norm(end_effector_to_object, axis=-1), ) - # Check if any state in the rollout has spot fallen - body_height = qpos[..., self.body_pose_start + 2] - spot_fallen_reward = -self.config.fall_penalty * (body_height <= self.config.spot_fallen_threshold).any(axis=-1) + goal_reward = goal_distance_reward(self.config, object_pos) + object_linear_velocity_penalty = object_linear_velocity_reward(self.config, object_linear_velocity) - # Compute a penalty to prefer small commands. - controls_reward = -self.config.w_controls * np.linalg.norm(controls, axis=-1).mean(-1) - - assert object_orientation_reward.shape == (batch_size,) assert gripper_proximity_reward.shape == (batch_size,) - assert torso_proximity_reward.shape == (batch_size,) assert object_linear_velocity_penalty.shape == (batch_size,) assert goal_reward.shape == (batch_size,) - assert spot_fallen_reward.shape == (batch_size,) - assert controls_reward.shape == (batch_size,) - - reward = ( - +spot_fallen_reward - + goal_reward - + object_orientation_reward - + torso_proximity_reward - + gripper_proximity_reward - + object_linear_velocity_penalty - + controls_reward - ) - return reward + return goal_reward + gripper_proximity_reward + object_linear_velocity_penalty @property def reset_pose(self) -> np.ndarray: @@ -160,4 +121,4 @@ def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: """Check if Spot has fallen.""" body_height = data.qpos[self.body_pose_start + 2] - return body_height <= self.config.spot_fallen_threshold + return body_height <= SPOT_FALLEN_THRESHOLD diff --git a/sumo/tasks/spot/spot_tire_upright.py b/sumo/tasks/spot/spot_tire_upright.py index 4bef9c7..b688179 100644 --- a/sumo/tasks/spot/spot_tire_upright.py +++ b/sumo/tasks/spot/spot_tire_upright.py @@ -1,13 +1,78 @@ # Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. -from judo.tasks.spot.spot_tire_upright import SpotTireUpright as _JudoSpotTireUpright -from judo.tasks.spot.spot_tire_upright import SpotTireUprightConfig +from dataclasses import dataclass +from typing import Any, cast -from sumo.tasks.spot.spot_base import SpotAssetMixin +import numpy as np +from mujoco import MjData, MjModel +from sumo.tasks.spot.spot_constants import LEGS_STANDING_POS, STANDING_HEIGHT +from sumo.tasks.spot.spot_tire_push import RADIUS_MAX, RADIUS_MIN, SpotTirePush +from sumo.tasks.spot.spot_upright import ( + SpotUprightConfig, + gripper_distance_reward, + horizontal_axis_orientation_reward, + random_object_pose, + sample_annulus_xy, + y_axis_is_horizontal, +) -class SpotTireUpright(SpotAssetMixin, _JudoSpotTireUpright): - """Public judo-rai SpotTireUpright with Sumo asset resolution.""" +ORIENTATION_TOLERANCE = 0.1 + + +@dataclass +class SpotTireUprightConfig(SpotUprightConfig): + """Configuration for Sumo's simplified Spot tire upright analysis task.""" + + +class SpotTireUpright(SpotTirePush): + """Task getting Spot to upright a randomly oriented tire.""" + + name: str = "spot_tire_upright" + config_t: type[SpotTireUprightConfig] = SpotTireUprightConfig # type: ignore[assignment] + config: SpotTireUprightConfig + + def __init__(self, config: SpotTireUprightConfig | None = None) -> None: + super().__init__(config=cast(Any, config)) + self.object_y_axis_idx = self.get_sensor_start_index("object_y_axis") + + def reward( + self, + states: np.ndarray, + sensors: np.ndarray, + controls: np.ndarray, + system_metadata: dict[str, Any] | None = None, + ) -> np.ndarray: + """Reward using only tire orientation and gripper distance.""" + batch_size = states.shape[0] + + tire_y_axis = sensors[..., self.object_y_axis_idx : self.object_y_axis_idx + 3] + gripper_to_object = sensors[..., self.end_effector_to_object_start : self.end_effector_to_object_start + 3] + gripper_distance = np.linalg.norm(gripper_to_object, axis=-1) + + orientation_reward = horizontal_axis_orientation_reward(self.config, tire_y_axis) + proximity_reward = gripper_distance_reward(self.config, gripper_distance) + + assert orientation_reward.shape == (batch_size,) + assert proximity_reward.shape == (batch_size,) + return orientation_reward + proximity_reward + + @property + def reset_pose(self) -> np.ndarray: + """Reset pose with a random tire attitude that clears the ground.""" + object_pose = random_object_pose( + self.model, + "tire", + sample_annulus_xy(RADIUS_MIN, RADIUS_MAX), + reject_orientation=lambda quat: y_axis_is_horizontal(quat, ORIENTATION_TOLERANCE), + ) + robot_pose = np.array([0.0, 0.0, STANDING_HEIGHT, 1.0, 0.0, 0.0, 0.0]) + return np.array([*robot_pose, *LEGS_STANDING_POS, *self.reset_arm_pos, *object_pose]) + + def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool: + """Check if the tire is upright.""" + tire_y_axis = data.sensordata[self.object_y_axis_idx : self.object_y_axis_idx + 3] + return bool(abs(tire_y_axis[2]) <= ORIENTATION_TOLERANCE) __all__ = ["SpotTireUpright", "SpotTireUprightConfig"] diff --git a/sumo/tasks/spot/spot_upright.py b/sumo/tasks/spot/spot_upright.py new file mode 100644 index 0000000..f7f1538 --- /dev/null +++ b/sumo/tasks/spot/spot_upright.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. + +from dataclasses import dataclass +from typing import Callable + +import numpy as np +from judo.tasks.base import TaskConfig +from mujoco import MjModel, mjtGeom + +Z_AXIS = np.array([0.0, 0.0, 1.0]) +Y_AXIS = np.array([0.0, 1.0, 0.0]) +GROUND_CLEARANCE_MARGIN = 0.02 +MAX_RESET_ORIENTATION_ATTEMPTS = 1000 + + +@dataclass +class SpotUprightConfig(TaskConfig): + """Configuration for Sumo's simplified Spot upright analysis tasks.""" + + w_orientation: float = 100.0 + w_gripper_proximity: float = 0.5 + + +def random_unit_quat() -> np.ndarray: + """Sample a random unit quaternion in MuJoCo's wxyz convention.""" + quat = np.random.normal(size=4) + quat /= np.linalg.norm(quat) + if quat[0] < 0: + quat = -quat + return quat + + +def z_axis_is_upright(quat: np.ndarray, tolerance: float) -> bool: + """Return whether the body z-axis already satisfies an upright task.""" + object_z_axis = quat_to_mat(quat) @ Z_AXIS + return bool(object_z_axis[2] >= 1.0 - tolerance) + + +def y_axis_is_horizontal(quat: np.ndarray, tolerance: float) -> bool: + """Return whether the body y-axis already satisfies a tire upright task.""" + object_y_axis = quat_to_mat(quat) @ Y_AXIS + return bool(abs(object_y_axis[2]) <= tolerance) + + +def quat_to_mat(quat: np.ndarray) -> np.ndarray: + """Convert a wxyz quaternion to a rotation matrix.""" + w, x, y, z = quat / np.linalg.norm(quat) + return np.array( + [ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ] + ) + + +def _geom_min_z_in_body_frame(model: MjModel, geom_id: int, body_rot: np.ndarray) -> float: + geom_pos = np.asarray(model.geom_pos[geom_id]) + geom_rot = quat_to_mat(np.asarray(model.geom_quat[geom_id])) + center_z = float((body_rot @ geom_pos)[2]) + z_in_geom = geom_rot.T @ body_rot.T @ Z_AXIS + geom_type = model.geom_type[geom_id] + geom_size = model.geom_size[geom_id] + + if geom_type == mjtGeom.mjGEOM_MESH: + mesh_id = model.geom_dataid[geom_id] + vert_start = model.mesh_vertadr[mesh_id] + vert_end = vert_start + model.mesh_vertnum[mesh_id] + vertices = np.asarray(model.mesh_vert[vert_start:vert_end]) + geom_points = geom_pos[:, None] + geom_rot @ vertices.T + return float((body_rot @ geom_points)[2].min()) + + if geom_type == mjtGeom.mjGEOM_BOX: + support = np.abs(geom_size * z_in_geom).sum() + elif geom_type == mjtGeom.mjGEOM_SPHERE: + support = geom_size[0] + elif geom_type == mjtGeom.mjGEOM_CAPSULE: + support = geom_size[0] + geom_size[1] * abs(z_in_geom[2]) + elif geom_type == mjtGeom.mjGEOM_CYLINDER: + support = geom_size[0] * np.linalg.norm(z_in_geom[:2]) + geom_size[1] * abs(z_in_geom[2]) + elif geom_type == mjtGeom.mjGEOM_ELLIPSOID: + support = np.linalg.norm(geom_size * z_in_geom) + else: + support = model.geom_rbound[geom_id] + + return center_z - float(support) + + +def ground_clearance_height(model: MjModel, body_name: str, quat: np.ndarray) -> float: + """Return a conservative free-body z value that keeps all body geoms above the ground.""" + body_id = model.body(body_name).id + rot = quat_to_mat(quat) + min_z = np.inf + + for geom_id in range(model.ngeom): + if model.geom_bodyid[geom_id] != body_id: + continue + min_z = min(min_z, _geom_min_z_in_body_frame(model, geom_id, rot)) + + if not np.isfinite(min_z): + return GROUND_CLEARANCE_MARGIN + return max(GROUND_CLEARANCE_MARGIN, -min_z + GROUND_CLEARANCE_MARGIN) + + +def random_object_pose( + model: MjModel, + body_name: str, + xy: np.ndarray, + reject_orientation: Callable[[np.ndarray], bool] | None = None, +) -> np.ndarray: + """Build a free-joint pose with random attitude and ground clearance.""" + for _ in range(MAX_RESET_ORIENTATION_ATTEMPTS): + quat = random_unit_quat() + if reject_orientation is None or not reject_orientation(quat): + break + else: + msg = f"Failed to sample a valid reset orientation for {body_name}" + raise RuntimeError(msg) + + z = ground_clearance_height(model, body_name, quat) + return np.array([xy[0], xy[1], z, *quat]) + + +def sample_annulus_xy(radius_min: float, radius_max: float, noise_scale: float = 0.1) -> np.ndarray: + """Sample an object position in an annulus around the origin.""" + radius = radius_min + (radius_max - radius_min) * np.random.rand() + theta = 2 * np.pi * np.random.rand() + return np.array([radius * np.cos(theta), radius * np.sin(theta)]) + noise_scale * np.random.randn(2) + + +def z_axis_orientation_reward(config: SpotUprightConfig, object_z_axis: np.ndarray) -> np.ndarray: + """Reward object z-axis alignment with world z.""" + alignment = np.sum(object_z_axis * Z_AXIS, axis=-1) + return -config.w_orientation * (1.0 - alignment).mean(axis=-1) + + +def horizontal_axis_orientation_reward(config: SpotUprightConfig, object_axis: np.ndarray) -> np.ndarray: + """Reward an object axis being horizontal.""" + return -config.w_orientation * np.abs(np.sum(object_axis * Z_AXIS, axis=-1)).mean(axis=-1) + + +def gripper_distance_reward(config: SpotUprightConfig, gripper_distance: np.ndarray) -> np.ndarray: + """Reward gripper proximity to the object.""" + return -config.w_gripper_proximity * gripper_distance.mean(axis=-1) diff --git a/tests/test_imports.py b/tests/test_imports.py index d772cb5..ee2483a 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -11,7 +11,10 @@ "sumo.tasks.spot.spot_cone_push", "sumo.tasks.spot.spot_rack_push", "sumo.tasks.spot.spot_tire_push", + "sumo.tasks.spot.spot_box_upright", + "sumo.tasks.spot.spot_chair_upright", "sumo.tasks.spot.spot_cone_upright", + "sumo.tasks.spot.spot_rack_upright", "sumo.tasks.spot.spot_chair_ramp", "sumo.tasks.spot.spot_barrier_upright", "sumo.tasks.spot.spot_barrier_drag", @@ -37,12 +40,15 @@ "spot_cone_push", "spot_rack_push", "spot_tire_push", + "spot_box_upright", + "spot_chair_upright", "spot_cone_upright", + "spot_rack_upright", + "spot_tire_upright", "spot_chair_ramp", "spot_barrier_upright", "spot_barrier_drag", "spot_tire_roll", - "spot_tire_upright", "spot_tire_stack", "spot_tire_rack_drag", "spot_rugged_box_push", diff --git a/tests/test_tasks.py b/tests/test_tasks.py index cfd170b..ff2a05f 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,9 @@ +from dataclasses import fields + import numpy as np import pytest from judo.tasks import get_registered_tasks +from mujoco import mj_forward from sumo.tasks.g1.g1_base import G1Base, G1BaseConfig from sumo.tasks.g1.g1_box import G1Box, G1BoxConfig @@ -11,17 +14,21 @@ from sumo.tasks.spot.spot_barrier_upright import SpotBarrierUpright, SpotBarrierUprightConfig from sumo.tasks.spot.spot_base import SpotBase, SpotBaseConfig from sumo.tasks.spot.spot_box_push import SpotBoxPush, SpotBoxPushConfig +from sumo.tasks.spot.spot_box_upright import SpotBoxUpright, SpotBoxUprightConfig from sumo.tasks.spot.spot_chair_push import SpotChairPush, SpotChairPushConfig from sumo.tasks.spot.spot_chair_ramp import SpotChairRamp, SpotChairRampConfig +from sumo.tasks.spot.spot_chair_upright import SpotChairUpright, SpotChairUprightConfig from sumo.tasks.spot.spot_cone_push import SpotConePush, SpotConePushConfig from sumo.tasks.spot.spot_cone_upright import SpotConeUpright, SpotConeUprightConfig from sumo.tasks.spot.spot_rack_push import SpotRackPush, SpotRackPushConfig +from sumo.tasks.spot.spot_rack_upright import SpotRackUpright, SpotRackUprightConfig from sumo.tasks.spot.spot_rugged_box_push import SpotRuggedBoxPush, SpotRuggedBoxPushConfig from sumo.tasks.spot.spot_tire_push import SpotTirePush, SpotTirePushConfig from sumo.tasks.spot.spot_tire_rack_drag import SpotTireRackDrag, SpotTireRackDragConfig from sumo.tasks.spot.spot_tire_roll import SpotTireRoll, SpotTireRollConfig from sumo.tasks.spot.spot_tire_stack import SpotTireStack, SpotTireStackConfig from sumo.tasks.spot.spot_tire_upright import SpotTireUpright, SpotTireUprightConfig +from sumo.tasks.spot.spot_upright import ground_clearance_height G1_TASK_CONFIGS = [ (G1Base, G1BaseConfig, 3), @@ -38,17 +45,36 @@ (SpotConePush, SpotConePushConfig), (SpotRackPush, SpotRackPushConfig), (SpotTirePush, SpotTirePushConfig), + (SpotBoxUpright, SpotBoxUprightConfig), + (SpotChairUpright, SpotChairUprightConfig), (SpotConeUpright, SpotConeUprightConfig), + (SpotRackUpright, SpotRackUprightConfig), + (SpotTireUpright, SpotTireUprightConfig), (SpotChairRamp, SpotChairRampConfig), (SpotBarrierUpright, SpotBarrierUprightConfig), (SpotBarrierDrag, SpotBarrierDragConfig), (SpotTireRoll, SpotTireRollConfig), - (SpotTireUpright, SpotTireUprightConfig), (SpotTireStack, SpotTireStackConfig), (SpotTireRackDrag, SpotTireRackDragConfig), (SpotRuggedBoxPush, SpotRuggedBoxPushConfig), ] +SIMPLIFIED_SPOT_PUSH_TASK_CONFIGS = [ + (SpotBoxPush, SpotBoxPushConfig), + (SpotChairPush, SpotChairPushConfig), + (SpotConePush, SpotConePushConfig), + (SpotRackPush, SpotRackPushConfig), + (SpotTirePush, SpotTirePushConfig), +] + +SIMPLIFIED_SPOT_UPRIGHT_TASK_CONFIGS = [ + (SpotBoxUpright, SpotBoxUprightConfig, "box_body", "z"), + (SpotChairUpright, SpotChairUprightConfig, "yellow_chair", "z"), + (SpotConeUpright, SpotConeUprightConfig, "traffic_cone", "z"), + (SpotRackUpright, SpotRackUprightConfig, "tire_rack", "z"), + (SpotTireUpright, SpotTireUprightConfig, "tire", "horizontal"), +] + REGISTERED_SPOT_TASK_NAMES = [ "spot_base", "spot_box_push", @@ -56,12 +82,15 @@ "spot_cone_push", "spot_rack_push", "spot_tire_push", + "spot_box_upright", + "spot_chair_upright", "spot_cone_upright", + "spot_rack_upright", + "spot_tire_upright", "spot_chair_ramp", "spot_barrier_upright", "spot_barrier_drag", "spot_tire_roll", - "spot_tire_upright", "spot_tire_stack", "spot_tire_rack_drag", "spot_rugged_box_push", @@ -219,6 +248,179 @@ def test_spot_task_reward_shape(task_cls, config_cls): assert reward.shape == (states.shape[0],) +def _get_task_attr(task, *names: str) -> int: + for name in names: + if hasattr(task, name): + return getattr(task, name) + raise AttributeError(f"{type(task).__name__} has none of {names}") + + +def _gripper_distance(task, sensors: np.ndarray, object_pos: np.ndarray) -> np.ndarray: + if hasattr(task, "gripper_pos_idx"): + gripper_pos_idx = task.gripper_pos_idx + gripper_pos = sensors[..., gripper_pos_idx : gripper_pos_idx + 3] + return np.linalg.norm(gripper_pos - object_pos, axis=-1) + + end_effector_to_object_start = _get_task_attr( + task, + "end_effector_to_object_idx", + "end_effector_to_object_start", + ) + end_effector_to_object = sensors[..., end_effector_to_object_start : end_effector_to_object_start + 3] + return np.linalg.norm(end_effector_to_object, axis=-1) + + +@pytest.mark.parametrize( + "task_cls,config_cls", + SIMPLIFIED_SPOT_PUSH_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_push_reward_uses_analysis_terms(task_cls, config_cls): + task = task_cls() + states, sensors, controls = _make_spot_rollout_inputs(task) + + qpos = states[..., : task.model.nq] + object_pose_idx = _get_task_attr(task, "object_pose_idx", "object_pose_start") + object_vel_idx = _get_task_attr(task, "object_vel_idx", "object_vel_start") + object_pos = qpos[..., object_pose_idx : object_pose_idx + 3] + object_linear_velocity = states[..., object_vel_idx : object_vel_idx + 3] + + expected = -task.config.w_goal * np.linalg.norm( + object_pos - np.asarray(task.config.goal_position)[None, None], axis=-1 + ).mean(-1) + expected += -task.config.w_gripper_proximity * _gripper_distance(task, sensors, object_pos).mean(-1) + expected += -task.config.w_object_velocity * np.square(np.linalg.norm(object_linear_velocity, axis=-1).mean(-1)) + + np.testing.assert_allclose(task.reward(states, sensors, controls), expected) + np.testing.assert_allclose(task.reward(states, sensors, controls + 100.0), expected) + + +@pytest.mark.parametrize( + "task_cls,config_cls", + SIMPLIFIED_SPOT_PUSH_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_push_configs_only_expose_analysis_terms(task_cls, config_cls): + config_fields = {field.name: field for field in fields(config_cls)} + assert set(config_fields) == { + "goal_position", + "w_goal", + "w_gripper_proximity", + "w_object_velocity", + } + assert config_fields["w_goal"].type is float + assert config_fields["w_gripper_proximity"].type is float + assert config_fields["w_object_velocity"].type is float + assert config_fields["goal_position"].type is np.ndarray + + +@pytest.mark.parametrize( + "task_cls,config_cls,body_name,orientation_kind", + SIMPLIFIED_SPOT_UPRIGHT_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_upright_reward_uses_analysis_terms(task_cls, config_cls, body_name, orientation_kind): + task = task_cls() + states, sensors, controls = _make_spot_rollout_inputs(task) + + qpos = states[..., : task.model.nq] + object_pose_idx = _get_task_attr(task, "object_pose_idx", "object_pose_start") + object_pos = qpos[..., object_pose_idx : object_pose_idx + 3] + gripper_distance = _gripper_distance(task, sensors, object_pos) + + if orientation_kind == "horizontal": + axis_idx = _get_task_attr(task, "object_y_axis_idx", "tire_y_axis_idx") + object_axis = sensors[..., axis_idx : axis_idx + 3] + orientation_error = np.abs(object_axis[..., 2]) + else: + axis_idx = task.object_z_axis_idx + object_axis = sensors[..., axis_idx : axis_idx + 3] + orientation_error = 1.0 - object_axis[..., 2] + + expected = -task.config.w_orientation * orientation_error.mean(axis=-1) + expected += -task.config.w_gripper_proximity * gripper_distance.mean(axis=-1) + + np.testing.assert_allclose(task.reward(states, sensors, controls), expected) + np.testing.assert_allclose(task.reward(states, sensors, controls + 100.0), expected) + + +@pytest.mark.parametrize( + "task_cls,config_cls,body_name,orientation_kind", + SIMPLIFIED_SPOT_UPRIGHT_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_upright_configs_only_expose_analysis_terms( + task_cls, + config_cls, + body_name, + orientation_kind, +): + config_fields = {field.name: field for field in fields(config_cls)} + assert set(config_fields) == { + "w_orientation", + "w_gripper_proximity", + } + assert config_fields["w_orientation"].type is float + assert config_fields["w_gripper_proximity"].type is float + + +@pytest.mark.parametrize( + "task_cls,config_cls,body_name,orientation_kind", + SIMPLIFIED_SPOT_UPRIGHT_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_upright_reset_randomizes_attitude_without_ground_penetration( + task_cls, + config_cls, + body_name, + orientation_kind, +): + task = task_cls() + object_pose_idx = _get_task_attr(task, "object_pose_idx", "object_pose_start") + + np.random.seed(0) + reset_pose_a = task.reset_pose + np.random.seed(1) + reset_pose_b = task.reset_pose + + quat_a = reset_pose_a[object_pose_idx + 3 : object_pose_idx + 7] + quat_b = reset_pose_b[object_pose_idx + 3 : object_pose_idx + 7] + z_a = reset_pose_a[object_pose_idx + 2] + + np.testing.assert_allclose(np.linalg.norm(quat_a), 1.0) + assert not np.allclose(quat_a, quat_b) + assert z_a >= ground_clearance_height(task.model, body_name, quat_a) + + +@pytest.mark.parametrize( + "task_cls,config_cls,body_name,orientation_kind", + SIMPLIFIED_SPOT_UPRIGHT_TASK_CONFIGS, + ids=lambda x: x.__name__ if hasattr(x, "__name__") else str(x), +) +def test_simplified_spot_upright_reset_rejects_already_successful_orientation( + task_cls, + config_cls, + body_name, + orientation_kind, + monkeypatch, +): + task = task_cls() + object_pose_idx = _get_task_attr(task, "object_pose_idx", "object_pose_start") + already_successful_quat = np.array([1.0, 0.0, 0.0, 0.0]) + fallen_quat = np.array([np.sqrt(0.5), np.sqrt(0.5), 0.0, 0.0]) + quat_samples = iter([already_successful_quat, fallen_quat]) + + monkeypatch.setattr("sumo.tasks.spot.spot_upright.random_unit_quat", lambda: next(quat_samples)) + + reset_pose = task.reset_pose + + np.testing.assert_allclose(reset_pose[object_pose_idx + 3 : object_pose_idx + 7], fallen_quat) + task.data.qpos[:] = reset_pose + task.data.qvel[:] = 0.0 + mj_forward(task.model, task.data) + assert not task.success(task.model, task.data) + + @pytest.mark.parametrize( "task_cls,config_cls", SPOT_TASK_CONFIGS,