From 6aeb5d3d6b531ef5c5a213c649e607b337096b1b Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Sun, 22 Feb 2026 02:34:15 -0800 Subject: [PATCH 1/5] Add ManagerCallSwitch max_mode cap, Scene capture config, and manager updates - Add manager_call_max_mode field for per-env capture ceiling (min(mode, cap)) - Support dict input for manager_call_config (in addition to JSON string) - Add "Scene" to MANAGER_NAMES for configurable Scene_write_data_to_sim mode - Remove hardcoded WARP_NOT_CAPTURED override from Scene_write_data_to_sim - Add warp_capturable decorator and is_warp_capturable check for mode=2 fallback - Update managers: action, observation, event with warp-first improvements - Update scene_entity_cfg with body_ids_wp resolution - Update train.py CLI arg handling --- .../reinforcement_learning/rsl_rl/train.py | 18 +++ .../envs/manager_based_env_warp.py | 125 ++++++++++++------ .../envs/manager_based_rl_env_warp.py | 12 +- .../envs/utils/io_descriptors.py | 35 ++++- .../managers/__init__.py | 4 +- .../managers/action_manager.py | 18 +-- .../managers/event_manager.py | 14 +- .../managers/manager_base.py | 7 + .../managers/observation_manager.py | 88 ++++++++++-- .../managers/scene_entity_cfg.py | 35 ++--- .../utils/warp/__init__.py | 2 +- .../isaaclab_experimental/utils/warp/utils.py | 39 ++++++ 12 files changed, 297 insertions(+), 100 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/train.py b/scripts/reinforcement_learning/rsl_rl/train.py index 4bff6bafe86..48e9f3759df 100644 --- a/scripts/reinforcement_learning/rsl_rl/train.py +++ b/scripts/reinforcement_learning/rsl_rl/train.py @@ -38,6 +38,18 @@ default=False, help="Enable Isaac Lab timers (use --no-timer to disable).", ) +parser.add_argument( + "--step-timer", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable granular timer sections in environment step().", +) +parser.add_argument( + "--reset-timer", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable granular timer sections in environment reset().", +) parser.add_argument( "--manager_call_config", type=str, @@ -91,12 +103,18 @@ import torch from datetime import datetime +import isaaclab_experimental.envs.manager_based_rl_env_warp as manager_based_rl_env_warp from rsl_rl.runners import DistillationRunner, OnPolicyRunner +import isaaclab.envs.manager_based_rl_env as manager_based_rl_env from isaaclab.utils.timer import Timer Timer.enable = args_cli.timer Timer.enable_display_output = args_cli.timer +manager_based_rl_env.TIMER_ENABLED_STEP = args_cli.step_timer +manager_based_rl_env.TIMER_ENABLED_RESET_IDX = args_cli.reset_timer +manager_based_rl_env_warp.TIMER_ENABLED_STEP = args_cli.step_timer +manager_based_rl_env_warp.TIMER_ENABLED_RESET_IDX = args_cli.reset_timer import isaaclab_tasks_experimental # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py index 73dd79401b6..89006dd8386 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -83,32 +83,47 @@ class ManagerCallSwitch: "TerminationManager", "RewardManager", "CurriculumManager", + "Scene", ) - def __init__(self, cfg_source: str | None = None): + def __init__(self, cfg_source: dict | str | None = None, max_modes: dict[str, int] | None = None): self._wp_graphs: dict[str, Any] = {} + self._non_capturable_managers: set[str] = set() self._cfg = self._load_cfg(cfg_source) + self._max_modes = self._validate_max_modes(max_modes) print("[INFO] ManagerCallSwitch configuration:") print(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}") for manager_name in self.MANAGER_NAMES: - print(f" - {manager_name}: {int(self.get_mode_for_manager(manager_name))}") + mode = int(self.get_mode_for_manager(manager_name)) + cap = self._max_modes.get(manager_name) + cap_str = f" (cap={cap})" if cap is not None else "" + print(f" - {manager_name}: {mode}{cap_str}") def invalidate_graphs(self) -> None: """Invalidate cached capture graphs.""" self._wp_graphs.clear() + def register_manager_capturability(self, manager_name: str, all_terms_capturable: bool) -> None: + """Register whether a manager's terms are all CUDA-graph-capturable.""" + if not all_terms_capturable: + self._non_capturable_managers.add(manager_name) + logger.warning(f"{manager_name} has non-capturable terms — mode=2 requests will fall back to mode=1.") + def call_stage( self, *, stage: str, stable_calls: Sequence[dict[str, Any]], warp_calls: Sequence[dict[str, Any]], + mode_override: ManagerCallMode | int | None = None, ) -> Any: - """Run the stage according to configured mode.""" + """Run the stage according to configured mode (or explicit override).""" manager_name = self._manager_name_from_stage(stage) - mode = self.get_mode_for_manager(manager_name) + mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override) if mode == ManagerCallMode.STABLE: return self._run_calls(stable_calls) + if mode == ManagerCallMode.WARP_CAPTURED and manager_name in self._non_capturable_managers: + mode = ManagerCallMode.WARP_NOT_CAPTURED if mode == ManagerCallMode.WARP_NOT_CAPTURED: return self._run_calls(warp_calls) self._wp_capture_or_launch(stage=stage, calls=warp_calls) @@ -122,14 +137,14 @@ def _manager_name_from_stage(self, stage: str) -> str: def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode: default_key = next(iter(self.DEFAULT_CONFIG)) mode_value = self._cfg.get(manager_name, self._cfg[default_key]) + cap = self._max_modes.get(manager_name) + if cap is not None: + mode_value = min(mode_value, cap) return ManagerCallMode(mode_value) - def resolve_manager_class(self, manager_name: str) -> type: - module_name = ( - "isaaclab.managers" - if self.get_mode_for_manager(manager_name) == ManagerCallMode.STABLE - else "isaaclab_experimental.managers" - ) + def resolve_manager_class(self, manager_name: str, mode_override: ManagerCallMode | int | None = None) -> type: + mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override) + module_name = "isaaclab.managers" if mode == ManagerCallMode.STABLE else "isaaclab_experimental.managers" module = importlib.import_module(module_name) if not hasattr(module, manager_name): raise AttributeError(f"Manager '{manager_name}' not found in module '{module_name}'.") @@ -138,35 +153,41 @@ def resolve_manager_class(self, manager_name: str) -> type: def _run_calls(self, calls: Sequence[dict[str, Any]]) -> Any: result = None for spec in calls: - fn = spec["fn"] - fn_args = spec.get("args", ()) - fn_kwargs = spec.get("kwargs", {}) - result = fn(*fn_args, **fn_kwargs) + result = self._run_call(call=spec) return result + def _run_call(self, call: dict[str, Any]) -> Any: + fn = call["fn"] + fn_args = call.get("args", ()) + fn_kwargs = call.get("kwargs", {}) + return fn(*fn_args, **fn_kwargs) + def _wp_capture_or_launch(self, stage: str, calls: Sequence[dict[str, Any]]) -> None: """Capture Warp CUDA graph for a stage on first call, then replay.""" graph = self._wp_graphs.get(stage) - if graph is None: - with wp.ScopedCapture() as capture: - for spec in calls: - fn = spec["fn"] - fn_args = spec.get("args", ()) - fn_kwargs = spec.get("kwargs", {}) - fn(*fn_args, **fn_kwargs) - graph = capture.graph - self._wp_graphs[stage] = graph - wp.capture_launch(graph) - - def _load_cfg(self, cfg_source: str | None) -> dict[str, int]: - if cfg_source is not None and not isinstance(cfg_source, str): - raise TypeError(f"cfg_source must be a string or None, got: {type(cfg_source)}") - if cfg_source is None or cfg_source.strip() == "": + if graph is not None: + wp.capture_launch(graph) + return + # Warmup: run eagerly to trigger first-call allocations (hasattr guards, wp.zeros, etc.) + self._run_calls(calls) + # Capture: allocations already done, only wp.launch calls are recorded + with wp.ScopedCapture() as capture: + self._run_calls(calls) + self._wp_graphs[stage] = capture.graph + + def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]: + if cfg_source is None: return dict(self.DEFAULT_CONFIG) - - parsed = json.loads(cfg_source) - if not isinstance(parsed, dict): - raise TypeError("manager_call_config must decode to a dict.") + if isinstance(cfg_source, dict): + parsed = cfg_source + elif isinstance(cfg_source, str): + if cfg_source.strip() == "": + return dict(self.DEFAULT_CONFIG) + parsed = json.loads(cfg_source) + if not isinstance(parsed, dict): + raise TypeError("manager_call_config must decode to a dict.") + else: + raise TypeError(f"cfg_source must be a dict, string, or None, got: {type(cfg_source)}") cfg = dict(parsed) if self.DEFAULT_KEY not in cfg: @@ -186,6 +207,20 @@ def _load_cfg(self, cfg_source: str | None) -> dict[str, int]: ) from exc return cfg + def _validate_max_modes(self, max_modes: dict[str, int] | None) -> dict[str, int]: + if max_modes is None: + return {} + if not isinstance(max_modes, dict): + raise TypeError(f"max_modes must be a dict or None, got: {type(max_modes)}") + for name, cap in max_modes.items(): + if not isinstance(cap, int): + raise TypeError(f"manager_call_max_mode value for '{name}' must be int (0/1/2), got: {type(cap)}") + try: + ManagerCallMode(cap) + except ValueError as exc: + raise ValueError(f"Invalid manager_call_max_mode value for '{name}': {cap}. Expected 0/1/2.") from exc + return dict(max_modes) + class ManagerBasedEnvWarp: """The base environment for the manager-based workflow (experimental fork). @@ -211,16 +246,9 @@ def __init__(self, cfg: ManagerBasedEnvCfg): # initialize internal variables self._is_closed = False # temporary debug runtime config for manager source/call switching. - cfg_source: str | None = getattr(self.cfg, "manager_call_config", None) - # if cfg_source is None: - # try: - # import __main__ - - # args_cli = getattr(__main__, "args_cli", None) - # cfg_source = getattr(args_cli, "manager_call_config", None) - # except Exception: - # cfg_source = None - self._manager_call_switch = ManagerCallSwitch(cfg_source) + cfg_source: dict | str | None = getattr(self.cfg, "manager_call_config", None) + max_modes: dict[str, int] | None = getattr(self.cfg, "manager_call_max_mode", None) + self._manager_call_switch = ManagerCallSwitch(cfg_source, max_modes=max_modes) self._apply_manager_term_cfg_profile() # set the seed for the environment @@ -409,6 +437,17 @@ def device(self): """The device on which the environment is running.""" return self.sim.device + @property + def env_origins_wp(self) -> wp.array: + """Scene env origins as a warp ``vec3f`` array. Cached on first access.""" + if not hasattr(self, "_env_origins_wp"): + origins = self.scene.env_origins + if isinstance(origins, wp.array): + self._env_origins_wp = origins + else: + self._env_origins_wp = wp.from_torch(origins, dtype=wp.vec3f) + return self._env_origins_wp + def resolve_env_mask( self, *, diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 184d65b8eda..a98c160527f 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -135,9 +135,10 @@ def load_managers(self): # note: this order is important since observation manager needs to know the command and action managers # and the reward manager needs to know the termination manager # -- command manager - self.command_manager = self._manager_call_switch.resolve_manager_class("CommandManager")( - self.cfg.commands, self - ) + # TODO(jichuanh): switch to experimental command manager once command-term isolation is complete. + self.command_manager = self._manager_call_switch.resolve_manager_class( + "CommandManager", mode_override=ManagerCallMode.STABLE + )(self.cfg.commands, self) print("[INFO] Command Manager: ", self.command_manager) # call the parent class to load the managers for observations and actions. @@ -379,6 +380,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: stage="CommandManager_compute", stable_calls=[{"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}], warp_calls=[{"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}], + mode_override=ManagerCallMode.STABLE, ) # -- step interval events @@ -669,10 +671,12 @@ def _reset_idx( enable=TIMER_ENABLED_RESET_IDX, format="us", ): + command_mode = ManagerCallMode.STABLE command_info = self._manager_call_switch.call_stage( stage="CommandManager_reset", stable_calls=[{"fn": self.command_manager.reset, "kwargs": {"env_ids": env_ids}}], warp_calls=[{"fn": self.command_manager.reset, "kwargs": {"env_mask": env_mask}}], + mode_override=command_mode, ) event_info = self._manager_call_switch.call_stage( stage="EventManager_reset", @@ -684,7 +688,7 @@ def _reset_idx( stable_calls=[{"fn": self.termination_manager.reset, "kwargs": {"env_ids": env_ids}}], warp_calls=[{"fn": self.termination_manager.reset, "kwargs": {"env_mask": env_mask}}], ) - if self._manager_call_switch.get_mode_for_manager("CommandManager") == ManagerCallMode.WARP_CAPTURED: + if command_mode == ManagerCallMode.WARP_CAPTURED: command_info = self.command_manager.reset_extras if self._manager_call_switch.get_mode_for_manager("EventManager") == ManagerCallMode.WARP_CAPTURED: event_info = {} diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py index b454ccb13de..0b26a2205f1 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py @@ -59,6 +59,9 @@ def _make_descriptor(**kwargs: Any) -> GenericObservationIODescriptor: desc = GenericObservationIODescriptor(**known) # User defined extras are stored in the descriptor under the `extras` field desc.extras = extras + # ``out_dim`` is kept as a top-level attribute (not in extras) so the + # observation manager can read it without inspecting extras. + desc.out_dim = extras.pop("out_dim", None) return desc @@ -187,18 +190,40 @@ def wrapper(env: ManagerBasedEnv, *args: P.args, **kwargs: P.kwargs) -> R: def record_shape(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None: """Record the shape of the output buffer. - No-op when ``output`` is ``None`` (the typical case during Warp-first - inspection). Use a type-specific hook such as :func:`record_joint_shape` - to derive shape from config instead. + When ``output`` is not ``None`` (eager path), shape is read directly. + When ``output`` is ``None`` (Warp-first inspection), shape is derived from: + - ``descriptor.extras["axes"]`` for RootState observations, or + - ``asset_cfg.joint_ids`` for JointState observations. + + BodyState shape cannot be derived without calling the function (the per-body + feature size varies). In that case shape is left unset. Args: output: The pre-allocated output buffer, or ``None`` during inspection. descriptor: The descriptor to record the shape to. **kwargs: Additional keyword arguments. """ - if output is None: + if output is not None: + descriptor.shape = (output.shape[-1],) + return + # --- Warp-first fallback: derive shape without output --- + # 1) From axes metadata (RootState) + axes = descriptor.extras.get("axes") if descriptor.extras else None + if axes: + descriptor.shape = (len(axes),) return - descriptor.shape = (output.shape[-1],) + # 2) From asset_cfg for JointState + if descriptor.observation_type == "JointState": + asset_cfg = kwargs.get("asset_cfg") + if asset_cfg is not None: + from isaaclab.assets import Articulation + + asset: Articulation = kwargs["env"].scene[asset_cfg.name] + joint_ids = asset_cfg.joint_ids + if joint_ids == slice(None): + descriptor.shape = (len(asset.joint_names),) + else: + descriptor.shape = (len(joint_ids),) def record_dtype(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None: diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py index 62d3171d32a..b4521b98434 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py @@ -11,10 +11,12 @@ from isaaclab.managers import * # noqa: F401,F403 -# Override the stable implementation with the experimental fork. from .action_manager import ActionManager # noqa: F401 from .command_manager import CommandManager # noqa: F401 from .event_manager import EventManager # noqa: F401 + +# Override the stable implementation with the experimental fork. +from .manager_base import ManagerTermBase # noqa: F401 from .manager_term_cfg import ObservationTermCfg, RewardTermCfg, TerminationTermCfg # noqa: F401 from .observation_manager import ObservationManager # noqa: F401 from .reward_manager import RewardManager # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py index 986d67e1a49..8b3bf61e91c 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py @@ -27,22 +27,8 @@ if TYPE_CHECKING: from isaaclab.envs import ManagerBasedEnv - -@wp.kernel -def _zero_masked_2d( - # input - mask: wp.array(dtype=wp.bool), - # input/output - data: wp.array(dtype=wp.float32, ndim=2), -): - """Zero rows of a 2D buffer where ``mask`` is True. - - Launched with dim = (num_envs, data.shape[1]). - """ - - env_id, j = wp.tid() - if mask[env_id]: - data[env_id, j] = 0.0 +# Shared kernel – imported from utils to avoid duplication. +from isaaclab_experimental.utils.warp.utils import zero_masked_2d as _zero_masked_2d class ActionTerm(ManagerTermBase): diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py index 52be53262fe..a560f664255 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py @@ -311,21 +311,21 @@ def apply( self._apply_interval(float(dt)) return + # resolve the environment mask + if env_mask_wp is None: + if wp.get_device().is_capturing: + raise ValueError(f"Event mode '{mode}' requires the environment mask to be provided when capturing.") + env_mask_wp = self._env.resolve_env_mask(env_ids=env_ids) + if mode == "reset": if global_env_step_count is None: raise ValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.") - if env_mask_wp is None: - if wp.get_device().is_capturing: - raise ValueError( - f"Event mode '{mode}' requires the environment mask to be provided when capturing." - ) - env_mask_wp = self._env.resolve_env_mask(env_ids=env_ids) self._apply_reset(env_mask_wp, global_env_step_count) return # other modes keep the stable convention (env_ids forwarded) for term_cfg in self._mode_term_cfgs[mode]: - term_cfg.func(self._env, env_ids, **term_cfg.params) + term_cfg.func(self._env, env_mask_wp, **term_cfg.params) def _apply_interval(self, dt: float) -> None: assert self._env.rng_state_wp is not None diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py index c16bebfa274..af4f5e02d2f 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any import warp as wp +from isaaclab_experimental.utils.warp import is_warp_capturable import isaaclab.utils.string as string_utils from isaaclab.utils import class_to_dict, string_to_callable @@ -401,6 +402,12 @@ def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, f" and optional parameters: {args_with_defaults}, but received: {term_params}." ) + # register non-capturable terms with the call switch for mode=2 fallback + if not is_warp_capturable(term_cfg.func): + switch = getattr(self._env, "_manager_call_switch", None) + if switch is not None: + switch.register_manager_capturability(type(self).__name__, False) + # process attributes at runtime # these properties are only resolvable once the simulation starts playing if self._env.sim.is_playing(): diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py index e9d7dd34208..3bfd0fd8965 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py @@ -701,9 +701,10 @@ def _prepare_terms(self): # noqa: C901 f" but received: {len(term_cfg.scale)}." ) - # cast the scale into torch tensor - term_cfg.scale = torch.tensor(term_cfg.scale, dtype=torch.float, device=self._env.device) - term_cfg.scale_wp = wp.from_torch(term_cfg.scale, dtype=wp.float32) + scale_vals = ( + term_cfg.scale if isinstance(term_cfg.scale, tuple) else [float(term_cfg.scale)] * obs_dims[1] + ) + term_cfg.scale_wp = wp.array(scale_vals, dtype=wp.float32, device=self._env.device) # prepare modifiers for each observation if term_cfg.modifiers is not None: @@ -840,16 +841,43 @@ def _prepare_terms(self): # noqa: C901 self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer def _infer_term_dim_scalar(self, term_cfg: ObservationTermCfg) -> int: - """Infer (D,) using scalar scene info (no term execution).""" - # allow explicit override + """Infer observation output dimension (D,) using decorator metadata, scene info, or manager state. + + Resolution order: + 1. ``out_dim`` on the function's ``@generic_io_descriptor_warp`` decorator. + 2. ``axes`` on the decorator (e.g. ``axes=["X","Y","Z"]`` → dim 3). + 3. Explicit ``term_dim`` / ``out_dim`` / ``obs_dim`` in ``term_cfg.params`` (legacy). + 4. ``asset_cfg.joint_ids`` count (joint-based observations). + """ + # --- 1-2. Decorator metadata (preferred) --- + func = term_cfg.func + # Check for descriptor on the (possibly wrapped) function first, + # then fall back to unwrapping for class-based terms. + descriptor = getattr(func, "_descriptor", None) + if descriptor is None and hasattr(func, "__wrapped__"): + descriptor = getattr(func.__wrapped__, "_descriptor", None) + if descriptor is not None: + # 1. Explicit out_dim on decorator + out_dim = getattr(descriptor, "out_dim", None) + if out_dim is not None: + return self._resolve_out_dim(out_dim, term_cfg) + # 2. Derive from axes metadata + axes = descriptor.extras.get("axes") if descriptor.extras else None + if axes is not None: + return len(axes) + + # --- 3. Legacy explicit override in params --- for k in ("term_dim", "out_dim", "obs_dim"): if k in term_cfg.params: return int(term_cfg.params[k]) - # try explicit param first + + # --- 3. Joint-based fallback via asset_cfg --- asset_cfg = term_cfg.params.get("asset_cfg") if asset_cfg is None: - raise ValueError(f"Observation term '{term_cfg.params}' has no asset_cfg parameter.") - # resolve selection + raise ValueError( + f"Cannot infer output dimension for observation term '{getattr(func, '__name__', func)}'. " + "Add `out_dim=` to its @generic_io_descriptor_warp decorator." + ) asset = self._env.scene[asset_cfg.name] joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) if joint_ids_wp is not None: @@ -858,3 +886,47 @@ def _infer_term_dim_scalar(self, term_cfg: ObservationTermCfg) -> int: if isinstance(joint_ids, slice): return int(getattr(asset, "num_joints", wp.to_torch(asset.data.joint_pos).shape[1])) return int(len(joint_ids)) + + def _resolve_out_dim(self, out_dim: int | str, term_cfg: ObservationTermCfg) -> int: + """Resolve an ``out_dim`` value from a decorator into a concrete integer. + + Supports: + - ``int``: returned as-is (fixed dimension). + - ``"joint"``: number of selected joints from ``asset_cfg``. + - ``"body:N"``: ``N`` components per selected body from ``asset_cfg``. + - ``"command"``: query ``command_manager.get_command(name).shape[-1]``. + - ``"action"``: query ``action_manager.action.shape[-1]``. + """ + if isinstance(out_dim, int): + return out_dim + + if out_dim == "joint": + asset_cfg = term_cfg.params.get("asset_cfg") + asset = self._env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is not None: + return int(joint_ids_wp.shape[0]) + joint_ids = getattr(asset_cfg, "joint_ids", slice(None)) + if isinstance(joint_ids, slice): + return int(getattr(asset, "num_joints", wp.to_torch(asset.data.joint_pos).shape[1])) + return int(len(joint_ids)) + + if isinstance(out_dim, str) and out_dim.startswith("body:"): + per_body = int(out_dim.split(":")[1]) + asset_cfg = term_cfg.params.get("asset_cfg") + body_ids = getattr(asset_cfg, "body_ids", None) + if body_ids is None or body_ids == slice(None): + asset = self._env.scene[asset_cfg.name] + return per_body * len(asset.body_names) + return per_body * len(body_ids) + + if out_dim == "command": + command_name = term_cfg.params.get("command_name") + cmd = self._env.command_manager.get_command(command_name) + return int(cmd.shape[-1]) + + if out_dim == "action": + action = self._env.action_manager.action + return int(action.shape[-1]) + + raise ValueError(f"Unknown out_dim sentinel: {out_dim!r}") diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py index 5c160a44732..554f667bd02 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py @@ -30,24 +30,29 @@ class SceneEntityCfg(_SceneEntityCfg): joint_ids_wp: wp.array | None = ( None # Needed for subset-sized outputs/gathers (len(selected)); mask can't map k→joint/order. ) + body_ids_wp: wp.array | None = None def resolve(self, scene: InteractiveScene): # run the stable resolution first (fills joint_ids/body_ids from names/regex) super().resolve(scene) - # Build a Warp joint mask for articulations only. entity = scene[self.name] - if not isinstance(entity, BaseArticulation): - return - - # Pre-allocate a full-length mask (all True for default selection). - if self.joint_ids == slice(None): - joint_ids_list = list(range(entity.num_joints)) - mask_list = [True] * entity.num_joints - else: - joint_ids_list = list(self.joint_ids) - mask_list = [False] * entity.num_joints - for idx in joint_ids_list: - mask_list[idx] = True - self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device) - self.joint_ids_wp = wp.array(joint_ids_list, dtype=wp.int32, device=scene.device) + + # -- Warp joint mask / ids for articulations + if isinstance(entity, BaseArticulation): + if self.joint_ids == slice(None): + joint_ids_list = list(range(entity.num_joints)) + mask_list = [True] * entity.num_joints + else: + joint_ids_list = list(self.joint_ids) + mask_list = [False] * entity.num_joints + for idx in joint_ids_list: + mask_list[idx] = True + self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device) + self.joint_ids_wp = wp.array(joint_ids_list, dtype=wp.int32, device=scene.device) + + # -- Warp body ids + if self.body_ids is not None and self.body_ids != slice(None): + self.body_ids_wp = wp.array(list(self.body_ids), dtype=wp.int32, device=scene.device) + elif hasattr(entity, "num_bodies"): + self.body_ids_wp = wp.array(list(range(entity.num_bodies)), dtype=wp.int32, device=scene.device) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py index 9767a31ccc9..1c3b3497a6c 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py @@ -6,4 +6,4 @@ """Warp utility functions and shared kernels for isaaclab_experimental.""" from .kernels import compute_reset_scale, count_masked -from .utils import resolve_asset_cfg, wrap_to_pi +from .utils import is_warp_capturable, resolve_asset_cfg, warp_capturable, wrap_to_pi diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py index c47c4361886..7a7b0cbf2b4 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py @@ -14,6 +14,37 @@ from isaaclab.envs import ManagerBasedEnv +def warp_capturable(capturable: bool): + """Annotate an MDP term's CUDA-graph capturability. + + No-wrapper decorator: sets ``_warp_capturable`` directly on the function + and returns it unchanged. Safe to stack with any other decorator in any order. + + By default all MDP terms are assumed capturable (True). Use + ``@warp_capturable(False)`` on terms that call non-capturable external APIs. + """ + + def decorator(func): + func._warp_capturable = capturable + return func + + return decorator + + +def is_warp_capturable(func) -> bool: + """Check if a term function is CUDA-graph-capturable. + + Checks ``_warp_capturable`` on the function and its ``__wrapped__`` target. + Returns True (capturable) by default if no annotation is found. + """ + for f in (func, getattr(func, "__wrapped__", None)): + if f is not None: + val = getattr(f, "_warp_capturable", None) + if val is not None: + return val + return True + + @wp.func def wrap_to_pi(angle: float) -> float: """Wrap input angle (in radians) to the range [-pi, pi).""" @@ -24,6 +55,14 @@ def wrap_to_pi(angle: float) -> float: return wp.where((wrapped_angle == 0) and (angle > 0), wp.pi, wrapped_angle - wp.pi) +@wp.kernel +def zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)): + """Zero out rows of a 2D float32 array where mask is True.""" + env_id, j = wp.tid() + if mask[env_id]: + values[env_id, j] = 0.0 + + def resolve_asset_cfg(cfg: dict, env: ManagerBasedEnv) -> SceneEntityCfg: asset_cfg = None From b224b9a29d86b6532bcca96efccbb465480271fd Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Sun, 22 Feb 2026 02:44:32 -0800 Subject: [PATCH 2/5] Add warp-first MDP terms used by tested envs Warp-first observation, reward, termination, event, and action terms referenced by the 14 verified training-parity envs. Observations: base_pos_z, base_lin_vel, base_ang_vel, projected_gravity, joint_pos, joint_pos_rel, joint_pos_limit_normalized, joint_vel, joint_vel_rel, last_action, generated_commands Rewards: is_alive, is_terminated, lin_vel_z_l2, ang_vel_xy_l2, flat_orientation_l2, joint_torques_l2, joint_vel_l1, joint_vel_l2, joint_acc_l2, joint_deviation_l1, joint_pos_limits, action_rate_l2, action_l2, undesired_contacts, track_lin_vel_xy_exp, track_ang_vel_z_exp Terminations: time_out, root_height_below_minimum, joint_pos_out_of_manual_limit, illegal_contact Events: randomize_rigid_body_com, apply_external_force_torque, reset_root_state_uniform, reset_joints_by_scale, reset_joints_by_offset, push_by_setting_velocity Actions: JointPositionAction, JointEffortAction Terms accessing lazy TimestampedWarpBuffer properties (Tier 2) are marked @warp_capturable(False) to prevent stale data under CUDA graph capture. --- .../contact_sensor/base_contact_sensor.py | 2 +- .../envs/mdp/actions/__init__.py | 6 +- .../envs/mdp/actions/actions_cfg.py | 41 +- .../envs/mdp/actions/joint_actions.py | 57 +-- .../isaaclab_experimental/envs/mdp/events.py | 428 ++++++++++++++++-- .../envs/mdp/observations.py | 256 ++++++++++- .../isaaclab_experimental/envs/mdp/rewards.py | 375 ++++++++++++++- .../envs/mdp/terminations.py | 103 ++++- 8 files changed, 1177 insertions(+), 91 deletions(-) diff --git a/source/isaaclab/isaaclab/sensors/contact_sensor/base_contact_sensor.py b/source/isaaclab/isaaclab/sensors/contact_sensor/base_contact_sensor.py index 15854725322..c87ec4fc86d 100644 --- a/source/isaaclab/isaaclab/sensors/contact_sensor/base_contact_sensor.py +++ b/source/isaaclab/isaaclab/sensors/contact_sensor/base_contact_sensor.py @@ -127,7 +127,7 @@ def reset(self, env_ids: Sequence[int] | None = None, env_mask: wp.array(dtype=w env_mask: The masks of the environments to reset. Defaults to None: all the environments are reset. """ # reset the timers and counters - super().reset(env_ids) + super().reset(env_ids, env_mask) @abstractmethod def find_bodies( diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py index 283805a279f..d295384149d 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py @@ -3,10 +3,10 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Experimental action terms (minimal). +"""Experimental action terms (Warp-first). -Only the action configs/terms currently required by the experimental manager-based Cartpole task -are provided here. +Provides Warp-first action term implementations overriding the stable +:mod:`isaaclab.envs.mdp.actions` module. """ from .actions_cfg import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py index fa75f69d045..6635bcb47db 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py @@ -3,12 +3,6 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Action term configuration (experimental, minimal). - -This module mirrors the stable :mod:`isaaclab.envs.mdp.actions.actions_cfg` but only keeps what -the experimental Cartpole task needs. -""" - from dataclasses import MISSING from isaaclab_experimental.managers.action_manager import ActionTerm, ActionTermCfg @@ -17,26 +11,51 @@ from . import joint_actions +## +# Joint actions. +## + @configclass class JointActionCfg(ActionTermCfg): - """Configuration for the base joint action term.""" + """Configuration for the base joint action term. + + See :class:`JointAction` for more details. + """ joint_names: list[str] = MISSING """List of joint names or regex expressions that the action will be mapped to.""" - scale: float | dict[str, float] = 1.0 """Scale factor for the action (float or dict of regex expressions). Defaults to 1.0.""" - offset: float | dict[str, float] = 0.0 """Offset factor for the action (float or dict of regex expressions). Defaults to 0.0.""" - preserve_order: bool = False """Whether to preserve the order of the joint names in the action output. Defaults to False.""" +@configclass +class JointPositionActionCfg(JointActionCfg): + """Configuration for the joint position action term. + + See :class:`JointPositionAction` for more details. + """ + + class_type: type[ActionTerm] = joint_actions.JointPositionAction + + use_default_offset: bool = True + """Whether to use default joint positions configured in the articulation asset as offset. + Defaults to True. + + If True, this flag results in overwriting the values of :attr:`offset` to the default joint positions + from the articulation asset. + """ + + @configclass class JointEffortActionCfg(JointActionCfg): - """Configuration for the joint effort action term.""" + """Configuration for the joint effort action term. + + See :class:`JointEffortAction` for more details. + """ class_type: type[ActionTerm] = joint_actions.JointEffortAction diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py index 78e0fd5b63d..1215625cef4 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py @@ -6,10 +6,12 @@ from __future__ import annotations import logging +import numpy as np from typing import TYPE_CHECKING import warp as wp from isaaclab_experimental.managers.action_manager import ActionTerm +from isaaclab_experimental.utils.warp.utils import zero_masked_2d import isaaclab.utils.string as string_utils from isaaclab.assets.articulation import Articulation @@ -53,24 +55,6 @@ def _process_joint_actions_kernel( processed_out[env_id, j] = x -@wp.kernel -def _set_clip_1d_to_2d( - clip_low: wp.array(dtype=wp.float32), - clip_high: wp.array(dtype=wp.float32), - out: wp.array(dtype=wp.float32, ndim=2), -): - j = wp.tid() - out[j, 0] = clip_low[j] - out[j, 1] = clip_high[j] - - -@wp.kernel -def _zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)): - env_id, j = wp.tid() - if mask[env_id]: - values[env_id, j] = 0.0 - - class JointAction(ActionTerm): r"""Base class for joint actions. @@ -171,15 +155,8 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non else: raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") - clip_low_vec = wp.array(clip_low, dtype=wp.float32, device=self.device) - clip_high_vec = wp.array(clip_high, dtype=wp.float32, device=self.device) - self._clip = wp.zeros((self.action_dim, 2), dtype=wp.float32, device=self.device) - # TODO(jichuanh): use np.stack([a, b], axis=0) - wp.launch( - kernel=_set_clip_1d_to_2d, - dim=self.action_dim, - inputs=[clip_low_vec, clip_high_vec, self._clip], - device=self.device, + self._clip = wp.array( + np.column_stack([clip_low, clip_high]).astype(np.float32), dtype=wp.float32, device=self.device ) """ @@ -259,13 +236,37 @@ def reset(self, env_mask: wp.array | None = None) -> None: self._raw_actions.fill_(0.0) return wp.launch( - kernel=_zero_masked_2d, + kernel=zero_masked_2d, dim=(self.num_envs, self.action_dim), inputs=[env_mask, self._raw_actions], device=self.device, ) +class JointPositionAction(JointAction): + """Joint action term that applies the processed actions to the articulation's joints as position commands. + + Warp-first override of :class:`isaaclab.envs.mdp.actions.JointPositionAction`. + """ + + cfg: actions_cfg.JointPositionActionCfg + """The configuration of the action term.""" + + def __init__(self, cfg: actions_cfg.JointPositionActionCfg, env: ManagerBasedEnv): + super().__init__(cfg, env) + # use default joint positions as offset + if cfg.use_default_offset: + defaults_np = self._asset.data.default_joint_pos.numpy() + if isinstance(self._joint_ids, slice): + offset_vals = defaults_np[0, :].tolist() + else: + offset_vals = [float(defaults_np[0, jid]) for jid in self._joint_ids] + self._offset = wp.array(offset_vals, dtype=wp.float32, device=self.device) + + def apply_actions(self): + self._asset.set_joint_position_target(self.processed_actions, joint_mask=self._joint_mask) + + class JointEffortAction(JointAction): """Joint action term that applies the processed actions to the articulation's joints as effort commands.""" diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py index 3429128b570..ad60100ea44 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py @@ -24,14 +24,381 @@ from __future__ import annotations +import logging +from typing import TYPE_CHECKING + import warp as wp from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp import warp_capturable from isaaclab.assets import Articulation +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Randomize rigid body center of mass +# --------------------------------------------------------------------------- + @wp.kernel -def _reset_joints_by_offset_kernel( +def _randomize_com_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + body_com_pos_b: wp.array(dtype=wp.vec3f, ndim=2), + body_ids: wp.array(dtype=wp.int32), + com_lo: wp.vec3f, + com_hi: wp.vec3f, +): + """Add random offset to center of mass positions for selected bodies.""" + env_id = wp.tid() + if not env_mask[env_id]: + return + + state = rng_state[env_id] + for k in range(body_ids.shape[0]): + b = body_ids[k] + v = body_com_pos_b[env_id, b] + dx = wp.randf(state, com_lo[0], com_hi[0]) + dy = wp.randf(state, com_lo[1], com_hi[1]) + dz = wp.randf(state, com_lo[2], com_hi[2]) + body_com_pos_b[env_id, b] = wp.vec3f(v[0] + dx, v[1] + dy, v[2] + dz) + rng_state[env_id] = state + + +def randomize_rigid_body_com( + env, + env_mask: wp.array, + com_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Randomize the center of mass (CoM) of rigid bodies by adding random offsets. + + Warp-first override of :func:`isaaclab.envs.mdp.events.randomize_rigid_body_com`. + Writes directly into the sim-bound ``body_com_pos_b`` buffer. + """ + asset: Articulation = env.scene[asset_cfg.name] + + fn = randomize_rigid_body_com + if not hasattr(fn, "_com_lo") or fn._asset_name != asset_cfg.name: + fn._asset_name = asset_cfg.name + r = [com_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z"]] + fn._com_lo = wp.vec3f(r[0][0], r[1][0], r[2][0]) + fn._com_hi = wp.vec3f(r[0][1], r[1][1], r[2][1]) + + wp.launch( + kernel=_randomize_com_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.body_com_pos_b, + asset_cfg.body_ids_wp, + fn._com_lo, + fn._com_hi, + ], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# Apply external force and torque +# --------------------------------------------------------------------------- + + +@wp.kernel +def _apply_external_force_torque_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + force_out: wp.array(dtype=wp.vec3f, ndim=2), + torque_out: wp.array(dtype=wp.vec3f, ndim=2), + force_lo: float, + force_hi: float, + torque_lo: float, + torque_hi: float, +): + env_id = wp.tid() + if not env_mask[env_id]: + # zero out unmasked envs so they don't accumulate stale forces + for b in range(force_out.shape[1]): + force_out[env_id, b] = wp.vec3f(0.0, 0.0, 0.0) + torque_out[env_id, b] = wp.vec3f(0.0, 0.0, 0.0) + return + + state = rng_state[env_id] + for b in range(force_out.shape[1]): + force_out[env_id, b] = wp.vec3f( + wp.randf(state, force_lo, force_hi), + wp.randf(state, force_lo, force_hi), + wp.randf(state, force_lo, force_hi), + ) + torque_out[env_id, b] = wp.vec3f( + wp.randf(state, torque_lo, torque_hi), + wp.randf(state, torque_lo, torque_hi), + wp.randf(state, torque_lo, torque_hi), + ) + rng_state[env_id] = state + + +@warp_capturable(False) +def apply_external_force_torque( + env, + env_mask: wp.array, + force_range: tuple[float, float], + torque_range: tuple[float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Randomize external forces and torques applied to the asset's bodies. + + Warp-first override of :func:`isaaclab.envs.mdp.events.apply_external_force_torque`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-convert constant arguments. + if not hasattr(apply_external_force_torque, "_scratch_forces"): + apply_external_force_torque._scratch_forces = wp.zeros( + (env.num_envs, asset.num_bodies), dtype=wp.vec3f, device=env.device + ) + apply_external_force_torque._scratch_torques = wp.zeros( + (env.num_envs, asset.num_bodies), dtype=wp.vec3f, device=env.device + ) + + wp.launch( + kernel=_apply_external_force_torque_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + apply_external_force_torque._scratch_forces, + apply_external_force_torque._scratch_torques, + force_range[0], + force_range[1], + torque_range[0], + torque_range[1], + ], + device=env.device, + ) + + asset.set_external_force_and_torque( + apply_external_force_torque._scratch_forces, + apply_external_force_torque._scratch_torques, + env_mask=env_mask, + ) + + +# --------------------------------------------------------------------------- +# Push by velocity +# --------------------------------------------------------------------------- + + +@wp.kernel +def _push_by_setting_velocity_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + vel_out: wp.array(dtype=wp.spatial_vectorf), + lin_lo: wp.vec3f, + lin_hi: wp.vec3f, + ang_lo: wp.vec3f, + ang_hi: wp.vec3f, +): + env_id = wp.tid() + if not env_mask[env_id]: + return + + vel = root_vel_w[env_id] + state = rng_state[env_id] + + vel_out[env_id] = wp.spatial_vectorf( + vel[0] + wp.randf(state, lin_lo[0], lin_hi[0]), + vel[1] + wp.randf(state, lin_lo[1], lin_hi[1]), + vel[2] + wp.randf(state, lin_lo[2], lin_hi[2]), + vel[3] + wp.randf(state, ang_lo[0], ang_hi[0]), + vel[4] + wp.randf(state, ang_lo[1], ang_hi[1]), + vel[5] + wp.randf(state, ang_lo[2], ang_hi[2]), + ) + + rng_state[env_id] = state + + +@warp_capturable(False) +def push_by_setting_velocity( + env, + env_mask: wp.array, + velocity_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Push the asset by setting the root velocity to a random value within the given ranges. + + Warp-first override of :func:`isaaclab.envs.mdp.events.push_by_setting_velocity`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-parse constant range arguments. + if not hasattr(push_by_setting_velocity, "_scratch_vel"): + push_by_setting_velocity._scratch_vel = wp.zeros((env.num_envs,), dtype=wp.spatial_vectorf, device=env.device) + r = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + push_by_setting_velocity._lin_lo = wp.vec3f(r[0][0], r[1][0], r[2][0]) + push_by_setting_velocity._lin_hi = wp.vec3f(r[0][1], r[1][1], r[2][1]) + push_by_setting_velocity._ang_lo = wp.vec3f(r[3][0], r[4][0], r[5][0]) + push_by_setting_velocity._ang_hi = wp.vec3f(r[3][1], r[4][1], r[5][1]) + + wp.launch( + kernel=_push_by_setting_velocity_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.root_vel_w, + push_by_setting_velocity._scratch_vel, + push_by_setting_velocity._lin_lo, + push_by_setting_velocity._lin_hi, + push_by_setting_velocity._ang_lo, + push_by_setting_velocity._ang_hi, + ], + device=env.device, + ) + + asset.write_root_velocity_to_sim(push_by_setting_velocity._scratch_vel, env_mask=env_mask) + + +# --------------------------------------------------------------------------- +# Reset root state uniform +# --------------------------------------------------------------------------- + + +@wp.kernel +def _reset_root_state_uniform_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + default_root_pose: wp.array(dtype=wp.transformf), + default_root_vel: wp.array(dtype=wp.spatial_vectorf), + env_origins: wp.array(dtype=wp.vec3f), + pose_out: wp.array(dtype=wp.transformf), + vel_out: wp.array(dtype=wp.spatial_vectorf), + pos_lo: wp.vec3f, + pos_hi: wp.vec3f, + rot_lo: wp.vec3f, + rot_hi: wp.vec3f, + vel_lin_lo: wp.vec3f, + vel_lin_hi: wp.vec3f, + vel_ang_lo: wp.vec3f, + vel_ang_hi: wp.vec3f, +): + env_id = wp.tid() + if not env_mask[env_id]: + return + + state = rng_state[env_id] + + # --- Pose --- + default_pose = default_root_pose[env_id] + default_pos = wp.transform_get_translation(default_pose) + default_q = wp.transform_get_rotation(default_pose) + origin = env_origins[env_id] + + # position = default + env_origin + random offset + pos = wp.vec3f( + default_pos[0] + origin[0] + wp.randf(state, pos_lo[0], pos_hi[0]), + default_pos[1] + origin[1] + wp.randf(state, pos_lo[1], pos_hi[1]), + default_pos[2] + origin[2] + wp.randf(state, pos_lo[2], pos_hi[2]), + ) + + # orientation = default * delta(euler_xyz) + roll = wp.randf(state, rot_lo[0], rot_hi[0]) + pitch = wp.randf(state, rot_lo[1], rot_hi[1]) + yaw = wp.randf(state, rot_lo[2], rot_hi[2]) + qx = wp.quat_from_axis_angle(wp.vec3f(1.0, 0.0, 0.0), roll) + qy = wp.quat_from_axis_angle(wp.vec3f(0.0, 1.0, 0.0), pitch) + qz = wp.quat_from_axis_angle(wp.vec3f(0.0, 0.0, 1.0), yaw) + # ZYX extrinsic = XYZ intrinsic: delta = qz * qy * qx + delta_q = wp.mul(wp.mul(qz, qy), qx) + final_q = wp.mul(default_q, delta_q) + + pose_out[env_id] = wp.transformf(pos, final_q) + + # --- Velocity --- + default_vel = default_root_vel[env_id] + vel_out[env_id] = wp.spatial_vectorf( + default_vel[0] + wp.randf(state, vel_lin_lo[0], vel_lin_hi[0]), + default_vel[1] + wp.randf(state, vel_lin_lo[1], vel_lin_hi[1]), + default_vel[2] + wp.randf(state, vel_lin_lo[2], vel_lin_hi[2]), + default_vel[3] + wp.randf(state, vel_ang_lo[0], vel_ang_hi[0]), + default_vel[4] + wp.randf(state, vel_ang_lo[1], vel_ang_hi[1]), + default_vel[5] + wp.randf(state, vel_ang_lo[2], vel_ang_hi[2]), + ) + + rng_state[env_id] = state + + +@warp_capturable(False) +def reset_root_state_uniform( + env, + env_mask: wp.array, + pose_range: dict[str, tuple[float, float]], + velocity_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Reset the asset root state to a random position and velocity uniformly within the given ranges. + + Warp-first override of :func:`isaaclab.envs.mdp.events.reset_root_state_uniform`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-parse range dicts. + if not hasattr(reset_root_state_uniform, "_scratch_pose"): + reset_root_state_uniform._scratch_pose = wp.zeros((env.num_envs,), dtype=wp.transformf, device=env.device) + reset_root_state_uniform._scratch_vel = wp.zeros((env.num_envs,), dtype=wp.spatial_vectorf, device=env.device) + # Pre-parse pose_range dict + p = [pose_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + reset_root_state_uniform._pos_lo = wp.vec3f(p[0][0], p[1][0], p[2][0]) + reset_root_state_uniform._pos_hi = wp.vec3f(p[0][1], p[1][1], p[2][1]) + reset_root_state_uniform._rot_lo = wp.vec3f(p[3][0], p[4][0], p[5][0]) + reset_root_state_uniform._rot_hi = wp.vec3f(p[3][1], p[4][1], p[5][1]) + # Pre-parse velocity_range dict + v = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + reset_root_state_uniform._vel_lin_lo = wp.vec3f(v[0][0], v[1][0], v[2][0]) + reset_root_state_uniform._vel_lin_hi = wp.vec3f(v[0][1], v[1][1], v[2][1]) + reset_root_state_uniform._vel_ang_lo = wp.vec3f(v[3][0], v[4][0], v[5][0]) + reset_root_state_uniform._vel_ang_hi = wp.vec3f(v[3][1], v[4][1], v[5][1]) + + wp.launch( + kernel=_reset_root_state_uniform_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.default_root_pose, + asset.data.default_root_vel, + env.env_origins_wp, + reset_root_state_uniform._scratch_pose, + reset_root_state_uniform._scratch_vel, + reset_root_state_uniform._pos_lo, + reset_root_state_uniform._pos_hi, + reset_root_state_uniform._rot_lo, + reset_root_state_uniform._rot_hi, + reset_root_state_uniform._vel_lin_lo, + reset_root_state_uniform._vel_lin_hi, + reset_root_state_uniform._vel_ang_lo, + reset_root_state_uniform._vel_ang_hi, + ], + device=env.device, + ) + + asset.write_root_pose_to_sim(reset_root_state_uniform._scratch_pose, env_mask=env_mask) + asset.write_root_velocity_to_sim(reset_root_state_uniform._scratch_vel, env_mask=env_mask) + + +# --------------------------------------------------------------------------- +# Reset joints by scale +# --------------------------------------------------------------------------- + + +@wp.kernel +def _reset_joints_by_scale_kernel( env_mask: wp.array(dtype=wp.bool), joint_ids: wp.array(dtype=wp.int32), rng_state: wp.array(dtype=wp.uint32), @@ -50,51 +417,44 @@ def _reset_joints_by_offset_kernel( if not env_mask[env_id]: return - # 1 thread per env so per-env RNG state updates are race-free. state = rng_state[env_id] for joint_i in range(joint_ids.shape[0]): joint_id = joint_ids[joint_i] - # offset samples in the provided ranges (Warp RNG state pattern) - pos_off = wp.randf(state, pos_lo, pos_hi) - vel_off = wp.randf(state, vel_lo, vel_hi) + # scale samples in the provided ranges + pos_scale = wp.randf(state, pos_lo, pos_hi) + vel_scale = wp.randf(state, vel_lo, vel_hi) - pos = default_joint_pos[env_id, joint_id] + pos_off - vel = default_joint_vel[env_id, joint_id] + vel_off + pos = default_joint_pos[env_id, joint_id] * pos_scale + vel = default_joint_vel[env_id, joint_id] * vel_scale - # clamp to soft limits lim = soft_joint_pos_limits[env_id, joint_id] pos = wp.clamp(pos, lim.x, lim.y) vmax = soft_joint_vel_limits[env_id, joint_id] vel = wp.clamp(vel, -vmax, vmax) - # write into sim-bound state buffers + # write into sim joint_pos[env_id, joint_id] = pos joint_vel[env_id, joint_id] = vel rng_state[env_id] = state -def reset_joints_by_offset( +def reset_joints_by_scale( env, env_mask: wp.array, position_range: tuple[float, float], velocity_range: tuple[float, float], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ): - """Warp-first reset of joint state by random offsets around defaults. - - This overrides the stable `isaaclab.envs.mdp.events.reset_joints_by_offset` when importing - via `isaaclab_experimental.envs.mdp`. - """ + """Warp-first reset of joint state by scaling defaults with random factors.""" asset: Articulation = env.scene[asset_cfg.name] - # Assume cfg params are already resolved by the manager stack (Warp-first workflow). assert asset_cfg.joint_ids_wp is not None assert env.rng_state_wp is not None wp.launch( - kernel=_reset_joints_by_offset_kernel, + kernel=_reset_joints_by_scale_kernel, dim=env.num_envs, inputs=[ env_mask, @@ -115,8 +475,13 @@ def reset_joints_by_offset( ) +# --------------------------------------------------------------------------- +# Reset joints by offset +# --------------------------------------------------------------------------- + + @wp.kernel -def _reset_joints_by_scale_kernel( +def _reset_joints_by_offset_kernel( env_mask: wp.array(dtype=wp.bool), joint_ids: wp.array(dtype=wp.int32), rng_state: wp.array(dtype=wp.uint32), @@ -135,44 +500,51 @@ def _reset_joints_by_scale_kernel( if not env_mask[env_id]: return + # 1 thread per env so per-env RNG state updates are race-free. state = rng_state[env_id] for joint_i in range(joint_ids.shape[0]): joint_id = joint_ids[joint_i] - # scale samples in the provided ranges - pos_scale = wp.randf(state, pos_lo, pos_hi) - vel_scale = wp.randf(state, vel_lo, vel_hi) + # offset samples in the provided ranges (Warp RNG state pattern) + pos_off = wp.randf(state, pos_lo, pos_hi) + vel_off = wp.randf(state, vel_lo, vel_hi) - pos = default_joint_pos[env_id, joint_id] * pos_scale - vel = default_joint_vel[env_id, joint_id] * vel_scale + pos = default_joint_pos[env_id, joint_id] + pos_off + vel = default_joint_vel[env_id, joint_id] + vel_off + # clamp to soft limits lim = soft_joint_pos_limits[env_id, joint_id] pos = wp.clamp(pos, lim.x, lim.y) vmax = soft_joint_vel_limits[env_id, joint_id] vel = wp.clamp(vel, -vmax, vmax) - # write into sim + # write into sim-bound state buffers joint_pos[env_id, joint_id] = pos joint_vel[env_id, joint_id] = vel rng_state[env_id] = state -def reset_joints_by_scale( +def reset_joints_by_offset( env, env_mask: wp.array, position_range: tuple[float, float], velocity_range: tuple[float, float], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ): - """Warp-first reset of joint state by scaling defaults with random factors.""" + """Warp-first reset of joint state by random offsets around defaults. + + This overrides the stable `isaaclab.envs.mdp.events.reset_joints_by_offset` when importing + via `isaaclab_experimental.envs.mdp`. + """ asset: Articulation = env.scene[asset_cfg.name] + # Assume cfg params are already resolved by the manager stack (Warp-first workflow). assert asset_cfg.joint_ids_wp is not None assert env.rng_state_wp is not None wp.launch( - kernel=_reset_joints_by_scale_kernel, + kernel=_reset_joints_by_offset_kernel, dim=env.num_envs, inputs=[ env_mask, @@ -191,3 +563,5 @@ def reset_joints_by_scale( ], device=env.device, ) + + diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py index 84acf612c44..49f46a70587 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py @@ -3,29 +3,35 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Warp-first observation terms (experimental, Cartpole-focused). +"""Warp-first observation terms (experimental). All functions in this file follow the Warp-compatible observation signature expected by the experimental Warp-first observation manager: - ``func(env, out, **params) -> None`` -where ``out`` is a pre-allocated Warp array with float32 dtype and shape ``(num_envs, term_dim)``. +where ``out`` is a pre-allocated Warp array with float32 dtype and shape ``(num_envs, D)``. +Output dimension ``D`` is inferred from decorator metadata: ``axes`` for root-state terms, +``out_dim`` for body/command/action/time terms, or ``joint_ids`` count for joint terms. """ from __future__ import annotations +import torch from typing import TYPE_CHECKING import warp as wp from isaaclab_experimental.envs.utils.io_descriptors import ( generic_io_descriptor_warp, + record_body_names, + record_dtype, record_joint_names, record_joint_pos_offsets, - record_joint_shape, record_joint_vel_offsets, + record_shape, ) from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp import warp_capturable from isaaclab.assets import Articulation @@ -33,6 +39,136 @@ from isaaclab.envs import ManagerBasedEnv +# --------------------------------------------------------------------------- +# Shared kernels +# --------------------------------------------------------------------------- + + +@wp.kernel +def _vec3_to_out3_kernel( + src: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id = wp.tid() + v = src[env_id] + out[env_id, 0] = v[0] + out[env_id, 1] = v[1] + out[env_id, 2] = v[2] + + +@wp.kernel +def _joint_gather_kernel( + src: wp.array(dtype=wp.float32, ndim=2), + joint_ids: wp.array(dtype=wp.int32), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id, k = wp.tid() + j = joint_ids[k] + out[env_id, k] = src[env_id, j] + + +""" +Root state. +""" + + +@wp.kernel +def _base_pos_z_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id = wp.tid() + out[env_id, 0] = root_pos_w[env_id][2] + + +# Reviewed(jichuanh): good +@generic_io_descriptor_warp( + units="m", axes=["Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_pos_z(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root height in the simulation world frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_pos_z_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, out], + device=env.device, + ) + + +# Reviewed(jichuanh): good +@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) +@generic_io_descriptor_warp( + units="m/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_lin_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root linear velocity in the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_vec3_to_out3_kernel, + dim=env.num_envs, + inputs=[asset.data.root_lin_vel_b, out], + device=env.device, + ) + + +# Reviewed(jichuanh): good +@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) +@generic_io_descriptor_warp( + units="rad/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_ang_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root angular velocity in the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_vec3_to_out3_kernel, + dim=env.num_envs, + inputs=[asset.data.root_ang_vel_b, out], + device=env.device, + ) + + +# Reviewed(jichuanh): good +@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) +@generic_io_descriptor_warp( + units="m/s^2", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def projected_gravity(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Gravity projection on the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_vec3_to_out3_kernel, + dim=env.num_envs, + inputs=[asset.data.projected_gravity_b, out], + device=env.device, + ) + + +""" +Joint state. +""" + + +@generic_io_descriptor_warp( + observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape], units="rad" +) +def joint_pos(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint positions of the asset.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." + ) + wp.launch( + kernel=_joint_gather_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_pos, joint_ids_wp, out], + device=env.device, + ) + + @wp.kernel def _joint_pos_rel_gather_kernel( joint_pos: wp.array(dtype=wp.float32, ndim=2), @@ -45,9 +181,10 @@ def _joint_pos_rel_gather_kernel( out[env_id, k] = joint_pos[env_id, j] - default_joint_pos[env_id, j] +# Reviewed(jichuanh): good @generic_io_descriptor_warp( observation_type="JointState", - on_inspect=[record_joint_names, record_joint_shape, record_joint_pos_offsets], + on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_pos_offsets], units="rad", ) def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: @@ -69,6 +206,69 @@ def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn ) +# Reviewed(jichuanh): logic is different from stable version. Even upper and lower are flipped, stable +# logic should work, fix this. +@wp.kernel +def _joint_pos_limit_normalized_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2), + joint_ids: wp.array(dtype=wp.int32), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id, k = wp.tid() + j = joint_ids[k] + pos = joint_pos[env_id, j] + lim = soft_joint_pos_limits[env_id, j] + lower = lim.x + upper = lim.y + mid = (lower + upper) * 0.5 + half_range = (upper - lower) * 0.5 + if half_range > 0.0: + out[env_id, k] = (pos - mid) / half_range + else: + out[env_id, k] = 0.0 + + +@generic_io_descriptor_warp(observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape]) +def joint_pos_limit_normalized(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint positions of the asset normalized with the asset's joint limits.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." + ) + wp.launch( + kernel=_joint_pos_limit_normalized_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_pos, asset.data.soft_joint_pos_limits, joint_ids_wp, out], + device=env.device, + ) + + +# Reviewed(jichuanh): good +@generic_io_descriptor_warp( + observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape], units="rad/s" +) +def joint_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint velocities of the asset.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." + ) + wp.launch( + kernel=_joint_gather_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_vel, joint_ids_wp, out], + device=env.device, + ) + + +# Reviewed(jichuanh): kernel impl seems duplicate, rel_gather kernel could be shared. @wp.kernel def _joint_vel_rel_gather_kernel( joint_vel: wp.array(dtype=wp.float32, ndim=2), @@ -83,7 +283,7 @@ def _joint_vel_rel_gather_kernel( @generic_io_descriptor_warp( observation_type="JointState", - on_inspect=[record_joint_names, record_joint_shape, record_joint_vel_offsets], + on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_vel_offsets], units="rad/s", ) def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: @@ -103,3 +303,49 @@ def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn inputs=[asset.data.joint_vel, asset.data.default_joint_vel, joint_ids_wp, out], device=env.device, ) + + +""" +Actions. +""" +# Reviewed(jichuanh): good + + +@generic_io_descriptor_warp(out_dim="action", dtype=torch.float32, observation_type="Action", on_inspect=[record_shape]) +def last_action(env: ManagerBasedEnv, out, action_name: str | None = None) -> None: + """The last input action to the environment.""" + # TODO(warp-migration): Cross-manager access (observation → action). Currently works + # because experimental ActionManager.action is already a warp array. No from_torch needed. + if action_name is not None: + raise NotImplementedError("Named action support is not yet implemented for Warp-first last_action observation.") + wp.copy(out, env.action_manager.action) + + +""" +Commands. +""" + + +# Reviewed(jichuanh): good +@generic_io_descriptor_warp( + out_dim="command", dtype=torch.float32, observation_type="Command", on_inspect=[record_shape] +) +def generated_commands(env: ManagerBasedEnv, out, command_name: str) -> None: + """The generated command from the command manager. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.observations.generated_commands`. + Uses ``wp.from_torch`` to create a zero-copy warp view of the command tensor on first call. + """ + # TODO(warp-migration): Cross-manager access (observation → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. + fn = generated_commands + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + fn._cmd_wp = cmd + else: + fn._cmd_wp = wp.from_torch(cmd) + fn._cmd_name = command_name + wp.copy(out, fn._cmd_wp) + + diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py index ba0086eda71..cea77832c31 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py @@ -5,9 +5,6 @@ """Common functions that can be used to enable reward functions (experimental). -This module is intentionally minimal: it only contains reward terms that are currently -used by the experimental manager-based Cartpole task. - All functions in this file follow the Warp-compatible reward signature expected by `isaaclab_experimental.managers.RewardManager`: @@ -22,6 +19,7 @@ import warp as wp from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp import warp_capturable from isaaclab.assets import Articulation @@ -42,6 +40,8 @@ def _is_alive_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dtype=wp def is_alive(env: ManagerBasedRLEnv, out: wp.array(dtype=wp.float32)) -> None: """Reward for being alive. Writes into ``out`` (shape: (num_envs,)).""" + # TODO(warp-migration): Cross-manager access (reward → termination). Replace with direct + # warp property once all managers are guaranteed to be warp-native. terminated_wp = wp.from_torch(env.termination_manager.terminated, dtype=wp.bool) wp.launch(kernel=_is_alive_kernel, dim=env.num_envs, inputs=[terminated_wp, out], device=env.device) @@ -54,15 +54,109 @@ def _is_terminated_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dty def is_terminated(env: ManagerBasedRLEnv, out) -> None: """Penalize terminated episodes. Writes into ``out``.""" + # TODO(warp-migration): Cross-manager access (reward → termination). Replace with direct + # warp property once all managers are guaranteed to be warp-native. terminated_wp = wp.from_torch(env.termination_manager.terminated, dtype=wp.bool) wp.launch(kernel=_is_terminated_kernel, dim=env.num_envs, inputs=[terminated_wp, out], device=env.device) +""" +Root penalties. +""" + + +# Reviewed(jichuanh): opportunity to share kernel should be explored, e.g. a square_index kernel with +# pre-allocated warp-ids array could be used. +@wp.kernel +def _lin_vel_z_l2_kernel(root_lin_vel_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): + i = wp.tid() + vz = root_lin_vel_b[i][2] + out[i] = vz * vz + + +@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) +def lin_vel_z_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize z-axis base linear velocity using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_lin_vel_z_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.root_lin_vel_b, out], + device=env.device, + ) + + +# Reviewed(jichuanh): same as previous +@wp.kernel +def _ang_vel_xy_l2_kernel(root_ang_vel_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): + i = wp.tid() + v = root_ang_vel_b[i] + out[i] = v[0] * v[0] + v[1] * v[1] + + +@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) +def ang_vel_xy_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize xy-axis base angular velocity using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_ang_vel_xy_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.root_ang_vel_b, out], + device=env.device, + ) + + +# Reviewed(jichuanh): same as previous +@wp.kernel +def _flat_orientation_l2_kernel(projected_gravity_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): + i = wp.tid() + g = projected_gravity_b[i] + out[i] = g[0] * g[0] + g[1] * g[1] + + +@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) +def flat_orientation_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize non-flat base orientation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_flat_orientation_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.projected_gravity_b, out], + device=env.device, + ) + + """ Joint penalties. """ +# TODO(warp-migration): Revisit whether 2D kernel + wp.atomic_add is faster than 1D with inner loop +# for the following masked reduction kernels. Profile with typical joint counts (12-30). +@wp.kernel +def _sum_sq_masked_kernel( + x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32) +): + i = wp.tid() + s = float(0.0) + for j in range(x.shape[1]): + if joint_mask[j]: + s += x[i, j] * x[i, j] + out[i] = s + + +def joint_torques_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint torques applied on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.applied_torque, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. @wp.kernel def _sum_abs_masked_kernel( x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32) @@ -84,3 +178,278 @@ def joint_vel_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg) -> None inputs=[asset.data.joint_vel, asset_cfg.joint_mask, out], device=env.device, ) + + +def joint_vel_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint velocities on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_vel, asset_cfg.joint_mask, out], + device=env.device, + ) + + +def joint_acc_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint accelerations on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_acc, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _sum_abs_diff_masked_kernel( + a: wp.array(dtype=wp.float32, ndim=2), + b: wp.array(dtype=wp.float32, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(a.shape[1]): + if joint_mask[j]: + s += wp.abs(a[i, j] - b[i, j]) + out[i] = s + + +def joint_deviation_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint positions that deviate from the default one.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_abs_diff_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset.data.default_joint_pos, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _joint_pos_limits_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(joint_pos.shape[1]): + if joint_mask[j]: + pos = joint_pos[i, j] + lim = soft_joint_pos_limits[i, j] + lower = lim.x + upper = lim.y + # penalty for exceeding lower limit + below = lower - pos + if below > 0.0: + s += below + # penalty for exceeding upper limit + above = pos - upper + if above > 0.0: + s += above + out[i] = s + + +def joint_pos_limits(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint positions if they cross the soft limits.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_joint_pos_limits_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset.data.soft_joint_pos_limits, asset_cfg.joint_mask, out], + device=env.device, + ) + + +""" +Action penalties. +""" + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _sum_sq_diff_2d_kernel( + a: wp.array(dtype=wp.float32, ndim=2), + b: wp.array(dtype=wp.float32, ndim=2), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(a.shape[1]): + d = a[i, j] - b[i, j] + s += d * d + out[i] = s + + +def action_rate_l2(env: ManagerBasedRLEnv, out) -> None: + """Penalize the rate of change of the actions using L2 squared kernel.""" + wp.launch( + kernel=_sum_sq_diff_2d_kernel, + dim=env.num_envs, + inputs=[env.action_manager.action, env.action_manager.prev_action, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _sum_sq_2d_kernel(x: wp.array(dtype=wp.float32, ndim=2), out: wp.array(dtype=wp.float32)): + i = wp.tid() + s = float(0.0) + for j in range(x.shape[1]): + s += x[i, j] * x[i, j] + out[i] = s + + +def action_l2(env: ManagerBasedRLEnv, out) -> None: + """Penalize the actions using L2 squared kernel.""" + wp.launch( + kernel=_sum_sq_2d_kernel, + dim=env.num_envs, + inputs=[env.action_manager.action, out], + device=env.device, + ) + + +""" +Contact sensor. +""" + + +# Reviewed(jichuanh): good +@wp.kernel +def _undesired_contacts_kernel( + forces: wp.array(dtype=wp.vec3f, ndim=3), + body_ids: wp.array(dtype=wp.int32), + threshold: float, + out: wp.array(dtype=wp.float32), +): + """Count bodies where max-over-history contact force norm exceeds threshold.""" + i = wp.tid() + count = float(0.0) + for k in range(body_ids.shape[0]): + b = body_ids[k] + max_force = float(0.0) + for h in range(forces.shape[1]): + f = forces[i, h, b] + norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2]) + if norm > max_force: + max_force = norm + if max_force > threshold: + count += 1.0 + out[i] = count + + +def undesired_contacts(env: ManagerBasedRLEnv, out, threshold: float, sensor_cfg: SceneEntityCfg) -> None: + """Penalize undesired contacts as the number of violations above a threshold. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.rewards.undesired_contacts`. + """ + contact_sensor = env.scene.sensors[sensor_cfg.name] + wp.launch( + kernel=_undesired_contacts_kernel, + dim=env.num_envs, + inputs=[contact_sensor.data.net_forces_w_history, sensor_cfg.body_ids_wp, threshold, out], + device=env.device, + ) + + +""" +Velocity-tracking rewards. +""" + + +@wp.kernel +def _track_lin_vel_xy_exp_kernel( + root_lin_vel_b: wp.array(dtype=wp.vec3f), + command: wp.array(dtype=wp.float32, ndim=2), + std_sq_inv: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + v = root_lin_vel_b[i] + dx = command[i, 0] - v[0] + dy = command[i, 1] - v[1] + error = dx * dx + dy * dy + out[i] = wp.exp(-error * std_sq_inv) + + +# Reviewed(jichuanh): Review if there's any gap to make term provide warp type by default. +@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) +def track_lin_vel_xy_exp( + env: ManagerBasedRLEnv, + out, + std: float, + command_name: str, + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Reward tracking of linear velocity commands (xy axes) using exponential kernel. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.rewards.track_lin_vel_xy_exp`. + """ + asset: Articulation = env.scene[asset_cfg.name] + # cache the warp view of the command tensor on first call (zero-copy) + # TODO(warp-migration): Cross-manager access (reward → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. + if not hasattr(track_lin_vel_xy_exp, "_cmd_wp") or track_lin_vel_xy_exp._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + track_lin_vel_xy_exp._cmd_wp = cmd + else: + track_lin_vel_xy_exp._cmd_wp = wp.from_torch(cmd) + track_lin_vel_xy_exp._cmd_name = command_name + wp.launch( + kernel=_track_lin_vel_xy_exp_kernel, + dim=env.num_envs, + inputs=[asset.data.root_lin_vel_b, track_lin_vel_xy_exp._cmd_wp, 1.0 / (std * std), out], + device=env.device, + ) + + +@wp.kernel +def _track_ang_vel_z_exp_kernel( + root_ang_vel_b: wp.array(dtype=wp.vec3f), + command: wp.array(dtype=wp.float32, ndim=2), + cmd_col: int, + std_sq_inv: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + dz = command[i, cmd_col] - root_ang_vel_b[i][2] + out[i] = wp.exp(-dz * dz * std_sq_inv) + + +@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) +def track_ang_vel_z_exp( + env: ManagerBasedRLEnv, + out, + std: float, + command_name: str, + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Reward tracking of angular velocity commands (yaw) using exponential kernel. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.rewards.track_ang_vel_z_exp`. + """ + asset: Articulation = env.scene[asset_cfg.name] + # TODO(warp-migration): Cross-manager access (reward → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. + if not hasattr(track_ang_vel_z_exp, "_cmd_wp") or track_ang_vel_z_exp._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + track_ang_vel_z_exp._cmd_wp = cmd + else: + track_ang_vel_z_exp._cmd_wp = wp.from_torch(cmd) + track_ang_vel_z_exp._cmd_name = command_name + wp.launch( + kernel=_track_ang_vel_z_exp_kernel, + dim=env.num_envs, + inputs=[asset.data.root_ang_vel_b, track_ang_vel_z_exp._cmd_wp, 2, 1.0 / (std * std), out], + device=env.device, + ) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py index 44500244128..721bbc26711 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py @@ -5,9 +5,6 @@ """Common functions that can be used to activate terminations (experimental). -This module is intentionally minimal: it only contains termination terms that are currently -used by the experimental manager-based Cartpole task. - All functions in this file follow the Warp-compatible termination signature expected by `isaaclab_experimental.managers.TerminationManager`: @@ -29,6 +26,12 @@ from isaaclab.envs import ManagerBasedRLEnv +""" +MDP terminations. +""" + + +# Reviewed(jichuanh) @wp.kernel def _time_out_kernel(episode_length: wp.array(dtype=wp.int64), max_episode_length: int, out: wp.array(dtype=wp.bool)): i = wp.tid() @@ -37,6 +40,8 @@ def _time_out_kernel(episode_length: wp.array(dtype=wp.int64), max_episode_lengt def time_out(env: ManagerBasedRLEnv, out) -> None: """Terminate the episode when episode length exceeds the maximum episode length.""" + # TODO(warp-migration): env.episode_length_buf is a torch.Tensor (torch.long). Replace + # once ManagerBasedRLEnv provides a native warp property. episode_length_wp = wp.from_torch(env.episode_length_buf, dtype=wp.int64) wp.launch( kernel=_time_out_kernel, @@ -46,6 +51,41 @@ def time_out(env: ManagerBasedRLEnv, out) -> None: ) +""" +Root terminations. +""" + + +# Reviewed(jichuanh): good. +@wp.kernel +def _root_height_below_min_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + minimum_height: float, + out: wp.array(dtype=wp.bool), +): + i = wp.tid() + out[i] = root_pos_w[i][2] < minimum_height + + +def root_height_below_minimum( + env: ManagerBasedRLEnv, out, minimum_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Terminate when the asset's root height is below the minimum height.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_root_height_below_min_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, minimum_height, out], + device=env.device, + ) + + +""" +Joint terminations. +""" + + +# Reviewed(jichuanh): good @wp.kernel def _joint_pos_out_of_manual_limit_kernel( joint_pos: wp.array(dtype=wp.float32, ndim=2), @@ -54,15 +94,12 @@ def _joint_pos_out_of_manual_limit_kernel( upper: float, out: wp.array(dtype=wp.bool), ): - i = wp.tid() - violated = bool(False) - for j in range(joint_pos.shape[1]): - if joint_mask[j]: - v = joint_pos[i, j] - if v < lower or v > upper: - violated = True - break - out[i] = violated + """2D kernel (num_envs, num_joints). ``out`` is pre-zeroed; only writes True.""" + i, j = wp.tid() + if joint_mask[j]: + v = joint_pos[i, j] + if v < lower or v > upper: + out[i] = True def joint_pos_out_of_manual_limit( @@ -74,7 +111,47 @@ def joint_pos_out_of_manual_limit( assert asset.data.joint_pos.shape[1] == asset_cfg.joint_mask.shape[0] wp.launch( kernel=_joint_pos_out_of_manual_limit_kernel, - dim=env.num_envs, + dim=(env.num_envs, asset.data.joint_pos.shape[1]), inputs=[asset.data.joint_pos, asset_cfg.joint_mask, bounds[0], bounds[1], out], device=env.device, ) + + +""" +Contact sensor. +""" + + +# Reviewed(jichuanh): good +@wp.kernel +def _illegal_contact_kernel( + forces: wp.array(dtype=wp.vec3f, ndim=3), + body_ids: wp.array(dtype=wp.int32), + threshold: float, + out: wp.array(dtype=wp.bool), +): + """Terminate when any selected body's max-over-history contact force norm exceeds threshold.""" + i = wp.tid() + violated = bool(False) + for k in range(body_ids.shape[0]): + b = body_ids[k] + for h in range(forces.shape[1]): + f = forces[i, h, b] + norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2]) + if norm > threshold: + violated = True + out[i] = violated + + +def illegal_contact(env: ManagerBasedRLEnv, out, threshold: float, sensor_cfg: SceneEntityCfg) -> None: + """Terminate when the contact force on the sensor exceeds the force threshold. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.terminations.illegal_contact`. + """ + contact_sensor = env.scene.sensors[sensor_cfg.name] + wp.launch( + kernel=_illegal_contact_kernel, + dim=env.num_envs, + inputs=[contact_sensor.data.net_forces_w_history, sensor_cfg.body_ids_wp, threshold, out], + device=env.device, + ) From 953654430951b122fb8d3a9c4fb3f4e1583e49ef Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Sun, 22 Feb 2026 02:45:06 -0800 Subject: [PATCH 3/5] Add tested warp env configs, task-specific MDP terms, tests, and docs Env configs and task-local MDP terms for 14 training-parity verified envs: - Classic: Cartpole, Humanoid, Ant - Locomotion velocity (flat): Anymal-B/C/D, G1-v0/v1, H1, Cassie, Unitree A1/Go1/Go2 - Manipulation: Reach-Franka Per-robot config registrations (gym IDs) and flat env cfgs for all tested locomotion and reach variants. Task-specific MDP terms: - Humanoid: base_yaw_roll, base_up_proj, base_heading_proj, base_angle_to_target, progress_reward, upright_posture_bonus, move_to_target_bonus, power_consumption, joint_pos_limits_penalty_ratio - Velocity: feet_air_time, feet_air_time_positive_biped, feet_slide, track_lin_vel_xy_yaw_frame_exp, track_ang_vel_z_world_exp, stand_still_joint_deviation_l1, terrain_out_of_bounds, terrain_levels_vel - Reach: position_command_error, position_command_error_tanh, orientation_command_error Also includes: - Warp parity tests (3 test files) - WARP_MIGRATION_GAP_ANALYSIS.md (MDP term catalog and per-task usage) - MANAGER_TEST_COVERAGE.md (capturability analysis) - GRAPH_CAPTURE_MIGRATION.md (ArticulationData Tier 1/2/3 property analysis) --- .../envs/mdp/MANAGER_TEST_COVERAGE.md | 355 +++++ .../envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md | 922 +++++++++++ .../test/envs/mdp/test_action_warp_parity.py | 521 +++++++ .../test/envs/mdp/test_mdp_warp_parity.py | 1378 +++++++++++++++++ .../mdp/test_mdp_warp_parity_new_terms.py | 920 +++++++++++ .../articulation/GRAPH_CAPTURE_MIGRATION.md | 200 +++ .../manager_based/classic/__init__.py | 8 +- .../manager_based/classic/ant/__init__.py | 30 + .../manager_based/classic/ant/ant_env_cfg.py | 196 +++ .../classic/cartpole/__init__.py | 23 +- .../classic/humanoid/__init__.py | 30 + .../classic/humanoid/humanoid_env_cfg.py | 231 +++ .../classic/humanoid/mdp/__init__.py | 11 + .../classic/humanoid/mdp/observations.py | 173 +++ .../classic/humanoid/mdp/rewards.py | 309 ++++ .../manager_based/locomotion/__init__.py | 6 + .../locomotion/velocity/__init__.py | 6 + .../locomotion/velocity/config/__init__.py | 9 + .../locomotion/velocity/config/a1/__init__.py | 37 + .../velocity/config/a1/flat_env_cfg.py | 61 + .../velocity/config/a1/rough_env_cfg.py | 92 ++ .../velocity/config/anymal_b/__init__.py | 37 + .../velocity/config/anymal_b/flat_env_cfg.py | 61 + .../velocity/config/anymal_b/rough_env_cfg.py | 34 + .../velocity/config/anymal_c/__init__.py | 39 + .../velocity/config/anymal_c/flat_env_cfg.py | 53 + .../velocity/config/anymal_c/rough_env_cfg.py | 39 + .../velocity/config/anymal_d/__init__.py | 60 + .../velocity/config/anymal_d/flat_env_cfg.py | 47 + .../velocity/config/anymal_d/rough_env_cfg.py | 37 + .../velocity/config/cassie/__init__.py | 35 + .../velocity/config/cassie/flat_env_cfg.py | 46 + .../velocity/config/cassie/rough_env_cfg.py | 95 ++ .../locomotion/velocity/config/g1/__init__.py | 58 + .../velocity/config/g1/flat_env_cfg.py | 60 + .../velocity/config/g1/rough_env_cfg.py | 177 +++ .../velocity/config/g1_29_dofs/__init__.py | 35 + .../config/g1_29_dofs/flat_env_cfg.py | 60 + .../config/g1_29_dofs/rough_env_cfg.py | 130 ++ .../velocity/config/go1/__init__.py | 35 + .../velocity/config/go1/flat_env_cfg.py | 47 + .../velocity/config/go1/rough_env_cfg.py | 61 + .../velocity/config/go2/__init__.py | 35 + .../velocity/config/go2/flat_env_cfg.py | 47 + .../velocity/config/go2/rough_env_cfg.py | 60 + .../locomotion/velocity/config/h1/__init__.py | 57 + .../velocity/config/h1/flat_env_cfg.py | 47 + .../velocity/config/h1/rough_env_cfg.py | 130 ++ .../locomotion/velocity/mdp/__init__.py | 12 + .../locomotion/velocity/mdp/curriculums.py | 40 + .../locomotion/velocity/mdp/rewards.py | 309 ++++ .../locomotion/velocity/mdp/terminations.py | 67 + .../locomotion/velocity/velocity_env_cfg.py | 291 ++++ .../manager_based/manipulation/__init__.py | 6 + .../manipulation/reach/__init__.py | 6 + .../manipulation/reach/config/__init__.py | 4 + .../reach/config/franka/__init__.py | 41 + .../reach/config/franka/joint_pos_env_cfg.py | 76 + .../reach/config/ur_10/__init__.py | 36 + .../reach/config/ur_10/joint_pos_env_cfg.py | 76 + .../manipulation/reach/mdp/__init__.py | 10 + .../manipulation/reach/mdp/rewards.py | 168 ++ .../manipulation/reach/reach_env_cfg.py | 205 +++ 63 files changed, 8475 insertions(+), 12 deletions(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/MANAGER_TEST_COVERAGE.md create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py create mode 100644 source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/humanoid_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/velocity_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/MANAGER_TEST_COVERAGE.md b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/MANAGER_TEST_COVERAGE.md new file mode 100644 index 00000000000..c45f59948c3 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/MANAGER_TEST_COVERAGE.md @@ -0,0 +1,355 @@ +# Manager Implementation Test Coverage + +> Which tasks from `run_single_manager_warp_sweep.sh` are needed to exercise every +> reachable code path in the warp manager implementations? + +--- + +## Minimal Test Set + +**4 tasks (env-ids `1,2,6,3`) cover every manager code path that existing tasks can reach.** + +```bash +# Sweep command example +./run_single_manager_warp_sweep.sh default=0 target=2 env-ids=1,2,6,3 +``` + +| env-id | Gym ID | Role | +|:------:|--------|------| +| 1 | `Isaac-Cartpole-Warp-v0` | Simplest baseline; `JointEffortAction`; `corruption=False`; no commands | +| 2 | `Isaac-Humanoid-Warp-v0` | Obs `scale`; per-joint action scale dict; class-based rewards | +| 6 | `Isaac-Velocity-Flat-Anymal-C-Warp-v0` | All 3 event modes; obs `noise`; sensor deps; velocity commands; terrain curriculum | +| 3 | `Isaac-Reach-Franka-Warp-v0` | Pose commands; `modify_reward_weight` curriculum | + +### Why every other task is redundant + +| Dropped | Reason | +|---------|--------| +| Ant (0) | Strict subset of Humanoid: fewer obs `scale` terms, no per-joint action dict, subset of class-based rewards | +| Velocity quadrupeds (5,7-8,12-14) | Identical manager structure to Anymal-C; differ only in hyperparameters (action scale, body names, joint names) | +| Velocity bipeds (9-11) | Add more reward/termination terms of the same types; no new manager code paths | +| Velocity rough (15) | Same manager structure as flat; terrain config is scene-level, not manager-level | +| Reach-UR10 (4) | Identical manager structure to Reach-Franka; differs only in robot asset and body names | + +--- + +## Manager Code Path Coverage Matrix + +### Observation Manager + +Source: `isaaclab_experimental/managers/observation_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| Basic term compute (`func(env, out)`) | `compute_group` L494 | Y | Y | Y | Y | YES | +| `scale` kernel (`_apply_scale`) | `compute_group` L517-523 | | Y | | | YES | +| `clip` kernel (`_apply_clip`) | `compute_group` L510-516 | | | | | **NO** | +| `noise` application | `compute_group` L503-508 | | | Y | Y | YES | +| `modifiers` pipeline | `compute_group` L498-501 | | | | | **NO** | +| `enable_corruption=False` (skip noise) | `compute_group` L496 | Y | Y | | | YES | +| `enable_corruption=True` | `compute_group` L496 | | | Y | Y | YES | +| `concatenate_terms=True` (contiguous buf) | `_prepare_terms` L637 | Y | Y | Y | Y | YES | +| `concatenate_terms=False` (separate bufs) | `_prepare_terms` L653 | | | | | **NO** | +| Dim inference: `axes` (root-state obs) | `_infer_term_dim_scalar` | | Y | Y | | YES | +| Dim inference: `out_dim` int (custom obs) | `_infer_term_dim_scalar` | | Y | | | YES | +| Dim inference: `"joint"` sentinel | `_infer_term_dim_scalar` | Y | Y | Y | Y | YES | +| Dim inference: `"command"` sentinel | `_infer_term_dim_scalar` | | | Y | Y | YES | +| Dim inference: `"action"` sentinel | `_infer_term_dim_scalar` | | Y | Y | Y | YES | +| Dim inference: `"body:N"` sentinel | `_infer_term_dim_scalar` | | | | | **NO** | +| Cross-manager obs (commands) | `generated_commands` | | | Y | Y | YES | +| Cross-manager obs (last_action) | `last_action` | | Y | Y | Y | YES | +| Class-based modifier `.reset()` | `reset` L402-404 | | | | | **NO** | + +### Action Manager + +Source: `isaaclab_experimental/managers/action_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| Single-term action processing | `process_action` L410-431 | Y | Y | Y | Y | YES | +| Multi-term offset splitting | `process_action` L425-427 | | | | | **NO** | +| `prev_action` copy kernel | `process_action` L421-422 | Y | Y | Y | Y | YES | +| `JointPositionAction` | `joint_actions.py` | | | Y | Y | YES | +| `JointEffortAction` | `joint_actions.py` | Y | Y | | | YES | +| `RelativeJointPositionAction` | `joint_actions.py` | | | | | **NO** | +| `JointVelocityAction` | `joint_actions.py` | | | | | **NO** | +| `BinaryJointPositionAction` | `binary_joint_actions.py` | | | | | **NO** | +| `BinaryJointVelocityAction` | `binary_joint_actions.py` | | | | | **NO** | +| `JointPositionToLimitsAction` | `joint_actions_to_limits.py` | | | | | **NO** | +| `EMAJointPositionToLimitsAction` | `joint_actions_to_limits.py` | | | | | **NO** | +| `NonHolonomicAction` | `non_holonomic_actions.py` | | | | | **NO** | +| Per-joint scale dict | `JointEffortAction.__init__` | | Y | | | YES | +| `use_default_offset=True` | `JointPositionAction.__init__` | | | Y | | YES | + +### Event Manager + +Source: `isaaclab_experimental/managers/event_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| `startup` mode dispatch | `_apply_startup` | | | Y | | YES | +| `reset` mode dispatch | `_apply_reset` L369-388 | Y | Y | Y | Y | YES | +| `interval` mode (per-env timer) | `_apply_interval` L353-367 | | | Y | | YES | +| `interval` mode (global timer) | `_apply_interval` L336-352 | | | | | **NO** | +| `_interval_step_per_env` kernel | L65-82 | | | Y | | YES | +| `_interval_step_global` kernel | L86-102 | | | | | **NO** | +| `_interval_reset_selected` (re-sample on env reset) | `reset` L262-278 | | | Y | | YES | +| `min_step_count_between_reset` logic | `_reset_compute_valid_mask` L128-158 | | | | | **NO** | +| Class-based event terms | `_prepare_terms` | | | | | **NO** | +| Function-based event terms | `_prepare_terms` | Y | Y | Y | Y | YES | + +### Reward Manager + +Source: `isaaclab_experimental/managers/reward_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| Function-based rewards | `compute` | Y | Y | Y | Y | YES | +| Class-based rewards (init/reset/call) | `compute`, `reset` | | Y | | | YES | +| `_reward_finalize` kernel (weighted sum) | `compute` | Y | Y | Y | Y | YES | +| `_reward_pre_compute_reset` (zero per step) | `compute` | Y | Y | Y | Y | YES | +| Episode sum tracking + reset logging | `reset` | Y | Y | Y | Y | YES | +| Sensor-dependent rewards (`wp.from_torch`) | via `undesired_contacts` etc. | | | Y | | YES | +| Command-dependent rewards (`wp.from_torch`) | via `track_lin_vel_xy_exp` etc. | | | Y | Y | YES | + +### Termination Manager + +Source: `isaaclab_experimental/managers/termination_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| `_termination_finalize` (reduce to dones) | `compute` | Y | Y | Y | Y | YES | +| `time_out=True` flag handling | `_term_is_time_out_wp` | Y | Y | Y | Y | YES | +| `time_out=False` (real termination) | `_termination_finalize` | Y | Y | Y | | YES | +| Sensor-based termination | via `illegal_contact` | | | Y | | YES | +| Reset mean logging | `_termination_reset_mean_all_2d` | Y | Y | Y | Y | YES | + +### Command Manager + +Source: `isaaclab_experimental/managers/command_manager.py` + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| No commands (empty manager) | `__init__` | Y | Y | | | YES | +| `UniformVelocityCommand` | `command_manager.py` | | | Y | | YES | +| `UniformPoseCommand` | `command_manager.py` | | | | Y | YES | +| Resample on reset | `_step_time_left_and_build_resample_mask` | | | Y | Y | YES | + +### Curriculum Manager + +Note: Curriculum manager is not warp-migrated (runs at reset, not per-step). Included for completeness. + +| Code Path | Location | Cartpole | Humanoid | Velocity | Reach | Covered? | +|-----------|----------|:--------:|:--------:|:--------:|:-----:|:--------:| +| No curriculum (empty manager) | | Y | Y | | | YES | +| Custom curriculum (`terrain_levels_vel`) | torch-based, reset-only | | | Y | | YES | +| `modify_reward_weight` (stable) | forwarded from stable | | | | Y | YES | + +--- + +## Uncovered Manager Code Paths + +These code paths exist in the manager implementations but **no existing migrated task exercises them**. + +### Observation Manager (5 gaps) + +| Gap | Code Location | What triggers it | +|-----|---------------|------------------| +| `clip` kernel | `_apply_clip` L71-74, launched at L510-516 | Any `ObsTermCfg` with `clip=(lo, hi)` | +| `modifiers` pipeline | L498-501 compute, L402-404 reset | Any `ObsTermCfg` with `modifiers=[...]` | +| `concatenate_terms=False` | `_prepare_terms` L653+ | An obs group with `concatenate_terms=False` | +| `scale` as tuple | `_prepare_terms` L691-705 | `ObsTermCfg(scale=(1.0, 2.0, ...))` with per-element values | +| `out_dim="body:N"` inference | `_infer_term_dim_scalar` | An obs using `body_pose_w` or `body_projected_gravity_b` | + +### Action Manager (8 gaps) + +| Gap | Code Location | What triggers it | +|-----|---------------|------------------| +| Multi-term offset splitting | `process_action` L425-427 | A task with 2+ non-None action terms | +| `RelativeJointPositionAction` | `joint_actions.py` | `RelativeJointPositionActionCfg` | +| `JointVelocityAction` | `joint_actions.py` | `JointVelocityActionCfg` | +| `BinaryJointPositionAction` | `binary_joint_actions.py` | `BinaryJointPositionActionCfg` | +| `BinaryJointVelocityAction` | `binary_joint_actions.py` | `BinaryJointVelocityActionCfg` | +| `JointPositionToLimitsAction` | `joint_actions_to_limits.py` | `JointPositionToLimitsActionCfg` | +| `EMAJointPositionToLimitsAction` | `joint_actions_to_limits.py` | `EMAJointPositionToLimitsActionCfg` | +| `NonHolonomicAction` | `non_holonomic_actions.py` | `NonHolonomicActionCfg` | + +### Event Manager (3 gaps) + +| Gap | Code Location | What triggers it | +|-----|---------------|------------------| +| `interval` with `is_global_time=True` | `_apply_interval` L336-352, `_interval_step_global` L86-102 | Any `EventTermCfg(mode="interval", is_global_time=True)` | +| `min_step_count_between_reset` | `_reset_compute_valid_mask` L128-158 | Any `EventTermCfg(mode="reset", min_step_count_between_reset=N)` where N > 0 | +| Class-based event terms | `_prepare_terms` class instantiation path | `randomize_rigid_body_material`, `randomize_rigid_body_mass`, `randomize_actuator_gains`, `randomize_joint_parameters` | + +--- + +## Coverage Summary + +| Manager | Total Paths | Covered | Uncovered | +|---------|:-----------:|:-------:|:---------:| +| Observation | 18 | 13 | **5** | +| Action | 14 | 6 | **8** | +| Event | 10 | 7 | **3** | +| Reward | 7 | 7 | 0 | +| Termination | 5 | 5 | 0 | +| Command | 4 | 4 | 0 | +| Curriculum | 3 | 3 | 0 | +| **Total** | **61** | **45 (74%)** | **16 (26%)** | + +The 16 uncovered paths break down as: +- **8 action term types** — no migrated task uses these action classes +- **5 obs post-processing features** — `clip`, `modifiers`, tuple `scale`, `concatenate_terms=False`, `body:N` dim inference +- **3 event features** — global interval, reset rate-limiting, class-based randomization events + +--- + +## Repo-Wide Search for Gap Coverage + +Searched the entire IsaacLab repo (stable tasks, direct-workflow tasks, test configs, +examples) for any usage of the 16 uncovered features. + +### Already covered by existing unit tests (not end-to-end) + +These features have **no task config usage** anywhere in the repo but are exercised by +dedicated unit tests. They do NOT need new task coverage — the unit tests validate the +warp implementation in isolation. + +| Gap | Unit Test | File | +|-----|-----------|------| +| `JointVelocityAction` | `TestJointActions` | `isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py:231` | +| `BinaryJointPositionAction` | `TestBinaryJointActions` | `test_action_warp_parity.py:300` | +| `BinaryJointVelocityAction` | `TestBinaryJointActions` | `test_action_warp_parity.py:300` | +| `JointPositionToLimitsAction` | `TestJointPositionToLimitsActions` | `test_action_warp_parity.py:374` | +| `EMAJointPositionToLimitsAction` | `TestJointPositionToLimitsActions` | `test_action_warp_parity.py:374` | +| `NonHolonomicAction` | `TestNonHolonomicAction` | `test_action_warp_parity.py:448` | +| `body_pose_w` (`out_dim="body:N"`) | `test_body_pose_w` | `test_mdp_warp_parity_new_terms.py:445,815` | +| `body_projected_gravity_b` | `test_body_projected_gravity_b` | `test_mdp_warp_parity_new_terms.py:435,803` | + +### Closable by uncommenting existing config (1 gap) + +The velocity env config has a complete, ready-to-use `randomize_rigid_body_material` +event term that is commented out. Uncommenting it would close the class-based event gap. + +``` +# source/isaaclab_tasks/.../locomotion/velocity/velocity_env_cfg.py L184-194 +# physics_material = EventTerm( +# func=mdp.randomize_rigid_body_material, +# mode="startup", +# params={ +# "asset_cfg": SceneEntityCfg("robot", body_names=".*"), +# "static_friction_range": (0.8, 0.8), +# "dynamic_friction_range": (0.6, 0.6), +# "restitution_range": (0.0, 0.0), +# "num_buckets": 64, +# }, +# ) +``` + +Note: `randomize_rigid_body_mass` is also commented out but marked as causing NaNs. + +| Gap | Stable Config Location | Status | +|-----|----------------------|--------| +| Class-based events (`randomize_rigid_body_material`) | `velocity_env_cfg.py:184-194` | Complete config, degenerate ranges (no actual randomization) — safe to uncomment | +| Class-based events (`randomize_rigid_body_mass`) | `velocity_env_cfg.py:196-205` | Known broken (NaN) — do NOT uncomment | + +### Used only in Dexsuite — complex, not practical to migrate for coverage alone (2 gaps) + +| Gap | Stable Task | Why impractical | +|-----|------------|-----------------| +| `clip` on ObsTerm | Dexsuite (`dexsuite_env_cfg.py:157,176`) | 15+ custom MDP terms, ADR curriculum, multi-obs-group — not migrated | +| `RelativeJointPositionAction` | Dexsuite Kuka-Allegro (`dexsuite_kuka_allegro_env_cfg.py:40`) | Same complexity | + +### True dead code — zero usage anywhere in the repo (5 gaps) + +These manager code paths are implemented but have **no usage in any task config, test, +or example** across the entire codebase. They are forward-looking infrastructure. + +| Gap | Description | +|-----|-------------| +| `modifiers` pipeline | `ObsTermCfg.modifiers` — defined but never set in any config | +| `scale` as tuple | Per-element varying scale — only float `scale=` is ever used | +| `concatenate_terms=False` | Every obs group in every config sets `True` | +| `is_global_time=True` | Tested in stable unit test only (`test_event_manager.py:276`); no task config | +| `min_step_count_between_reset` | Tested in stable unit test only (`test_event_manager.py:331`); no task config | +| Multi-term action splitting | Reach declares `gripper_action` slot but never populates it | + +--- + +## Revised Gap Classification + +| Category | Count | Gaps | +|----------|:-----:|------| +| Covered by unit tests (no task needed) | 8 | 6 action types + 2 body obs (`body:N` inference) | +| Closable by uncommenting stable config | 1 | Class-based events (`randomize_rigid_body_material`) | +| Blocked behind complex task migration | 2 | `clip`, `RelativeJointPositionAction` (Dexsuite) | +| True dead code (no usage anywhere) | 5 | `modifiers`, tuple `scale`, `concatenate_terms=False`, `is_global_time`, `min_step_count_between_reset` | +| **Total** | **16** | | + +### Effective end-to-end gap after accounting for unit tests: 8 + +Of those 8: +- **1 is actionable now** (uncomment DR event in velocity config) +- **2 require Dexsuite migration** (large effort, low priority) +- **5 have zero usage anywhere** (cannot be tested without writing new configs) + +--- + +## Open: Per-MDP Capturability Tracking + +### Problem + +Manager mode=2 (WARP_CAPTURED) assumes all MDP terms are CUDA-graph-capturable. +Some terms call non-capturable external APIs (e.g., `write_root_pose_to_sim`, +`set_external_force_and_torque`). If any term is non-capturable, the manager +should fall back to mode=1. + +Not capturability issues: +- `wp.from_torch` — stable pointers, fine in graphs +- `wp.zeros` in `hasattr` guards — solvable via warmup in `_wp_capture_or_launch` + +### Proposed: `@warp_capturable` decorator + +By default all MDP terms are assumed capturable (True). Only non-capturable +terms need `@warp_capturable(False)`. The decorator sets an attribute directly +on the function (no wrapper), so it composes safely with any other decorator +in any order. + +```python +def warp_capturable(capturable: bool): + """Annotate an MDP term's CUDA-graph capturability. Default assumption: True.""" + def decorator(func): + func._warp_capturable = capturable + return func # no wrapper + return decorator + +def is_warp_capturable(func) -> bool: + """Check capturability. Default: True. Checks __wrapped__ for decorated fns.""" + for f in (func, getattr(func, '__wrapped__', None)): + if f is not None: + val = getattr(f, '_warp_capturable', None) + if val is not None: + return val + return True +``` + +Usage: +```python +@warp_capturable(False) +def apply_external_force_torque(env, env_mask, ...): + ... +``` + +Manager integration: during `_prepare_terms`, check all terms. If any returns +`is_warp_capturable(func) == False`, fall back to mode=1 with a warning. + +### Terms requiring `@warp_capturable(False)` + +| Term | Non-capturable dependency | +|------|--------------------------| +| `apply_external_force_torque` | `wrench_composer.set_forces_and_torques()` | +| `reset_root_state_uniform` | `write_root_pose_to_sim()` / `write_root_velocity_to_sim()` | +| `reset_root_state_with_random_orientation` | Same | +| `reset_root_state_from_terrain` | Same | +| `reset_scene_to_default` | Same | +| `push_by_setting_velocity` | `write_root_velocity_to_sim()` | diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md new file mode 100644 index 00000000000..07e6e446b7c --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md @@ -0,0 +1,922 @@ +# MDP Warp-First Migration: Gap Analysis + +> Updated 2026-02-18. Tracks the stable -> experimental (warp-first) conversion status of every +> public MDP term and every manager-based task. + +--- + +## Table of Contents + +1. [Testing Requirements](#testing-requirements) +2. [Shared MDP Term Catalog](#shared-mdp-term-catalog) +3. [Per-Task Gym ID Migration Table](#per-task-gym-id-migration-table) +4. [Per-Task MDP Term Usage](#per-task-mdp-term-usage) +5. [Custom (Task-Local) MDP Terms](#custom-task-local-mdp-terms) +6. [Shared Terms Not Used by Any Migrated Task](#shared-terms-not-used-by-any-migrated-task) +7. [Cross-Cutting Notes](#cross-cutting-notes) +8. [Key Warp Conversion Patterns](#key-warp-conversion-patterns) + +--- + +## Testing Requirements + +Every migrated MDP term **must** pass the following checks before it is considered complete: + +### a) Parity Check: `stable == warp == warp-captured` + +For each converted term, verify numerical equivalence across three execution modes: + +1. **Stable** -- original torch-based term from `isaaclab.envs.mdp` +2. **Warp** -- experimental warp-first term from `isaaclab_experimental.envs.mdp` (eager launch) +3. **Warp Captured** -- same warp term executed inside a `wp.ScopedCapture` / `wp.capture_launch` graph + +All three modes must produce identical results (within floating-point tolerance) for the same inputs. + +### b) Dynamic Dependency Update Check + +For the **warp-captured** execution path, verify that the captured graph produces correct results +even after upstream data changes between replay invocations. Specifically: + +- Sim state buffers (joint_pos, root_vel, etc.) update between steps -- the captured graph must + read the latest values from persistent pointers, not stale data baked into the graph. +- When a dependency (e.g., action buffer, sensor data, command output) is updated externally, + the next `wp.capture_launch` must reflect the change. +- Resetting a subset of environments (via `env_mask`) must not corrupt state of non-reset environments. + +--- + +## Shared MDP Term Catalog + +Legend: **S** = Shared library (`isaaclab.envs.mdp`), **W** = Warp override exists in `isaaclab_experimental.envs.mdp` + +### Observations (22 stable terms) + +| # | Function/Class | Warp | Notes | +|---|---|---|---| +| 1 | `base_pos_z` | YES | Pure warp kernel | +| 2 | `base_lin_vel` | YES | Pure warp kernel | +| 3 | `base_ang_vel` | YES | Pure warp kernel | +| 4 | `projected_gravity` | YES | Pure warp kernel | +| 5 | `root_pos_w` | YES | Pure warp kernel | +| 6 | `root_quat_w` | YES | Pure warp kernel | +| 7 | `root_lin_vel_w` | YES | Pure warp kernel | +| 8 | `root_ang_vel_w` | YES | Pure warp kernel | +| 9 | `body_pose_w` | YES | Pure warp kernel | +| 10 | `body_projected_gravity_b` | YES | Pure warp kernel | +| 11 | `joint_pos` | YES | Pure warp kernel | +| 12 | `joint_pos_rel` | YES | Pure warp kernel with joint_mask | +| 13 | `joint_pos_limit_normalized` | YES | Pure warp kernel | +| 14 | `joint_vel` | YES | Pure warp kernel | +| 15 | `joint_vel_rel` | YES | Pure warp kernel with joint_mask | +| 16 | `joint_effort` | YES | Pure warp kernel | +| 17 | `last_action` | YES | Pure warp kernel | +| 18 | `generated_commands` | YES | `wp.from_torch` bridge (zero-copy) | +| 19 | `current_time_s` | YES | Pure warp kernel | +| 20 | `remaining_time_s` | YES | Pure warp kernel | +| 21 | `image` | **NO** | 4D tensor, per-type normalization. Deferred. | +| 22 | `image_features` | **NO** | PyTorch NN inference (ResNet/Theia). Not convertible. | + +**Coverage: 20/22 (91%)** + +### Rewards (22 stable terms) + +| # | Function/Class | Warp | Notes | +|---|---|---|---| +| 1 | `is_alive` | YES | Pure warp kernel | +| 2 | `is_terminated` | YES | Pure warp kernel | +| 3 | `is_terminated_term` | YES | Class-based, reads `_term_dones_wp` + `time_outs_wp` | +| 4 | `lin_vel_z_l2` | YES | Pure warp kernel | +| 5 | `ang_vel_xy_l2` | YES | Pure warp kernel | +| 6 | `flat_orientation_l2` | YES | Pure warp kernel | +| 7 | `base_height_l2` | YES | Pure warp kernel | +| 8 | `body_lin_acc_l2` | YES | Pure warp kernel | +| 9 | `joint_torques_l2` | YES | Pure warp kernel | +| 10 | `joint_vel_l1` | YES | Pure warp kernel | +| 11 | `joint_vel_l2` | YES | Pure warp kernel | +| 12 | `joint_acc_l2` | YES | Pure warp kernel | +| 13 | `joint_deviation_l1` | YES | Pure warp kernel | +| 14 | `joint_pos_limits` | YES | Pure warp kernel | +| 15 | `joint_vel_limits` | YES | Pure warp kernel | +| 16 | `applied_torque_limits` | YES | Pure warp kernel | +| 17 | `action_rate_l2` | YES | Pure warp kernel | +| 18 | `action_l2` | YES | Pure warp kernel | +| 19 | `undesired_contacts` | YES | `wp.from_torch` bridge for sensor data | +| 20 | `desired_contacts` | YES | `wp.from_torch` bridge for sensor data | +| 21 | `contact_forces` | YES | `wp.from_torch` bridge for sensor data | +| 22 | `track_lin_vel_xy_exp` | YES | `wp.from_torch` bridge for commands | +| 23 | `track_ang_vel_z_exp` | YES | `wp.from_torch` bridge for commands | + +**Coverage: 22/22 (100%)** (note: `track_*` counted as shared rewards) + +### Terminations (10 stable terms) + +| # | Function/Class | Warp | Notes | +|---|---|---|---| +| 1 | `time_out` | YES | Pure warp kernel | +| 2 | `command_resample` | YES | Pure warp kernel | +| 3 | `bad_orientation` | YES | Pure warp kernel | +| 4 | `root_height_below_minimum` | YES | Pure warp kernel | +| 5 | `joint_pos_out_of_limit` | YES | Pure warp kernel | +| 6 | `joint_pos_out_of_manual_limit` | YES | Pure warp kernel | +| 7 | `joint_vel_out_of_limit` | YES | Pure warp kernel | +| 8 | `joint_vel_out_of_manual_limit` | YES | Pure warp kernel | +| 9 | `joint_effort_out_of_limit` | YES | Pure warp kernel | +| 10 | `illegal_contact` | YES | `wp.from_torch` bridge for sensor data | + +**Coverage: 10/10 (100%)** + +### Events (20 stable terms) + +| # | Function/Class | Warp | Notes | +|---|---|---|---| +| 1 | `randomize_rigid_body_material` | YES | Class-based, warp kernel for mu sampling | +| 2 | `randomize_rigid_body_mass` | YES | Class-based, `_scale_inertia_kernel` | +| 3 | `randomize_rigid_body_com` | YES | Warp kernel | +| 4 | `randomize_actuator_gains` | YES | Class-based, writes directly to warp arrays | +| 5 | `randomize_joint_parameters` | YES | Class-based, warp kernels with clamp | +| 6 | `apply_external_force_torque` | YES | Warp kernel | +| 7 | `push_by_setting_velocity` | YES | Warp kernel | +| 8 | `reset_root_state_uniform` | YES | Warp kernel | +| 9 | `reset_root_state_with_random_orientation` | YES | Warp kernel | +| 10 | `reset_root_state_from_terrain` | YES | Warp kernel | +| 11 | `reset_joints_by_scale` | YES | Warp kernel | +| 12 | `reset_joints_by_offset` | YES | Warp kernel | +| 13 | `reset_scene_to_default` | YES | Warp kernel | +| 14 | `randomize_rigid_body_collider_offsets` | **NO** | Stub (`NotImplementedError`) in stable | +| 15 | `randomize_physics_scene_gravity` | **NO** | Class-based, per-env gravity. Low priority. | +| 16 | `randomize_fixed_tendon_parameters` | **NO** | Stub (`NotImplementedError`) in stable | +| 17 | `reset_nodal_state_uniform` | **NO** | Stub (`NotImplementedError`) in stable | +| 18 | `randomize_rigid_body_scale` | **NO** | USD `pxr` API, pre-sim only. Not convertible. | +| 19 | `randomize_visual_texture_material` | **NO** | Omni Replicator API. Not convertible. | +| 20 | `randomize_visual_color` | **NO** | Omni Replicator API. Not convertible. | + +**Coverage: 13/20 (65%)** -- remaining are stubs, USD/Replicator APIs, or low-priority + +### Actions (10 stable classes) + +| # | Class | Warp | Notes | +|---|---|---|---| +| 1 | `JointPositionActionCfg` | YES | Warp-first process_actions/apply_actions | +| 2 | `RelativeJointPositionActionCfg` | YES | | +| 3 | `JointVelocityActionCfg` | YES | | +| 4 | `JointEffortActionCfg` | YES | | +| 5 | `BinaryJointPositionActionCfg` | YES | | +| 6 | `BinaryJointVelocityActionCfg` | YES | | +| 7 | `JointPositionToLimitsActionCfg` | YES | | +| 8 | `EMAJointPositionToLimitsActionCfg` | YES | | +| 9 | `NonHolonomicActionCfg` | YES | | +| 10 | (IK-based actions) | N/A | Not used by current tasks | + +**Coverage: 10/10 (100%)** + +### Commands (6 stable classes) + +| # | Class | Warp | Notes | +|---|---|---|---| +| 1 | `NullCommand` / `NullCommandCfg` | NO | Bridged via `wp.from_torch` (zero-copy) | +| 2 | `UniformVelocityCommand` / Cfg | NO | Bridged via `wp.from_torch` | +| 3 | `NormalVelocityCommand` / Cfg | NO | Bridged via `wp.from_torch` | +| 4 | `UniformPoseCommand` / Cfg | NO | Bridged via `wp.from_torch` | +| 5 | `UniformPose2dCommand` / Cfg | NO | Bridged via `wp.from_torch` | +| 6 | `TerrainBasedPose2dCommand` / Cfg | NO | Bridged via `wp.from_torch` | + +**Coverage: 0/6 (0%)** -- **NOT a blocker** (see [Cross-Cutting Notes](#cross-cutting-notes)) + +### Curriculums (3 stable classes) + +| # | Class | Warp | Notes | +|---|---|---|---| +| 1 | `modify_reward_weight` | NO | Runs at reset, not per-step. Low priority. | +| 2 | `modify_env_param` | NO | Runs at reset, not per-step. Low priority. | +| 3 | `modify_term_cfg` | NO | Inherits from `modify_env_param`. Low priority. | + +**Coverage: 0/3 (0%)** -- Low priority (not in hot loop) + +### Overall Shared Library Coverage + +| Category | Stable | Warp | Coverage | +|---|---|---|---| +| **Actions** | 10 | 10 | **100%** | +| **Observations** | 22 | 20 | **91%** | +| **Rewards** | 22 | 22 | **100%** | +| **Terminations** | 10 | 10 | **100%** | +| **Events** | 20 | 13 | **65%** | +| **Commands** | 6 | 0 | **0%** (bridged) | +| **Curriculums** | 3 | 0 | **0%** (low priority) | +| **Total** | **93** | **75** | **~81%** | + +--- + +## Per-Task Gym ID Migration Table + +All stable manager-based tasks and their experimental `-Warp` counterpart status. + +### Classic Tasks + +| Stable Gym ID | Exp Gym ID | Status | Custom Terms | Blockers | +|---|---|---|---|---| +| `Isaac-Cartpole-v0` | `Isaac-Cartpole-Warp-v0` | **MIGRATED** | 1 reward | None | +| `Isaac-Humanoid-v0` | `Isaac-Humanoid-Warp-v0` | **MIGRATED** | 4 obs, 5 rewards | None | +| `Isaac-Ant-v0` | `Isaac-Ant-Warp-v0` | **MIGRATED** | Reuses humanoid mdp | None | + +### Locomotion Velocity -- Standard Robots (use base velocity MDP) + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-Velocity-Flat-Unitree-A1-v0` | `Isaac-Velocity-Flat-Unitree-A1-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Unitree-A1-Play-v0` | `Isaac-Velocity-Flat-Unitree-A1-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Unitree-Go1-v0` | `Isaac-Velocity-Flat-Unitree-Go1-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Unitree-Go1-Play-v0` | `Isaac-Velocity-Flat-Unitree-Go1-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Unitree-Go2-v0` | `Isaac-Velocity-Flat-Unitree-Go2-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Unitree-Go2-Play-v0` | `Isaac-Velocity-Flat-Unitree-Go2-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-B-v0` | `Isaac-Velocity-Flat-Anymal-B-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-B-Play-v0` | `Isaac-Velocity-Flat-Anymal-B-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-C-v0` | `Isaac-Velocity-Flat-Anymal-C-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-C-Play-v0` | `Isaac-Velocity-Flat-Anymal-C-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-D-v0` | `Isaac-Velocity-Flat-Anymal-D-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Anymal-D-Play-v0` | `Isaac-Velocity-Flat-Anymal-D-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Rough-Anymal-D-v0` | `Isaac-Velocity-Rough-Anymal-D-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Rough-Anymal-D-Play-v0` | `Isaac-Velocity-Rough-Anymal-D-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Cassie-v0` | `Isaac-Velocity-Flat-Cassie-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-Cassie-Play-v0` | `Isaac-Velocity-Flat-Cassie-Warp-Play-v0` | **MIGRATED** | None | + +### Locomotion Velocity -- Biped Robots (use base + biped rewards) + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-Velocity-Flat-G1-v0` | `Isaac-Velocity-Flat-G1-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-G1-Play-v0` | `Isaac-Velocity-Flat-G1-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-G1-v1` | `Isaac-Velocity-Flat-G1-Warp-v1` | **MIGRATED** | None (29-DOF) | +| `Isaac-Velocity-Flat-G1-Play-v1` | `Isaac-Velocity-Flat-G1-Warp-Play-v1` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-H1-v0` | `Isaac-Velocity-Flat-H1-Warp-v0` | **MIGRATED** | None | +| `Isaac-Velocity-Flat-H1-Play-v0` | `Isaac-Velocity-Flat-H1-Warp-Play-v0` | **MIGRATED** | None | + +### Locomotion Velocity -- Rough Terrain (biped) + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-Velocity-Rough-G1-v0` | `Isaac-Velocity-Rough-G1-Warp-v0` | **MIGRATED** | None (stable has this commented out) | +| `Isaac-Velocity-Rough-G1-Play-v0` | `Isaac-Velocity-Rough-G1-Warp-Play-v0` | **MIGRATED** | None (stable has this commented out) | +| `Isaac-Velocity-Rough-H1-v0` | `Isaac-Velocity-Rough-H1-Warp-v0` | **MIGRATED** | None (stable has this commented out) | +| `Isaac-Velocity-Rough-H1-Play-v0` | `Isaac-Velocity-Rough-H1-Warp-Play-v0` | **MIGRATED** | None (stable has this commented out) | + +### Locomotion Velocity -- Spot (custom MDP) + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-Velocity-Flat-Spot-v0` | -- | **NOT MIGRATED** | 14 custom reward fns + 1 event fn + `GaitReward` class need conversion | +| `Isaac-Velocity-Flat-Spot-Play-v0` | -- | **NOT MIGRATED** | Same as above | + +### Locomotion Velocity -- Distillation Variants + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Velocity-G1-Distillation-v1` | -- | **NOT MIGRATED** | Teacher-student pipeline not in scope | +| `Velocity-G1-Student-Finetune-v1` | -- | **NOT MIGRATED** | Teacher-student pipeline not in scope | + +### Manipulation -- Reach + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-Reach-Franka-v0` | `Isaac-Reach-Franka-Warp-v0` | **MIGRATED** | None | +| `Isaac-Reach-Franka-Play-v0` | `Isaac-Reach-Franka-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Reach-UR10-v0` | `Isaac-Reach-UR10-Warp-v0` | **MIGRATED** | None | +| `Isaac-Reach-UR10-Play-v0` | `Isaac-Reach-UR10-Warp-Play-v0` | **MIGRATED** | None | +| `Isaac-Reach-Franka-IK-Abs-v0` | -- | **NOT MIGRATED** | IK action not in scope | +| `Isaac-Reach-Franka-IK-Rel-v0` | -- | **NOT MIGRATED** | IK action not in scope | +| `Isaac-Reach-UR10-IK-Abs-v0` | -- | **NOT MIGRATED** | IK action not in scope | +| `Isaac-Reach-UR10-IK-Rel-v0` | -- | **NOT MIGRATED** | IK action not in scope | + +### Manipulation -- Dexsuite + +| Stable Gym ID | Exp Gym ID | Status | Blockers | +|---|---|---|---| +| `Isaac-DexsuiteKukaAllegroReorient-v0` | -- | **NOT MIGRATED** | Custom command class, 7 obs fns, 5 reward fns, 3 term fns, ADR curriculum | +| `Isaac-DexsuiteKukaAllegroReorient-Play-v0` | -- | **NOT MIGRATED** | Same as above | +| `Isaac-DexsuiteKukaAllegroLift-v0` | -- | **NOT MIGRATED** | Same + lift-specific overrides | +| `Isaac-DexsuiteKukaAllegroLift-Play-v0` | -- | **NOT MIGRATED** | Same as above | +| `Isaac-DexsuiteKukaAllegroReorientVision-v0` | -- | **NOT MIGRATED** | Same + `image`/`vision_camera` obs (not convertible) | +| `Isaac-DexsuiteKukaAllegroLiftVision-v0` | -- | **NOT MIGRATED** | Same as above | + +### Migration Summary + +| Category | Total Stable | Migrated | Not Migrated | % | +|---|---|---|---|---| +| Classic | 3 | 3 | 0 | **100%** | +| Velocity (flat, quadruped) | 16 | 16 | 0 | **100%** | +| Velocity (flat, biped) | 6 | 6 | 0 | **100%** | +| Velocity (rough) | 6 | 6 | 0 | **100%** | +| Velocity (Spot) | 2 | 0 | 2 | **0%** | +| Velocity (distillation) | 2 | 0 | 2 | **0%** | +| Reach (joint-space) | 4 | 4 | 0 | **100%** | +| Reach (IK) | 4 | 0 | 4 | **0%** | +| Dexsuite | 6 | 0 | 6 | **0%** | +| **Total** | **49** | **35** | **14** | **71%** | + +--- + +## Per-Task MDP Term Usage + +Shows which shared and custom MDP terms each task group uses. Terms from `isaaclab.envs.mdp` (shared) +are marked **S**. Task-local custom terms are marked **C**. + +### Cartpole + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Actions | `JointEffortActionCfg` | S | YES | +| Obs | `joint_pos_rel` | S | YES | +| Obs | `joint_vel_rel` | S | YES | +| Rewards | `is_alive` | S | YES | +| Rewards | `is_terminated` | S | YES | +| Rewards | `joint_pos_target_l2` | **C** | YES | +| Rewards | `joint_vel_l1` | S | YES | +| Terms | `time_out` | S | YES | +| Terms | `joint_pos_out_of_manual_limit` | S | YES | +| Events | `reset_joints_by_offset` | S | YES | + +### Humanoid / Ant + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Actions | `JointEffortActionCfg` | S | YES | +| Obs | `base_pos_z` | S | YES | +| Obs | `base_lin_vel` | S | YES | +| Obs | `base_ang_vel` | S | YES | +| Obs | `base_yaw_roll` | **C** | YES | +| Obs | `base_angle_to_target` | **C** | YES | +| Obs | `base_up_proj` | **C** | YES | +| Obs | `base_heading_proj` | **C** | YES | +| Obs | `joint_pos_limit_normalized` | S | YES | +| Obs | `joint_vel_rel` | S | YES | +| Obs | `last_action` | S | YES | +| Rewards | `progress_reward` | **C** (class) | YES | +| Rewards | `is_alive` | S | YES | +| Rewards | `upright_posture_bonus` | **C** | YES | +| Rewards | `move_to_target_bonus` | **C** | YES | +| Rewards | `action_l2` | S | YES | +| Rewards | `power_consumption` | **C** (class) | YES | +| Rewards | `joint_pos_limits_penalty_ratio` | **C** (class) | YES | +| Terms | `time_out` | S | YES | +| Terms | `root_height_below_minimum` | S | YES | +| Events | `reset_root_state_uniform` | S | YES | +| Events | `reset_joints_by_offset` | S | YES | + +### Velocity Locomotion (base config -- all non-Spot robots) + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Commands | `UniformVelocityCommandCfg` | S | bridged | +| Actions | `JointPositionActionCfg` | S | YES | +| Obs | `base_lin_vel` | S | YES | +| Obs | `base_ang_vel` | S | YES | +| Obs | `projected_gravity` | S | YES | +| Obs | `generated_commands` | S | YES | +| Obs | `joint_pos_rel` | S | YES | +| Obs | `joint_vel_rel` | S | YES | +| Obs | `last_action` | S | YES | +| Rewards | `track_lin_vel_xy_exp` | S | YES | +| Rewards | `track_ang_vel_z_exp` | S | YES | +| Rewards | `lin_vel_z_l2` | S | YES | +| Rewards | `ang_vel_xy_l2` | S | YES | +| Rewards | `joint_torques_l2` | S | YES | +| Rewards | `joint_acc_l2` | S | YES | +| Rewards | `action_rate_l2` | S | YES | +| Rewards | `feet_air_time` | **C** | YES | +| Rewards | `undesired_contacts` | S | YES | +| Rewards | `flat_orientation_l2` | S | YES | +| Rewards | `joint_pos_limits` | S | YES | +| Terms | `time_out` | S | YES | +| Terms | `illegal_contact` | S | YES | +| Terms | `terrain_out_of_bounds` | **C** | YES | +| Events | `randomize_rigid_body_com` | S | YES | +| Events | `apply_external_force_torque` | S | YES | +| Events | `reset_root_state_uniform` | S | YES | +| Events | `reset_joints_by_scale` | S | YES | +| Events | `push_by_setting_velocity` | S | YES | +| Curriculum | `terrain_levels_vel` | **C** | YES (torch, reset-only) | + +Biped robots (G1, G1-29, H1) additionally use: + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Rewards | `feet_air_time_positive_biped` | **C** | YES | +| Rewards | `feet_slide` | **C** | YES | +| Rewards | `track_lin_vel_xy_yaw_frame_exp` | **C** | YES | +| Rewards | `track_ang_vel_z_world_exp` | **C** | YES | +| Rewards | `joint_deviation_l1` | S | YES | +| Rewards | `is_terminated` | S | YES | + +### Velocity Locomotion -- Spot (custom MDP, NOT migrated) + +Additional terms beyond the base velocity config: + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Events | `reset_joints_around_default` | **C** | **NO** | +| Rewards | `air_time_reward` | **C** | **NO** | +| Rewards | `base_angular_velocity_reward` | **C** | **NO** | +| Rewards | `base_linear_velocity_reward` | **C** | **NO** | +| Rewards | `GaitReward` | **C** (class) | **NO** | +| Rewards | `foot_clearance_reward` | **C** | **NO** | +| Rewards | `action_smoothness_penalty` | **C** | **NO** | +| Rewards | `air_time_variance_penalty` | **C** | **NO** | +| Rewards | `base_motion_penalty` | **C** | **NO** | +| Rewards | `base_orientation_penalty` | **C** | **NO** | +| Rewards | `foot_slip_penalty` | **C** | **NO** | +| Rewards | `joint_acceleration_penalty` | **C** | **NO** | +| Rewards | `joint_position_penalty` | **C** | **NO** | +| Rewards | `joint_torques_penalty` | **C** | **NO** | +| Rewards | `joint_velocity_penalty` | **C** | **NO** | + +**Total Spot custom terms to convert: 15** (1 event + 14 rewards including 1 class) + +### Reach (Franka, UR10) + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Commands | `UniformPoseCommandCfg` | S | bridged | +| Actions | `JointPositionActionCfg` | S | YES | +| Obs | `joint_pos_rel` | S | YES | +| Obs | `joint_vel_rel` | S | YES | +| Obs | `generated_commands` | S | YES | +| Obs | `last_action` | S | YES | +| Rewards | `position_command_error` | **C** | YES | +| Rewards | `position_command_error_tanh` | **C** | YES | +| Rewards | `orientation_command_error` | **C** | YES | +| Rewards | `action_rate_l2` | S | YES | +| Rewards | `joint_vel_l2` | S | YES | +| Terms | `time_out` | S | YES | +| Events | `reset_root_state_uniform` | S | YES | +| Events | `reset_joints_by_scale` | S | YES | +| Curriculum | `modify_reward_weight` | S | forwarded from stable | + +### Dexsuite (NOT migrated) + +| Manager | Term | Source | Warp | +|---|---|---|---| +| Commands | `ObjectUniformPoseCommandCfg` | **C** | **NO** | +| Actions | `RelativeJointPositionActionCfg` | S | YES | +| Obs | `object_quat_b` | **C** | **NO** | +| Obs | `generated_commands` | S | YES | +| Obs | `last_action` | S | YES | +| Obs | `time_left` | **C** | **NO** | +| Obs | `joint_pos` | S | YES | +| Obs | `joint_vel` | S | YES | +| Obs | `body_state_b` | **C** | **NO** | +| Obs | `object_point_cloud_b` | **C** (class) | **NO** | +| Obs | `fingers_contact_force_b` | **C** | **NO** | +| Obs | `vision_camera` | **C** (class) | **NO** (not convertible) | +| Rewards | `action_l2_clamped` | **C** | **NO** | +| Rewards | `action_rate_l2_clamped` | **C** | **NO** | +| Rewards | `object_ee_distance` | **C** | **NO** | +| Rewards | `position_command_error_tanh` | S | YES | +| Rewards | `orientation_command_error_tanh` | S | YES | +| Rewards | `success_reward` | **C** | **NO** | +| Rewards | `is_terminated_term` | S | YES | +| Terms | `time_out` | S | YES | +| Terms | `out_of_bound` | **C** (class) | **NO** | +| Terms | `object_spinning_too_fast` | **C** | **NO** | +| Terms | `abnormal_robot_state` | **C** | **NO** | +| Events | `reset_root_state_uniform` | S | YES | +| Events | `reset_joints_by_offset` | S | YES | +| Events | `randomize_physics_scene_gravity` | S | **NO** | +| Curriculum | ADR | **C** (class) | **NO** | + +**Total Dexsuite custom terms to convert: 15** (1 command class, 6 obs, 4 rewards, 3 terms, 1 curriculum) +Plus `randomize_physics_scene_gravity` from shared events. +Plus `vision_camera` / `image` are not convertible (PyTorch NN). + +--- + +## Custom (Task-Local) MDP Terms + +All task-specific MDP functions and classes, grouped by task, with conversion status. + +### Cartpole -- `isaaclab_tasks.manager_based.classic.cartpole.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `joint_pos_target_l2` | reward fn | YES | Pure warp kernel `_joint_pos_target_l2_kernel` | + +### Humanoid -- `isaaclab_tasks.manager_based.classic.humanoid.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `base_yaw_roll` | obs fn | YES | Warp kernel `_base_yaw_roll_kernel` | +| `base_up_proj` | obs fn | YES | Warp kernel `_base_up_proj_kernel` | +| `base_heading_proj` | obs fn | YES | Warp kernel `_base_heading_proj_kernel` | +| `base_angle_to_target` | obs fn | YES | Warp kernel | +| `upright_posture_bonus` | reward fn | YES | Warp kernel `_upright_posture_bonus_kernel` | +| `move_to_target_bonus` | reward fn | YES | Warp kernel `_move_to_target_bonus_kernel` | +| `progress_reward` | reward class | YES | Warp kernels `_progress_reward_kernel`, `_progress_reward_reset_kernel` | +| `joint_pos_limits_penalty_ratio` | reward class | YES | Warp kernel `_joint_pos_limits_penalty_ratio_kernel`, gear ratios cached in `__init__` | +| `power_consumption` | reward class | YES | Warp kernel `_power_consumption_kernel`, gear ratios cached in `__init__` | + +### Velocity Locomotion -- `isaaclab_tasks.manager_based.locomotion.velocity.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `feet_air_time` | reward fn | YES | Warp kernel, sensor data cached via `wp.from_torch` | +| `feet_air_time_positive_biped` | reward fn | YES | Warp kernel, sensor data cached via `wp.from_torch` | +| `feet_slide` | reward fn | YES | Warp kernel, force history cached via `wp.from_torch` | +| `track_lin_vel_xy_yaw_frame_exp` | reward fn | YES | Warp kernel with yaw-frame rotation | +| `track_ang_vel_z_world_exp` | reward fn | YES | Warp kernel with exponential error | +| `stand_still_joint_deviation_l1` | reward fn | YES | Warp kernel with command gating | +| `terrain_out_of_bounds` | term fn | YES | Warp kernel, terrain config cached on first call | +| `terrain_levels_vel` | curriculum fn | YES | Torch-based (runs at reset, not per-step) | + +### Velocity Locomotion / Spot -- `isaaclab_tasks.manager_based.locomotion.velocity.config.spot.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `reset_joints_around_default` | event fn | **NO** | Torch-based | +| `air_time_reward` | reward fn | **NO** | Torch-based | +| `base_angular_velocity_reward` | reward fn | **NO** | Torch-based | +| `base_linear_velocity_reward` | reward fn | **NO** | Torch-based | +| `GaitReward` | reward class | **NO** | Torch-based, stateful (gait phase tracking) | +| `foot_clearance_reward` | reward fn | **NO** | Torch-based | +| `action_smoothness_penalty` | reward fn | **NO** | Torch-based | +| `air_time_variance_penalty` | reward fn | **NO** | Torch-based | +| `base_motion_penalty` | reward fn | **NO** | Torch-based | +| `base_orientation_penalty` | reward fn | **NO** | Torch-based | +| `foot_slip_penalty` | reward fn | **NO** | Torch-based | +| `joint_acceleration_penalty` | reward fn | **NO** | Torch-based | +| `joint_position_penalty` | reward fn | **NO** | Torch-based | +| `joint_torques_penalty` | reward fn | **NO** | Torch-based | +| `joint_velocity_penalty` | reward fn | **NO** | Torch-based | + +### Reach -- `isaaclab_tasks.manager_based.manipulation.reach.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `position_command_error` | reward fn | YES | Warp kernel `_position_command_error_kernel` with frame transform | +| `position_command_error_tanh` | reward fn | YES | Warp kernel `_position_command_error_tanh_kernel` | +| `orientation_command_error` | reward fn | YES | Warp kernel `_orientation_command_error_kernel` with quat math | + +### Dexsuite -- `isaaclab_tasks.manager_based.manipulation.dexsuite.mdp` + +| Term | Type | Warp | Implementation | +|---|---|---|---| +| `ObjectUniformPoseCommandCfg` | command class | **NO** | Torch-based | +| `object_pos_b` | obs fn | **NO** | Torch-based | +| `object_quat_b` | obs fn | **NO** | Torch-based | +| `body_state_b` | obs fn | **NO** | Torch-based | +| `object_point_cloud_b` | obs class | **NO** | Torch-based + USD point cloud | +| `fingers_contact_force_b` | obs fn | **NO** | Torch-based | +| `vision_camera` | obs class | **NO** | Not convertible (image pipeline) | +| `time_left` | obs fn | **NO** | Torch-based (simple) | +| `action_l2_clamped` | reward fn | **NO** | Torch-based | +| `action_rate_l2_clamped` | reward fn | **NO** | Torch-based | +| `object_ee_distance` | reward fn | **NO** | Torch-based | +| `success_reward` | reward fn | **NO** | Torch-based | +| `contacts` | reward fn | **NO** | Torch-based | +| `out_of_bound` | term class | **NO** | Torch-based | +| `abnormal_robot_state` | term fn | **NO** | Torch-based | +| `object_spinning_too_fast` | term fn | **NO** | Torch-based | +| ADR curriculum | curriculum class | **NO** | Torch-based | + +--- + +## Shared Terms Not Used by Any Migrated Task + +These shared MDP terms have warp overrides but are **not referenced** by any of the 31 currently +migrated gym IDs. They are available for future tasks. + +### Observations (not used by any migrated task) + +| Term | Warp | Potential Users | +|---|---|---| +| `root_pos_w` | YES | Navigation, locomanipulation | +| `root_quat_w` | YES | Navigation, locomanipulation | +| `root_lin_vel_w` | YES | Navigation | +| `root_ang_vel_w` | YES | Navigation | +| `body_pose_w` | YES | Manipulation (end-effector tracking) | +| `body_projected_gravity_b` | YES | Manipulation, dexterity | +| `joint_pos` | YES | Dexsuite (uses it, but not migrated) | +| `joint_vel` | YES | Dexsuite (uses it, but not migrated) | +| `joint_effort` | YES | Future tasks | +| `current_time_s` | YES | Future tasks | +| `remaining_time_s` | YES | Future tasks | + +### Rewards (not used by any migrated task) + +| Term | Warp | Potential Users | +|---|---|---| +| `base_height_l2` | YES | Locomotion with height tracking | +| `body_lin_acc_l2` | YES | Smooth motion tasks | +| `joint_vel_limits` | YES | Safety-critical tasks | +| `applied_torque_limits` | YES | Torque-limited robots | +| `desired_contacts` | YES | Gait-specific tasks | +| `contact_forces` | YES | Force control tasks | + +### Terminations (not used by any migrated task) + +| Term | Warp | Potential Users | +|---|---|---| +| `command_resample` | YES | Command-driven tasks with resampling | +| `bad_orientation` | YES | Tasks needing orientation limits | +| `joint_pos_out_of_limit` | YES | Safety-critical tasks | +| `joint_vel_out_of_limit` | YES | Safety-critical tasks | +| `joint_vel_out_of_manual_limit` | YES | Safety-critical tasks | +| `joint_effort_out_of_limit` | YES | Torque-limited tasks | + +### Events (not used by any migrated task) + +| Term | Warp | Potential Users | +|---|---|---| +| `randomize_rigid_body_material` | YES | Sim-to-real tasks | +| `randomize_rigid_body_mass` | YES | Sim-to-real tasks | +| `randomize_actuator_gains` | YES | Sim-to-real tasks | +| `randomize_joint_parameters` | YES | Sim-to-real tasks | +| `reset_root_state_with_random_orientation` | YES | Manipulation reset | +| `reset_root_state_from_terrain` | YES | Rough terrain locomotion | +| `reset_scene_to_default` | YES | General reset | + +--- + +## Cross-Cutting Notes + +### Command Manager -- NOT a Blocker + +The command manager classes themselves remain torch-based, but this does **not** block +downstream MDP terms that consume command data. The command manager's output tensors +(`get_command()`, `get_term().time_left`, etc.) have stable pointers that can be wrapped +with `wp.from_torch()` on the first call. Subsequent calls reuse the zero-copy view, +so the warp kernel always reads the latest command values with no conversion overhead. + +Pattern used in all command-dependent terms: +```python +def some_term(env, out, command_name: str, ...) -> None: + if not hasattr(some_term, "_cmd_wp"): + cmd_torch = env.command_manager.get_command(command_name) + some_term._cmd_wp = wp.from_torch(cmd_torch.contiguous()) + wp.launch(kernel=_some_kernel, ..., inputs=[some_term._cmd_wp, ...], ...) +``` + +### Integrator Setting + +All experimental configs use `integrator="implicit"` (standard Newton solver) instead of +the stable `"implicitfast"` variant. This is required for the MJWarp solver backend. + +### Graphability: No Torch in the Hot Path + +All per-step MDP terms must be CUDA-graph capturable. This means: + +- **No `torch.*` ops** in function bodies (only in `__init__` or `reset`) +- **No Python conditionals on changing values** in the per-step path +- Cross-manager torch tensors (commands, sensors) are cached as zero-copy warp views + via `wp.from_torch()` on first call using the `hasattr` pattern (see below) +- After warmup, every per-step call is a pure `wp.launch()` chain + +### Non-Capturable MDP Terms (`@warp_capturable(False)`) + +Some MDP terms access `ArticulationData` properties backed by `TimestampedWarpBuffer` +(lazy derived properties). These are incompatible with `wp.ScopedCapture` because the +timestamp guard (`if timestamp < sim_timestamp`) is a Python branch — the warmup call +updates the timestamp, causing the capture call to skip the kernel entirely. The graph +then replays with stale data. See `GRAPH_CAPTURE_MIGRATION.md` in the Newton +articulation package for the full Tier 1/2/3 property analysis. + +Affected terms are marked `@warp_capturable(False)`, which causes the owning manager to +fall back to mode=1 (warp not captured) automatically via `register_manager_capturability`. + +**Observations** (`isaaclab_experimental/envs/mdp/observations.py`): + +| Term | Accesses | Status | +|------|----------|--------| +| `base_lin_vel` | `root_lin_vel_b` → `root_com_vel_b` (Tier 2) | Applied | +| `base_ang_vel` | `root_ang_vel_b` → `root_com_vel_b` (Tier 2) | Applied | +| `projected_gravity` | `projected_gravity_b` (Tier 2) | Applied | +| `body_projected_gravity_b` | `projected_gravity_b` (Tier 2) | Applied | + +**Rewards — base** (`isaaclab_experimental/envs/mdp/rewards.py`): + +| Term | Accesses | Status | +|------|----------|--------| +| `lin_vel_z_l2` | `root_lin_vel_b` → `root_com_vel_b` (Tier 2) | Applied | +| `ang_vel_xy_l2` | `root_ang_vel_b` → `root_com_vel_b` (Tier 2) | Applied | +| `flat_orientation_l2` | `projected_gravity_b` (Tier 2) | Applied | +| `track_lin_vel_xy_exp` | `root_lin_vel_b` → `root_com_vel_b` (Tier 2) | Applied | +| `track_ang_vel_z_exp` | `root_ang_vel_b` → `root_com_vel_b` (Tier 2) | Applied | + +**Rewards — humanoid** (`isaaclab_tasks_experimental/.../humanoid/mdp/rewards.py`): + +| Term | Accesses | Status | +|------|----------|--------| +| `upright_posture_bonus` | `projected_gravity_b` (Tier 2) | Applied | + +**Rewards — safe** (no Tier 2 access, fully capturable): + +- `track_lin_vel_xy_yaw_frame_exp` → `root_quat_w`, `root_lin_vel_w` (Tier 3 from Tier 1) +- `track_ang_vel_z_world_exp` → `root_ang_vel_w` (Tier 3 from Tier 1) +- `feet_slide` → `body_lin_vel_w` → `body_com_lin_vel_w` (Tier 3 from Tier 1) +- `feet_air_time`, `feet_air_time_positive_biped` → contact sensor data +- `joint_torques_l2`, `joint_acc_l2`, `joint_vel_l2`, etc. → `joint_pos`, `joint_vel` (Tier 1) +- `is_alive`, `is_terminated`, `action_rate_l2`, `action_l2` → no articulation data + +**Pending fix:** Implement `materialize_derived()` in `ArticulationData.update()` to +eagerly compute Tier 2 properties before captured graphs replay. Once applied, all +`@warp_capturable(False)` annotations for Tier 2 access can be removed and these terms +become fully capturable. See `GRAPH_CAPTURE_MIGRATION.md` in the Newton articulation +package for the proposed implementation. + +### Resolved Cross-Cutting Blockers + +| Blocker | Resolution | +|---|---| +| Class-based term support | All 5 class-based terms converted (see Events table) | +| Command Manager dependency | `wp.from_torch` bridge (zero-copy) | +| ContactSensor dependency | `net_forces_w_history` wrapped via `wp.from_torch` on first call | +| `body_mask` pattern | body_ids cached as `wp.array(dtype=wp.int32)` on first call | + +### Remaining Gaps (shared events) + +| Term | Why Deferred | +|---|---| +| `randomize_rigid_body_collider_offsets` | Stub (`NotImplementedError`) in stable | +| `randomize_physics_scene_gravity` | Class-based, per-env gravity. Low priority. | +| `randomize_fixed_tendon_parameters` | Stub (`NotImplementedError`) in stable | +| `reset_nodal_state_uniform` | Stub (`NotImplementedError`) in stable | +| `randomize_rigid_body_scale` | USD `pxr` API, pre-sim only. Not convertible. | +| `randomize_visual_texture_material` | Omni Replicator API. Not convertible. | +| `randomize_visual_color` | Omni Replicator API. Not convertible. | + +### Not Convertible to Pure Warp + +| Term | Reason | +|---|---| +| `image` | 4D image tensor with per-type normalization, depth conversion. | +| `image_features` | PyTorch NN inference (ResNet/Theia); must remain hybrid. | +| `vision_camera` (dexsuite) | Image pipeline, same limitation as `image`. | +| `randomize_rigid_body_scale` | USD `pxr` API, no tensor math. | +| `randomize_visual_texture_material` | Omni Replicator API. | +| `randomize_visual_color` | Omni Replicator API. | + +--- + +## Migration Pattern + +### How to create an `_exp` task copy + +For each stable manager-based task, the experimental copy follows this structure: + +``` +isaaclab_tasks_experimental/manager_based/// +├── __init__.py # gym.register with -Warp suffix +├── _env_cfg.py # Copy of stable, change imports +└── mdp/ + ├── __init__.py # from isaaclab_experimental.envs.mdp import *; from .custom import * + └── .py # Warp-first versions of task-specific terms +``` + +No `agents/` directory -- reuse stable agent configs via import. + +### `__init__.py` registration pattern + +```python +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.. import agents + +gym.register( + id="Isaac--Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}._env_cfg:EnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:PPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + # ... all agent configs from stable + }, +) +``` + +Key rules: +- Use `f"{__name__}"` for `env_cfg_entry_point` (local experimental config) +- Use `f"{agents.__name__}"` for all agent configs (points to stable) +- Include ALL agent configs that stable has (rsl_rl, skrl, sb3, rl_games, symmetry) +- Entry point is always `isaaclab_experimental.envs:ManagerBasedRLEnvWarp` + +### env_cfg.py import changes + +```python +# Stable imports: +from isaaclab.managers import ... +import isaaclab_tasks.manager_based..mdp as mdp + +# Experimental imports: +from isaaclab_experimental.managers import ... +import isaaclab_tasks_experimental.manager_based..mdp as mdp +``` + +### MDP term signature change + +```python +# Stable (torch): returns tensor +def term(env, **params) -> torch.Tensor: + +# Experimental (warp): writes to pre-allocated output +def term(env, out: wp.array, **params) -> None: +``` + +### Kernel + function co-location + +Every warp kernel must be placed directly above the function that launches it: + +```python +@wp.kernel +def _my_reward_kernel(data: ..., out: wp.array(dtype=wp.float32)): + i = wp.tid() + out[i] = ... + +def my_reward(env, out, asset_cfg=SceneEntityCfg("robot")) -> None: + asset = env.scene[asset_cfg.name] + wp.launch(kernel=_my_reward_kernel, dim=env.num_envs, inputs=[..., out], device=env.device) +``` + +### Cross-manager reference caching (commands, sensors) + +Torch tensors from other managers (commands, contact sensors) are cached as zero-copy +warp views on first call. The `hasattr` guard only executes during warmup (before graph +capture). After warmup, only the `wp.launch` runs. + +```python +def some_reward(env, out, command_name: str, ...) -> None: + fn = some_reward + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch(kernel=_some_kernel, ..., inputs=[fn._cmd_wp, ...], ...) +``` + +### Observation dimension inference (`out_dim`) + +The observation manager infers output buffer dimensions from decorator metadata. +No `term_dim` parameter is needed in env_cfg `params`. + +Resolution order: +1. `out_dim` on `@generic_io_descriptor_warp` decorator (for body/command/action/time obs) +2. `axes` on decorator (for root-state obs: `len(axes)` gives dimension) +3. `asset_cfg.joint_ids` count (for joint-state obs) + +```python +# Root state: dimension derived from axes (no out_dim needed) +@generic_io_descriptor_warp(axes=["X", "Y", "Z"], observation_type="RootState", ...) +def base_lin_vel(env, out, asset_cfg=...): ... + +# Body state: out_dim required (per-body component count varies) +@generic_io_descriptor_warp(out_dim="body:7", observation_type="BodyState", ...) +def body_pose_w(env, out, asset_cfg=...): ... + +# Cross-manager: out_dim sentinel queries manager at init time +@generic_io_descriptor_warp(out_dim="command", observation_type="Command", ...) +def generated_commands(env, out, command_name: str): ... + +# Custom task obs: explicit int +@generic_io_descriptor_warp(out_dim=2, observation_type="RootState") +def base_yaw_roll(env, out, asset_cfg=...): ... +``` + +Supported `out_dim` values: `int`, `"joint"`, `"body:N"`, `"command"`, `"action"`. + +### Class-based term pattern + +```python +class my_reward(ManagerTermBase): + def __init__(self, env, cfg): + # Torch ops OK here (one-time setup) + # Cache persistent warp arrays + self._gear_wp = wp.from_torch(gear_tensor) + + def reset(self, env_mask=None): + # Warp kernel for reset + wp.launch(kernel=_reset_kernel, ...) + + def __call__(self, env, out, **params): + # Pure wp.launch only -- no torch ops + wp.launch(kernel=_compute_kernel, ..., inputs=[self._gear_wp, ...]) +``` + +### Joint subset: mask vs ids + +```python +# Stable: asset_cfg.joint_ids (int list) +# Experimental: asset_cfg.joint_mask (wp.array(dtype=wp.bool)) +``` + +### Action class changes + +```python +# process_actions: (actions: torch.Tensor) -> (actions: wp.array, action_offset: int) +# reset: (env_ids: Sequence[int]) -> (mask: wp.array(dtype=wp.bool)) +# joint targeting: joint_ids= -> joint_mask= +``` + +### Buffer management + +- Pre-allocate all output buffers in `__init__` (persistent pointers for graph capture) +- No dynamic tensor creation in per-step functions +- Per-joint constants stored as 1D arrays, indexed by `j` inside kernels diff --git a/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py new file mode 100644 index 00000000000..ecaa6ca0a11 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py @@ -0,0 +1,521 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright (c) 2022-2026, The Isaac Lab Project Developers. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests for Warp-first action term classes. + +Tests all 10 experimental action classes: process_actions, apply_actions, reset. + +Usage:: + python -m pytest test_action_warp_parity.py -v +""" + +from __future__ import annotations + +import numpy as np +import torch + +import pytest +import warp as wp + +wp.init() +pytestmark = pytest.mark.skipif(not wp.is_cuda_available(), reason="CUDA device required") + +from isaaclab_experimental.envs.mdp.actions import ( + AbsBinaryJointPositionAction, + AbsBinaryJointPositionActionCfg, + BinaryJointPositionAction, + BinaryJointPositionActionCfg, + BinaryJointVelocityAction, + BinaryJointVelocityActionCfg, + EMAJointPositionToLimitsAction, + EMAJointPositionToLimitsActionCfg, + JointEffortAction, + JointEffortActionCfg, + JointPositionAction, + JointPositionActionCfg, + JointPositionToLimitsAction, + JointPositionToLimitsActionCfg, + JointVelocityAction, + JointVelocityActionCfg, + NonHolonomicAction, + NonHolonomicActionCfg, + RelativeJointPositionAction, + RelativeJointPositionActionCfg, +) + +NUM_ENVS = 32 +NUM_JOINTS = 6 +DEVICE = "cuda:0" +ATOL = 1e-5 +RTOL = 1e-5 +JOINT_NAMES = [f"joint_{i}" for i in range(NUM_JOINTS)] + + +# ============================================================================ +# Mock infrastructure +# ============================================================================ + + +class MockArticulationData: + def __init__(self, seed=42): + rng = np.random.RandomState(seed) + self.joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=DEVICE) + self.joint_vel = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=DEVICE) + self.default_joint_pos = wp.array( + np.tile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], (NUM_ENVS, 1)).astype(np.float32), device=DEVICE + ) + self.default_joint_vel = wp.array(np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32), device=DEVICE) + + limits_np = np.zeros((NUM_ENVS, NUM_JOINTS, 2), dtype=np.float32) + limits_np[:, :, 0] = -3.14 + limits_np[:, :, 1] = 3.14 + self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=DEVICE) + + # Body quaternion for NonHolonomicAction (identity = [0,0,0,1] in xyzw) + num_bodies = 3 + quat_np = np.zeros((NUM_ENVS, num_bodies, 4), dtype=np.float32) + quat_np[:, :, 3] = 1.0 # w=1 (identity) + self.body_quat_w = wp.array(quat_np, dtype=wp.quatf, device=DEVICE) + + self._num_joints = NUM_JOINTS + + def resolve_joint_mask(self, joint_ids=None): + mask = [False] * NUM_JOINTS + if joint_ids is None or isinstance(joint_ids, slice): + mask = [True] * NUM_JOINTS + else: + for j in joint_ids: + mask[j] = True + return wp.array(mask, dtype=wp.bool, device=DEVICE) + + +class MockArticulation: + def __init__(self, data: MockArticulationData): + self.data = data + self.num_joints = NUM_JOINTS + self.num_bodies = 3 + self.device = DEVICE + # Track what was last written for verification + self.last_pos_target = None + self.last_vel_target = None + self.last_effort_target = None + self.last_joint_mask = None + + def find_joints(self, names, preserve_order=False): + if isinstance(names, list) and names == [".*"]: + return None, JOINT_NAMES, list(range(NUM_JOINTS)) + # For specific joint names, resolve them + ids = [] + resolved_names = [] + for name in names if isinstance(names, list) else [names]: + for i, jn in enumerate(JOINT_NAMES): + if name in jn or name == jn or name == ".*": + if i not in ids: + ids.append(i) + resolved_names.append(jn) + if not ids: + ids = list(range(NUM_JOINTS)) + resolved_names = list(JOINT_NAMES) + return None, resolved_names, ids + + def find_bodies(self, name): + return None, [name], [0] + + def set_joint_position_target(self, target, joint_ids=None, joint_mask=None): + self.last_pos_target = target + self.last_joint_mask = joint_mask + + def set_joint_velocity_target(self, target, joint_ids=None, joint_mask=None): + self.last_vel_target = target + self.last_joint_mask = joint_mask + + def set_joint_effort_target(self, target, joint_ids=None, joint_mask=None): + self.last_effort_target = target + self.last_joint_mask = joint_mask + + +class MockScene: + def __init__(self, asset): + self._asset = asset + + def __getitem__(self, name): + return self._asset + + +class MockEnv: + def __init__(self, asset): + self.scene = MockScene(asset) + self.num_envs = NUM_ENVS + self.device = DEVICE + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture() +def art_data(): + return MockArticulationData() + + +@pytest.fixture() +def asset(art_data): + return MockArticulation(art_data) + + +@pytest.fixture() +def env(asset): + return MockEnv(asset) + + +@pytest.fixture() +def actions_wp(): + rng = np.random.RandomState(99) + return wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=DEVICE) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def assert_close(actual, expected, atol=ATOL, rtol=RTOL): + if isinstance(actual, wp.array): + actual = wp.to_torch(actual) + if isinstance(expected, wp.array): + expected = wp.to_torch(expected) + torch.testing.assert_close(actual.float(), expected.float(), atol=atol, rtol=rtol) + + +# ============================================================================ +# Joint action tests (JointPosition, JointVelocity, JointEffort, Relative) +# ============================================================================ + + +class TestJointActions: + """Test JointAction subclasses: process, apply, reset.""" + + def test_joint_effort_process_apply(self, env, asset, actions_wp): + cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"]) + term = JointEffortAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + term.apply_actions() + + # Processed = raw * scale(1.0) + offset(0.0) = raw + assert_close(term.processed_actions, actions_wp) + assert asset.last_effort_target is not None + + def test_joint_position_default_offset(self, env, asset, art_data, actions_wp): + cfg = JointPositionActionCfg(asset_name="robot", joint_names=[".*"], use_default_offset=True) + term = JointPositionAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + term.apply_actions() + + # Processed = raw * 1.0 + default_joint_pos[0] + defaults = wp.to_torch(art_data.default_joint_pos)[0] + raw = wp.to_torch(actions_wp) + expected = raw + defaults.unsqueeze(0) + assert_close(term.processed_actions, expected) + assert asset.last_pos_target is not None + + def test_joint_velocity_default_offset(self, env, asset, actions_wp): + cfg = JointVelocityActionCfg(asset_name="robot", joint_names=[".*"], use_default_offset=True) + term = JointVelocityAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + term.apply_actions() + + # Default vel is all zeros, so processed = raw + assert_close(term.processed_actions, actions_wp) + assert asset.last_vel_target is not None + + def test_relative_joint_position(self, env, asset, art_data, actions_wp): + cfg = RelativeJointPositionActionCfg(asset_name="robot", joint_names=[".*"], use_zero_offset=True) + term = RelativeJointPositionAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + term.apply_actions() + + # Applied = processed(=raw) + current joint_pos + raw = wp.to_torch(actions_wp) + current_pos = wp.to_torch(art_data.joint_pos) + expected = raw + current_pos + assert_close(asset.last_pos_target, expected) + + def test_joint_action_reset(self, env, asset, actions_wp): + cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"]) + term = JointEffortAction(cfg, env) + + # Process some actions + term.process_actions(actions_wp, action_offset=0) + assert wp.to_torch(term.raw_actions).abs().sum() > 0 + + # Reset all + term.reset(mask=None) + assert_close(term.raw_actions, wp.zeros_like(term.raw_actions)) + + def test_joint_action_reset_masked(self, env, asset, actions_wp): + cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"]) + term = JointEffortAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + raw_before = wp.to_torch(term.raw_actions).clone() + + # Reset only first half + mask_np = [i < NUM_ENVS // 2 for i in range(NUM_ENVS)] + mask = wp.array(mask_np, dtype=wp.bool, device=DEVICE) + term.reset(mask=mask) + + raw_after = wp.to_torch(term.raw_actions) + # First half zeroed + assert_close(raw_after[: NUM_ENVS // 2], torch.zeros(NUM_ENVS // 2, NUM_JOINTS, device=DEVICE)) + # Second half unchanged + assert_close(raw_after[NUM_ENVS // 2 :], raw_before[NUM_ENVS // 2 :]) + + def test_joint_action_with_scale(self, env, asset, actions_wp): + cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"], scale=2.5) + term = JointEffortAction(cfg, env) + + term.process_actions(actions_wp, action_offset=0) + + raw = wp.to_torch(actions_wp) + expected = raw * 2.5 + assert_close(term.processed_actions, expected) + + +# ============================================================================ +# Binary joint action tests +# ============================================================================ + + +class TestBinaryJointActions: + """Test BinaryJointAction subclasses.""" + + def _make_binary_cfg(self, cls): + return cls( + asset_name="robot", + joint_names=[".*"], + open_command_expr={f"joint_{i}": 0.04 for i in range(NUM_JOINTS)}, + close_command_expr={f"joint_{i}": 0.0 for i in range(NUM_JOINTS)}, + ) + + def test_binary_position_open(self, env, asset): + cfg = self._make_binary_cfg(BinaryJointPositionActionCfg) + term = BinaryJointPositionAction(cfg, env) + + # Positive action → open + actions = wp.array(np.full((NUM_ENVS, NUM_JOINTS + 10), 1.0, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + term.apply_actions() + + processed = wp.to_torch(term.processed_actions) + expected_open = torch.full((NUM_ENVS, NUM_JOINTS), 0.04, device=DEVICE) + assert_close(processed, expected_open) + + def test_binary_position_close(self, env, asset): + cfg = self._make_binary_cfg(BinaryJointPositionActionCfg) + term = BinaryJointPositionAction(cfg, env) + + # Negative action → close + actions = wp.array(np.full((NUM_ENVS, NUM_JOINTS + 10), -1.0, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + + processed = wp.to_torch(term.processed_actions) + expected_close = torch.zeros(NUM_ENVS, NUM_JOINTS, device=DEVICE) + assert_close(processed, expected_close) + + def test_binary_velocity(self, env, asset): + cfg = self._make_binary_cfg(BinaryJointVelocityActionCfg) + term = BinaryJointVelocityAction(cfg, env) + + actions = wp.array(np.full((NUM_ENVS, 20), 1.0, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + term.apply_actions() + assert asset.last_vel_target is not None + + def test_abs_binary_threshold(self, env, asset): + cfg = AbsBinaryJointPositionActionCfg( + asset_name="robot", + joint_names=[".*"], + open_command_expr={f"joint_{i}": 0.04 for i in range(NUM_JOINTS)}, + close_command_expr={f"joint_{i}": 0.0 for i in range(NUM_JOINTS)}, + threshold=0.5, + positive_threshold=True, + ) + term = AbsBinaryJointPositionAction(cfg, env) + + # Action > threshold → open + actions = wp.array(np.full((NUM_ENVS, 20), 0.8, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + processed = wp.to_torch(term.processed_actions) + assert_close(processed, torch.full((NUM_ENVS, NUM_JOINTS), 0.04, device=DEVICE)) + + # Action < threshold → close + actions2 = wp.array(np.full((NUM_ENVS, 20), 0.2, dtype=np.float32), device=DEVICE) + term.process_actions(actions2, action_offset=0) + processed2 = wp.to_torch(term.processed_actions) + assert_close(processed2, torch.zeros(NUM_ENVS, NUM_JOINTS, device=DEVICE)) + + +# ============================================================================ +# Joint position to limits tests +# ============================================================================ + + +class TestJointPositionToLimitsActions: + """Test JointPositionToLimitsAction and EMA variant.""" + + def test_rescale_to_limits(self, env, asset): + cfg = JointPositionToLimitsActionCfg(asset_name="robot", joint_names=[".*"], rescale_to_limits=True, scale=1.0) + term = JointPositionToLimitsAction(cfg, env) + + # Input +1.0 → should map to upper limit (3.14) + actions = wp.array(np.full((NUM_ENVS, NUM_JOINTS), 1.0, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + + processed = wp.to_torch(term.processed_actions) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 3.14, device=DEVICE) + assert_close(processed, expected) + + def test_rescale_negative_one(self, env, asset): + cfg = JointPositionToLimitsActionCfg(asset_name="robot", joint_names=[".*"], rescale_to_limits=True, scale=1.0) + term = JointPositionToLimitsAction(cfg, env) + + # Input -1.0 → should map to lower limit (-3.14) + actions = wp.array(np.full((NUM_ENVS, NUM_JOINTS), -1.0, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + + processed = wp.to_torch(term.processed_actions) + expected = torch.full((NUM_ENVS, NUM_JOINTS), -3.14, device=DEVICE) + assert_close(processed, expected) + + def test_rescale_zero(self, env, asset): + cfg = JointPositionToLimitsActionCfg(asset_name="robot", joint_names=[".*"], rescale_to_limits=True, scale=1.0) + term = JointPositionToLimitsAction(cfg, env) + + # Input 0.0 → should map to midpoint (0.0 for symmetric limits) + actions = wp.array(np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + + processed = wp.to_torch(term.processed_actions) + expected = torch.zeros(NUM_ENVS, NUM_JOINTS, device=DEVICE) + assert_close(processed, expected) + + def test_ema_alpha_one(self, env, asset): + """alpha=1.0 means no smoothing — should behave like parent.""" + cfg = EMAJointPositionToLimitsActionCfg( + asset_name="robot", joint_names=[".*"], rescale_to_limits=True, scale=1.0, alpha=1.0 + ) + term = EMAJointPositionToLimitsAction(cfg, env) + term.reset(mask=None) + + actions = wp.array(np.full((NUM_ENVS, NUM_JOINTS), 0.5, dtype=np.float32), device=DEVICE) + term.process_actions(actions, action_offset=0) + + # With alpha=1.0, EMA = 1.0 * processed + 0.0 * prev = processed + # 0.5 rescaled: (0.5+1)/2 * 6.28 + (-3.14) = 0.75*6.28 - 3.14 = 4.71 - 3.14 = 1.57 + processed = wp.to_torch(term.processed_actions) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 1.57, device=DEVICE) + assert_close(processed, expected, atol=0.01, rtol=0.01) + + def test_ema_reset_to_current_pos(self, env, asset, art_data): + """After reset, prev_applied should match current joint positions.""" + cfg = EMAJointPositionToLimitsActionCfg( + asset_name="robot", joint_names=[".*"], rescale_to_limits=True, alpha=0.5 + ) + term = EMAJointPositionToLimitsAction(cfg, env) + term.reset(mask=None) + + prev = wp.to_torch(term._prev_applied_actions) + current_pos = wp.to_torch(art_data.joint_pos) + assert_close(prev, current_pos) + + +# ============================================================================ +# Non-holonomic action tests +# ============================================================================ + + +class TestNonHolonomicAction: + """Test NonHolonomicAction.""" + + def test_identity_orientation(self, env, asset): + """With identity quaternion (yaw=0), forward velocity maps to x only.""" + cfg = NonHolonomicActionCfg( + asset_name="robot", + body_name="base", + x_joint_name="joint_0", + y_joint_name="joint_1", + yaw_joint_name="joint_2", + scale=(1.0, 1.0), + offset=(0.0, 0.0), + ) + term = NonHolonomicAction(cfg, env) + + # Forward velocity = 1.0, yaw rate = 0.0 + actions = wp.zeros((NUM_ENVS, NUM_JOINTS), dtype=wp.float32, device=DEVICE) + # Set first 2 columns (action_dim=2) to [1.0, 0.0] + act_np = np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32) + act_np[:, 0] = 1.0 # forward vel + act_np[:, 1] = 0.0 # yaw rate + actions = wp.array(act_np, device=DEVICE) + + term.process_actions(actions, action_offset=0) + term.apply_actions() + + vel_cmd = wp.to_torch(term._joint_vel_command) + # With identity quat (yaw=0): vx = cos(0)*1.0 = 1.0, vy = sin(0)*1.0 = 0.0, omega = 0.0 + expected = torch.zeros(NUM_ENVS, 3, device=DEVICE) + expected[:, 0] = 1.0 + assert_close(vel_cmd, expected, atol=1e-4, rtol=1e-4) + + def test_pure_yaw(self, env, asset): + """Pure yaw rate input.""" + cfg = NonHolonomicActionCfg( + asset_name="robot", + body_name="base", + x_joint_name="joint_0", + y_joint_name="joint_1", + yaw_joint_name="joint_2", + ) + term = NonHolonomicAction(cfg, env) + + act_np = np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32) + act_np[:, 0] = 0.0 # no forward vel + act_np[:, 1] = 0.5 # yaw rate + actions = wp.array(act_np, device=DEVICE) + + term.process_actions(actions, action_offset=0) + term.apply_actions() + + vel_cmd = wp.to_torch(term._joint_vel_command) + # vx = vy = 0 (no forward vel), omega = 0.5 + expected = torch.zeros(NUM_ENVS, 3, device=DEVICE) + expected[:, 2] = 0.5 + assert_close(vel_cmd, expected, atol=1e-4, rtol=1e-4) + + def test_reset(self, env, asset): + cfg = NonHolonomicActionCfg( + asset_name="robot", + body_name="base", + x_joint_name="joint_0", + y_joint_name="joint_1", + yaw_joint_name="joint_2", + ) + term = NonHolonomicAction(cfg, env) + + act_np = np.ones((NUM_ENVS, NUM_JOINTS), dtype=np.float32) + term.process_actions(wp.array(act_np, device=DEVICE), action_offset=0) + assert wp.to_torch(term.raw_actions).abs().sum() > 0 + + term.reset(mask=None) + assert_close(term.raw_actions, wp.zeros((NUM_ENVS, 2), dtype=wp.float32, device=DEVICE)) diff --git a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py new file mode 100644 index 00000000000..d54f14bac43 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py @@ -0,0 +1,1378 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests: stable (torch) MDP terms vs experimental (Warp-first) implementations. + +Verifies that every newly ported Warp-first MDP function produces **identical** results +to the stable torch-based implementation under three execution modes: + + 1. **Stable baseline** — the torch implementation in ``isaaclab.envs.mdp`` + 2. **Warp uncaptured** — the Warp kernel launched normally + 3. **Warp captured** — the Warp kernel recorded in a CUDA graph and replayed + +Usage:: + + python -m pytest test_mdp_warp_parity.py -v +""" + +from __future__ import annotations + +import numpy as np +import torch + +import isaaclab_experimental.envs.mdp.events as warp_evt + +# --------------------------------------------------------------------------- +# Experimental (Warp-first) implementations +# --------------------------------------------------------------------------- +import isaaclab_experimental.envs.mdp.observations as warp_obs +import isaaclab_experimental.envs.mdp.rewards as warp_rew +import isaaclab_experimental.envs.mdp.terminations as warp_term +import pytest +import warp as wp + +# --------------------------------------------------------------------------- +# Stable (torch) implementations +# --------------------------------------------------------------------------- +import isaaclab.envs.mdp.observations as stable_obs +import isaaclab.envs.mdp.rewards as stable_rew +import isaaclab.envs.mdp.terminations as stable_term + +# --------------------------------------------------------------------------- +# Test constants +# --------------------------------------------------------------------------- +NUM_ENVS = 64 +NUM_JOINTS = 12 +NUM_ACTIONS = 6 +DEVICE = "cuda:0" + +# Tolerance for float32 comparison (torch vs warp may differ by FMA / instruction order) +ATOL = 1e-5 +RTOL = 1e-5 + + +# ============================================================================ +# Mock objects +# ============================================================================ + + +class MockArticulationData: + """Mock articulation data backed by Warp arrays (same storage Newton uses).""" + + def __init__(self, num_envs: int, num_joints: int, device: str, seed: int = 42): + rng = np.random.RandomState(seed) + + # --- Joint state (float32 2D) --- + self.joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32), device=device) + self.joint_vel = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 2.0, device=device) + self.joint_acc = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.5, device=device) + self.default_joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.01, device=device) + self.default_joint_vel = wp.array(np.zeros((num_envs, num_joints), dtype=np.float32), device=device) + self.applied_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + self.computed_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + + # --- Soft joint position limits (vec2f 2D) --- + limits_np = np.zeros((num_envs, num_joints, 2), dtype=np.float32) + limits_np[:, :, 0] = -3.14 # lower + limits_np[:, :, 1] = 3.14 # upper + self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) + + # --- Soft joint velocity limits (float32 2D) --- + self.soft_joint_vel_limits = wp.array(np.full((num_envs, num_joints), 10.0, dtype=np.float32), device=device) + + # --- Root state --- + root_pos_np = rng.randn(num_envs, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 # positive heights + self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) + + self.root_lin_vel_b = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) + self.root_ang_vel_b = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) + + # Gravity projection (unit-ish vectors pointing mostly down) + gravity_np = np.zeros((num_envs, 3), dtype=np.float32) + gravity_np[:, 2] = -1.0 + gravity_np += rng.randn(num_envs, 3).astype(np.float32) * 0.1 + gravity_np /= np.linalg.norm(gravity_np, axis=1, keepdims=True) + self.projected_gravity_b = wp.array(gravity_np, dtype=wp.vec3f, device=device) + + # --- Additional root state for new observations --- + # Quaternion (random unit quaternions) + quat_np = rng.randn(num_envs, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) + + # World-frame velocities + self.root_lin_vel_w = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) + self.root_ang_vel_w = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) + + # --- Event-specific data --- + # Spatial velocity (6-component: lin + ang) + self.root_vel_w = wp.array(rng.randn(num_envs, 6).astype(np.float32), dtype=wp.spatial_vectorf, device=device) + + # Default root pose (transformf = position vec3f + quaternion quatf) + default_pose_np = np.zeros((num_envs, 7), dtype=np.float32) + default_pose_np[:, 0:3] = rng.randn(num_envs, 3).astype(np.float32) * 0.1 # small position offsets + default_pose_np[:, 3:7] = [0.0, 0.0, 0.0, 1.0] # identity quaternion (xyzw) + self.default_root_pose = wp.array(default_pose_np, dtype=wp.transformf, device=device) + + # Default root velocity (spatial_vectorf) + self.default_root_vel = wp.array( + np.zeros((num_envs, 6), dtype=np.float32), dtype=wp.spatial_vectorf, device=device + ) + + +class MockArticulation: + def __init__(self, data: MockArticulationData): + self.data = data + self.num_bodies = 1 + self.device = DEVICE + + # Stub write APIs for events (no-ops — we verify scratch buffer contents instead) + def write_root_velocity_to_sim(self, root_velocity, env_ids=None, env_mask=None): + pass + + def write_root_pose_to_sim(self, root_pose, env_ids=None, env_mask=None): + pass + + def set_external_force_and_torque(self, forces, torques, body_ids=None, env_ids=None, env_mask=None): + pass + + +class MockScene: + def __init__(self, assets: dict, env_origins: torch.Tensor): + self._assets = assets + self.env_origins = env_origins + + def __getitem__(self, name: str): + return self._assets[name] + + +class MockActionManagerWarp: + """Returns warp arrays (for experimental functions).""" + + def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): + self._action = action_wp + self._prev_action = prev_action_wp + + @property + def action(self) -> wp.array: + return self._action + + @property + def prev_action(self) -> wp.array: + return self._prev_action + + +class MockActionManagerTorch: + """Returns torch tensors (for stable functions).""" + + def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): + self._action = wp.to_torch(action_wp) + self._prev_action = wp.to_torch(prev_action_wp) + + @property + def action(self) -> torch.Tensor: + return self._action + + @property + def prev_action(self) -> torch.Tensor: + return self._prev_action + + +class MockSceneEntityCfg: + """Unified cfg that works for both stable (joint_ids) and experimental (joint_mask / joint_ids_wp).""" + + def __init__(self, name: str, joint_ids: list[int], num_joints: int, device: str): + self.name = name + self.joint_ids = joint_ids + + # Experimental extras + mask = [False] * num_joints + for idx in joint_ids: + mask[idx] = True + self.joint_mask = wp.array(mask, dtype=wp.bool, device=device) + self.joint_ids_wp = wp.array(joint_ids, dtype=wp.int32, device=device) + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture(autouse=True) +def _clear_function_caches(): + """Clear first-call caches on warp MDP functions so each test starts fresh. + + Functions like ``current_time_s`` and ``root_pos_w`` cache warp views on + themselves (``hasattr`` pattern). Without clearing, a cached view from a + prior test's fixture would be stale when a new test creates different tensors. + """ + yield + for fn in ( + warp_obs.root_pos_w, + warp_obs.current_time_s, + warp_obs.remaining_time_s, + warp_evt.push_by_setting_velocity, + warp_evt.apply_external_force_torque, + warp_evt.reset_root_state_uniform, + ): + for attr in list(vars(fn)): + if attr.startswith("_"): + delattr(fn, attr) + + +@pytest.fixture() +def art_data(): + return MockArticulationData(NUM_ENVS, NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def env_origins(): + rng = np.random.RandomState(77) + # Newton stores env_origins as a warp vec3f array (stable root_pos_w calls wp.to_torch on it). + origins_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) + + +@pytest.fixture() +def scene(art_data, env_origins): + return MockScene({"robot": MockArticulation(art_data)}, env_origins) + + +@pytest.fixture() +def action_wp(): + rng = np.random.RandomState(99) + a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + return a, b # (action, prev_action) + + +@pytest.fixture() +def episode_length_buf(): + torch.manual_seed(55) + return torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + +@pytest.fixture() +def warp_env(scene, action_wp, episode_length_buf): + """Env with warp action manager (for experimental functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + # RNG state for events (seeded deterministically) + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env(scene, action_wp, episode_length_buf): + """Env with torch action manager (for stable functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerTorch(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + return env + + +@pytest.fixture() +def all_joints_cfg(): + return MockSceneEntityCfg("robot", list(range(NUM_JOINTS)), NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def subset_cfg(): + return MockSceneEntityCfg("robot", [0, 2, 5, 8], NUM_JOINTS, DEVICE) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _run_warp_obs(func, env, shape, device=DEVICE, **kwargs): + """Run a warp observation function and return the result as a torch tensor.""" + out = wp.zeros(shape, dtype=wp.float32, device=device) + func(env, out, **kwargs) + return wp.to_torch(out) + + +def _run_warp_obs_captured(func, env, shape, device=DEVICE, **kwargs): + """Run a warp observation function under CUDA graph capture and return the result.""" + out = wp.zeros(shape, dtype=wp.float32, device=device) + # Warm-up (triggers any first-call lazy init) + func(env, out, **kwargs) + # Capture + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + # Replay + wp.capture_launch(capture.graph) + return wp.to_torch(out) + + +def _run_warp_rew(func, env, device=DEVICE, **kwargs): + """Run a warp reward function and return the result as a torch tensor.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) + func(env, out, **kwargs) + return wp.to_torch(out) + + +def _run_warp_rew_captured(func, env, device=DEVICE, **kwargs): + """Run a warp reward function under CUDA graph capture.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) + func(env, out, **kwargs) # warm-up + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + wp.capture_launch(capture.graph) + return wp.to_torch(out) + + +def _run_warp_term(func, env, device=DEVICE, **kwargs): + """Run a warp termination function and return the result as a torch tensor.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) + func(env, out, **kwargs) + return wp.to_torch(out) + + +def _run_warp_term_captured(func, env, device=DEVICE, **kwargs): + """Run a warp termination function under CUDA graph capture.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) + func(env, out, **kwargs) # warm-up + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + wp.capture_launch(capture.graph) + return wp.to_torch(out) + + +def assert_close(actual: torch.Tensor, expected: torch.Tensor, atol: float = ATOL, rtol: float = RTOL): + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + + +def assert_equal(actual: torch.Tensor, expected: torch.Tensor): + assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}" + + +# ============================================================================ +# Observation parity tests +# ============================================================================ + + +class TestObservationParity: + """Verify experimental observation Warp kernels match stable torch implementations.""" + + # -- Root state observations ------------------------------------------------ + + def test_base_pos_z(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_pos_z(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_base_lin_vel(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_lin_vel(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_base_ang_vel(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_ang_vel(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_projected_gravity(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.projected_gravity(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint observations (all joints) ---------------------------------------- + + def test_joint_pos_all(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_vel_all(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint observations (subset) ------------------------------------------- + + def test_joint_pos_subset(self, warp_env, stable_env, subset_cfg): + cfg = subset_cfg + n_selected = len(cfg.joint_ids) + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_vel_subset(self, warp_env, stable_env, subset_cfg): + cfg = subset_cfg + n_selected = len(cfg.joint_ids) + expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Normalized joint position ---------------------------------------------- + + def test_joint_pos_limit_normalized(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_pos_limit_normalized(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured( + warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Action observation ----------------------------------------------------- + + def test_last_action(self, warp_env, stable_env, action_wp): + # Stable last_action returns env.action_manager.action (torch tensor) + expected = stable_obs.last_action(stable_env) + actual = _run_warp_obs(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) + actual_cap = _run_warp_obs_captured(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Additional root state observations ------------------------------------- + + def test_root_pos_w(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.root_pos_w(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.root_pos_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.root_pos_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_root_quat_w(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.root_quat_w(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_root_quat_w_unique(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.root_quat_w(stable_env, make_quat_unique=True, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), make_quat_unique=True, asset_cfg=cfg) + actual_cap = _run_warp_obs_captured( + warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), make_quat_unique=True, asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_root_lin_vel_w(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.root_lin_vel_w(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.root_lin_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.root_lin_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_root_ang_vel_w(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.root_ang_vel_w(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.root_ang_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.root_ang_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_effort(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_effort(stable_env, asset_cfg=cfg) + actual = _run_warp_obs(warp_obs.joint_effort, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = _run_warp_obs_captured(warp_obs.joint_effort, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Time observations ------------------------------------------------------ + + def test_current_time_s(self, warp_env, stable_env): + expected = stable_obs.current_time_s(stable_env) + actual = _run_warp_obs(warp_obs.current_time_s, warp_env, (NUM_ENVS, 1)) + actual_cap = _run_warp_obs_captured(warp_obs.current_time_s, warp_env, (NUM_ENVS, 1)) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_remaining_time_s(self, warp_env, stable_env): + expected = stable_obs.remaining_time_s(stable_env) + actual = _run_warp_obs(warp_obs.remaining_time_s, warp_env, (NUM_ENVS, 1)) + actual_cap = _run_warp_obs_captured(warp_obs.remaining_time_s, warp_env, (NUM_ENVS, 1)) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# Reward parity tests +# ============================================================================ + + +class TestRewardParity: + """Verify experimental reward Warp kernels match stable torch implementations.""" + + # -- Root penalties --------------------------------------------------------- + + def test_lin_vel_z_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.lin_vel_z_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_ang_vel_xy_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.ang_vel_xy_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_flat_orientation_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.flat_orientation_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint L2 penalties (masked) -------------------------------------------- + + def test_joint_vel_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_vel_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_acc_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_acc_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_torques_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_torques_l2(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Action penalties ------------------------------------------------------- + + def test_action_l2(self, warp_env, stable_env): + expected = stable_rew.action_l2(stable_env) + actual = _run_warp_rew(warp_rew.action_l2, warp_env) + actual_cap = _run_warp_rew_captured(warp_rew.action_l2, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_action_rate_l2(self, warp_env, stable_env): + expected = stable_rew.action_rate_l2(stable_env) + actual = _run_warp_rew(warp_rew.action_rate_l2, warp_env) + actual_cap = _run_warp_rew_captured(warp_rew.action_rate_l2, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Limit penalties -------------------------------------------------------- + + def test_joint_pos_limits(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_pos_limits(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_vel_limits(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_vel_limits(stable_env, soft_ratio=0.9, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_vel_limits, warp_env, soft_ratio=0.9, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_vel_limits, warp_env, soft_ratio=0.9, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_applied_torque_limits(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.applied_torque_limits(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.applied_torque_limits, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.applied_torque_limits, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Additional penalties --------------------------------------------------- + + def test_joint_deviation_l1(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_deviation_l1(stable_env, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_base_height_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + target = 0.5 + expected = stable_rew.base_height_l2(stable_env, target_height=target, asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.base_height_l2, warp_env, target_height=target, asset_cfg=cfg) + actual_cap = _run_warp_rew_captured(warp_rew.base_height_l2, warp_env, target_height=target, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# Termination parity tests +# ============================================================================ + + +class TestTerminationParity: + """Verify experimental termination Warp kernels match stable torch implementations.""" + + def test_root_height_below_minimum(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + min_h = 0.5 + expected = stable_term.root_height_below_minimum(stable_env, minimum_height=min_h, asset_cfg=cfg) + actual = _run_warp_term(warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg) + actual_cap = _run_warp_term_captured( + warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_bad_orientation(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + limit = 0.5 # ~29 degrees + expected = stable_term.bad_orientation(stable_env, limit_angle=limit, asset_cfg=cfg) + actual = _run_warp_term(warp_term.bad_orientation, warp_env, limit_angle=limit, asset_cfg=cfg) + actual_cap = _run_warp_term_captured(warp_term.bad_orientation, warp_env, limit_angle=limit, asset_cfg=cfg) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_joint_pos_out_of_limit(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_term.joint_pos_out_of_limit(stable_env, asset_cfg=cfg) + actual = _run_warp_term(warp_term.joint_pos_out_of_limit, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_term_captured(warp_term.joint_pos_out_of_limit, warp_env, asset_cfg=cfg) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_joint_vel_out_of_limit(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_term.joint_vel_out_of_limit(stable_env, asset_cfg=cfg) + actual = _run_warp_term(warp_term.joint_vel_out_of_limit, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_term_captured(warp_term.joint_vel_out_of_limit, warp_env, asset_cfg=cfg) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + # -- Additional joint terminations ------------------------------------------ + + def test_joint_vel_out_of_manual_limit(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + max_vel = 5.0 + expected = stable_term.joint_vel_out_of_manual_limit(stable_env, max_velocity=max_vel, asset_cfg=cfg) + actual = _run_warp_term(warp_term.joint_vel_out_of_manual_limit, warp_env, max_velocity=max_vel, asset_cfg=cfg) + actual_cap = _run_warp_term_captured( + warp_term.joint_vel_out_of_manual_limit, warp_env, max_velocity=max_vel, asset_cfg=cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_joint_effort_out_of_limit(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_term.joint_effort_out_of_limit(stable_env, asset_cfg=cfg) + actual = _run_warp_term(warp_term.joint_effort_out_of_limit, warp_env, asset_cfg=cfg) + actual_cap = _run_warp_term_captured(warp_term.joint_effort_out_of_limit, warp_env, asset_cfg=cfg) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + +# ============================================================================ +# Capture-then-mutate-then-replay tests +# +# Verify that a captured CUDA graph produces correct results when the +# underlying buffer *data* changes between capture and replay (simulating +# a new simulation step). +# ============================================================================ + + +def _copy_np_to_wp(dest: wp.array, src_np: np.ndarray): + """In-place overwrite of a warp array's contents from numpy (preserves pointer).""" + tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device)) + wp.copy(dest, tmp) + + +def _mutate_art_data(art_data: MockArticulationData, warp_env, rng_seed: int = 200): + """Mutate every data array in-place so captured graphs see fresh values.""" + rng = np.random.RandomState(rng_seed) + + _copy_np_to_wp(art_data.joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 1.5) + _copy_np_to_wp(art_data.joint_vel, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 3.0) + _copy_np_to_wp(art_data.joint_acc, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.8) + _copy_np_to_wp(art_data.default_joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.02) + _copy_np_to_wp(art_data.applied_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) + _copy_np_to_wp(art_data.computed_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) + + root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 + _copy_np_to_wp(art_data.root_pos_w, root_pos_np) + _copy_np_to_wp(art_data.root_lin_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) + _copy_np_to_wp(art_data.root_ang_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) + _copy_np_to_wp(art_data.root_lin_vel_w, rng.randn(NUM_ENVS, 3).astype(np.float32)) + _copy_np_to_wp(art_data.root_ang_vel_w, rng.randn(NUM_ENVS, 3).astype(np.float32)) + + gravity_np = np.zeros((NUM_ENVS, 3), dtype=np.float32) + gravity_np[:, 2] = -1.0 + gravity_np += rng.randn(NUM_ENVS, 3).astype(np.float32) * 0.15 + gravity_np /= np.linalg.norm(gravity_np, axis=1, keepdims=True) + _copy_np_to_wp(art_data.projected_gravity_b, gravity_np) + + quat_np = rng.randn(NUM_ENVS, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + _copy_np_to_wp(art_data.root_quat_w, quat_np) + + # Actions (in-place via warp copy — torch views auto-update) + _copy_np_to_wp(warp_env.action_manager._action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) + _copy_np_to_wp(warp_env.action_manager._prev_action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) + + # Episode length (in-place torch update — warp zero-copy view auto-updates) + warp_env.episode_length_buf[:] = torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + wp.synchronize() + + +class TestCapturedDataMutation: + """Capture a graph, mutate buffer data in-place, replay — results must match stable on the *new* data. + + This verifies every migrated MDP function is truly capture-safe: the CUDA graph + reads from the same buffer pointers but picks up whatever data is there at replay time. + """ + + # -- helpers --------------------------------------------------------------- + + def _capture_mutate_check_obs(self, warp_fn, stable_fn, warp_env, stable_env, art_data, shape, **kwargs): + out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) # warm-up + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_close(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + def _capture_mutate_check_rew(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_close(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + def _capture_mutate_check_term(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_equal(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + # -- observations ----------------------------------------------------------- + + def test_base_pos_z(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_pos_z, + stable_obs.base_pos_z, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 1), + asset_cfg=all_joints_cfg, + ) + + def test_base_lin_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_lin_vel, + stable_obs.base_lin_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_base_ang_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_ang_vel, + stable_obs.base_ang_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_projected_gravity(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.projected_gravity, + stable_obs.projected_gravity, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_joint_pos(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_pos, + stable_obs.joint_pos, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_vel, + stable_obs.joint_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_joint_pos_limit_normalized(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_pos_limit_normalized, + stable_obs.joint_pos_limit_normalized, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_last_action(self, warp_env, stable_env, art_data): + self._capture_mutate_check_obs( + warp_obs.last_action, + stable_obs.last_action, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_ACTIONS), + ) + + def test_root_pos_w(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.root_pos_w, + stable_obs.root_pos_w, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_root_quat_w(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.root_quat_w, + stable_obs.root_quat_w, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 4), + asset_cfg=all_joints_cfg, + ) + + def test_root_quat_w_unique(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.root_quat_w, + stable_obs.root_quat_w, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 4), + make_quat_unique=True, + asset_cfg=all_joints_cfg, + ) + + def test_root_lin_vel_w(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.root_lin_vel_w, + stable_obs.root_lin_vel_w, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_root_ang_vel_w(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.root_ang_vel_w, + stable_obs.root_ang_vel_w, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_joint_effort(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_effort, + stable_obs.joint_effort, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_current_time_s(self, warp_env, stable_env, art_data): + self._capture_mutate_check_obs( + warp_obs.current_time_s, + stable_obs.current_time_s, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 1), + ) + + def test_remaining_time_s(self, warp_env, stable_env, art_data): + self._capture_mutate_check_obs( + warp_obs.remaining_time_s, + stable_obs.remaining_time_s, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 1), + ) + + # -- rewards ---------------------------------------------------------------- + + def test_lin_vel_z_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.lin_vel_z_l2, + stable_rew.lin_vel_z_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_ang_vel_xy_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.ang_vel_xy_l2, + stable_rew.ang_vel_xy_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_flat_orientation_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.flat_orientation_l2, + stable_rew.flat_orientation_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_vel_l2, + stable_rew.joint_vel_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_acc_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_acc_l2, + stable_rew.joint_acc_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_torques_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_torques_l2, + stable_rew.joint_torques_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_action_l2(self, warp_env, stable_env, art_data): + self._capture_mutate_check_rew( + warp_rew.action_l2, + stable_rew.action_l2, + warp_env, + stable_env, + art_data, + ) + + def test_action_rate_l2(self, warp_env, stable_env, art_data): + self._capture_mutate_check_rew( + warp_rew.action_rate_l2, + stable_rew.action_rate_l2, + warp_env, + stable_env, + art_data, + ) + + def test_joint_pos_limits(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_pos_limits, + stable_rew.joint_pos_limits, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel_limits(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_vel_limits, + stable_rew.joint_vel_limits, + warp_env, + stable_env, + art_data, + soft_ratio=0.9, + asset_cfg=all_joints_cfg, + ) + + def test_applied_torque_limits(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.applied_torque_limits, + stable_rew.applied_torque_limits, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_deviation_l1(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_deviation_l1, + stable_rew.joint_deviation_l1, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_base_height_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.base_height_l2, + stable_rew.base_height_l2, + warp_env, + stable_env, + art_data, + target_height=0.5, + asset_cfg=all_joints_cfg, + ) + + # -- terminations ----------------------------------------------------------- + + def test_root_height_below_minimum(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.root_height_below_minimum, + stable_term.root_height_below_minimum, + warp_env, + stable_env, + art_data, + minimum_height=0.5, + asset_cfg=all_joints_cfg, + ) + + def test_bad_orientation(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.bad_orientation, + stable_term.bad_orientation, + warp_env, + stable_env, + art_data, + limit_angle=0.5, + asset_cfg=all_joints_cfg, + ) + + def test_joint_pos_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.joint_pos_out_of_limit, + stable_term.joint_pos_out_of_limit, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.joint_vel_out_of_limit, + stable_term.joint_vel_out_of_limit, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel_out_of_manual_limit(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.joint_vel_out_of_manual_limit, + stable_term.joint_vel_out_of_manual_limit, + warp_env, + stable_env, + art_data, + max_velocity=5.0, + asset_cfg=all_joints_cfg, + ) + + def test_joint_effort_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.joint_effort_out_of_limit, + stable_term.joint_effort_out_of_limit, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + +# ============================================================================ +# Event tests +# +# Events use warp RNG (wp.randf) so exact parity with stable (torch RNG) is +# not possible. Instead we test: +# 1. Uncaptured run produces structurally correct output +# 2. Captured replay does not crash +# 3. Capture-then-mutate-then-replay: the graph picks up new input data +# (tested with zero-width ranges to eliminate RNG dependency) +# ============================================================================ + + +class TestEventCapturedDataMutation: + """Verify event functions are capture-safe and react to mutated input data.""" + + # -- reset_joints_by_offset ------------------------------------------------- + + def test_reset_joints_by_offset(self, warp_env, art_data, all_joints_cfg): + """With zero-width offset, result == defaults. Mutate defaults → result tracks.""" + cfg = all_joints_cfg + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + + # Warm-up + warp_evt.reset_joints_by_offset( + warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg + ) + + # Capture + with wp.ScopedCapture() as cap: + warp_evt.reset_joints_by_offset( + warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg + ) + + # Mutate defaults in-place + new_defaults = np.full((NUM_ENVS, NUM_JOINTS), 0.5, dtype=np.float32) + _copy_np_to_wp(art_data.default_joint_pos, new_defaults) + + # Replay + wp.capture_launch(cap.graph) + wp.synchronize() + + # With zero offset, joint_pos should equal new defaults (clamped to limits [-3.14, 3.14]) + result = wp.to_torch(art_data.joint_pos) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 0.5, device=DEVICE) + assert_close(result, expected) + + # -- reset_joints_by_scale -------------------------------------------------- + + def test_reset_joints_by_scale(self, warp_env, art_data, all_joints_cfg): + """With scale=1.0, result == defaults. Mutate defaults → result tracks.""" + cfg = all_joints_cfg + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + + warp_evt.reset_joints_by_scale( + warp_env, mask, position_range=(1.0, 1.0), velocity_range=(1.0, 1.0), asset_cfg=cfg + ) + with wp.ScopedCapture() as cap: + warp_evt.reset_joints_by_scale( + warp_env, mask, position_range=(1.0, 1.0), velocity_range=(1.0, 1.0), asset_cfg=cfg + ) + + new_defaults = np.full((NUM_ENVS, NUM_JOINTS), 0.25, dtype=np.float32) + _copy_np_to_wp(art_data.default_joint_pos, new_defaults) + + wp.capture_launch(cap.graph) + wp.synchronize() + + result = wp.to_torch(art_data.joint_pos) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 0.25, device=DEVICE) + assert_close(result, expected) + + # -- push_by_setting_velocity ----------------------------------------------- + + def test_push_by_setting_velocity(self, warp_env, art_data, all_joints_cfg): + """With zero-width velocity range, scratch == root_vel_w. Mutate root_vel_w → scratch tracks.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + zero_range = { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + } + + warp_evt.push_by_setting_velocity(warp_env, mask, velocity_range=zero_range) + with wp.ScopedCapture() as cap: + warp_evt.push_by_setting_velocity(warp_env, mask, velocity_range=zero_range) + + # Mutate root_vel_w + new_vel = np.tile([1.0, 2.0, 3.0, 0.1, 0.2, 0.3], (NUM_ENVS, 1)).astype(np.float32) + _copy_np_to_wp(art_data.root_vel_w, new_vel) + + wp.capture_launch(cap.graph) + wp.synchronize() + + scratch = wp.to_torch(warp_evt.push_by_setting_velocity._scratch_vel) + expected = torch.tensor([1.0, 2.0, 3.0, 0.1, 0.2, 0.3], device=DEVICE).expand(NUM_ENVS, -1) + assert_close(scratch, expected) + + # -- apply_external_force_torque -------------------------------------------- + + def test_apply_external_force_torque(self, warp_env, art_data, all_joints_cfg): + """With zero-width ranges, forces/torques are zero. Non-zero ranges produce non-zero output.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + + # Zero-range: forces and torques should be zero + warp_evt.apply_external_force_torque(warp_env, mask, force_range=(0.0, 0.0), torque_range=(0.0, 0.0)) + with wp.ScopedCapture() as cap: + warp_evt.apply_external_force_torque(warp_env, mask, force_range=(0.0, 0.0), torque_range=(0.0, 0.0)) + wp.capture_launch(cap.graph) + wp.synchronize() + + forces = wp.to_torch(warp_evt.apply_external_force_torque._scratch_forces) + torques = wp.to_torch(warp_evt.apply_external_force_torque._scratch_torques) + assert_close(forces, torch.zeros_like(forces)) + assert_close(torques, torch.zeros_like(torques)) + + # -- reset_root_state_uniform ----------------------------------------------- + + def test_reset_root_state_uniform(self, warp_env, art_data, all_joints_cfg, env_origins): + """With zero-width ranges, pose = default + env_origin, vel = default. Mutate defaults → tracks.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + zero_pose = { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + } + zero_vel = dict(zero_pose) + + warp_evt.reset_root_state_uniform(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) + with wp.ScopedCapture() as cap: + warp_evt.reset_root_state_uniform(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) + + # Mutate default_root_pose: set all positions to (1, 2, 3), identity quat + new_pose = np.zeros((NUM_ENVS, 7), dtype=np.float32) + new_pose[:, 0:3] = [1.0, 2.0, 3.0] + new_pose[:, 3:7] = [0.0, 0.0, 0.0, 1.0] # identity (xyzw) + _copy_np_to_wp(art_data.default_root_pose, new_pose) + + wp.capture_launch(cap.graph) + wp.synchronize() + + scratch_pose = wp.to_torch(warp_evt.reset_root_state_uniform._scratch_pose) + origins_t = wp.to_torch(env_origins) + + # position = default(1,2,3) + env_origin + 0 + expected_pos = torch.tensor([1.0, 2.0, 3.0], device=DEVICE).unsqueeze(0) + origins_t + assert_close(scratch_pose[:, :3], expected_pos) + + # quaternion = identity * identity_delta = identity = (0,0,0,1) in xyzw + expected_quat = torch.tensor([0.0, 0.0, 0.0, 1.0], device=DEVICE).expand(NUM_ENVS, -1) + assert_close(scratch_pose[:, 3:7], expected_quat) + + # -- env_mask selectivity --------------------------------------------------- + + def test_reset_joints_mask_selectivity(self, warp_env, art_data, all_joints_cfg): + """Only masked envs are modified; unmasked envs retain their state.""" + cfg = all_joints_cfg + # Mask: only first half of envs + mask_np = np.array([i < NUM_ENVS // 2 for i in range(NUM_ENVS)]) + mask = wp.array(mask_np, dtype=wp.bool, device=DEVICE) + + # Set joint_pos to a known value + sentinel = np.full((NUM_ENVS, NUM_JOINTS), 999.0, dtype=np.float32) + _copy_np_to_wp(art_data.joint_pos, sentinel) + + # Set defaults to 0 + _copy_np_to_wp(art_data.default_joint_pos, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) + + warp_evt.reset_joints_by_offset( + warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg + ) + wp.synchronize() + + result = wp.to_torch(art_data.joint_pos) + # Masked envs: reset to 0 (defaults + 0 offset) + assert_close(result[: NUM_ENVS // 2], torch.zeros(NUM_ENVS // 2, NUM_JOINTS, device=DEVICE)) + # Unmasked envs: still 999.0 + assert_close(result[NUM_ENVS // 2 :], torch.full((NUM_ENVS // 2, NUM_JOINTS), 999.0, device=DEVICE)) diff --git a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py new file mode 100644 index 00000000000..1aadc590f04 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py @@ -0,0 +1,920 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright (c) 2022-2026, The Isaac Lab Project Developers. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests for newly migrated Warp-first MDP terms. + +Tests: body observations, command-dependent rewards, contact sensor rewards/terminations, +and new event functions. + +Usage:: + python -m pytest test_mdp_warp_parity_new_terms.py -v +""" + +from __future__ import annotations + +import numpy as np +import torch + +import pytest +import warp as wp + +# Skip entire module if no CUDA device available +wp.init() +pytestmark = pytest.mark.skipif(not wp.is_cuda_available(), reason="CUDA device required") + +import isaaclab_experimental.envs.mdp.events as warp_evt +import isaaclab_experimental.envs.mdp.observations as warp_obs +import isaaclab_experimental.envs.mdp.rewards as warp_rew +import isaaclab_experimental.envs.mdp.terminations as warp_term + +import isaaclab.envs.mdp.observations as stable_obs +import isaaclab.envs.mdp.rewards as stable_rew +import isaaclab.envs.mdp.terminations as stable_term + +# --------------------------------------------------------------------------- +NUM_ENVS = 64 +NUM_JOINTS = 12 +NUM_BODIES = 4 +NUM_ACTIONS = 6 +NUM_HISTORY = 3 +CMD_DIM = 3 +DEVICE = "cuda:0" +ATOL = 1e-5 +RTOL = 1e-5 +BODY_IDS = [0, 2] # subset of bodies to test + + +# ============================================================================ +# Mock infrastructure +# ============================================================================ + + +def _make_rng(seed=42): + return np.random.RandomState(seed) + + +class MockMultiBodyArticulationData: + """Mock articulation data with multi-body arrays for body-level observations.""" + + def __init__(self, device=DEVICE, seed=42): + rng = _make_rng(seed) + + # --- Joint state --- + self.joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=device) + self.joint_vel = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 2.0, device=device) + self.default_joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.01, device=device) + self.default_joint_vel = wp.array(np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32), device=device) + self.joint_acc = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.5, device=device) + self.applied_torque = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 10.0, device=device) + self.computed_torque = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 10.0, device=device) + + # --- Root state --- + root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 + self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) + self.root_lin_vel_b = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) + self.root_ang_vel_b = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) + + # --- Soft limits --- + limits_np = np.zeros((NUM_ENVS, NUM_JOINTS, 2), dtype=np.float32) + limits_np[:, :, 0] = -3.14 + limits_np[:, :, 1] = 3.14 + self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) + self.soft_joint_vel_limits = wp.array(np.full((NUM_ENVS, NUM_JOINTS), 10.0, dtype=np.float32), device=device) + + # --- Body-level data (2D vec3f / transformf) --- + # projected_gravity_b: (num_envs, num_bodies) vec3f + grav_np = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) + grav_np[:, :, 2] = -1.0 + norms = np.linalg.norm(grav_np, axis=2, keepdims=True) + grav_np /= norms + self.projected_gravity_b = wp.array(grav_np, dtype=wp.vec3f, device=device) + + # body_pose_w: (num_envs, num_bodies) transformf — pos + identity quat + pose_np = np.zeros((NUM_ENVS, NUM_BODIES, 7), dtype=np.float32) + pose_np[:, :, :3] = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) + pose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] + self.body_pose_w = wp.array(pose_np, dtype=wp.transformf, device=device) + + # body_lin_acc_w: (num_envs, num_bodies) vec3f + self.body_lin_acc_w = wp.array( + rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32), dtype=wp.vec3f, device=device + ) + + # body_com_pos_b: (num_envs, num_bodies) vec3f + self.body_com_pos_b = wp.array( + rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) * 0.01, dtype=wp.vec3f, device=device + ) + + # Event-specific + self.root_vel_w = wp.array(rng.randn(NUM_ENVS, 6).astype(np.float32), dtype=wp.spatial_vectorf, device=device) + default_pose_np = np.zeros((NUM_ENVS, 7), dtype=np.float32) + default_pose_np[:, 0:3] = rng.randn(NUM_ENVS, 3).astype(np.float32) * 0.1 + default_pose_np[:, 3:7] = [0.0, 0.0, 0.0, 1.0] + self.default_root_pose = wp.array(default_pose_np, dtype=wp.transformf, device=device) + self.default_root_vel = wp.array( + np.zeros((NUM_ENVS, 6), dtype=np.float32), dtype=wp.spatial_vectorf, device=device + ) + + quat_np = rng.randn(NUM_ENVS, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) + self.root_lin_vel_w = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) + self.root_ang_vel_w = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) + + def resolve_joint_mask(self, joint_ids=None): + mask = [False] * NUM_JOINTS + if joint_ids is None or isinstance(joint_ids, slice): + mask = [True] * NUM_JOINTS + else: + for j in joint_ids: + mask[j] = True + return wp.array(mask, dtype=wp.bool, device=DEVICE) + + +class MockMultiBodyArticulation: + def __init__(self, data: MockMultiBodyArticulationData): + self.data = data + self.num_bodies = NUM_BODIES + self.num_joints = NUM_JOINTS + self.device = DEVICE + + def write_root_velocity_to_sim(self, *a, **kw): + pass + + def write_root_pose_to_sim(self, *a, **kw): + pass + + def write_joint_state_to_sim(self, *a, **kw): + pass + + def set_external_force_and_torque(self, *a, **kw): + pass + + def find_joints(self, names, preserve_order=False): + return None, [f"j{i}" for i in range(NUM_JOINTS)], list(range(NUM_JOINTS)) + + +class MockContactSensorData: + def __init__(self, device=DEVICE, seed=77): + rng = _make_rng(seed) + self.net_forces_w_history = torch.tensor( + rng.randn(NUM_ENVS, NUM_HISTORY, NUM_BODIES, 3).astype(np.float32), device=device + ) + + +class MockContactSensor: + def __init__(self, data: MockContactSensorData): + self.data = data + self.num_bodies = NUM_BODIES + + +class MockCommandTerm: + def __init__(self, device=DEVICE, seed=88): + rng = _make_rng(seed) + self.time_left = torch.tensor(rng.rand(NUM_ENVS).astype(np.float32) * 0.05, device=device) + self.command_counter = torch.tensor(rng.randint(0, 3, (NUM_ENVS,)), dtype=torch.float32, device=device) + + +class MockCommandManager: + def __init__(self, command_tensor: torch.Tensor, cmd_term: MockCommandTerm): + self._cmd = command_tensor + self._term = cmd_term + + def get_command(self, name: str) -> torch.Tensor: + return self._cmd + + def get_term(self, name: str): + return self._term + + +class MockBodyCfg: + """SceneEntityCfg-like object for body-level terms.""" + + def __init__(self, name="robot", body_ids=None): + self.name = name + self.body_ids = body_ids if body_ids is not None else BODY_IDS + + +class MockSensorCfg: + """SceneEntityCfg-like object for contact sensor terms.""" + + def __init__(self, name="contact_sensor", body_ids=None): + self.name = name + self.body_ids = body_ids if body_ids is not None else BODY_IDS + + +class MockScene: + def __init__(self, assets: dict, env_origins, sensors=None): + self._assets = assets + self.env_origins = env_origins + self.sensors = sensors or {} + self.articulations = {k: v for k, v in assets.items()} + self.rigid_objects = {} + self.num_envs = NUM_ENVS + + def __getitem__(self, name: str): + return self._assets[name] + + +class MockActionManagerWarp: + def __init__(self, action_wp, prev_action_wp): + self._action = action_wp + self._prev_action = prev_action_wp + + @property + def action(self): + return self._action + + @property + def prev_action(self): + return self._prev_action + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture(autouse=True) +def _clear_caches(): + yield + # Clear function-level caches from all new warp functions + fns_to_clear = [ + warp_obs.body_projected_gravity_b, + warp_obs.body_pose_w, + warp_obs.generated_commands, + warp_rew.body_lin_acc_l2, + warp_rew.track_lin_vel_xy_exp, + warp_rew.track_ang_vel_z_exp, + warp_rew.undesired_contacts, + warp_rew.desired_contacts, + warp_rew.contact_forces, + warp_term.command_resample, + warp_term.illegal_contact, + warp_evt.reset_root_state_with_random_orientation, + warp_evt.reset_scene_to_default, + warp_evt.randomize_rigid_body_com, + ] + for fn in fns_to_clear: + for attr in list(vars(fn)): + if attr.startswith("_"): + delattr(fn, attr) + + +@pytest.fixture() +def art_data(): + return MockMultiBodyArticulationData() + + +@pytest.fixture() +def env_origins(): + origins_np = _make_rng(77).randn(NUM_ENVS, 3).astype(np.float32) + return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) + + +@pytest.fixture() +def contact_data(): + return MockContactSensorData() + + +@pytest.fixture() +def cmd_tensor(): + rng = _make_rng(99) + return torch.tensor(rng.randn(NUM_ENVS, CMD_DIM).astype(np.float32), device=DEVICE) + + +@pytest.fixture() +def cmd_term(): + return MockCommandTerm() + + +@pytest.fixture() +def scene(art_data, env_origins, contact_data): + art = MockMultiBodyArticulation(art_data) + sensor = MockContactSensor(contact_data) + return MockScene({"robot": art}, env_origins, sensors={"contact_sensor": sensor}) + + +@pytest.fixture() +def action_wp(): + rng = _make_rng(55) + a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + return a, b + + +@pytest.fixture() +def episode_length_buf(): + torch.manual_seed(55) + return torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + +@pytest.fixture() +def warp_env(scene, action_wp, episode_length_buf, cmd_tensor, cmd_term): + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env(scene, action_wp, episode_length_buf, cmd_tensor, cmd_term): + class _Env: + pass + + env = _Env() + env.scene = scene + # stable functions access action_manager.action as torch + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + # stable termination_manager needed for time_out + env.termination_manager = type("_TM", (), {"terminated": torch.zeros(NUM_ENVS, dtype=torch.bool, device=DEVICE)})() + return env + + +@pytest.fixture() +def body_cfg(): + return MockBodyCfg("robot", BODY_IDS) + + +@pytest.fixture() +def sensor_cfg(): + return MockSensorCfg("contact_sensor", BODY_IDS) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _run_warp_obs(func, env, shape, **kwargs): + out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def _run_warp_obs_captured(func, env, shape, **kwargs): + out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) + func(env, out, **kwargs) + with wp.ScopedCapture() as cap: + func(env, out, **kwargs) + wp.capture_launch(cap.graph) + return wp.to_torch(out).clone() + + +def _run_warp_rew(func, env, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def _run_warp_rew_captured(func, env, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + func(env, out, **kwargs) + with wp.ScopedCapture() as cap: + func(env, out, **kwargs) + wp.capture_launch(cap.graph) + return wp.to_torch(out).clone() + + +def _run_warp_term(func, env, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def _run_warp_term_captured(func, env, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + func(env, out, **kwargs) + with wp.ScopedCapture() as cap: + func(env, out, **kwargs) + wp.capture_launch(cap.graph) + return wp.to_torch(out).clone() + + +def assert_close(actual, expected, atol=ATOL, rtol=RTOL): + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + + +def assert_equal(actual, expected): + assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}" + + +# ============================================================================ +# Body observation parity tests +# ============================================================================ + + +class TestBodyObservationParity: + """Verify body-level observation Warp kernels match stable torch implementations.""" + + def test_body_projected_gravity_b(self, warp_env, stable_env, body_cfg): + n_sel = len(body_cfg.body_ids) + expected = stable_obs.body_projected_gravity_b(stable_env, asset_cfg=body_cfg) + actual = _run_warp_obs(warp_obs.body_projected_gravity_b, warp_env, (NUM_ENVS, n_sel * 3), asset_cfg=body_cfg) + actual_cap = _run_warp_obs_captured( + warp_obs.body_projected_gravity_b, warp_env, (NUM_ENVS, n_sel * 3), asset_cfg=body_cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_body_pose_w(self, warp_env, stable_env, body_cfg): + n_sel = len(body_cfg.body_ids) + # Stable body_pose_w calls env.scene.env_origins.unsqueeze(1) — needs torch tensor. + # Temporarily swap env_origins to torch for the stable call. + orig_origins = stable_env.scene.env_origins + stable_env.scene.env_origins = wp.to_torch(orig_origins) + expected = stable_obs.body_pose_w(stable_env, asset_cfg=body_cfg) + stable_env.scene.env_origins = orig_origins # restore + actual = _run_warp_obs(warp_obs.body_pose_w, warp_env, (NUM_ENVS, n_sel * 7), asset_cfg=body_cfg) + actual_cap = _run_warp_obs_captured(warp_obs.body_pose_w, warp_env, (NUM_ENVS, n_sel * 7), asset_cfg=body_cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_generated_commands(self, warp_env, stable_env): + expected = stable_obs.generated_commands(stable_env, command_name="vel") + actual = _run_warp_obs(warp_obs.generated_commands, warp_env, (NUM_ENVS, CMD_DIM), command_name="vel") + actual_cap = _run_warp_obs_captured( + warp_obs.generated_commands, warp_env, (NUM_ENVS, CMD_DIM), command_name="vel" + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# New reward parity tests +# ============================================================================ + + +class TestNewRewardParity: + """Verify newly migrated reward Warp kernels match stable torch implementations.""" + + def test_body_lin_acc_l2(self, warp_env, stable_env, body_cfg): + expected = stable_rew.body_lin_acc_l2(stable_env, asset_cfg=body_cfg) + actual = _run_warp_rew(warp_rew.body_lin_acc_l2, warp_env, asset_cfg=body_cfg) + actual_cap = _run_warp_rew_captured(warp_rew.body_lin_acc_l2, warp_env, asset_cfg=body_cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_track_lin_vel_xy_exp(self, warp_env, stable_env, body_cfg): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) # needed for stable + std = 0.25 + expected = stable_rew.track_lin_vel_xy_exp(stable_env, std=std, command_name="vel", asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.track_lin_vel_xy_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) + actual_cap = _run_warp_rew_captured( + warp_rew.track_lin_vel_xy_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_track_ang_vel_z_exp(self, warp_env, stable_env, body_cfg): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + std = 0.25 + expected = stable_rew.track_ang_vel_z_exp(stable_env, std=std, command_name="vel", asset_cfg=cfg) + actual = _run_warp_rew(warp_rew.track_ang_vel_z_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) + actual_cap = _run_warp_rew_captured( + warp_rew.track_ang_vel_z_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_undesired_contacts(self, warp_env, stable_env, sensor_cfg): + threshold = 0.5 + # Stable returns int64 (torch.sum of bools); warp returns float32 — cast for comparison. + expected = stable_rew.undesired_contacts(stable_env, threshold=threshold, sensor_cfg=sensor_cfg).float() + actual = _run_warp_rew(warp_rew.undesired_contacts, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) + actual_cap = _run_warp_rew_captured( + warp_rew.undesired_contacts, warp_env, threshold=threshold, sensor_cfg=sensor_cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_desired_contacts(self, warp_env, stable_env, sensor_cfg): + threshold = 0.5 + expected = stable_rew.desired_contacts(stable_env, sensor_cfg=sensor_cfg, threshold=threshold) + actual = _run_warp_rew(warp_rew.desired_contacts, warp_env, sensor_cfg=sensor_cfg, threshold=threshold) + actual_cap = _run_warp_rew_captured( + warp_rew.desired_contacts, warp_env, sensor_cfg=sensor_cfg, threshold=threshold + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_contact_forces(self, warp_env, stable_env, sensor_cfg): + threshold = 0.5 + expected = stable_rew.contact_forces(stable_env, threshold=threshold, sensor_cfg=sensor_cfg) + actual = _run_warp_rew(warp_rew.contact_forces, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) + actual_cap = _run_warp_rew_captured( + warp_rew.contact_forces, warp_env, threshold=threshold, sensor_cfg=sensor_cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# New termination parity tests +# ============================================================================ + + +class TestNewTerminationParity: + """Verify newly migrated termination Warp kernels match stable torch implementations.""" + + def test_time_out(self, warp_env, stable_env): + expected = stable_term.time_out(stable_env) + actual = _run_warp_term(warp_term.time_out, warp_env) + actual_cap = _run_warp_term_captured(warp_term.time_out, warp_env) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_illegal_contact(self, warp_env, stable_env, sensor_cfg): + threshold = 0.5 + expected = stable_term.illegal_contact(stable_env, threshold=threshold, sensor_cfg=sensor_cfg) + actual = _run_warp_term(warp_term.illegal_contact, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) + actual_cap = _run_warp_term_captured( + warp_term.illegal_contact, warp_env, threshold=threshold, sensor_cfg=sensor_cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + +# ============================================================================ +# New event capture-safety tests +# ============================================================================ + + +def _copy_np_to_wp(dest: wp.array, src_np: np.ndarray): + tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device)) + wp.copy(dest, tmp) + + +class TestNewEventCaptureSafety: + """Verify new event functions are capture-safe.""" + + def test_reset_root_state_with_random_orientation(self, warp_env, art_data, env_origins): + """With zero-width position ranges, positions = default + env_origin.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + zero_pose = {"x": (0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0)} + zero_vel = { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + } + + warp_evt.reset_root_state_with_random_orientation(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) + with wp.ScopedCapture() as cap: + warp_evt.reset_root_state_with_random_orientation( + warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel + ) + + # Mutate defaults + new_pose = np.zeros((NUM_ENVS, 7), dtype=np.float32) + new_pose[:, 0:3] = [1.0, 2.0, 3.0] + new_pose[:, 3:7] = [0.0, 0.0, 0.0, 1.0] + _copy_np_to_wp(art_data.default_root_pose, new_pose) + + wp.capture_launch(cap.graph) + wp.synchronize() + + fn = warp_evt.reset_root_state_with_random_orientation + scratch_pose = wp.to_torch(fn._scratch_pose) + origins_t = wp.to_torch(env_origins) + + # Positions: default(1,2,3) + env_origin + 0 + expected_pos = torch.tensor([1.0, 2.0, 3.0], device=DEVICE).unsqueeze(0) + origins_t + assert_close(scratch_pose[:, :3], expected_pos) + + # Quaternions: should be unit quaternions (random SO(3)) + qnorm = scratch_pose[:, 3:7].norm(dim=1) + assert_close(qnorm, torch.ones(NUM_ENVS, device=DEVICE), atol=1e-4, rtol=1e-4) + + def test_reset_scene_to_default(self, warp_env, art_data, env_origins): + """With all envs masked, joints should be reset to defaults.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + + # Set defaults to known values + _copy_np_to_wp(art_data.default_joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 0.42, dtype=np.float32)) + _copy_np_to_wp(art_data.default_joint_vel, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) + + warp_evt.reset_scene_to_default(warp_env, mask) + wp.synchronize() + + result_pos = wp.to_torch(art_data.joint_pos) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 0.42, device=DEVICE) + assert_close(result_pos, expected) + + def test_reset_scene_to_default_mask_selectivity(self, warp_env, art_data, env_origins): + """Only masked envs are reset.""" + mask_np = np.array([i < NUM_ENVS // 2 for i in range(NUM_ENVS)]) + mask = wp.array(mask_np, dtype=wp.bool, device=DEVICE) + + # Set joint_pos to sentinel + _copy_np_to_wp(art_data.joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 999.0, dtype=np.float32)) + # Set defaults to 0 + _copy_np_to_wp(art_data.default_joint_pos, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) + + warp_evt.reset_scene_to_default(warp_env, mask) + wp.synchronize() + + result = wp.to_torch(art_data.joint_pos) + # Masked: reset to 0 + assert_close(result[: NUM_ENVS // 2], torch.zeros(NUM_ENVS // 2, NUM_JOINTS, device=DEVICE)) + # Unmasked: still 999 + assert_close(result[NUM_ENVS // 2 :], torch.full((NUM_ENVS // 2, NUM_JOINTS), 999.0, device=DEVICE)) + + def test_randomize_rigid_body_com(self, warp_env, art_data): + """With zero-width range, CoM should not change. With nonzero range, CoM should differ.""" + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + body_cfg = MockBodyCfg("robot", list(range(NUM_BODIES))) + + # Snapshot original CoM + original_com = wp.to_torch(art_data.body_com_pos_b).clone() + + # Zero range: no change + warp_evt.randomize_rigid_body_com( + warp_env, mask, com_range={"x": (0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0)}, asset_cfg=body_cfg + ) + wp.synchronize() + assert_close(wp.to_torch(art_data.body_com_pos_b), original_com) + + def test_reset_root_state_from_terrain(self, warp_env, art_data, env_origins): + """With zero-width orientation and velocity ranges, verify positions come from terrain patches.""" + # Create mock terrain + rng = _make_rng(123) + num_levels, num_types, num_patches = 2, 2, 5 + flat_patches_np = rng.randn(num_levels, num_types, num_patches, 3).astype(np.float32) + flat_patches_torch = torch.tensor(flat_patches_np, device=DEVICE) + + terrain_levels = torch.zeros(NUM_ENVS, dtype=torch.int32, device=DEVICE) + terrain_types = torch.zeros(NUM_ENVS, dtype=torch.int32, device=DEVICE) + + # Attach terrain mock to scene + warp_env.scene.terrain = type( + "_T", + (), + { + "flat_patches": {"init_pos": flat_patches_torch}, + "terrain_levels": terrain_levels, + "terrain_types": terrain_types, + }, + )() + + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + zero_pose = {"roll": (0.0, 0.0), "pitch": (0.0, 0.0), "yaw": (0.0, 0.0)} + zero_vel = { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + } + + warp_evt.reset_root_state_from_terrain(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) + wp.synchronize() + + fn = warp_evt.reset_root_state_from_terrain + scratch_pose = wp.to_torch(fn._scratch_pose) + + # All envs use level=0, type=0 so positions must come from flat_patches[0, 0, *, :] + valid_positions = flat_patches_torch[0, 0] # (num_patches, 3) + default_pos = wp.to_torch(art_data.default_root_pose)[:, :3] + + # Each env's position should be one of the valid patches + default offset + for i in range(min(8, NUM_ENVS)): # spot check first 8 + pos = scratch_pose[i, :3] + diffs = (valid_positions + default_pos[i]) - pos + min_dist = diffs.norm(dim=1).min() + assert min_dist < 1e-4, f"env {i}: position {pos} not near any valid patch" + + def test_command_resample(self, warp_env, cmd_term): + """Parity check for command_resample termination.""" + # Set up deterministic data: half the envs should trigger + cmd_term.time_left[:] = 0.01 # all below step_dt=0.02 + cmd_term.command_counter[: NUM_ENVS // 2] = 1.0 # match num_resamples=1 + cmd_term.command_counter[NUM_ENVS // 2 :] = 0.0 # no match + + expected = torch.logical_and( + cmd_term.time_left <= warp_env.step_dt, + cmd_term.command_counter == 1.0, + ) + + actual = _run_warp_term(warp_term.command_resample, warp_env, command_name="vel", num_resamples=1) + assert_equal(actual, expected) + + +# ============================================================================ +# Capture-mutate-replay tests for new terms +# ============================================================================ + + +def _mutate_body_data(art_data: MockMultiBodyArticulationData, rng_seed=200): + """Mutate body-level and root-level data in-place so captured graphs see fresh values.""" + rng = _make_rng(rng_seed) + + # Root state + root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 + _copy_np_to_wp(art_data.root_pos_w, root_pos_np) + _copy_np_to_wp(art_data.root_lin_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) + _copy_np_to_wp(art_data.root_ang_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) + + # Body data + grav_np = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) + grav_np[:, :, 2] = -1.0 + grav_np /= np.linalg.norm(grav_np, axis=2, keepdims=True) + _copy_np_to_wp(art_data.projected_gravity_b, grav_np) + + _copy_np_to_wp(art_data.body_lin_acc_w, rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32)) + + pose_np = np.zeros((NUM_ENVS, NUM_BODIES, 7), dtype=np.float32) + pose_np[:, :, :3] = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) + pose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] + _copy_np_to_wp(art_data.body_pose_w, pose_np) + + wp.synchronize() + + +class TestCapturedDataMutationNewTerms: + """Capture graph, mutate buffer data, replay — verify results match stable on new data. + + This validates the dynamic dependency update check (test requirement b). + """ + + def _capture_mutate_check_obs(self, warp_fn, stable_fn, warp_env, stable_env, art_data, shape, **kwargs): + out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) # warm-up + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_body_data(art_data) + wp.capture_launch(cap.graph) + expected = stable_fn(stable_env, **kwargs) + assert_close(wp.to_torch(out).clone(), expected) + + def _capture_mutate_check_rew(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_body_data(art_data) + wp.capture_launch(cap.graph) + expected = stable_fn(stable_env, **kwargs) + assert_close(wp.to_torch(out).clone(), expected) + + def _capture_mutate_check_term(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_body_data(art_data) + wp.capture_launch(cap.graph) + expected = stable_fn(stable_env, **kwargs) + assert_equal(wp.to_torch(out).clone(), expected) + + # -- body observations ------------------------------------------------- + + def test_body_projected_gravity_b(self, warp_env, stable_env, art_data, body_cfg): + n_sel = len(body_cfg.body_ids) + self._capture_mutate_check_obs( + warp_obs.body_projected_gravity_b, + stable_obs.body_projected_gravity_b, + warp_env, + stable_env, + art_data, + (NUM_ENVS, n_sel * 3), + asset_cfg=body_cfg, + ) + + def test_body_pose_w(self, warp_env, stable_env, art_data, body_cfg): + n_sel = len(body_cfg.body_ids) + + # Stable needs torch env_origins for unsqueeze + def stable_body_pose_w_fixed(env, **kw): + orig = env.scene.env_origins + env.scene.env_origins = wp.to_torch(orig) + result = stable_obs.body_pose_w(env, **kw) + env.scene.env_origins = orig + return result + + out = wp.zeros((NUM_ENVS, n_sel * 7), dtype=wp.float32, device=DEVICE) + warp_obs.body_pose_w(warp_env, out, asset_cfg=body_cfg) + with wp.ScopedCapture() as cap: + warp_obs.body_pose_w(warp_env, out, asset_cfg=body_cfg) + _mutate_body_data(art_data) + wp.capture_launch(cap.graph) + expected = stable_body_pose_w_fixed(stable_env, asset_cfg=body_cfg) + assert_close(wp.to_torch(out).clone(), expected) + + def test_generated_commands(self, warp_env, stable_env, art_data, cmd_tensor): + """Mutate command tensor, replay captured graph, verify new commands are read.""" + out = wp.zeros((NUM_ENVS, CMD_DIM), dtype=wp.float32, device=DEVICE) + warp_obs.generated_commands(warp_env, out, command_name="vel") + with wp.ScopedCapture() as cap: + warp_obs.generated_commands(warp_env, out, command_name="vel") + # Mutate the command tensor in-place (zero-copy view picks it up) + cmd_tensor[:] = torch.randn_like(cmd_tensor) + wp.capture_launch(cap.graph) + expected = stable_obs.generated_commands(stable_env, command_name="vel") + assert_close(wp.to_torch(out).clone(), expected) + + # -- rewards ----------------------------------------------------------- + + def test_body_lin_acc_l2(self, warp_env, stable_env, art_data, body_cfg): + self._capture_mutate_check_rew( + warp_rew.body_lin_acc_l2, + stable_rew.body_lin_acc_l2, + warp_env, + stable_env, + art_data, + asset_cfg=body_cfg, + ) + + def test_track_lin_vel_xy_exp(self, warp_env, stable_env, art_data): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + self._capture_mutate_check_rew( + warp_rew.track_lin_vel_xy_exp, + stable_rew.track_lin_vel_xy_exp, + warp_env, + stable_env, + art_data, + std=0.25, + command_name="vel", + asset_cfg=cfg, + ) + + def test_track_ang_vel_z_exp(self, warp_env, stable_env, art_data): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + self._capture_mutate_check_rew( + warp_rew.track_ang_vel_z_exp, + stable_rew.track_ang_vel_z_exp, + warp_env, + stable_env, + art_data, + std=0.25, + command_name="vel", + asset_cfg=cfg, + ) + + def test_contact_forces(self, warp_env, stable_env, art_data, contact_data, sensor_cfg): + """Mutate contact force history, verify captured graph picks up changes.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + warp_rew.contact_forces(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) + with wp.ScopedCapture() as cap: + warp_rew.contact_forces(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) + # Mutate contact sensor data in-place + contact_data.net_forces_w_history[:] = torch.randn_like(contact_data.net_forces_w_history) * 3.0 + wp.capture_launch(cap.graph) + expected = stable_rew.contact_forces(stable_env, threshold=0.5, sensor_cfg=sensor_cfg) + assert_close(wp.to_torch(out).clone(), expected) + + # -- terminations ------------------------------------------------------ + + def test_time_out(self, warp_env, stable_env, art_data): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_term.time_out(warp_env, out) + with wp.ScopedCapture() as cap: + warp_term.time_out(warp_env, out) + # Mutate episode length in-place + warp_env.episode_length_buf[:] = torch.randint(0, 600, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + wp.capture_launch(cap.graph) + expected = stable_term.time_out(stable_env) + assert_equal(wp.to_torch(out).clone(), expected) + + def test_illegal_contact(self, warp_env, stable_env, art_data, contact_data, sensor_cfg): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_term.illegal_contact(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) + with wp.ScopedCapture() as cap: + warp_term.illegal_contact(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) + contact_data.net_forces_w_history[:] = torch.randn_like(contact_data.net_forces_w_history) * 5.0 + wp.capture_launch(cap.graph) + expected = stable_term.illegal_contact(stable_env, threshold=0.5, sensor_cfg=sensor_cfg) + assert_equal(wp.to_torch(out).clone(), expected) diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md b/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md new file mode 100644 index 00000000000..24bbea777ee --- /dev/null +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md @@ -0,0 +1,200 @@ +# Articulation Data: CUDA Graph Capture Migration + +## Problem + +`ArticulationData` uses lazy `TimestampedWarpBuffer` properties that are incompatible with +`wp.ScopedCapture` / CUDA graph capture. When MDP terms access these properties inside a +captured scope, the compute kernel is skipped (timestamp already fresh from warmup) and +never recorded in the graph. Subsequent replays read stale data. + +## Property Tiers + +### Tier 1 — Sim-bind raw buffers (graph-safe) + +Direct physics solver outputs. Updated in-place each `sim.step()`. Stable pointers. + +| Property | Type | Source | +|----------|------|--------| +| `_sim_bind_root_link_pose_w` | `wp.transformf` | Solver root pose | +| `_sim_bind_root_com_vel_w` | `wp.spatial_vectorf` | Solver root COM velocity | +| `_sim_bind_body_link_pose_w` | `wp.transformf (2D)` | Solver body poses | +| `_sim_bind_body_com_vel_w` | `wp.spatial_vectorf (2D)` | Solver body COM velocities | +| `_sim_bind_joint_pos` | `wp.float32 (2D)` | Solver joint positions | +| `_sim_bind_joint_vel` | `wp.float32 (2D)` | Solver joint velocities | + +### Tier 2 — Derived properties (graph-hostile) + +Computed from Tier 1 via `wp.launch`, guarded by `TimestampedWarpBuffer` timestamp check. +The timestamp guard is a Python `if` that prevents the kernel from being captured. + +| Property | Computation | Inputs | +|----------|-------------|--------| +| `projected_gravity_b` | Rotate gravity into body frame | `root_link_pose_w` | +| `heading_w` | Extract yaw from quaternion | `root_link_pose_w` | +| `root_link_vel_w` | Project COM vel to link frame | `root_com_vel_w`, `root_link_pose_w` | +| `root_link_vel_b` | Project link vel to body frame | `root_link_vel_w`, `root_link_pose_w` | +| `root_com_vel_b` | Project COM vel to body frame | `root_com_vel_w`, `root_link_pose_w` | +| `root_com_pose_w` | Apply COM offset to link pose | `root_link_pose_w`, `body_com_pos_b` | +| `root_com_acc_w` | Finite difference of COM vel | `root_com_vel_w`, previous vel | +| `body_link_vel_w` | Project body COM vel to link frame | `body_com_vel_w`, `body_link_pose_w` | +| `body_com_pose_w` | Apply COM offset to body poses | `body_link_pose_w`, `body_com_pos_b` | +| `body_com_acc_w` | Finite difference of body vel | `body_com_vel_w`, previous vel | +| `joint_acc` | Finite difference of joint vel | `joint_vel`, previous vel | + +### Tier 3 — Sliced properties (mostly graph-safe) + +Extract a single component from a compound type. If data is contiguous, a strided +`wp.array` view is created once (zero-cost). If not contiguous, a `wp.launch(split_...)` +runs each access — which IS captured correctly since `is_contiguous` is a fixed flag. + +**Exception:** Tier 3 properties that chain through Tier 2 are NOT graph-safe: + +| Property | Chains through | Graph-safe? | +|----------|---------------|-------------| +| `root_link_pos_w` | Tier 1 (`_sim_bind_root_link_pose_w`) | Yes | +| `root_link_quat_w` | Tier 1 (`_sim_bind_root_link_pose_w`) | Yes | +| `root_com_lin_vel_w` | Tier 1 (`_sim_bind_root_com_vel_w`) | Yes | +| `root_com_ang_vel_w` | Tier 1 (`_sim_bind_root_com_vel_w`) | Yes | +| `root_link_lin_vel_b` | **Tier 2** (`root_link_vel_b`) | **No** | +| `root_link_ang_vel_b` | **Tier 2** (`root_link_vel_b`) | **No** | +| `root_com_lin_vel_b` | **Tier 2** (`root_com_vel_b`) | **No** | +| `root_com_ang_vel_b` | **Tier 2** (`root_com_vel_b`) | **No** | + +## Why Lazy? + +The laziness exists for two reasons: + +1. **Avoid unnecessary computation.** An env using only `joint_pos` should not pay for + `projected_gravity_b`. Most envs only use a small subset of derived properties. + +2. **Deduplicate within a step.** If multiple MDP terms access `projected_gravity_b` in the + same step, the timestamp guard ensures the kernel runs only once. Without it, the same + transform would be recomputed per access. + +`update()` (called each `scene.update(dt)`) only eagerly pre-computes `joint_acc` and +`body_com_acc_w` because these need the previous-step velocity snapshot for finite differencing. +Everything else stays lazy. + +## Capture Failure Mechanism + +``` +_wp_capture_or_launch: + 1. WARMUP (eager): + - MDP term accesses asset.data.projected_gravity_b + - TimestampedWarpBuffer: timestamp(-1) < sim_timestamp(T) → True + - wp.launch(project_vec_from_pose_single, ...) runs + - timestamp set to T + + 2. CAPTURE (wp.ScopedCapture): + - MDP term accesses asset.data.projected_gravity_b + - TimestampedWarpBuffer: timestamp(T) < sim_timestamp(T) → False + - wp.launch SKIPPED — kernel NOT recorded in graph + - MDP term's own wp.launch recorded, pointing to projected_gravity_b.data + + 3. REPLAY (all subsequent steps): + - Only MDP term's kernel replays + - Reads from projected_gravity_b.data — NEVER recomputed + - Data is stale from warmup +``` + +## Proposed Fix: `materialize_derived()` + +Add a method to `ArticulationData` that unconditionally launches all Tier 2 kernels +and updates timestamps. Call from `scene.update()` which runs outside capture scopes. + +```python +# ArticulationData +def materialize_derived(self) -> None: + """Eagerly compute all Tier 2 derived properties. + + Call before any captured graph that reads derived data. + Safe to call every step — cost is the same as accessing each property once. + """ + # Root-level derived + _ = self.projected_gravity_b # forces timestamp check → launches if stale + _ = self.heading_w + _ = self.root_link_vel_w + _ = self.root_link_vel_b + _ = self.root_com_vel_b + _ = self.root_com_pose_w + # Body-level derived + _ = self.body_link_vel_w + _ = self.body_com_pose_w +``` + +Integration point — `scene.update()` or `ArticulationData.update()`: + +```python +def update(self, dt: float): + self._sim_timestamp += dt + # Existing: finite-difference quantities (need previous-step snapshot) + self.joint_acc + self.body_com_acc_w + # NEW: eagerly materialize all derived properties for graph capture + self.materialize_derived() +``` + +**Trade-off:** This removes the lazy optimization — every derived property computes +every step, even if unused. For capture-mode envs this is the correct trade-off (the +kernel cost is negligible vs graph replay savings). For non-capture envs, the extra +kernels add overhead for unused properties. + +**Better approach — opt-in materialization:** + +Only materialize properties that the env actually uses. The `ManagerCallSwitch` knows +which managers are in capture mode. The env can call `materialize_derived()` only when +capture mode is active: + +```python +# In ManagerBasedRLEnvWarp, after scene.update(): +if any_manager_in_capture_mode: + for articulation in self.scene.articulations.values(): + articulation.data.materialize_derived() +``` + +Or more selectively, track which properties were accessed during warmup and only +materialize those on subsequent steps. + +## Alternative: Use Compound Types in MDP Kernels + +Instead of fixing the data class, modify MDP terms to use Tier 1 compound types directly +(`root_link_pose_w` as `wp.transformf`, `root_com_vel_w` as `wp.spatial_vectorf`) and +extract components inside warp kernels: + +```python +@wp.kernel +def _projected_gravity_kernel( + pose_w: wp.array(dtype=wp.transformf), + gravity: wp.vec3f, + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + q = wp.transform_get_rotation(pose_w[i]) + g_b = wp.quat_rotate_inv(q, gravity) + out[i, 0] = g_b[0] + out[i, 1] = g_b[1] + out[i, 2] = g_b[2] +``` + +**Pros:** No changes to articulation data class. Eliminates all Tier 2/3 overhead. +**Cons:** Every MDP term must be rewritten. Duplicates split logic across terms. + +## Affected MDP Terms + +See "Non-Capturable MDP Terms" section in +`isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md` for the full list of +MDP terms marked `@warp_capturable(False)` due to Tier 2 access, and the pending fix +(`materialize_derived()`) that would make them capturable again. + +## Recommendation + +Short-term: Mark affected MDP terms `@warp_capturable(False)` so they fall back to +mode=1 automatically. No incorrect results, modest perf regression for those terms. + +Medium-term: Add `materialize_derived()` to `ArticulationData` and call it from +`scene.update()` when capture mode is active. Minimal changes, preserves lazy +optimization for non-capture users. Once applied, all `@warp_capturable(False)` +annotations for Tier 2 access can be removed and these terms become fully capturable. + +Long-term: Migrate MDP kernels to use compound Tier 1 types directly. Best performance, +no derived property overhead at all. diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py index 4781f141af4..79c13e2aa8f 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py @@ -3,4 +3,10 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Classic experimental task registrations (manager-based).""" +"""Classic environments for control. + +These environments are based on the MuJoCo environments provided by OpenAI. + +Reference: + https://github.com/openai/gym/tree/master/gym/envs/mujoco +""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py new file mode 100644 index 00000000000..5f123abaa75 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Ant locomotion environment (experimental manager-based entry point). +""" + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.classic.ant import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Ant-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.ant_env_cfg:AntEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AntPPORunnerCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py new file mode 100644 index 00000000000..a6d8e9e1d7e --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py @@ -0,0 +1,196 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# Ant reuses humanoid's experimental MDP (mirrors stable pattern). +import isaaclab_tasks_experimental.manager_based.classic.humanoid.mdp as mdp +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +import isaaclab.sim as sim_utils +from isaaclab.assets import AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sim import SimulationCfg +from isaaclab.terrains import TerrainImporterCfg +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets.robots.ant import ANT_CFG # isort: skip + + +@configclass +class MySceneCfg(InteractiveSceneCfg): + """Configuration for the terrain scene with an ant robot.""" + + # terrain + terrain = TerrainImporterCfg( + prim_path="/World/ground", + terrain_type="plane", + collision_group=-1, + physics_material=sim_utils.RigidBodyMaterialCfg( + friction_combine_mode="average", + restitution_combine_mode="average", + static_friction=1.0, + dynamic_friction=1.0, + restitution=0.0, + ), + debug_vis=False, + ) + + # robot + robot = ANT_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + # lights + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DistantLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0), + ) + + +## +# MDP settings +## + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_effort = mdp.JointEffortActionCfg(asset_name="robot", joint_names=[".*"], scale=7.5) + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for the policy.""" + + base_height = ObsTerm(func=mdp.base_pos_z) + base_lin_vel = ObsTerm(func=mdp.base_lin_vel) + base_ang_vel = ObsTerm(func=mdp.base_ang_vel) + base_yaw_roll = ObsTerm(func=mdp.base_yaw_roll) + base_angle_to_target = ObsTerm(func=mdp.base_angle_to_target, params={"target_pos": (1000.0, 0.0, 0.0)}) + base_up_proj = ObsTerm(func=mdp.base_up_proj) + base_heading_proj = ObsTerm(func=mdp.base_heading_proj, params={"target_pos": (1000.0, 0.0, 0.0)}) + joint_pos_norm = ObsTerm(func=mdp.joint_pos_limit_normalized) + joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel, scale=0.2) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={"pose_range": {}, "velocity_range": {}}, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "position_range": (-0.2, 0.2), + "velocity_range": (-0.1, 0.1), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # (1) Reward for moving forward + progress = RewTerm(func=mdp.progress_reward, weight=1.0, params={"target_pos": (1000.0, 0.0, 0.0)}) + # (2) Stay alive bonus + alive = RewTerm(func=mdp.is_alive, weight=0.5) + # (3) Reward for non-upright posture + upright = RewTerm(func=mdp.upright_posture_bonus, weight=0.1, params={"threshold": 0.93}) + # (4) Reward for moving in the right direction + move_to_target = RewTerm( + func=mdp.move_to_target_bonus, weight=0.5, params={"threshold": 0.8, "target_pos": (1000.0, 0.0, 0.0)} + ) + # (5) Penalty for large action commands + action_l2 = RewTerm(func=mdp.action_l2, weight=-0.005) + # (6) Penalty for energy consumption + energy = RewTerm(func=mdp.power_consumption, weight=-0.05, params={"gear_ratio": {".*": 15.0}}) + # (7) Penalty for reaching close to joint limits + joint_pos_limits = RewTerm( + func=mdp.joint_pos_limits_penalty_ratio, weight=-0.1, params={"threshold": 0.99, "gear_ratio": {".*": 15.0}} + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Terminate if the episode length is exceeded + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Terminate if the robot falls + torso_height = DoneTerm(func=mdp.root_height_below_minimum, params={"minimum_height": 0.31}) + + +## +# Environment configuration +## + + +@configclass +class AntEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the MuJoCo-style Ant walking environment.""" + + # Simulation settings + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=38, + nconmax=15, + ls_iterations=10, + cone="pyramidal", + ls_parallel=True, + impratio=1, + ), + num_substeps=1, + debug_mode=False, + ) + ) + + # Scene settings + scene: MySceneCfg = MySceneCfg(num_envs=4096, env_spacing=5.0, clone_in_fabric=True) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + events: EventCfg = EventCfg() + + def __post_init__(self): + """Post initialization.""" + # general settings + self.decimation = 2 + self.episode_length_s = 16.0 + # simulation settings + self.sim.dt = 1 / 120.0 + self.sim.render_interval = self.decimation + # default friction material + self.sim.physics_material.static_friction = 1.0 + self.sim.physics_material.dynamic_friction = 1.0 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py index 17a4c5c03cd..4e332426494 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py @@ -9,21 +9,22 @@ import gymnasium as gym +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.classic.cartpole import agents + +## +# Register Gym environments. +## + gym.register( id="Isaac-Cartpole-Warp-v0", entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", disable_env_checker=True, kwargs={ - # Use experimental Cartpole cfg (allows isolated modifications). - "env_cfg_entry_point": ( - "isaaclab_tasks_experimental.manager_based.classic.cartpole.cartpole_env_cfg:CartpoleEnvCfg" - ), - # Point agent configs to the existing task package. - "rl_games_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:rl_games_ppo_cfg.yaml", - "rsl_rl_cfg_entry_point": ( - "isaaclab_tasks.manager_based.classic.cartpole.agents.rsl_rl_ppo_cfg:CartpolePPORunnerCfg" - ), - "skrl_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:skrl_ppo_cfg.yaml", - "sb3_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:sb3_ppo_cfg.yaml", + "env_cfg_entry_point": f"{__name__}.cartpole_env_cfg:CartpoleEnvCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", }, ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py new file mode 100644 index 00000000000..c08e5156b92 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Humanoid locomotion environment (experimental manager-based entry point). +""" + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.classic.humanoid import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Humanoid-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.humanoid_env_cfg:HumanoidEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:HumanoidPPORunnerCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/humanoid_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/humanoid_env_cfg.py new file mode 100644 index 00000000000..40841677070 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/humanoid_env_cfg.py @@ -0,0 +1,231 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.classic.humanoid.mdp as mdp +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +import isaaclab.sim as sim_utils +from isaaclab.assets import AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sim import SimulationCfg +from isaaclab.terrains import TerrainImporterCfg +from isaaclab.utils import configclass + +from isaaclab_assets.robots.humanoid import HUMANOID_CFG # isort:skip + + +## +# Scene definition +## + + +@configclass +class MySceneCfg(InteractiveSceneCfg): + """Configuration for the terrain scene with a humanoid robot.""" + + # terrain + terrain = TerrainImporterCfg( + prim_path="/World/ground", + terrain_type="plane", + collision_group=-1, + physics_material=sim_utils.RigidBodyMaterialCfg(static_friction=1.0, dynamic_friction=1.0, restitution=0.0), + debug_vis=False, + ) + + # robot + robot = HUMANOID_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + # lights + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DistantLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0), + ) + + +## +# MDP settings +## + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_effort = mdp.JointEffortActionCfg( + asset_name="robot", + joint_names=[".*"], + scale={ + ".*_waist.*": 67.5, + ".*_upper_arm.*": 67.5, + "pelvis": 67.5, + ".*_lower_arm": 45.0, + ".*_thigh:0": 45.0, + ".*_thigh:1": 135.0, + ".*_thigh:2": 45.0, + ".*_shin": 90.0, + ".*_foot.*": 22.5, + }, + ) + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for the policy.""" + + base_height = ObsTerm(func=mdp.base_pos_z) + base_lin_vel = ObsTerm(func=mdp.base_lin_vel) + base_ang_vel = ObsTerm(func=mdp.base_ang_vel, scale=0.25) + base_yaw_roll = ObsTerm(func=mdp.base_yaw_roll) + base_angle_to_target = ObsTerm(func=mdp.base_angle_to_target, params={"target_pos": (1000.0, 0.0, 0.0)}) + base_up_proj = ObsTerm(func=mdp.base_up_proj) + base_heading_proj = ObsTerm(func=mdp.base_heading_proj, params={"target_pos": (1000.0, 0.0, 0.0)}) + joint_pos_norm = ObsTerm(func=mdp.joint_pos_limit_normalized) + joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel, scale=0.1) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={"pose_range": {}, "velocity_range": {}}, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "position_range": (-0.2, 0.2), + "velocity_range": (-0.1, 0.1), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # (1) Reward for moving forward + progress = RewTerm(func=mdp.progress_reward, weight=1.0, params={"target_pos": (1000.0, 0.0, 0.0)}) + # (2) Stay alive bonus + alive = RewTerm(func=mdp.is_alive, weight=2.0) + # (3) Reward for non-upright posture + upright = RewTerm(func=mdp.upright_posture_bonus, weight=0.1, params={"threshold": 0.93}) + # (4) Reward for moving in the right direction + move_to_target = RewTerm( + func=mdp.move_to_target_bonus, weight=0.5, params={"threshold": 0.8, "target_pos": (1000.0, 0.0, 0.0)} + ) + # (5) Penalty for large action commands + action_l2 = RewTerm(func=mdp.action_l2, weight=-0.01) + # (6) Penalty for energy consumption + energy = RewTerm( + func=mdp.power_consumption, + weight=-0.005, + params={ + "gear_ratio": { + ".*_waist.*": 67.5, + ".*_upper_arm.*": 67.5, + "pelvis": 67.5, + ".*_lower_arm": 45.0, + ".*_thigh:0": 45.0, + ".*_thigh:1": 135.0, + ".*_thigh:2": 45.0, + ".*_shin": 90.0, + ".*_foot.*": 22.5, + } + }, + ) + # (7) Penalty for reaching close to joint limits + joint_pos_limits = RewTerm( + func=mdp.joint_pos_limits_penalty_ratio, + weight=-0.25, + params={ + "threshold": 0.98, + "gear_ratio": { + ".*_waist.*": 67.5, + ".*_upper_arm.*": 67.5, + "pelvis": 67.5, + ".*_lower_arm": 45.0, + ".*_thigh:0": 45.0, + ".*_thigh:1": 135.0, + ".*_thigh:2": 45.0, + ".*_shin": 90.0, + ".*_foot.*": 22.5, + }, + }, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Terminate if the episode length is exceeded + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Terminate if the robot falls + torso_height = DoneTerm(func=mdp.root_height_below_minimum, params={"minimum_height": 0.8}) + + +@configclass +class HumanoidEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the MuJoCo-style Humanoid walking environment.""" + + # Scene settings + scene: MySceneCfg = MySceneCfg(num_envs=4096, env_spacing=5.0, clone_in_fabric=True) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + events: EventCfg = EventCfg() + + def __post_init__(self): + """Post initialization.""" + # general settings + self.decimation = 2 + self.episode_length_s = 16.0 + # simulation settings + self.sim: SimulationCfg = SimulationCfg( + dt=1 / 120.0, + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=100, + nconmax=25, + ls_iterations=15, + cone="pyramidal", + ls_parallel=True, + update_data_interval=2, + impratio=1, + ), + num_substeps=2, + debug_mode=False, + ), + ) + # self.sim.dt = 1 / 120.0 + self.sim.render_interval = self.decimation + # default friction material + self.sim.physics_material.static_friction = 1.0 + self.sim.physics_material.dynamic_friction = 1.0 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/__init__.py new file mode 100644 index 00000000000..df0802edf05 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the humanoid environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .observations import * # noqa: F401, F403 +from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py new file mode 100644 index 00000000000..59906f4b5ef --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py @@ -0,0 +1,173 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-first observation terms for the humanoid task. + +All observation functions follow the ``func(env, out, **params) -> None`` signature. +Dimensions are declared via ``out_dim`` on the ``@generic_io_descriptor_warp`` decorator. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.envs.utils.io_descriptors import generic_io_descriptor_warp +from isaaclab_experimental.managers import SceneEntityCfg + +from isaaclab.assets import Articulation + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + + +# Reviewed(jichuanh): file reviewed +@wp.kernel +def _base_yaw_roll_kernel( + root_quat_w: wp.array(dtype=wp.quatf), + out: wp.array(dtype=wp.float32, ndim=2), +): + """Extract yaw and roll angles from root quaternion (x, y, z, w layout).""" + i = wp.tid() + q = root_quat_w[i] + qx = q[0] + qy = q[1] + qz = q[2] + qw = q[3] + # roll = atan2(2*(qw*qx + qy*qz), 1 - 2*(qx^2 + qy^2)) + sin_roll = 2.0 * (qw * qx + qy * qz) + cos_roll = 1.0 - 2.0 * (qx * qx + qy * qy) + roll = wp.atan2(sin_roll, cos_roll) + # yaw = atan2(2*(qw*qz + qx*qy), 1 - 2*(qy^2 + qz^2)) + sin_yaw = 2.0 * (qw * qz + qx * qy) + cos_yaw = 1.0 - 2.0 * (qy * qy + qz * qz) + yaw = wp.atan2(sin_yaw, cos_yaw) + out[i, 0] = yaw + out[i, 1] = roll + + +@generic_io_descriptor_warp(out_dim=2, observation_type="RootState") +def base_yaw_roll(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Yaw and roll of the base in the simulation world frame. Shape: (num_envs, 2).""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_yaw_roll_kernel, + dim=env.num_envs, + inputs=[asset.data.root_quat_w, out], + device=env.device, + ) + + +@wp.kernel +def _base_up_proj_kernel( + projected_gravity_b: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + """Project base up vector onto world up: -gravity_b[2].""" + i = wp.tid() + out[i, 0] = -projected_gravity_b[i][2] + + +@generic_io_descriptor_warp(out_dim=1, observation_type="RootState") +def base_up_proj(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Projection of the base up vector onto the world up vector. Shape: (num_envs, 1).""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_up_proj_kernel, + dim=env.num_envs, + inputs=[asset.data.projected_gravity_b, out], + device=env.device, + ) + + +@wp.kernel +def _base_heading_proj_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + root_quat_w: wp.array(dtype=wp.quatf), + target_x: float, + target_y: float, + target_z: float, + out: wp.array(dtype=wp.float32, ndim=2), +): + """Dot product between robot forward and direction to target.""" + i = wp.tid() + pos = root_pos_w[i] + q = root_quat_w[i] + # compute direction to target (zeroed z) + dx = target_x - pos[0] + dy = target_y - pos[1] + dist = wp.sqrt(dx * dx + dy * dy) + # avoid division by zero + inv_dist = wp.where(dist > 1.0e-6, 1.0 / dist, 0.0) + to_target_x = dx * inv_dist + to_target_y = dy * inv_dist + # compute forward vector via quaternion rotation of (1,0,0) + fwd = wp.quat_rotate(q, wp.vec3f(1.0, 0.0, 0.0)) + # dot product (xy only) + heading_proj = fwd[0] * to_target_x + fwd[1] * to_target_y + out[i, 0] = heading_proj + + +@generic_io_descriptor_warp(out_dim=1, observation_type="RootState") +def base_heading_proj( + env: ManagerBasedEnv, + out, + target_pos: tuple[float, float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Dot product between the base forward direction and direction to target. Shape: (num_envs, 1).""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_heading_proj_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, asset.data.root_quat_w, target_pos[0], target_pos[1], target_pos[2], out], + device=env.device, + ) + + +@wp.kernel +def _base_angle_to_target_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + root_quat_w: wp.array(dtype=wp.quatf), + target_x: float, + target_y: float, + out: wp.array(dtype=wp.float32, ndim=2), +): + """Angle between base forward and vector to target, normalized to [-pi, pi].""" + i = wp.tid() + pos = root_pos_w[i] + q = root_quat_w[i] + # angle to target in world frame + dx = target_x - pos[0] + dy = target_y - pos[1] + walk_target_angle = wp.atan2(dy, dx) + # extract yaw from quaternion + qx = q[0] + qy = q[1] + qz = q[2] + qw = q[3] + sin_yaw = 2.0 * (qw * qz + qx * qy) + cos_yaw = 1.0 - 2.0 * (qy * qy + qz * qz) + yaw = wp.atan2(sin_yaw, cos_yaw) + # normalize to [-pi, pi] + angle = walk_target_angle - yaw + out[i, 0] = wp.atan2(wp.sin(angle), wp.cos(angle)) + + +@generic_io_descriptor_warp(out_dim=1, observation_type="RootState") +def base_angle_to_target( + env: ManagerBasedEnv, + out, + target_pos: tuple[float, float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Angle between the base forward vector and the vector to the target. Shape: (num_envs, 1).""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_angle_to_target_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, asset.data.root_quat_w, target_pos[0], target_pos[1], out], + device=env.device, + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py new file mode 100644 index 00000000000..de6e24be978 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py @@ -0,0 +1,309 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-first reward terms for the humanoid task. + +All reward functions follow the ``func(env, out, **params) -> None`` signature +where ``out`` is a pre-allocated Warp array of shape ``(num_envs,)`` with float32 dtype. +""" + +from __future__ import annotations + +import torch +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.managers.manager_base import ManagerTermBase +from isaaclab_experimental.utils.warp import warp_capturable + +import isaaclab.utils.string as string_utils +from isaaclab.assets import Articulation + +if TYPE_CHECKING: + from isaaclab_experimental.managers.manager_term_cfg import RewardTermCfg + + from isaaclab.envs import ManagerBasedRLEnv + + +# --------------------------------------------------------------------------- +# Function-based reward terms +# --------------------------------------------------------------------------- + +# Reviewed(jichuanh): file roughly reviewed + + +@wp.kernel +def _upright_posture_bonus_kernel( + projected_gravity_b: wp.array(dtype=wp.vec3f), + threshold: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + up_proj = -projected_gravity_b[i][2] + out[i] = wp.where(up_proj > threshold, 1.0, 0.0) + + +@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) +def upright_posture_bonus( + env: ManagerBasedRLEnv, out, threshold: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Reward for maintaining an upright posture. Writes 1.0 if up_proj > threshold, else 0.0.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_upright_posture_bonus_kernel, + dim=env.num_envs, + inputs=[asset.data.projected_gravity_b, threshold, out], + device=env.device, + ) + + +@wp.kernel +def _move_to_target_bonus_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + root_quat_w: wp.array(dtype=wp.quatf), + target_x: float, + target_y: float, + threshold: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + pos = root_pos_w[i] + q = root_quat_w[i] + # direction to target + dx = target_x - pos[0] + dy = target_y - pos[1] + dist = wp.sqrt(dx * dx + dy * dy) + inv_dist = wp.where(dist > 1.0e-6, 1.0 / dist, 0.0) + to_target_x = dx * inv_dist + to_target_y = dy * inv_dist + # forward vector + fwd = wp.quat_rotate(q, wp.vec3f(1.0, 0.0, 0.0)) + heading_proj = fwd[0] * to_target_x + fwd[1] * to_target_y + out[i] = wp.where(heading_proj > threshold, 1.0, heading_proj / threshold) + + +def move_to_target_bonus( + env: ManagerBasedRLEnv, + out, + threshold: float, + target_pos: tuple[float, float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Reward for heading towards the target.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_move_to_target_bonus_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, asset.data.root_quat_w, target_pos[0], target_pos[1], threshold, out], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# Class-based reward terms +# --------------------------------------------------------------------------- + + +@wp.kernel +def _progress_reward_reset_kernel( + env_mask: wp.array(dtype=wp.bool), + root_pos_w: wp.array(dtype=wp.vec3f), + target_x: float, + target_y: float, + target_z: float, + inv_step_dt: float, + potentials: wp.array(dtype=wp.float32), +): + i = wp.tid() + if env_mask[i]: + pos = root_pos_w[i] + dx = target_x - pos[0] + dy = target_y - pos[1] + dz = target_z - pos[2] + dist = wp.sqrt(dx * dx + dy * dy + dz * dz) + potentials[i] = -dist * inv_step_dt + + +@wp.kernel +def _progress_reward_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + target_x: float, + target_y: float, + inv_step_dt: float, + potentials: wp.array(dtype=wp.float32), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + pos = root_pos_w[i] + dx = target_x - pos[0] + dy = target_y - pos[1] + # z component is zeroed (xy distance only, matching stable) + dist = wp.sqrt(dx * dx + dy * dy) + prev = potentials[i] + pot = -dist * inv_step_dt + potentials[i] = pot + out[i] = pot - prev + + +class progress_reward(ManagerTermBase): + """Reward for making progress towards the target (potential-based).""" + + def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv): + super().__init__(cfg, env) + self.potentials = wp.zeros(env.num_envs, dtype=wp.float32, device=env.device) + self._target_pos = cfg.params["target_pos"] + + def reset(self, env_mask: wp.array | None = None) -> None: + if env_mask is None: + self.potentials.zero_() + return + asset: Articulation = self._env.scene["robot"] + inv_dt = 1.0 / self._env.step_dt + wp.launch( + kernel=_progress_reward_reset_kernel, + dim=self.num_envs, + inputs=[ + env_mask, + asset.data.root_pos_w, + self._target_pos[0], + self._target_pos[1], + self._target_pos[2], + inv_dt, + self.potentials, + ], + device=self.device, + ) + + def __call__( + self, + env: ManagerBasedRLEnv, + out, + target_pos: tuple[float, float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), + ) -> None: + asset: Articulation = env.scene[asset_cfg.name] + inv_dt = 1.0 / env.step_dt + wp.launch( + kernel=_progress_reward_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, target_pos[0], target_pos[1], inv_dt, self.potentials, out], + device=env.device, + ) + + +@wp.kernel +def _joint_pos_limits_penalty_ratio_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + soft_limits: wp.array(dtype=wp.vec2f, ndim=2), + gear_ratio_scaled: wp.array(dtype=wp.float32, ndim=2), + threshold: float, + inv_range: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + n_joints = joint_pos.shape[1] + s = float(0.0) + for j in range(n_joints): + lim = soft_limits[i, j] + lower = lim.x + upper = lim.y + mid = (lower + upper) * 0.5 + half_range = (upper - lower) * 0.5 + scaled = float(0.0) + if half_range > 0.0: + scaled = (joint_pos[i, j] - mid) / half_range + abs_scaled = wp.abs(scaled) + if abs_scaled > threshold: + violation = (abs_scaled - threshold) * inv_range + s += violation * gear_ratio_scaled[i, j] + out[i] = s + + +class joint_pos_limits_penalty_ratio(ManagerTermBase): + """Penalty for violating joint position limits weighted by the gear ratio.""" + + def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg): + asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot")) + asset: Articulation = env.scene[asset_cfg.name] + + # resolve the gear ratio for each joint (torch in __init__ is fine) + gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device) + index_list, _, value_list = string_utils.resolve_matching_names_values( + cfg.params["gear_ratio"], asset.joint_names + ) + gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device) + gear_ratio_scaled = gear_ratio / torch.max(gear_ratio) + self._gear_ratio_scaled_wp = wp.from_torch(gear_ratio_scaled) + + def __call__( + self, + env: ManagerBasedRLEnv, + out, + threshold: float, + gear_ratio: dict[str, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), + ) -> None: + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_joint_pos_limits_penalty_ratio_kernel, + dim=env.num_envs, + inputs=[ + asset.data.joint_pos, + asset.data.soft_joint_pos_limits, + self._gear_ratio_scaled_wp, + threshold, + 1.0 / (1.0 - threshold), + out, + ], + device=env.device, + ) + + +@wp.kernel +def _power_consumption_kernel( + action: wp.array(dtype=wp.float32, ndim=2), + joint_vel: wp.array(dtype=wp.float32, ndim=2), + gear_ratio_scaled: wp.array(dtype=wp.float32, ndim=2), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + n_joints = action.shape[1] + s = float(0.0) + for j in range(n_joints): + s += wp.abs(action[i, j] * joint_vel[i, j] * gear_ratio_scaled[i, j]) + out[i] = s + + +class power_consumption(ManagerTermBase): + """Penalty for the power consumed by the actions to the environment.""" + + def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg): + asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot")) + asset: Articulation = env.scene[asset_cfg.name] + + # resolve the gear ratio for each joint (torch in __init__ is fine) + gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device) + index_list, _, value_list = string_utils.resolve_matching_names_values( + cfg.params["gear_ratio"], asset.joint_names + ) + gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device) + gear_ratio_scaled = gear_ratio / torch.max(gear_ratio) + self._gear_ratio_scaled_wp = wp.from_torch(gear_ratio_scaled) + + def __call__( + self, + env: ManagerBasedRLEnv, + out, + gear_ratio: dict[str, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), + ) -> None: + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_power_consumption_kernel, + dim=env.num_envs, + inputs=[env.action_manager.action, asset.data.joint_vel, self._gear_ratio_scaled_wp, out], + device=env.device, + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/__init__.py new file mode 100644 index 00000000000..0660d38f065 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Locomotion experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py new file mode 100644 index 00000000000..0857176a3fc --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Velocity locomotion experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py new file mode 100644 index 00000000000..26f3257daef --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Configurations for velocity-based locomotion environments.""" + +# We leave this file empty since we don't want to expose any configs in this package directly. +# We still need this file to import the "config" module in the parent package. diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py new file mode 100644 index 00000000000..6c79524e853 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.a1 import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Unitree-A1-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Unitree-A1-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py new file mode 100644 index 00000000000..0705892c827 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import UnitreeA1RoughEnvCfg + + +@configclass +class UnitreeA1FlatEnvCfg(UnitreeA1RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=45, + nconmax=30, + ls_iterations=30, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 0.25 + + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no height scan + # self.scene.height_scanner = None + # self.observations.policy.height_scan = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class UnitreeA1FlatEnvCfg_PLAY(UnitreeA1FlatEnvCfg): + def __post_init__(self) -> None: + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py new file mode 100644 index 00000000000..ace7b9cf8a6 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + TerminationsCfg, +) + +from isaaclab.utils import configclass + +from isaaclab_assets.robots.unitree import UNITREE_A1_CFG # isort: skip + +# reviewed(jichuanh): file roughly reviewed + + +class TerminationsCfg_A1(TerminationsCfg): + base_too_low = DoneTerm(func=mdp.root_height_below_minimum, params={"minimum_height": 0.2}) + + +@configclass +class UnitreeA1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + terminations: TerminationsCfg_A1 = TerminationsCfg_A1() + + def __post_init__(self): + # post init of parent + super().__post_init__() + + self.scene.robot = UNITREE_A1_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.scene.terrain.terrain_generator.sub_terrains["boxes"].grid_height_range = (0.025, 0.1) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_range = (0.01, 0.06) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_step = 0.01 + + # reduce action scale + self.actions.joint_pos.scale = 0.25 + + # event + self.events.push_robot = None + # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation + # self.events.add_base_mass.params["mass_distribution_params"] = (-1.0, 3.0) + # self.events.add_base_mass.params["asset_cfg"].body_names = "trunk" + self.events.base_external_force_torque.params["asset_cfg"].body_names = "trunk" + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + + # rewards + self.rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot" + self.rewards.feet_air_time.weight = 0.01 + self.rewards.undesired_contacts.params["sensor_cfg"].body_names = ".*thigh" + self.rewards.dof_torques_l2.weight = -0.0002 + self.rewards.track_lin_vel_xy_exp.weight = 1.5 + self.rewards.track_ang_vel_z_exp.weight = 0.75 + self.rewards.dof_acc_l2.weight = -2.5e-7 + self.terminations.base_contact.params["sensor_cfg"].body_names = "trunk" + + +@configclass +class UnitreeA1RoughEnvCfg_PLAY(UnitreeA1RoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # spawn the robot randomly in the grid (instead of their terrain levels) + self.scene.terrain.max_init_terrain_level = None + # reduce the number of terrains to save memory + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py new file mode 100644 index 00000000000..cbbf5290e82 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_b import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Anymal-B-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalBFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerWithSymmetryCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-B-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalBFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerWithSymmetryCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py new file mode 100644 index 00000000000..3dbe3af6144 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalBRoughEnvCfg + + +@configclass +class AnymalBFlatEnvCfg(AnymalBRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=50, + nconmax=15, + ls_iterations=15, + cone="elliptic", + impratio=100, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no height scan + # self.scene.height_scanner = None + # self.observations.policy.height_scan = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class AnymalBFlatEnvCfg_PLAY(AnymalBFlatEnvCfg): + def __post_init__(self) -> None: + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py new file mode 100644 index 00000000000..3829a6999ba --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab.utils import configclass + +from isaaclab_assets import ANYMAL_B_CFG # isort: skip + + +@configclass +class AnymalBRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = ANYMAL_B_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalBRoughEnvCfg_PLAY(AnymalBRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py new file mode 100644 index 00000000000..318b13cc470 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_c import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Anymal-C-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalCFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerWithSymmetryCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_flat_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-C-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalCFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerWithSymmetryCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_flat_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py new file mode 100644 index 00000000000..185225219b0 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalCRoughEnvCfg + + +@configclass +class AnymalCFlatEnvCfg(AnymalCRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=50, + nconmax=15, + ls_iterations=40, + cone="elliptic", + impratio=100, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class AnymalCFlatEnvCfg_PLAY(AnymalCFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py new file mode 100644 index 00000000000..1814aba423a --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets.robots.anymal import ANYMAL_C_CFG # isort: skip + + +@configclass +class AnymalCRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # switch robot to anymal-c + self.scene.robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalCRoughEnvCfg_PLAY(AnymalCRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py new file mode 100644 index 00000000000..e5e75d19dc2 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_d import agents + +## +# Register Gym environments. +## + +# Rough env disabled: requires isaaclab_physx which is not yet available on dev/newton. +# The package exists on upstream/develop (commit 308400f1d35) but has not been merged. +# Re-enable once dev/newton picks up isaaclab_physx. +# gym.register( +# id="Isaac-Velocity-Rough-Anymal-D-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:AnymalDRoughEnvCfg", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDRoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +# gym.register( +# id="Isaac-Velocity-Rough-Anymal-D-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:AnymalDRoughEnvCfg_PLAY", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDRoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-D-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-D-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py new file mode 100644 index 00000000000..2db47cad0ec --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalDRoughEnvCfg + + +@configclass +class AnymalDFlatEnvCfg(AnymalDRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=60, + nconmax=25, + ls_iterations=40, + cone="elliptic", + impratio=100.0, + ls_parallel=True, + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class AnymalDFlatEnvCfg_PLAY(AnymalDFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py new file mode 100644 index 00000000000..b5521534dfa --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets.robots.anymal import ANYMAL_D_CFG # isort: skip + + +@configclass +class AnymalDRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = ANYMAL_D_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalDRoughEnvCfg_PLAY(AnymalDRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py new file mode 100644 index 00000000000..4d9d4a77883 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.cassie import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Cassie-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:CassieFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CassieFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Cassie-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:CassieFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CassieFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py new file mode 100644 index 00000000000..bffe0eb8ea8 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import CassieRoughEnvCfg + + +@configclass +class CassieFlatEnvCfg(CassieRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=52, + nconmax=15, + ls_iterations=10, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 5.0 + self.rewards.joint_deviation_hip.params["asset_cfg"].joint_names = ["hip_rotation_.*"] + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class CassieFlatEnvCfg_PLAY(CassieFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py new file mode 100644 index 00000000000..aba4b82aebb --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + RewardsCfg, +) + +from isaaclab.utils import configclass + +from isaaclab_assets.robots.cassie import CASSIE_CFG # isort: skip + + +@configclass +class CassieRewardsCfg(RewardsCfg): + termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) + feet_air_time = RewTerm( + func=mdp.feet_air_time_positive_biped, + weight=2.5, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*toe"), + "command_name": "base_velocity", + "threshold": 0.3, + }, + ) + joint_deviation_hip = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["hip_abduction_.*", "hip_rotation_.*"])}, + ) + joint_deviation_toes = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["toe_joint_.*"])}, + ) + dof_pos_limits = RewTerm( + func=mdp.joint_pos_limits, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names="toe_joint_.*")}, + ) + + +@configclass +class CassieRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + rewards: CassieRewardsCfg = CassieRewardsCfg() + + def __post_init__(self): + super().__post_init__() + self.scene.robot = CASSIE_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.actions.joint_pos.scale = 0.5 + self.events.push_robot = None + # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation + # self.events.add_base_mass = None + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.base_external_force_torque.params["asset_cfg"].body_names = [".*pelvis"] + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.terminations.base_contact.params["sensor_cfg"].body_names = [".*pelvis"] + self.rewards.undesired_contacts = None + self.rewards.dof_torques_l2.weight = -5.0e-6 + self.rewards.track_lin_vel_xy_exp.weight = 2.0 + self.rewards.track_ang_vel_z_exp.weight = 1.0 + self.rewards.action_rate_l2.weight *= 1.5 + self.rewards.dof_acc_l2.weight *= 1.5 + + +@configclass +class CassieRoughEnvCfg_PLAY(CassieRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py new file mode 100644 index 00000000000..83bdf047a48 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.g1 import agents + +## +# Register Gym environments. +## + +# gym.register( +# id="Isaac-Velocity-Rough-G1-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:G1RoughEnvCfg", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1RoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + + +# gym.register( +# id="Isaac-Velocity-Rough-G1-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:G1RoughEnvCfg_PLAY", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1RoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +gym.register( + id="Isaac-Velocity-Flat-G1-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-G1-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py new file mode 100644 index 00000000000..16bfaf1e9b7 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from isaaclab_experimental.managers import SceneEntityCfg + +from .rough_env_cfg import G1RoughEnvCfg + + +@configclass +class G1FlatEnvCfg(G1RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=95, + nconmax=10, + ls_iterations=10, + ls_parallel=True, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + self.rewards.track_ang_vel_z_exp.weight = 1.0 + self.rewards.lin_vel_z_l2.weight = -0.2 + self.rewards.action_rate_l2.weight = -0.005 + self.rewards.dof_acc_l2.weight = -1.0e-7 + self.rewards.feet_air_time.weight = 0.75 + self.rewards.feet_air_time.params["threshold"] = 0.4 + self.rewards.dof_torques_l2.weight = -2.0e-6 + self.rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint"] + ) + self.commands.base_velocity.ranges.lin_vel_x = (0.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (-0.5, 0.5) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + + +class G1FlatEnvCfg_PLAY(G1FlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py new file mode 100644 index 00000000000..740f70a279a --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py @@ -0,0 +1,177 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + RewardsCfg, +) + +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets import G1_MINIMAL_CFG # isort: skip + + +@configclass +class G1Rewards(RewardsCfg): + """Reward terms for the MDP.""" + + termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) + track_lin_vel_xy_exp = RewTerm( + func=mdp.track_lin_vel_xy_yaw_frame_exp, + weight=1.0, + params={"command_name": "base_velocity", "std": 0.5}, + ) + track_ang_vel_z_exp = RewTerm( + func=mdp.track_ang_vel_z_world_exp, weight=2.0, params={"command_name": "base_velocity", "std": 0.5} + ) + feet_air_time = RewTerm( + func=mdp.feet_air_time_positive_biped, + weight=0.25, + params={ + "command_name": "base_velocity", + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"), + "threshold": 0.4, + }, + ) + feet_slide = RewTerm( + func=mdp.feet_slide, + weight=-0.1, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"), + "asset_cfg": SceneEntityCfg("robot", body_names=".*_ankle_roll_link"), + }, + ) + dof_pos_limits = RewTerm( + func=mdp.joint_pos_limits, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])}, + ) + joint_deviation_hip = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.1, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])}, + ) + joint_deviation_arms = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.1, + params={ + "asset_cfg": SceneEntityCfg( + "robot", + joint_names=[ + ".*_shoulder_pitch_joint", + ".*_shoulder_roll_joint", + ".*_shoulder_yaw_joint", + ".*_elbow_pitch_joint", + ".*_elbow_roll_joint", + ], + ) + }, + ) + joint_deviation_fingers = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.05, + params={ + "asset_cfg": SceneEntityCfg( + "robot", + joint_names=[ + ".*_five_joint", + ".*_three_joint", + ".*_six_joint", + ".*_four_joint", + ".*_zero_joint", + ".*_one_joint", + ".*_two_joint", + ], + ) + }, + ) + joint_deviation_torso = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.1, + params={"asset_cfg": SceneEntityCfg("robot", joint_names="torso_joint")}, + ) + + +@configclass +class G1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + rewards: G1Rewards = G1Rewards() + + def __post_init__(self): + super().__post_init__() + self.scene.robot = G1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.events.push_robot = None + # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation + # self.events.add_base_mass = None + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"] + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + + # Rewards + self.rewards.lin_vel_z_l2.weight = 0.0 + self.rewards.undesired_contacts = None + self.rewards.flat_orientation_l2.weight = -1.0 + self.rewards.action_rate_l2.weight = -0.005 + self.rewards.dof_acc_l2.weight = -1.25e-7 + self.rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint"] + ) + self.rewards.dof_torques_l2.weight = -1.5e-7 + self.rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"] + ) + + # Commands + self.commands.base_velocity.ranges.lin_vel_x = (0.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (-0.0, 0.0) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + + # terminations + self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link" + + +@configclass +class G1RoughEnvCfg_PLAY(G1RoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.episode_length_s = 40.0 + # spawn the robot randomly in the grid (instead of their terrain levels) + self.scene.terrain.max_init_terrain_level = None + # reduce the number of terrains to save memory + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + + self.commands.base_velocity.ranges.lin_vel_x = (1.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (0.0, 0.0) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + self.commands.base_velocity.ranges.heading = (0.0, 0.0) + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py new file mode 100644 index 00000000000..a0ae516387a --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.g1_29_dofs import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-G1-Warp-v1", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1_29_DOFs_FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1_29_DOFs_FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-G1-Warp-Play-v1", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1_29_DOFs_FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1_29_DOFs_FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py new file mode 100644 index 00000000000..7d293ef7e88 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from isaaclab_experimental.managers import SceneEntityCfg + +from .rough_env_cfg import G1_29_DOFs_RoughEnvCfg + + +@configclass +class G1_29_DOFs_FlatEnvCfg(G1_29_DOFs_RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=210, + nconmax=35, + ls_iterations=10, + ls_parallel=True, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + self.rewards.track_ang_vel_z_exp.weight = 1.0 + self.rewards.lin_vel_z_l2.weight = -0.2 + self.rewards.action_rate_l2.weight = -0.005 + self.rewards.dof_acc_l2.weight = -1.0e-7 + self.rewards.feet_air_time.weight = 0.75 + self.rewards.feet_air_time.params["threshold"] = 0.4 + self.rewards.dof_torques_l2.weight = -2.0e-6 + self.rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint"] + ) + self.commands.base_velocity.ranges.lin_vel_x = (-1.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (-0.5, 0.5) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + + +class G1_29_DOFs_FlatEnvCfg_PLAY(G1_29_DOFs_FlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py new file mode 100644 index 00000000000..901ed99b128 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + RewardsCfg, +) + +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets import G1_29_DOF_CFG # isort: skip + + +@configclass +class G1_29_DOFs_Rewards(RewardsCfg): + """Reward terms for the MDP.""" + + termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) + track_lin_vel_xy_exp = RewTerm( + func=mdp.track_lin_vel_xy_yaw_frame_exp, + weight=1.0, + params={"command_name": "base_velocity", "std": 0.5}, + ) + track_ang_vel_z_exp = RewTerm( + func=mdp.track_ang_vel_z_world_exp, weight=2.0, params={"command_name": "base_velocity", "std": 0.5} + ) + feet_air_time = RewTerm( + func=mdp.feet_air_time_positive_biped, + weight=0.25, + params={ + "command_name": "base_velocity", + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"), + "threshold": 0.4, + }, + ) + feet_slide = RewTerm( + func=mdp.feet_slide, + weight=-0.1, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"), + "asset_cfg": SceneEntityCfg("robot", body_names=".*_ankle_roll_link"), + }, + ) + dof_pos_limits = RewTerm( + func=mdp.joint_pos_limits, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])}, + ) + joint_deviation_hip = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.1, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])}, + ) + joint_deviation_arms = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.1, + params={ + "asset_cfg": SceneEntityCfg( + "robot", + joint_names=[ + ".*_shoulder_pitch_joint", + ".*_shoulder_roll_joint", + ".*_shoulder_yaw_joint", + ".*.*_elbow_joint", + ".*_wrist_.*_joint", + ], + ) + }, + ) + joint_deviation_torso = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names="waist_.*_joint")}, + ) + + +@configclass +class G1_29_DOFs_RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + rewards: G1_29_DOFs_Rewards = G1_29_DOFs_Rewards() + observed_joint_names: list[str] = ["waist.*", ".*_hip.*", ".*_knee.*", ".*_ankle.*"] + + def __post_init__(self): + super().__post_init__() + self.scene.robot = G1_29_DOF_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.events.push_robot = None + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"] + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.rewards.lin_vel_z_l2.weight = 0.0 + self.rewards.undesired_contacts = None + self.rewards.flat_orientation_l2.weight = -1.0 + self.rewards.action_rate_l2.weight = -0.005 + self.rewards.dof_acc_l2.weight = -1.25e-7 + self.rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint"] + ) + self.rewards.dof_torques_l2.weight = -1.5e-7 + self.rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"] + ) + self.commands.base_velocity.ranges.lin_vel_x = (0.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (-0.0, 0.0) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link" + self.actions.joint_pos.joint_names = self.observed_joint_names + self.observations.policy.joint_pos.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=self.observed_joint_names + ) + self.observations.policy.joint_vel.params["asset_cfg"] = SceneEntityCfg( + "robot", joint_names=self.observed_joint_names + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py new file mode 100644 index 00000000000..038f5574072 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.go1 import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Unitree-Go1-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo1FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Unitree-Go1-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo1FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py new file mode 100644 index 00000000000..11c209daf63 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import UnitreeGo1RoughEnvCfg + + +@configclass +class UnitreeGo1FlatEnvCfg(UnitreeGo1RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=60, + nconmax=25, + ls_iterations=30, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 0.25 + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class UnitreeGo1FlatEnvCfg_PLAY(UnitreeGo1FlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py new file mode 100644 index 00000000000..a76998703e3 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab.utils import configclass + +from isaaclab_assets.robots.unitree import UNITREE_GO1_CFG # isort: skip + + +@configclass +class UnitreeGo1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = UNITREE_GO1_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + self.scene.terrain.terrain_generator.sub_terrains["boxes"].grid_height_range = (0.025, 0.1) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_range = (0.01, 0.06) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_step = 0.01 + self.actions.joint_pos.scale = 0.25 + self.events.push_robot = None + self.events.base_external_force_torque.params["asset_cfg"].body_names = "trunk" + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot" + self.rewards.feet_air_time.weight = 0.01 + self.rewards.undesired_contacts = None + self.rewards.dof_torques_l2.weight = -0.0002 + self.rewards.track_lin_vel_xy_exp.weight = 1.5 + self.rewards.track_ang_vel_z_exp.weight = 0.75 + self.rewards.dof_acc_l2.weight = -2.5e-7 + self.terminations.base_contact.params["sensor_cfg"].body_names = "trunk" + + +@configclass +class UnitreeGo1RoughEnvCfg_PLAY(UnitreeGo1RoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py new file mode 100644 index 00000000000..7e124029c68 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.go2 import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Unitree-Go2-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo2FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo2FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Unitree-Go2-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo2FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo2FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py new file mode 100644 index 00000000000..317fb4720d7 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import UnitreeGo2RoughEnvCfg + + +@configclass +class UnitreeGo2FlatEnvCfg(UnitreeGo2RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=65, + nconmax=35, + ls_iterations=20, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 0.25 + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class UnitreeGo2FlatEnvCfg_PLAY(UnitreeGo2FlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py new file mode 100644 index 00000000000..f188d54021b --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab.utils import configclass + +from isaaclab_assets.robots.unitree import UNITREE_GO2_CFG # isort: skip + + +@configclass +class UnitreeGo2RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = UNITREE_GO2_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.scene.terrain.terrain_generator.sub_terrains["boxes"].grid_height_range = (0.025, 0.1) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_range = (0.01, 0.06) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_step = 0.01 + self.actions.joint_pos.scale = 0.25 + self.events.push_robot = None + self.events.base_external_force_torque.params["asset_cfg"].body_names = "base" + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot" + self.rewards.feet_air_time.weight = 0.01 + self.rewards.undesired_contacts = None + self.rewards.dof_torques_l2.weight = -0.0002 + self.rewards.track_lin_vel_xy_exp.weight = 1.5 + self.rewards.track_ang_vel_z_exp.weight = 0.75 + self.rewards.dof_acc_l2.weight = -2.5e-7 + self.terminations.base_contact.params["sensor_cfg"].body_names = "base" + + +@configclass +class UnitreeGo2RoughEnvCfg_PLAY(UnitreeGo2RoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py new file mode 100644 index 00000000000..95a1e8f29e3 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.h1 import agents + +## +# Register Gym environments. +## + +# gym.register( +# id="Isaac-Velocity-Rough-H1-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:H1RoughEnvCfg", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1RoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +# gym.register( +# id="Isaac-Velocity-Rough-H1-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:H1RoughEnvCfg_PLAY", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1RoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +gym.register( + id="Isaac-Velocity-Flat-H1-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:H1FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-H1-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:H1FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/flat_env_cfg.py new file mode 100644 index 00000000000..65e2a3a9a92 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/flat_env_cfg.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import H1RoughEnvCfg + + +@configclass +class H1FlatEnvCfg(H1RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=65, + nconmax=15, + ls_iterations=10, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + self.rewards.feet_air_time.weight = 1.0 + self.rewards.feet_air_time.params["threshold"] = 0.6 + + +class H1FlatEnvCfg_PLAY(H1FlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/rough_env_cfg.py new file mode 100644 index 00000000000..87a0d9a2a9a --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/rough_env_cfg.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + RewardsCfg, +) + +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets import H1_MINIMAL_CFG # isort: skip + + +@configclass +class H1Rewards(RewardsCfg): + """Reward terms for the MDP.""" + + termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) + lin_vel_z_l2 = None + track_lin_vel_xy_exp = RewTerm( + func=mdp.track_lin_vel_xy_yaw_frame_exp, + weight=1.0, + params={"command_name": "base_velocity", "std": 0.5}, + ) + track_ang_vel_z_exp = RewTerm( + func=mdp.track_ang_vel_z_world_exp, weight=1.0, params={"command_name": "base_velocity", "std": 0.5} + ) + feet_air_time = RewTerm( + func=mdp.feet_air_time_positive_biped, + weight=0.25, + params={ + "command_name": "base_velocity", + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*ankle_link"), + "threshold": 0.4, + }, + ) + feet_slide = RewTerm( + func=mdp.feet_slide, + weight=-0.25, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*ankle_link"), + "asset_cfg": SceneEntityCfg("robot", body_names=".*ankle_link"), + }, + ) + dof_pos_limits = RewTerm( + func=mdp.joint_pos_limits, weight=-1.0, params={"asset_cfg": SceneEntityCfg("robot", joint_names=".*_ankle")} + ) + joint_deviation_hip = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw", ".*_hip_roll"])}, + ) + joint_deviation_arms = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_shoulder_.*", ".*_elbow"])}, + ) + joint_deviation_torso = RewTerm( + func=mdp.joint_deviation_l1, weight=-0.1, params={"asset_cfg": SceneEntityCfg("robot", joint_names="torso")} + ) + + +@configclass +class H1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + rewards: H1Rewards = H1Rewards() + + def __post_init__(self): + super().__post_init__() + self.scene.robot = H1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.events.push_robot = None + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.base_external_force_torque.params["asset_cfg"].body_names = [".*torso_link"] + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.rewards.undesired_contacts = None + self.rewards.flat_orientation_l2.weight = -1.0 + self.rewards.dof_torques_l2.weight = 0.0 + self.rewards.action_rate_l2.weight = -0.005 + self.rewards.dof_acc_l2.weight = -1.25e-7 + self.commands.base_velocity.ranges.lin_vel_x = (0.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (0.0, 0.0) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + self.terminations.base_contact.params["sensor_cfg"].body_names = ".*torso_link" + + +@configclass +class H1RoughEnvCfg_PLAY(H1RoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.episode_length_s = 40.0 + # spawn the robot randomly in the grid (instead of their terrain levels) + self.scene.terrain.max_init_terrain_level = None + # reduce the number of terrains to save memory + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + + self.commands.base_velocity.ranges.lin_vel_x = (1.0, 1.0) + self.commands.base_velocity.ranges.lin_vel_y = (0.0, 0.0) + self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0) + self.commands.base_velocity.ranges.heading = (0.0, 0.0) + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/__init__.py new file mode 100644 index 00000000000..cdc532db425 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the locomotion environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .curriculums import * # noqa: F401, F403 +from .rewards import * # noqa: F401, F403 +from .terminations import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py new file mode 100644 index 00000000000..06a40a40491 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Curriculum functions for the velocity locomotion environment. + +Curriculum terms are not warp-managed (they run at reset time, not per-step), +so they remain torch-based. +""" + +from __future__ import annotations + +import torch +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.assets import Articulation +from isaaclab.managers import SceneEntityCfg +from isaaclab.terrains import TerrainImporter + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +def terrain_levels_vel( + env: ManagerBasedRLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> torch.Tensor: + """Curriculum based on the distance the robot walked when commanded to move at a desired velocity.""" + asset: Articulation = env.scene[asset_cfg.name] + terrain: TerrainImporter = env.scene.terrain + command = env.command_manager.get_command("base_velocity") + distance = torch.norm(wp.to_torch(asset.data.root_pos_w)[env_ids, :2] - env.scene.env_origins[env_ids, :2], dim=1) + move_up = distance > terrain.cfg.terrain_generator.size[0] / 2 + move_down = distance < torch.norm(command[env_ids, :2], dim=1) * env.max_episode_length_s * 0.5 + move_down *= ~move_up + terrain.update_env_origins(env_ids, move_up, move_down) + return torch.mean(terrain.terrain_levels.float()) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py new file mode 100644 index 00000000000..e19ce6de96d --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py @@ -0,0 +1,309 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-first reward functions for the velocity locomotion environment. + +All functions follow the ``func(env, out, **params) -> None`` signature. +Cross-manager torch tensors (contact sensor, commands) are cached as zero-copy +warp views on first call via ``wp.from_torch``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.managers import SceneEntityCfg +from isaaclab.sensors import ContactSensor + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + +# Review(jichuanh): Needs revisit. + +# --------------------------------------------------------------------------- +# feet_air_time +# --------------------------------------------------------------------------- + + +@wp.kernel +def _feet_air_time_kernel( + last_air_time: wp.array(dtype=wp.float32, ndim=2), + first_contact: wp.array(dtype=wp.float32, ndim=2), + body_ids: wp.array(dtype=wp.int32), + cmd_xy: wp.array(dtype=wp.float32, ndim=2), + threshold: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for k in range(body_ids.shape[0]): + b = body_ids[k] + s += (last_air_time[i, b] - threshold) * first_contact[i, b] + # gate by command magnitude + cx = cmd_xy[i, 0] + cy = cmd_xy[i, 1] + cmd_norm = wp.sqrt(cx * cx + cy * cy) + out[i] = wp.where(cmd_norm > 0.1, s, 0.0) + + +def feet_air_time(env: ManagerBasedRLEnv, out, command_name: str, sensor_cfg: SceneEntityCfg, threshold: float) -> None: + """Reward long steps taken by the feet using L2-kernel.""" + fn = feet_air_time + contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] + # Cache command bridge (persistent pointer) + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + # Newton contact sensor returns persistent wp.arrays — use directly, no wp.from_torch needed + first_contact = contact_sensor.compute_first_contact(env.step_dt) + wp.launch( + kernel=_feet_air_time_kernel, + dim=env.num_envs, + inputs=[contact_sensor.data.last_air_time, first_contact, sensor_cfg.body_ids_wp, fn._cmd_wp, threshold, out], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# feet_air_time_positive_biped +# --------------------------------------------------------------------------- + + +@wp.kernel +def _feet_air_time_positive_biped_kernel( + air_time: wp.array(dtype=wp.float32, ndim=2), + contact_time: wp.array(dtype=wp.float32, ndim=2), + body_ids: wp.array(dtype=wp.int32), + cmd_xy: wp.array(dtype=wp.float32, ndim=2), + threshold: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + n_feet = body_ids.shape[0] + # count feet in contact and find single-stance min mode time + n_contact = int(0) + for k in range(n_feet): + b = body_ids[k] + if contact_time[i, b] > 0.0: + n_contact += 1 + single_stance = n_contact == 1 + min_val = threshold # clamp upper bound + for k in range(n_feet): + b = body_ids[k] + in_contact = contact_time[i, b] > 0.0 + mode_time = wp.where(in_contact, contact_time[i, b], air_time[i, b]) + val = wp.where(single_stance, mode_time, 0.0) + min_val = wp.min(min_val, val) + # gate by command magnitude + cx = cmd_xy[i, 0] + cy = cmd_xy[i, 1] + cmd_norm = wp.sqrt(cx * cx + cy * cy) + out[i] = wp.where(cmd_norm > 0.1, min_val, 0.0) + + +def feet_air_time_positive_biped(env, out, command_name: str, threshold: float, sensor_cfg: SceneEntityCfg) -> None: + """Reward long steps taken by the feet for bipeds.""" + fn = feet_air_time_positive_biped + contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_feet_air_time_positive_biped_kernel, + dim=env.num_envs, + inputs=[ + contact_sensor.data.current_air_time, + contact_sensor.data.current_contact_time, + sensor_cfg.body_ids_wp, + fn._cmd_wp, + threshold, + out, + ], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# feet_slide +# --------------------------------------------------------------------------- + + +@wp.kernel +def _feet_slide_kernel( + body_lin_vel_w: wp.array(dtype=wp.vec3f, ndim=2), + net_forces_w: wp.array(dtype=wp.vec3f, ndim=3), + body_ids: wp.array(dtype=wp.int32), + n_history: int, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for k in range(body_ids.shape[0]): + b = body_ids[k] + # check if in contact: max force norm over history > 1.0 + max_force = float(0.0) + for h in range(n_history): + f = net_forces_w[i, h, b] + f_norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2]) + max_force = wp.max(max_force, f_norm) + in_contact = wp.where(max_force > 1.0, 1.0, 0.0) + # planar velocity norm + vx = body_lin_vel_w[i, b][0] + vy = body_lin_vel_w[i, b][1] + vel_norm = wp.sqrt(vx * vx + vy * vy) + s += vel_norm * in_contact + out[i] = s + + +def feet_slide(env, out, sensor_cfg: SceneEntityCfg, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize feet sliding.""" + contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] + asset = env.scene[asset_cfg.name] + wp.launch( + kernel=_feet_slide_kernel, + dim=env.num_envs, + inputs=[ + asset.data.body_lin_vel_w, + contact_sensor.data.net_forces_w_history, + sensor_cfg.body_ids_wp, + contact_sensor.data.net_forces_w_history.shape[1], + out, + ], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# track_lin_vel_xy_yaw_frame_exp +# --------------------------------------------------------------------------- + + +@wp.kernel +def _track_lin_vel_xy_yaw_frame_exp_kernel( + root_quat_w: wp.array(dtype=wp.quatf), + root_lin_vel_w: wp.array(dtype=wp.vec3f), + cmd: wp.array(dtype=wp.float32, ndim=2), + inv_std_sq: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + q = root_quat_w[i] + # extract yaw-only quaternion + qx = q[0] + qy = q[1] + qz = q[2] + qw = q[3] + sin_yaw = 2.0 * (qw * qz + qx * qy) + cos_yaw = 1.0 - 2.0 * (qy * qy + qz * qz) + yaw_half = wp.atan2(sin_yaw, cos_yaw) * 0.5 + yaw_q = wp.quatf(0.0, 0.0, wp.sin(yaw_half), wp.cos(yaw_half)) + # rotate world velocity into yaw frame (inverse = conjugate for unit quat) + vel_w = root_lin_vel_w[i] + vel_yaw = wp.quat_rotate(wp.quat_inverse(yaw_q), vel_w) + # error + ex = cmd[i, 0] - vel_yaw[0] + ey = cmd[i, 1] - vel_yaw[1] + err_sq = ex * ex + ey * ey + out[i] = wp.exp(-err_sq * inv_std_sq) + + +def track_lin_vel_xy_yaw_frame_exp( + env, out, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Reward tracking of linear velocity commands (xy axes) in the gravity aligned robot frame.""" + fn = track_lin_vel_xy_yaw_frame_exp + asset = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_track_lin_vel_xy_yaw_frame_exp_kernel, + dim=env.num_envs, + inputs=[asset.data.root_quat_w, asset.data.root_lin_vel_w, fn._cmd_wp, 1.0 / (std * std), out], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# track_ang_vel_z_world_exp +# --------------------------------------------------------------------------- + + +@wp.kernel +def _track_ang_vel_z_world_exp_kernel( + root_ang_vel_w: wp.array(dtype=wp.vec3f), + cmd: wp.array(dtype=wp.float32, ndim=2), + inv_std_sq: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + err = cmd[i, 2] - root_ang_vel_w[i][2] + out[i] = wp.exp(-(err * err) * inv_std_sq) + + +def track_ang_vel_z_world_exp( + env, out, command_name: str, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Reward tracking of angular velocity commands (yaw) in world frame.""" + fn = track_ang_vel_z_world_exp + asset = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_track_ang_vel_z_world_exp_kernel, + dim=env.num_envs, + inputs=[asset.data.root_ang_vel_w, fn._cmd_wp, 1.0 / (std * std), out], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# stand_still_joint_deviation_l1 +# --------------------------------------------------------------------------- + + +@wp.kernel +def _stand_still_joint_deviation_l1_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + default_joint_pos: wp.array(dtype=wp.float32, ndim=2), + cmd: wp.array(dtype=wp.float32, ndim=2), + command_threshold: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + n_joints = joint_pos.shape[1] + dev = float(0.0) + for j in range(n_joints): + dev += wp.abs(joint_pos[i, j] - default_joint_pos[i, j]) + # gate: only penalize when command is small + cx = cmd[i, 0] + cy = cmd[i, 1] + cmd_norm = wp.sqrt(cx * cx + cy * cy) + out[i] = wp.where(cmd_norm < command_threshold, dev, 0.0) + + +def stand_still_joint_deviation_l1( + env, out, command_name: str, command_threshold: float = 0.06, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Penalize offsets from the default joint positions when the command is very small.""" + fn = stand_still_joint_deviation_l1 + asset = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_stand_still_joint_deviation_l1_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset.data.default_joint_pos, fn._cmd_wp, command_threshold, out], + device=env.device, + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py new file mode 100644 index 00000000000..3d81bb0c2df --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-first termination functions for the velocity locomotion environment.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.assets import Articulation +from isaaclab.managers import SceneEntityCfg + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +# Review(jichuanh): Needs revisit. +@wp.kernel +def _terrain_out_of_bounds_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + half_width: float, + half_height: float, + distance_buffer: float, + out: wp.array(dtype=wp.bool), +): + i = wp.tid() + px = wp.abs(root_pos_w[i][0]) + py = wp.abs(root_pos_w[i][1]) + out[i] = px > half_width - distance_buffer or py > half_height - distance_buffer + + +def terrain_out_of_bounds( + env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), distance_buffer: float = 3.0 +) -> None: + """Terminate when the actor moves too close to the edge of the terrain.""" + fn = terrain_out_of_bounds + if not hasattr(fn, "_terrain_resolved"): + fn._terrain_resolved = True + terrain_type = env.scene.cfg.terrain.terrain_type + if terrain_type == "plane": + fn._is_plane = True + elif terrain_type == "generator": + fn._is_plane = False + terrain_gen_cfg = env.scene.terrain.cfg.terrain_generator + grid_width, grid_length = terrain_gen_cfg.size + n_rows, n_cols = terrain_gen_cfg.num_rows, terrain_gen_cfg.num_cols + border_width = terrain_gen_cfg.border_width + fn._half_width = 0.5 * (n_rows * grid_width + 2 * border_width) + fn._half_height = 0.5 * (n_cols * grid_length + 2 * border_width) + else: + raise ValueError("Received unsupported terrain type, must be either 'plane' or 'generator'.") + + if fn._is_plane: + out.zero_() + return + + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_terrain_out_of_bounds_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, fn._half_width, fn._half_height, distance_buffer, out], + device=env.device, + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/velocity_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/velocity_env_cfg.py new file mode 100644 index 00000000000..a8e912d2575 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/velocity_env_cfg.py @@ -0,0 +1,291 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import math +from dataclasses import MISSING + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm + +import isaaclab.sim as sim_utils +from isaaclab.assets import ArticulationCfg, AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import CurriculumTermCfg as CurrTerm +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sensors import ContactSensorCfg +from isaaclab.terrains import TerrainImporterCfg +from isaaclab.utils import configclass +from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR, ISAACLAB_NUCLEUS_DIR +from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise + +## +# Pre-defined configs +## +from isaaclab.terrains.config.rough import ROUGH_TERRAINS_CFG # isort: skip + + +## +# Scene definition +## + + +@configclass +class MySceneCfg(InteractiveSceneCfg): + """Configuration for the terrain scene with a legged robot.""" + + # ground terrain + terrain = TerrainImporterCfg( + prim_path="/World/ground", + terrain_type="generator", + terrain_generator=ROUGH_TERRAINS_CFG, + max_init_terrain_level=5, + collision_group=-1, + physics_material=sim_utils.RigidBodyMaterialCfg( + friction_combine_mode="multiply", + restitution_combine_mode="multiply", + static_friction=1.0, + dynamic_friction=1.0, + ), + visual_material=sim_utils.MdlFileCfg( + mdl_path=f"{ISAACLAB_NUCLEUS_DIR}/Materials/TilesMarbleSpiderWhiteBrickBondHoned/TilesMarbleSpiderWhiteBrickBondHoned.mdl", + project_uvw=True, + texture_scale=(0.25, 0.25), + ), + debug_vis=False, + ) + # robots + robot: ArticulationCfg = MISSING + # sensors + contact_forces = ContactSensorCfg( + prim_path="{ENV_REGEX_NS}/Robot/.*", + filter_shape_paths_expr=None, + history_length=3, + track_air_time=True, + ) + # lights + sky_light = AssetBaseCfg( + prim_path="/World/skyLight", + spawn=sim_utils.DomeLightCfg( + intensity=750.0, + texture_file=f"{ISAAC_NUCLEUS_DIR}/Materials/Textures/Skies/PolyHaven/kloofendal_43d_clear_puresky_4k.hdr", + ), + ) + + +## +# MDP settings +## + + +@configclass +class CommandsCfg: + """Command specifications for the MDP.""" + + base_velocity = mdp.UniformVelocityCommandCfg( + asset_name="robot", + resampling_time_range=(1.0e9, 1.0e9), + rel_standing_envs=0.02, + rel_heading_envs=1.0, + heading_command=True, + heading_control_stiffness=0.5, + debug_vis=True, + ranges=mdp.UniformVelocityCommandCfg.Ranges( + lin_vel_x=(-1.0, 1.0), lin_vel_y=(-1.0, 1.0), ang_vel_z=(-1.0, 1.0), heading=(-math.pi, math.pi) + ), + ) + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_pos = mdp.JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True) + + +@configclass +class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation terms (order preserved) + base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1)) + base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2)) + projected_gravity = ObsTerm( + func=mdp.projected_gravity, + noise=Unoise(n_min=-0.05, n_max=0.05), + ) + velocity_commands = ObsTerm(func=mdp.generated_commands, params={"command_name": "base_velocity"}) + joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) + joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5)) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = True + self.concatenate_terms = True + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + base_com = EventTerm( + func=mdp.randomize_rigid_body_com, + mode="startup", + params={ + "asset_cfg": SceneEntityCfg("robot", body_names="base"), + "com_range": {"x": (-0.05, 0.05), "y": (-0.05, 0.05), "z": (-0.01, 0.01)}, + }, + ) + + # reset + base_external_force_torque = EventTerm( + func=mdp.apply_external_force_torque, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", body_names="base"), + "force_range": (0.0, 0.0), + "torque_range": (-0.0, 0.0), + }, + ) + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={ + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (-0.1, 0.1), + "y": (-0.1, 0.1), + "z": (-0.1, 0.1), + "roll": (-0.1, 0.1), + "pitch": (-0.1, 0.1), + "yaw": (-0.1, 0.1), + }, + }, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_scale, + mode="reset", + params={ + "position_range": (0.5, 1.5), + "velocity_range": (0.0, 0.0), + }, + ) + + # interval + push_robot = EventTerm( + func=mdp.push_by_setting_velocity, + mode="interval", + interval_range_s=(10.0, 15.0), + params={"velocity_range": {"x": (-0.1, 0.1), "y": (-0.1, 0.1)}}, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # -- task + track_lin_vel_xy_exp = RewTerm( + func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"command_name": "base_velocity", "std": math.sqrt(0.25)} + ) + track_ang_vel_z_exp = RewTerm( + func=mdp.track_ang_vel_z_exp, weight=0.5, params={"command_name": "base_velocity", "std": math.sqrt(0.25)} + ) + # -- penalties + lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0) + ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05) + dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5) + dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7) + action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01) + feet_air_time = RewTerm( + func=mdp.feet_air_time, + weight=0.125, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"), + "command_name": "base_velocity", + "threshold": 0.5, + }, + ) + undesired_contacts = RewTerm( + func=mdp.undesired_contacts, + weight=-1.0, + params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*THIGH"), "threshold": 1.0}, + ) + # -- optional penalties + flat_orientation_l2 = RewTerm(func=mdp.flat_orientation_l2, weight=0.0) + dof_pos_limits = RewTerm(func=mdp.joint_pos_limits, weight=0.0) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + time_out = DoneTerm(func=mdp.time_out, time_out=True) + base_contact = DoneTerm( + func=mdp.illegal_contact, + params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names="base"), "threshold": 1.0}, + ) + + +@configclass +class CurriculumCfg: + """Curriculum terms for the MDP.""" + + terrain_levels = CurrTerm(func=mdp.terrain_levels_vel) + + +## +# Environment configuration +## + + +@configclass +class LocomotionVelocityRoughEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the locomotion velocity-tracking environment.""" + + # Scene settings + scene: MySceneCfg = MySceneCfg(num_envs=4096, env_spacing=2.5, replicate_physics=True) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + commands: CommandsCfg = CommandsCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + events: EventCfg = EventCfg() + curriculum: CurriculumCfg = CurriculumCfg() + + def __post_init__(self): + """Post initialization.""" + # general settings + self.decimation = 4 + self.episode_length_s = 20.0 + # simulation settings + self.sim.dt = 1.0 / 200.0 + self.sim.render_interval = self.decimation + # update sensor update periods + if self.scene.contact_forces is not None: + self.scene.contact_forces.update_period = self.sim.dt + # check if terrain levels curriculum is enabled + if getattr(self.curriculum, "terrain_levels", None) is not None: + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.curriculum = True + else: + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.curriculum = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/__init__.py new file mode 100644 index 00000000000..6cd56351b6e --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from .reach import * diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/__init__.py new file mode 100644 index 00000000000..fe34199f232 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Reach experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py new file mode 100644 index 00000000000..460a3056908 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py new file mode 100644 index 00000000000..b08612ccc74 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.manipulation.reach.config.franka import agents + +## +# Register Gym environments. +## + +## +# Joint Position Control +## + +gym.register( + id="Isaac-Reach-Franka-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:FrankaReachEnvCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:FrankaReachPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Reach-Franka-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:FrankaReachEnvCfg_PLAY", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:FrankaReachPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py new file mode 100644 index 00000000000..e762d78a664 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +from isaaclab.sim import SimulationCfg +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp +from isaaclab_tasks_experimental.manager_based.manipulation.reach.reach_env_cfg import ReachEnvCfg + +## +# Pre-defined configs +## +from isaaclab_assets import FRANKA_PANDA_CFG # isort: skip + + +## +# Environment configuration +## + + +@configclass +class FrankaReachEnvCfg(ReachEnvCfg): + sim: SimulationCfg = SimulationCfg( + dt=1 / 120, + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=20, + nconmax=20, + ls_iterations=20, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + save_to_mjcf="FrankaReachEnv.xml", + ), + num_substeps=1, + debug_mode=True, + ), + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # switch robot to franka + self.scene.robot = FRANKA_PANDA_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + # override rewards + self.rewards.end_effector_position_tracking.params["asset_cfg"].body_names = ["panda_hand"] + self.rewards.end_effector_position_tracking_fine_grained.params["asset_cfg"].body_names = ["panda_hand"] + self.rewards.end_effector_orientation_tracking.params["asset_cfg"].body_names = ["panda_hand"] + + # override actions + self.actions.arm_action = mdp.JointPositionActionCfg( + asset_name="robot", joint_names=["panda_joint.*"], scale=0.5, use_default_offset=True + ) + # override command generator body + # end-effector is along z-direction + self.commands.ee_pose.body_name = "panda_hand" + self.commands.ee_pose.ranges.pitch = (math.pi, math.pi) + + +@configclass +class FrankaReachEnvCfg_PLAY(FrankaReachEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py new file mode 100644 index 00000000000..85908c15805 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# UR10 env disabled: USD asset has composition errors (broken asset file). +# Fails on both torch baseline and warp with: +# RuntimeError: USD stage has composition errors while loading provided stage +# Re-enable once the UR10 USD asset is fixed. + +# import gymnasium as gym +# from isaaclab_tasks.manager_based.manipulation.reach.config.ur_10 import agents + +# gym.register( +# id="Isaac-Reach-UR10-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:UR10ReachEnvCfg", +# "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UR10ReachPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", +# }, +# ) + +# gym.register( +# id="Isaac-Reach-UR10-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:UR10ReachEnvCfg_PLAY", +# "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UR10ReachPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", +# }, +# ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py new file mode 100644 index 00000000000..c4f3b5b12a8 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg +from isaaclab_tasks_experimental.manager_based.manipulation.reach.reach_env_cfg import ReachEnvCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets import UR10_CFG # isort: skip + + +## +# Environment configuration +## + + +@configclass +class UR10ReachEnvCfg(ReachEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=20, + nconmax=20, + ls_iterations=20, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicitfast", + save_to_mjcf="UR10ReachEnv.xml", + ), + num_substeps=1, + debug_mode=True, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # switch robot to ur10 + self.scene.robot = UR10_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + # override events + self.events.reset_robot_joints.params["position_range"] = (0.75, 1.25) + # override rewards + self.rewards.end_effector_position_tracking.params["asset_cfg"].body_names = ["ee_link"] + self.rewards.end_effector_position_tracking_fine_grained.params["asset_cfg"].body_names = ["ee_link"] + self.rewards.end_effector_orientation_tracking.params["asset_cfg"].body_names = ["ee_link"] + # override actions + self.actions.arm_action = mdp.JointPositionActionCfg( + asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True + ) + # override command generator body + # end-effector is along x-direction + self.commands.ee_pose.body_name = "ee_link" + self.commands.ee_pose.ranges.pitch = (math.pi / 2, math.pi / 2) + + +@configclass +class UR10ReachEnvCfg_PLAY(UR10ReachEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py new file mode 100644 index 00000000000..b0845d6735b --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the reach environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py new file mode 100644 index 00000000000..ea15d91f831 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-first reward terms for the reach task. + +All functions follow the ``func(env, out, **params) -> None`` signature. +Command tensors are cached as zero-copy warp views on first call. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.assets import Articulation +from isaaclab.managers import SceneEntityCfg + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +# Review(jichuanh): Needs revisit. +# --------------------------------------------------------------------------- +# position_command_error +# --------------------------------------------------------------------------- + + +@wp.kernel +def _position_command_error_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + root_quat_w: wp.array(dtype=wp.quatf), + body_pos_w: wp.array(dtype=wp.vec3f, ndim=2), + cmd: wp.array(dtype=wp.float32, ndim=2), + body_idx: int, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + # desired position in body frame -> world frame + des_b = wp.vec3f(cmd[i, 0], cmd[i, 1], cmd[i, 2]) + des_w = root_pos_w[i] + wp.quat_rotate(root_quat_w[i], des_b) + # current end-effector position + cur_w = body_pos_w[i, body_idx] + dx = cur_w[0] - des_w[0] + dy = cur_w[1] - des_w[1] + dz = cur_w[2] - des_w[2] + out[i] = wp.sqrt(dx * dx + dy * dy + dz * dz) + + +def position_command_error(env: ManagerBasedRLEnv, out, command_name: str, asset_cfg: SceneEntityCfg) -> None: + """Penalize tracking of the position error using L2-norm.""" + fn = position_command_error + asset: Articulation = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_position_command_error_kernel, + dim=env.num_envs, + inputs=[ + asset.data.root_pos_w, + asset.data.root_quat_w, + asset.data.body_pos_w, + fn._cmd_wp, + asset_cfg.body_ids[0], + out, + ], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# position_command_error_tanh +# --------------------------------------------------------------------------- + + +@wp.kernel +def _position_command_error_tanh_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + root_quat_w: wp.array(dtype=wp.quatf), + body_pos_w: wp.array(dtype=wp.vec3f, ndim=2), + cmd: wp.array(dtype=wp.float32, ndim=2), + body_idx: int, + inv_std: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + des_b = wp.vec3f(cmd[i, 0], cmd[i, 1], cmd[i, 2]) + des_w = root_pos_w[i] + wp.quat_rotate(root_quat_w[i], des_b) + cur_w = body_pos_w[i, body_idx] + dx = cur_w[0] - des_w[0] + dy = cur_w[1] - des_w[1] + dz = cur_w[2] - des_w[2] + dist = wp.sqrt(dx * dx + dy * dy + dz * dz) + out[i] = 1.0 - wp.tanh(dist * inv_std) + + +def position_command_error_tanh( + env: ManagerBasedRLEnv, out, std: float, command_name: str, asset_cfg: SceneEntityCfg +) -> None: + """Reward tracking of the position using the tanh kernel.""" + fn = position_command_error_tanh + asset: Articulation = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_position_command_error_tanh_kernel, + dim=env.num_envs, + inputs=[ + asset.data.root_pos_w, + asset.data.root_quat_w, + asset.data.body_pos_w, + fn._cmd_wp, + asset_cfg.body_ids[0], + 1.0 / std, + out, + ], + device=env.device, + ) + + +# --------------------------------------------------------------------------- +# orientation_command_error +# --------------------------------------------------------------------------- + + +@wp.kernel +def _orientation_command_error_kernel( + root_quat_w: wp.array(dtype=wp.quatf), + body_quat_w: wp.array(dtype=wp.quatf, ndim=2), + cmd: wp.array(dtype=wp.float32, ndim=2), + body_idx: int, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + # desired quat in body frame -> world frame: q_des_w = q_root * q_des_b + des_b = wp.quatf(cmd[i, 3], cmd[i, 4], cmd[i, 5], cmd[i, 6]) + des_w = wp.quat_inverse(root_quat_w[i]) * des_b # TODO: verify if mul order matches stable + des_w = root_quat_w[i] * des_b + # current ee orientation + cur_w = body_quat_w[i, body_idx] + # shortest-path error: angle of q_err = cur^-1 * des + q_err = wp.quat_inverse(cur_w) * des_w + # error magnitude = 2 * acos(|w|) (w component of the error quaternion) + qw = wp.abs(q_err[3]) + qw = wp.clamp(qw, 0.0, 1.0) + out[i] = 2.0 * wp.acos(qw) + + +def orientation_command_error(env: ManagerBasedRLEnv, out, command_name: str, asset_cfg: SceneEntityCfg) -> None: + """Penalize tracking orientation error using shortest path.""" + fn = orientation_command_error + asset: Articulation = env.scene[asset_cfg.name] + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd) + fn._cmd_name = command_name + wp.launch( + kernel=_orientation_command_error_kernel, + dim=env.num_envs, + inputs=[asset.data.root_quat_w, asset.data.body_quat_w, fn._cmd_wp, asset_cfg.body_ids[0], out], + device=env.device, + ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py new file mode 100644 index 00000000000..57436618342 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py @@ -0,0 +1,205 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from dataclasses import MISSING + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm + +import isaaclab.sim as sim_utils +from isaaclab.assets import ArticulationCfg, AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import ActionTermCfg as ActionTerm +from isaaclab.managers import CurriculumTermCfg as CurrTerm +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.utils import configclass +from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise + +## +# Scene definition +## + + +@configclass +class ReachSceneCfg(InteractiveSceneCfg): + """Configuration for the scene with a robotic arm.""" + + # world + ground = AssetBaseCfg( + prim_path="/World/ground", + spawn=sim_utils.GroundPlaneCfg(), + init_state=AssetBaseCfg.InitialStateCfg(pos=(0.0, 0.0, -1.05)), + ) + + # robots + robot: ArticulationCfg = MISSING + + # lights + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DomeLightCfg(color=(0.75, 0.75, 0.75), intensity=2500.0), + ) + + +## +# MDP settings +## + + +@configclass +class CommandsCfg: + """Command terms for the MDP.""" + + ee_pose = mdp.UniformPoseCommandCfg( + asset_name="robot", + body_name=MISSING, + resampling_time_range=(4.0, 4.0), + debug_vis=True, + ranges=mdp.UniformPoseCommandCfg.Ranges( + pos_x=(0.35, 0.65), + pos_y=(-0.2, 0.2), + pos_z=(0.15, 0.5), + roll=(0.0, 0.0), + pitch=MISSING, # depends on end-effector axis + yaw=(-3.14, 3.14), + ), + ) + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + arm_action: ActionTerm = MISSING + gripper_action: ActionTerm | None = None + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation terms (order preserved) + joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) + joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) + pose_command = ObsTerm(func=mdp.generated_commands, params={"command_name": "ee_pose"}) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = True + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={"pose_range": {}, "velocity_range": {}}, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_scale, + mode="reset", + params={ + "position_range": (0.5, 1.5), + "velocity_range": (0.0, 0.0), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # task terms + end_effector_position_tracking = RewTerm( + func=mdp.position_command_error, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"}, + ) + end_effector_position_tracking_fine_grained = RewTerm( + func=mdp.position_command_error_tanh, + weight=0.1, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "std": 0.1, "command_name": "ee_pose"}, + ) + end_effector_orientation_tracking = RewTerm( + func=mdp.orientation_command_error, + weight=-0.1, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"}, + ) + + # action penalty + action_rate = RewTerm(func=mdp.action_rate_l2, weight=-0.0001) + joint_vel = RewTerm( + func=mdp.joint_vel_l2, + weight=-0.0001, + params={"asset_cfg": SceneEntityCfg("robot")}, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + time_out = DoneTerm(func=mdp.time_out, time_out=True) + + +@configclass +class CurriculumCfg: + """Curriculum terms for the MDP.""" + + action_rate = CurrTerm( + func=mdp.modify_reward_weight, params={"term_name": "action_rate", "weight": -0.005, "num_steps": 4500} + ) + + joint_vel = CurrTerm( + func=mdp.modify_reward_weight, params={"term_name": "joint_vel", "weight": -0.001, "num_steps": 4500} + ) + + +## +# Environment configuration +## + + +@configclass +class ReachEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the reach end-effector pose tracking environment.""" + + # Scene settings + scene: ReachSceneCfg = ReachSceneCfg(num_envs=4096, env_spacing=2.5) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + commands: CommandsCfg = CommandsCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + events: EventCfg = EventCfg() + curriculum: CurriculumCfg = CurriculumCfg() + + def __post_init__(self): + """Post initialization.""" + # general settings + self.decimation = 2 + self.sim.render_interval = self.decimation + self.episode_length_s = 12.0 + self.viewer.eye = (3.5, 3.5, 3.5) + # simulation settings + self.sim.dt = 1.0 / 60.0 From b17d5c2230d9e960574e0c0a0d4081c1487a80a4 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Wed, 25 Feb 2026 01:14:08 -0800 Subject: [PATCH 4/5] Make MDP kernels graph-capturable and consolidate test infrastructure - Rewrite obs/reward kernels to consume Tier 1 compound types directly, bypassing lazy Tier 2 properties that break CUDA graph capture - Update GRAPH_CAPTURE_MIGRATION.md and WARP_MIGRATION_GAP_ANALYSIS.md --- .../envs/manager_based_env_warp.py | 30 +- .../envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md | 15 +- .../envs/mdp/observations.py | 113 +-- .../isaaclab_experimental/envs/mdp/rewards.py | 82 +- .../envs/mdp/terminations.py | 4 - .../test/envs/mdp/parity_helpers.py | 425 +++++++++++ .../test/envs/mdp/test_action_warp_parity.py | 105 +-- .../test/envs/mdp/test_mdp_warp_parity.py | 722 +++--------------- .../mdp/test_mdp_warp_parity_new_terms.py | 591 ++------------ .../articulation/GRAPH_CAPTURE_MIGRATION.md | 125 +-- .../isaaclab_newton/kernels/state_kernels.py | 18 + .../classic/humanoid/mdp/observations.py | 14 +- .../classic/humanoid/mdp/rewards.py | 17 +- 13 files changed, 811 insertions(+), 1450 deletions(-) create mode 100644 source/isaaclab_experimental/test/envs/mdp/parity_helpers.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py index 89006dd8386..5cfd3303298 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -88,7 +88,6 @@ class ManagerCallSwitch: def __init__(self, cfg_source: dict | str | None = None, max_modes: dict[str, int] | None = None): self._wp_graphs: dict[str, Any] = {} - self._non_capturable_managers: set[str] = set() self._cfg = self._load_cfg(cfg_source) self._max_modes = self._validate_max_modes(max_modes) print("[INFO] ManagerCallSwitch configuration:") @@ -104,10 +103,18 @@ def invalidate_graphs(self) -> None: self._wp_graphs.clear() def register_manager_capturability(self, manager_name: str, all_terms_capturable: bool) -> None: - """Register whether a manager's terms are all CUDA-graph-capturable.""" + """Register whether a manager's terms are all CUDA-graph-capturable. + + Called during manager _prepare_terms(). If any term is non-capturable and the + manager is currently configured for mode=2 (WARP_CAPTURED), downgrades it to + mode=1 (WARP_NOT_CAPTURED) in-place in _cfg. This is safe because mode=1 and + mode=2 use the same experimental manager class (only mode=0 switches to stable). + """ if not all_terms_capturable: - self._non_capturable_managers.add(manager_name) - logger.warning(f"{manager_name} has non-capturable terms — mode=2 requests will fall back to mode=1.") + current = self.get_mode_for_manager(manager_name) + if current == ManagerCallMode.WARP_CAPTURED: + self._cfg[manager_name] = int(ManagerCallMode.WARP_NOT_CAPTURED) + logger.warning(f"{manager_name} has non-capturable terms — downgraded from mode=2 to mode=1.") def call_stage( self, @@ -122,8 +129,6 @@ def call_stage( mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override) if mode == ManagerCallMode.STABLE: return self._run_calls(stable_calls) - if mode == ManagerCallMode.WARP_CAPTURED and manager_name in self._non_capturable_managers: - mode = ManagerCallMode.WARP_NOT_CAPTURED if mode == ManagerCallMode.WARP_NOT_CAPTURED: return self._run_calls(warp_calls) self._wp_capture_or_launch(stage=stage, calls=warp_calls) @@ -135,6 +140,19 @@ def _manager_name_from_stage(self, stage: str) -> str: return stage.split("_", 1)[0] def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode: + """Get the effective execution mode for a manager. + + The mode is resolved from _cfg which accumulates all overrides: + 1. Default mode from DEFAULT_CONFIG (set at init). + 2. Per-manager overrides from manager_call_config (set at init). + 3. max_mode cap from manager_call_max_mode (applied at query time). + 4. Non-capturable downgrade: mode=2 → mode=1 when a manager has terms marked + @warp_capturable(False). Written into _cfg by register_manager_capturability() + during manager init, so this is reflected here automatically. + + The result is always the effective runtime mode — safe to use for both + manager class resolution and execution dispatch. + """ default_key = next(iter(self.DEFAULT_CONFIG)) mode_value = self._cfg.get(manager_name, self._cfg[default_key]) cap = self._max_modes.get(manager_name) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md index 07e6e446b7c..a0ab787cc80 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md @@ -688,7 +688,7 @@ fall back to mode=1 (warp not captured) automatically via `register_manager_capt | `base_lin_vel` | `root_lin_vel_b` → `root_com_vel_b` (Tier 2) | Applied | | `base_ang_vel` | `root_ang_vel_b` → `root_com_vel_b` (Tier 2) | Applied | | `projected_gravity` | `projected_gravity_b` (Tier 2) | Applied | -| `body_projected_gravity_b` | `projected_gravity_b` (Tier 2) | Applied | +| `body_projected_gravity_b` | `projected_gravity_b` (Tier 2) | Pending (body-level, not yet in experimental module) | **Rewards — base** (`isaaclab_experimental/envs/mdp/rewards.py`): @@ -715,11 +715,14 @@ fall back to mode=1 (warp not captured) automatically via `register_manager_capt - `joint_torques_l2`, `joint_acc_l2`, `joint_vel_l2`, etc. → `joint_pos`, `joint_vel` (Tier 1) - `is_alive`, `is_terminated`, `action_rate_l2`, `action_l2` → no articulation data -**Pending fix:** Implement `materialize_derived()` in `ArticulationData.update()` to -eagerly compute Tier 2 properties before captured graphs replay. Once applied, all -`@warp_capturable(False)` annotations for Tier 2 access can be removed and these terms -become fully capturable. See `GRAPH_CAPTURE_MIGRATION.md` in the Newton articulation -package for the proposed implementation. +**Applied fix (Phase 1):** Affected MDP kernels were rewritten to consume Tier 1 compound +types directly (`root_link_pose_w` as `wp.transformf`, `root_com_vel_w` as +`wp.spatial_vectorf`) and perform the body-frame rotation inline, eliminating the Tier 2 +dependency entirely. The `@warp_capturable(False)` annotations have been removed and these +terms are now fully capturable. Shared `@wp.func` helpers (`body_lin_vel_from_root`, +`body_ang_vel_from_root`, `rotate_vec_to_body_frame`) live in +`isaaclab_newton.kernels.state_kernels`. See `GRAPH_CAPTURE_MIGRATION.md` in the Newton +articulation package for Phase 2 plans (making Tier 2 lazy update itself graph-safe). ### Resolved Cross-Cutting Blockers diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py index 49f46a70587..9802eb86fbd 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py @@ -23,7 +23,6 @@ import warp as wp from isaaclab_experimental.envs.utils.io_descriptors import ( generic_io_descriptor_warp, - record_body_names, record_dtype, record_joint_names, record_joint_pos_offsets, @@ -31,7 +30,11 @@ record_shape, ) from isaaclab_experimental.managers import SceneEntityCfg -from isaaclab_experimental.utils.warp import warp_capturable +from isaaclab_newton.kernels.state_kernels import ( + body_ang_vel_from_root, + body_lin_vel_from_root, + rotate_vec_to_body_frame, +) from isaaclab.assets import Articulation @@ -81,7 +84,6 @@ def _base_pos_z_kernel( out[env_id, 0] = root_pos_w[env_id][2] -# Reviewed(jichuanh): good @generic_io_descriptor_warp( units="m", axes=["Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] ) @@ -96,8 +98,27 @@ def base_pos_z(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntit ) -# Reviewed(jichuanh): good -@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) +# Inline Tier 1 access: these observations derive body-frame quantities directly from +# root_link_pose_w (transformf) and root_com_vel_w (spatial_vectorf), avoiding the lazy +# TimestampedWarpBuffer properties which are not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. +# If ArticulationData Tier 2 lazy update is made graph-safe in the future, these can +# revert to reading the pre-computed .data buffers (simpler, avoids redundant rotations). + + +@wp.kernel +def _base_lin_vel_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + v = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i]) + out[i, 0] = v[0] + out[i, 1] = v[1] + out[i, 2] = v[2] + + @generic_io_descriptor_warp( units="m/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] ) @@ -105,15 +126,26 @@ def base_lin_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEnt """Root linear velocity in the asset's root frame.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( - kernel=_vec3_to_out3_kernel, + kernel=_base_lin_vel_kernel, dim=env.num_envs, - inputs=[asset.data.root_lin_vel_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], device=env.device, ) -# Reviewed(jichuanh): good -@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) +@wp.kernel +def _base_ang_vel_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + v = body_ang_vel_from_root(root_pose_w[i], root_vel_w[i]) + out[i, 0] = v[0] + out[i, 1] = v[1] + out[i, 2] = v[2] + + @generic_io_descriptor_warp( units="rad/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] ) @@ -121,15 +153,26 @@ def base_ang_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEnt """Root angular velocity in the asset's root frame.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( - kernel=_vec3_to_out3_kernel, + kernel=_base_ang_vel_kernel, dim=env.num_envs, - inputs=[asset.data.root_ang_vel_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], device=env.device, ) -# Reviewed(jichuanh): good -@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) +@wp.kernel +def _projected_gravity_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.vec3f, + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + g = rotate_vec_to_body_frame(gravity_w, root_pose_w[i]) + out[i, 0] = g[0] + out[i, 1] = g[1] + out[i, 2] = g[2] + + @generic_io_descriptor_warp( units="m/s^2", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] ) @@ -137,9 +180,9 @@ def projected_gravity(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = Sce """Gravity projection on the asset's root frame.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( - kernel=_vec3_to_out3_kernel, + kernel=_projected_gravity_kernel, dim=env.num_envs, - inputs=[asset.data.projected_gravity_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.GRAVITY_VEC_W, out], device=env.device, ) @@ -170,18 +213,17 @@ def joint_pos(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntity @wp.kernel -def _joint_pos_rel_gather_kernel( - joint_pos: wp.array(dtype=wp.float32, ndim=2), - default_joint_pos: wp.array(dtype=wp.float32, ndim=2), +def _joint_rel_gather_kernel( + values: wp.array(dtype=wp.float32, ndim=2), + defaults: wp.array(dtype=wp.float32, ndim=2), joint_ids: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.float32, ndim=2), ): env_id, k = wp.tid() j = joint_ids[k] - out[env_id, k] = joint_pos[env_id, j] - default_joint_pos[env_id, j] + out[env_id, k] = values[env_id, j] - defaults[env_id, j] -# Reviewed(jichuanh): good @generic_io_descriptor_warp( observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_pos_offsets], @@ -199,15 +241,13 @@ def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn "Pass `asset_cfg` via term cfg params so it is resolved at manager init." ) wp.launch( - kernel=_joint_pos_rel_gather_kernel, + kernel=_joint_rel_gather_kernel, dim=(env.num_envs, out.shape[1]), inputs=[asset.data.joint_pos, asset.data.default_joint_pos, joint_ids_wp, out], device=env.device, ) -# Reviewed(jichuanh): logic is different from stable version. Even upper and lower are flipped, stable -# logic should work, fix this. @wp.kernel def _joint_pos_limit_normalized_kernel( joint_pos: wp.array(dtype=wp.float32, ndim=2), @@ -221,12 +261,7 @@ def _joint_pos_limit_normalized_kernel( lim = soft_joint_pos_limits[env_id, j] lower = lim.x upper = lim.y - mid = (lower + upper) * 0.5 - half_range = (upper - lower) * 0.5 - if half_range > 0.0: - out[env_id, k] = (pos - mid) / half_range - else: - out[env_id, k] = 0.0 + out[env_id, k] = 2.0 * (pos - (lower + upper) * 0.5) / (upper - lower) @generic_io_descriptor_warp(observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape]) @@ -247,7 +282,6 @@ def joint_pos_limit_normalized(env: ManagerBasedEnv, out, asset_cfg: SceneEntity ) -# Reviewed(jichuanh): good @generic_io_descriptor_warp( observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape], units="rad/s" ) @@ -268,19 +302,6 @@ def joint_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntity ) -# Reviewed(jichuanh): kernel impl seems duplicate, rel_gather kernel could be shared. -@wp.kernel -def _joint_vel_rel_gather_kernel( - joint_vel: wp.array(dtype=wp.float32, ndim=2), - default_joint_vel: wp.array(dtype=wp.float32, ndim=2), - joint_ids: wp.array(dtype=wp.int32), - out: wp.array(dtype=wp.float32, ndim=2), -): - env_id, k = wp.tid() - j = joint_ids[k] - out[env_id, k] = joint_vel[env_id, j] - default_joint_vel[env_id, j] - - @generic_io_descriptor_warp( observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_vel_offsets], @@ -298,7 +319,7 @@ def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn "Pass `asset_cfg` via term cfg params so it is resolved at manager init." ) wp.launch( - kernel=_joint_vel_rel_gather_kernel, + kernel=_joint_rel_gather_kernel, dim=(env.num_envs, out.shape[1]), inputs=[asset.data.joint_vel, asset.data.default_joint_vel, joint_ids_wp, out], device=env.device, @@ -308,7 +329,6 @@ def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn """ Actions. """ -# Reviewed(jichuanh): good @generic_io_descriptor_warp(out_dim="action", dtype=torch.float32, observation_type="Action", on_inspect=[record_shape]) @@ -326,7 +346,6 @@ def last_action(env: ManagerBasedEnv, out, action_name: str | None = None) -> No """ -# Reviewed(jichuanh): good @generic_io_descriptor_warp( out_dim="command", dtype=torch.float32, observation_type="Command", on_inspect=[record_shape] ) @@ -347,5 +366,3 @@ def generated_commands(env: ManagerBasedEnv, out, command_name: str) -> None: fn._cmd_wp = wp.from_torch(cmd) fn._cmd_name = command_name wp.copy(out, fn._cmd_wp) - - diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py index cea77832c31..1e444fd22d6 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py @@ -19,7 +19,11 @@ import warp as wp from isaaclab_experimental.managers import SceneEntityCfg -from isaaclab_experimental.utils.warp import warp_capturable +from isaaclab_newton.kernels.state_kernels import ( + body_ang_vel_from_root, + body_lin_vel_from_root, + rotate_vec_to_body_frame, +) from isaaclab.assets import Articulation @@ -65,63 +69,76 @@ def is_terminated(env: ManagerBasedRLEnv, out) -> None: """ -# Reviewed(jichuanh): opportunity to share kernel should be explored, e.g. a square_index kernel with -# pre-allocated warp-ids array could be used. +# Inline Tier 1 access: these rewards derive body-frame quantities directly from +# root_link_pose_w (transformf) and root_com_vel_w (spatial_vectorf), avoiding the lazy +# TimestampedWarpBuffer properties which are not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. +# If ArticulationData Tier 2 lazy update is made graph-safe in the future, these can +# revert to reading the pre-computed .data buffers (simpler, avoids redundant rotations). + + @wp.kernel -def _lin_vel_z_l2_kernel(root_lin_vel_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): +def _lin_vel_z_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32), +): i = wp.tid() - vz = root_lin_vel_b[i][2] + vz = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i])[2] out[i] = vz * vz -@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) def lin_vel_z_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: """Penalize z-axis base linear velocity using L2 squared kernel.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( kernel=_lin_vel_z_l2_kernel, dim=env.num_envs, - inputs=[asset.data.root_lin_vel_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], device=env.device, ) -# Reviewed(jichuanh): same as previous @wp.kernel -def _ang_vel_xy_l2_kernel(root_ang_vel_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): +def _ang_vel_xy_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32), +): i = wp.tid() - v = root_ang_vel_b[i] + v = body_ang_vel_from_root(root_pose_w[i], root_vel_w[i]) out[i] = v[0] * v[0] + v[1] * v[1] -@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) def ang_vel_xy_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: """Penalize xy-axis base angular velocity using L2 squared kernel.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( kernel=_ang_vel_xy_l2_kernel, dim=env.num_envs, - inputs=[asset.data.root_ang_vel_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], device=env.device, ) -# Reviewed(jichuanh): same as previous @wp.kernel -def _flat_orientation_l2_kernel(projected_gravity_b: wp.array(dtype=wp.vec3f), out: wp.array(dtype=wp.float32)): +def _flat_orientation_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.vec3f, + out: wp.array(dtype=wp.float32), +): i = wp.tid() - g = projected_gravity_b[i] + g = rotate_vec_to_body_frame(gravity_w, root_pose_w[i]) out[i] = g[0] * g[0] + g[1] * g[1] -@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) def flat_orientation_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: """Penalize non-flat base orientation using L2 squared kernel.""" asset: Articulation = env.scene[asset_cfg.name] wp.launch( kernel=_flat_orientation_l2_kernel, dim=env.num_envs, - inputs=[asset.data.projected_gravity_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.GRAVITY_VEC_W, out], device=env.device, ) @@ -322,7 +339,6 @@ def action_l2(env: ManagerBasedRLEnv, out) -> None: """ -# Reviewed(jichuanh): good @wp.kernel def _undesired_contacts_kernel( forces: wp.array(dtype=wp.vec3f, ndim=3), @@ -367,21 +383,20 @@ def undesired_contacts(env: ManagerBasedRLEnv, out, threshold: float, sensor_cfg @wp.kernel def _track_lin_vel_xy_exp_kernel( - root_lin_vel_b: wp.array(dtype=wp.vec3f), + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), command: wp.array(dtype=wp.float32, ndim=2), std_sq_inv: float, out: wp.array(dtype=wp.float32), ): i = wp.tid() - v = root_lin_vel_b[i] + v = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i]) dx = command[i, 0] - v[0] dy = command[i, 1] - v[1] error = dx * dx + dy * dy out[i] = wp.exp(-error * std_sq_inv) -# Reviewed(jichuanh): Review if there's any gap to make term provide warp type by default. -@warp_capturable(False) # accesses root_lin_vel_b → lazy TimestampedWarpBuffer (Tier 2) def track_lin_vel_xy_exp( env: ManagerBasedRLEnv, out, @@ -407,25 +422,31 @@ def track_lin_vel_xy_exp( wp.launch( kernel=_track_lin_vel_xy_exp_kernel, dim=env.num_envs, - inputs=[asset.data.root_lin_vel_b, track_lin_vel_xy_exp._cmd_wp, 1.0 / (std * std), out], + inputs=[ + asset.data.root_link_pose_w, + asset.data.root_com_vel_w, + track_lin_vel_xy_exp._cmd_wp, + 1.0 / (std * std), + out, + ], device=env.device, ) @wp.kernel def _track_ang_vel_z_exp_kernel( - root_ang_vel_b: wp.array(dtype=wp.vec3f), + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), command: wp.array(dtype=wp.float32, ndim=2), cmd_col: int, std_sq_inv: float, out: wp.array(dtype=wp.float32), ): i = wp.tid() - dz = command[i, cmd_col] - root_ang_vel_b[i][2] + dz = command[i, cmd_col] - body_ang_vel_from_root(root_pose_w[i], root_vel_w[i])[2] out[i] = wp.exp(-dz * dz * std_sq_inv) -@warp_capturable(False) # accesses root_ang_vel_b → lazy TimestampedWarpBuffer (Tier 2) def track_ang_vel_z_exp( env: ManagerBasedRLEnv, out, @@ -450,6 +471,13 @@ def track_ang_vel_z_exp( wp.launch( kernel=_track_ang_vel_z_exp_kernel, dim=env.num_envs, - inputs=[asset.data.root_ang_vel_b, track_ang_vel_z_exp._cmd_wp, 2, 1.0 / (std * std), out], + inputs=[ + asset.data.root_link_pose_w, + asset.data.root_com_vel_w, + track_ang_vel_z_exp._cmd_wp, + 2, + 1.0 / (std * std), + out, + ], device=env.device, ) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py index 721bbc26711..2fa0179c3f8 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py @@ -31,7 +31,6 @@ """ -# Reviewed(jichuanh) @wp.kernel def _time_out_kernel(episode_length: wp.array(dtype=wp.int64), max_episode_length: int, out: wp.array(dtype=wp.bool)): i = wp.tid() @@ -56,7 +55,6 @@ def time_out(env: ManagerBasedRLEnv, out) -> None: """ -# Reviewed(jichuanh): good. @wp.kernel def _root_height_below_min_kernel( root_pos_w: wp.array(dtype=wp.vec3f), @@ -85,7 +83,6 @@ def root_height_below_minimum( """ -# Reviewed(jichuanh): good @wp.kernel def _joint_pos_out_of_manual_limit_kernel( joint_pos: wp.array(dtype=wp.float32, ndim=2), @@ -122,7 +119,6 @@ def joint_pos_out_of_manual_limit( """ -# Reviewed(jichuanh): good @wp.kernel def _illegal_contact_kernel( forces: wp.array(dtype=wp.vec3f, ndim=3), diff --git a/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py b/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py new file mode 100644 index 00000000000..cceb364b6db --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py @@ -0,0 +1,425 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Shared test utilities for MDP warp-vs-stable parity tests. + +Contains constants, assertion helpers, warp kernel runners, mock classes, and +numpy math utilities used by both ``test_mdp_warp_parity.py`` and +``test_mdp_warp_parity_new_terms.py``. +""" + +from __future__ import annotations + +import numpy as np +import torch + +import warp as wp + +# --------------------------------------------------------------------------- +# Constants (shared across all MDP parity test files) +# --------------------------------------------------------------------------- +NUM_ENVS = 64 +NUM_JOINTS = 12 +NUM_ACTIONS = 6 +DEVICE = "cuda:0" +ATOL = 1e-5 +RTOL = 1e-5 + +# Gravity direction constant (normalized, same as ArticulationData.GRAVITY_VEC_W) +GRAVITY_DIR_NP = np.array([[0.0, 0.0, -1.0]], dtype=np.float32) + + +# --------------------------------------------------------------------------- +# Numpy math utilities +# --------------------------------------------------------------------------- + + +def quat_rotate_inv_np(q_xyzw: np.ndarray, v: np.ndarray) -> np.ndarray: + """Apply inverse quaternion rotation to vectors (numpy, batch). + + Equivalent to ``wp.quat_rotate_inv`` — rotates *v* by the conjugate of *q*. + + Args: + q_xyzw: (N, 4) quaternion array in [x, y, z, w] order (warp convention). + v: (N, 3) vector array. + + Returns: + (N, 3) rotated vectors in float32. + """ + qv = -q_xyzw[..., :3] # conjugate xyz + qw = q_xyzw[..., 3:4] + t = 2.0 * np.cross(qv, v) + return (v + qw * t + np.cross(qv, t)).astype(np.float32) + + +# --------------------------------------------------------------------------- +# Warp / numpy utilities +# --------------------------------------------------------------------------- + + +def copy_np_to_wp(dest: wp.array, src_np: np.ndarray): + """In-place overwrite of a warp array's contents from numpy (preserves pointer).""" + tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device)) + wp.copy(dest, tmp) + + +# --------------------------------------------------------------------------- +# Test runner helpers +# --------------------------------------------------------------------------- + + +def run_warp_obs(func, env, shape, device=DEVICE, **kwargs): + """Run a warp observation function and return the result as a torch tensor.""" + out = wp.zeros(shape, dtype=wp.float32, device=device) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def run_warp_obs_captured(func, env, shape, device=DEVICE, **kwargs): + """Run a warp observation function under CUDA graph capture and return the result.""" + out = wp.zeros(shape, dtype=wp.float32, device=device) + func(env, out, **kwargs) # warm-up + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + wp.capture_launch(capture.graph) + return wp.to_torch(out).clone() + + +def run_warp_rew(func, env, device=DEVICE, **kwargs): + """Run a warp reward function and return the result as a torch tensor.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def run_warp_rew_captured(func, env, device=DEVICE, **kwargs): + """Run a warp reward function under CUDA graph capture.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) + func(env, out, **kwargs) # warm-up + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + wp.capture_launch(capture.graph) + return wp.to_torch(out).clone() + + +def run_warp_term(func, env, device=DEVICE, **kwargs): + """Run a warp termination function and return the result as a torch tensor.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) + func(env, out, **kwargs) + return wp.to_torch(out).clone() + + +def run_warp_term_captured(func, env, device=DEVICE, **kwargs): + """Run a warp termination function under CUDA graph capture.""" + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) + func(env, out, **kwargs) # warm-up + with wp.ScopedCapture() as capture: + func(env, out, **kwargs) + wp.capture_launch(capture.graph) + return wp.to_torch(out).clone() + + +# --------------------------------------------------------------------------- +# Assertion helpers +# --------------------------------------------------------------------------- + + +def assert_close(actual: torch.Tensor, expected: torch.Tensor, atol: float = ATOL, rtol: float = RTOL): + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + + +def assert_equal(actual: torch.Tensor, expected: torch.Tensor): + assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}" + + +# --------------------------------------------------------------------------- +# Mock classes (shared across parity test files) +# --------------------------------------------------------------------------- + + +class MockArticulationData: + """Mock articulation data backed by Warp arrays (same storage Newton uses). + + Args: + num_envs: Number of environments. + num_joints: Number of joints. + device: Warp device string. + seed: Random seed for reproducibility. + num_bodies: Number of bodies. When > 0, generates body-level arrays + (body_pose_w, body_lin_acc_w, body_com_pos_b) and multi-body + projected_gravity_b. When 0, projected_gravity_b is root-level + (derived from root quaternion). + """ + + def __init__(self, num_envs=NUM_ENVS, num_joints=NUM_JOINTS, device=DEVICE, seed=42, num_bodies=0): + rng = np.random.RandomState(seed) + + # --- Joint state (float32 2D) --- + self.joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32), device=device) + self.joint_vel = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 2.0, device=device) + self.joint_acc = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.5, device=device) + self.default_joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.01, device=device) + self.default_joint_vel = wp.array(np.zeros((num_envs, num_joints), dtype=np.float32), device=device) + self.applied_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + self.computed_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + + # --- Soft joint limits --- + limits_np = np.zeros((num_envs, num_joints, 2), dtype=np.float32) + limits_np[:, :, 0] = -3.14 + limits_np[:, :, 1] = 3.14 + self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) + self.soft_joint_vel_limits = wp.array(np.full((num_envs, num_joints), 10.0, dtype=np.float32), device=device) + + # --- Root state --- + root_pos_np = rng.randn(num_envs, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 # positive heights + self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) + + # Unit quaternions + quat_np = rng.randn(num_envs, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) + + # Tier 1 compound: root_link_pose_w (transformf = pos + quat) + pose_np = np.zeros((num_envs, 7), dtype=np.float32) + pose_np[:, :3] = root_pos_np + pose_np[:, 3:] = quat_np + self.root_link_pose_w = wp.array(pose_np, dtype=wp.transformf, device=device) + + # World-frame velocities + lin_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + ang_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + self.root_lin_vel_w = wp.array(lin_vel_w_np, dtype=wp.vec3f, device=device) + self.root_ang_vel_w = wp.array(ang_vel_w_np, dtype=wp.vec3f, device=device) + + # Tier 1 compound: root_com_vel_w (spatial_vectorf: top=linear, bottom=angular) + vel_np = np.zeros((num_envs, 6), dtype=np.float32) + vel_np[:, :3] = lin_vel_w_np + vel_np[:, 3:] = ang_vel_w_np + self.root_com_vel_w = wp.array(vel_np, dtype=wp.spatial_vectorf, device=device) + + # Gravity direction constant + self.GRAVITY_VEC_W = wp.vec3f(0.0, 0.0, -1.0) + + # Derived body-frame quantities (consistent with Tier 1 compounds) + self.root_lin_vel_b = wp.array(quat_rotate_inv_np(quat_np, lin_vel_w_np), dtype=wp.vec3f, device=device) + self.root_ang_vel_b = wp.array(quat_rotate_inv_np(quat_np, ang_vel_w_np), dtype=wp.vec3f, device=device) + + # --- projected_gravity_b and body-level data --- + if num_bodies > 0: + # Multi-body projected_gravity_b: (num_envs, num_bodies) vec3f + grav_np = rng.randn(num_envs, num_bodies, 3).astype(np.float32) + grav_np[:, :, 2] = -1.0 + grav_np /= np.linalg.norm(grav_np, axis=2, keepdims=True) + self.projected_gravity_b = wp.array(grav_np, dtype=wp.vec3f, device=device) + + # body_pose_w: (num_envs, num_bodies) transformf + bpose_np = np.zeros((num_envs, num_bodies, 7), dtype=np.float32) + bpose_np[:, :, :3] = rng.randn(num_envs, num_bodies, 3).astype(np.float32) + bpose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] + self.body_pose_w = wp.array(bpose_np, dtype=wp.transformf, device=device) + + # body_lin_acc_w: (num_envs, num_bodies) vec3f + self.body_lin_acc_w = wp.array( + rng.randn(num_envs, num_bodies, 3).astype(np.float32), dtype=wp.vec3f, device=device + ) + + # body_com_pos_b: (num_envs, num_bodies) vec3f + self.body_com_pos_b = wp.array( + rng.randn(num_envs, num_bodies, 3).astype(np.float32) * 0.01, dtype=wp.vec3f, device=device + ) + else: + # Root-level projected_gravity_b: (num_envs,) vec3f — derived from root quat + self.projected_gravity_b = wp.array( + quat_rotate_inv_np(quat_np, np.tile(GRAVITY_DIR_NP, (num_envs, 1))), + dtype=wp.vec3f, + device=device, + ) + + # --- Event-specific data --- + self.root_vel_w = wp.array(rng.randn(num_envs, 6).astype(np.float32), dtype=wp.spatial_vectorf, device=device) + + default_pose_np = np.zeros((num_envs, 7), dtype=np.float32) + default_pose_np[:, 0:3] = rng.randn(num_envs, 3).astype(np.float32) * 0.1 + default_pose_np[:, 3:7] = [0.0, 0.0, 0.0, 1.0] + self.default_root_pose = wp.array(default_pose_np, dtype=wp.transformf, device=device) + + self.default_root_vel = wp.array( + np.zeros((num_envs, 6), dtype=np.float32), dtype=wp.spatial_vectorf, device=device + ) + + def resolve_joint_mask(self, joint_ids=None): + n = self.joint_pos.shape[1] + mask = [False] * n + if joint_ids is None or isinstance(joint_ids, slice): + mask = [True] * n + else: + for j in joint_ids: + mask[j] = True + return wp.array(mask, dtype=wp.bool, device=str(self.joint_pos.device)) + + +class MockArticulation: + """Mock articulation asset with simulation write stubs. + + Provides both no-op write stubs (for event tests) and tracking write stubs + (for action tests). The ``last_*_target`` attributes record the most recent + values passed to ``set_joint_*_target``, enabling verification in action tests. + """ + + def __init__(self, data: MockArticulationData, num_bodies: int = 1, num_joints: int = NUM_JOINTS): + self.data = data + self.num_bodies = num_bodies + self.num_joints = num_joints + self.device = DEVICE + self._joint_names = [f"joint_{i}" for i in range(num_joints)] + # Tracking attributes for action tests + self.last_pos_target = None + self.last_vel_target = None + self.last_effort_target = None + self.last_joint_mask = None + + # -- Simulation write stubs (no-op, for event tests) -------------------- + + def write_root_velocity_to_sim(self, *a, **kw): + pass + + def write_root_pose_to_sim(self, *a, **kw): + pass + + def write_joint_state_to_sim(self, *a, **kw): + pass + + def set_external_force_and_torque(self, *a, **kw): + pass + + # -- Action write stubs (tracking, for action tests) -------------------- + + def set_joint_position_target(self, target, joint_ids=None, joint_mask=None): + self.last_pos_target = target + self.last_joint_mask = joint_mask + + def set_joint_velocity_target(self, target, joint_ids=None, joint_mask=None): + self.last_vel_target = target + self.last_joint_mask = joint_mask + + def set_joint_effort_target(self, target, joint_ids=None, joint_mask=None): + self.last_effort_target = target + self.last_joint_mask = joint_mask + + # -- Query stubs -------------------------------------------------------- + + def find_joints(self, names, preserve_order=False): + if isinstance(names, list) and names == [".*"]: + return None, list(self._joint_names), list(range(self.num_joints)) + ids = [] + resolved = [] + for name in names if isinstance(names, list) else [names]: + for i, jn in enumerate(self._joint_names): + if (name in jn or name == jn or name == ".*") and i not in ids: + ids.append(i) + resolved.append(jn) + if not ids: + ids = list(range(self.num_joints)) + resolved = list(self._joint_names) + return None, resolved, ids + + def find_bodies(self, name): + return None, [name], [0] + + +class MockScene: + """Mock scene with asset lookup, env origins, and optional sensors.""" + + def __init__(self, assets: dict, env_origins, sensors=None): + self._assets = assets + self.env_origins = env_origins + self.sensors = sensors or {} + self.articulations = dict(assets) + self.rigid_objects = {} + self.num_envs = NUM_ENVS + + def __getitem__(self, name: str): + return self._assets[name] + + +# --------------------------------------------------------------------------- +# Root-state mutation helper +# --------------------------------------------------------------------------- + + +def mutate_root_state(rng: np.random.RandomState, art_data: MockArticulationData, num_envs: int = NUM_ENVS): + """Mutate root-level state arrays in-place (preserves buffer pointers). + + Updates root_pos_w, root_quat_w, root_link_pose_w, root_com_vel_w, + root_lin_vel_w, root_ang_vel_w, root_lin_vel_b, root_ang_vel_b, and + (when 1D) projected_gravity_b — all consistently derived from a fresh + random quaternion and world-frame velocities. + """ + root_pos_np = rng.randn(num_envs, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 + copy_np_to_wp(art_data.root_pos_w, root_pos_np) + + quat_np = rng.randn(num_envs, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + copy_np_to_wp(art_data.root_quat_w, quat_np) + + pose_np = np.zeros((num_envs, 7), dtype=np.float32) + pose_np[:, :3] = root_pos_np + pose_np[:, 3:] = quat_np + copy_np_to_wp(art_data.root_link_pose_w, pose_np) + + lin_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + ang_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + copy_np_to_wp(art_data.root_lin_vel_w, lin_vel_w_np) + copy_np_to_wp(art_data.root_ang_vel_w, ang_vel_w_np) + + vel_np = np.zeros((num_envs, 6), dtype=np.float32) + vel_np[:, :3] = lin_vel_w_np + vel_np[:, 3:] = ang_vel_w_np + copy_np_to_wp(art_data.root_com_vel_w, vel_np) + + copy_np_to_wp(art_data.root_lin_vel_b, quat_rotate_inv_np(quat_np, lin_vel_w_np)) + copy_np_to_wp(art_data.root_ang_vel_b, quat_rotate_inv_np(quat_np, ang_vel_w_np)) + + # Root-level projected_gravity_b (1D) is derived from quat. + # Multi-body (2D) is mutated separately by callers. + if art_data.projected_gravity_b.ndim == 1: + copy_np_to_wp( + art_data.projected_gravity_b, + quat_rotate_inv_np(quat_np, np.tile(GRAVITY_DIR_NP, (num_envs, 1))), + ) + + +class MockActionManagerWarp: + """Returns warp arrays (for experimental functions).""" + + def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): + self._action = action_wp + self._prev_action = prev_action_wp + + @property + def action(self) -> wp.array: + return self._action + + @property + def prev_action(self) -> wp.array: + return self._prev_action + + +class MockActionManagerTorch: + """Returns torch tensors (for stable functions).""" + + def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): + self._action = wp.to_torch(action_wp) + self._prev_action = wp.to_torch(prev_action_wp) + + @property + def action(self) -> torch.Tensor: + return self._action + + @property + def prev_action(self) -> torch.Tensor: + return self._prev_action diff --git a/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py index ecaa6ca0a11..9af90f6ec71 100644 --- a/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py +++ b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py @@ -48,9 +48,11 @@ RelativeJointPositionAction, RelativeJointPositionActionCfg, ) +from parity_helpers import MockArticulation, MockArticulationData, MockScene, copy_np_to_wp NUM_ENVS = 32 NUM_JOINTS = 6 +NUM_BODIES = 3 DEVICE = "cuda:0" ATOL = 1e-5 RTOL = 1e-5 @@ -62,95 +64,9 @@ # ============================================================================ -class MockArticulationData: - def __init__(self, seed=42): - rng = np.random.RandomState(seed) - self.joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=DEVICE) - self.joint_vel = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=DEVICE) - self.default_joint_pos = wp.array( - np.tile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], (NUM_ENVS, 1)).astype(np.float32), device=DEVICE - ) - self.default_joint_vel = wp.array(np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32), device=DEVICE) - - limits_np = np.zeros((NUM_ENVS, NUM_JOINTS, 2), dtype=np.float32) - limits_np[:, :, 0] = -3.14 - limits_np[:, :, 1] = 3.14 - self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=DEVICE) - - # Body quaternion for NonHolonomicAction (identity = [0,0,0,1] in xyzw) - num_bodies = 3 - quat_np = np.zeros((NUM_ENVS, num_bodies, 4), dtype=np.float32) - quat_np[:, :, 3] = 1.0 # w=1 (identity) - self.body_quat_w = wp.array(quat_np, dtype=wp.quatf, device=DEVICE) - - self._num_joints = NUM_JOINTS - - def resolve_joint_mask(self, joint_ids=None): - mask = [False] * NUM_JOINTS - if joint_ids is None or isinstance(joint_ids, slice): - mask = [True] * NUM_JOINTS - else: - for j in joint_ids: - mask[j] = True - return wp.array(mask, dtype=wp.bool, device=DEVICE) - - -class MockArticulation: - def __init__(self, data: MockArticulationData): - self.data = data - self.num_joints = NUM_JOINTS - self.num_bodies = 3 - self.device = DEVICE - # Track what was last written for verification - self.last_pos_target = None - self.last_vel_target = None - self.last_effort_target = None - self.last_joint_mask = None - - def find_joints(self, names, preserve_order=False): - if isinstance(names, list) and names == [".*"]: - return None, JOINT_NAMES, list(range(NUM_JOINTS)) - # For specific joint names, resolve them - ids = [] - resolved_names = [] - for name in names if isinstance(names, list) else [names]: - for i, jn in enumerate(JOINT_NAMES): - if name in jn or name == jn or name == ".*": - if i not in ids: - ids.append(i) - resolved_names.append(jn) - if not ids: - ids = list(range(NUM_JOINTS)) - resolved_names = list(JOINT_NAMES) - return None, resolved_names, ids - - def find_bodies(self, name): - return None, [name], [0] - - def set_joint_position_target(self, target, joint_ids=None, joint_mask=None): - self.last_pos_target = target - self.last_joint_mask = joint_mask - - def set_joint_velocity_target(self, target, joint_ids=None, joint_mask=None): - self.last_vel_target = target - self.last_joint_mask = joint_mask - - def set_joint_effort_target(self, target, joint_ids=None, joint_mask=None): - self.last_effort_target = target - self.last_joint_mask = joint_mask - - -class MockScene: - def __init__(self, asset): - self._asset = asset - - def __getitem__(self, name): - return self._asset - - class MockEnv: def __init__(self, asset): - self.scene = MockScene(asset) + self.scene = MockScene({"robot": asset}, env_origins=None) self.num_envs = NUM_ENVS self.device = DEVICE @@ -162,12 +78,23 @@ def __init__(self, asset): @pytest.fixture() def art_data(): - return MockArticulationData() + data = MockArticulationData(num_envs=NUM_ENVS, num_joints=NUM_JOINTS, num_bodies=NUM_BODIES) + # Override defaults with specific per-joint values for action tests + copy_np_to_wp( + data.default_joint_pos, + np.tile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], (NUM_ENVS, 1)).astype(np.float32), + ) + # Body quaternion for NonHolonomicAction (identity = [0,0,0,1] in xyzw) + quat_np = np.zeros((NUM_ENVS, NUM_BODIES, 4), dtype=np.float32) + quat_np[:, :, 3] = 1.0 + data.body_quat_w = wp.array(quat_np, dtype=wp.quatf, device=DEVICE) + data._num_joints = NUM_JOINTS + return data @pytest.fixture() def asset(art_data): - return MockArticulation(art_data) + return MockArticulation(art_data, num_bodies=NUM_BODIES, num_joints=NUM_JOINTS) @pytest.fixture() diff --git a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py index d54f14bac43..dfdfcd762e4 100644 --- a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py +++ b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity.py @@ -33,6 +33,31 @@ import pytest import warp as wp +# --------------------------------------------------------------------------- +# Shared utilities (from parity_helpers.py) +# --------------------------------------------------------------------------- +from parity_helpers import ( + DEVICE, + NUM_ACTIONS, + NUM_ENVS, + NUM_JOINTS, + MockActionManagerTorch, + MockActionManagerWarp, + MockArticulation, + MockArticulationData, + MockScene, + assert_close, + assert_equal, + copy_np_to_wp, + mutate_root_state, + run_warp_obs, + run_warp_obs_captured, + run_warp_rew, + run_warp_rew_captured, + run_warp_term, + run_warp_term_captured, +) + # --------------------------------------------------------------------------- # Stable (torch) implementations # --------------------------------------------------------------------------- @@ -40,147 +65,11 @@ import isaaclab.envs.mdp.rewards as stable_rew import isaaclab.envs.mdp.terminations as stable_term -# --------------------------------------------------------------------------- -# Test constants -# --------------------------------------------------------------------------- -NUM_ENVS = 64 -NUM_JOINTS = 12 -NUM_ACTIONS = 6 -DEVICE = "cuda:0" - -# Tolerance for float32 comparison (torch vs warp may differ by FMA / instruction order) -ATOL = 1e-5 -RTOL = 1e-5 - - # ============================================================================ -# Mock objects +# File-specific mock objects # ============================================================================ -class MockArticulationData: - """Mock articulation data backed by Warp arrays (same storage Newton uses).""" - - def __init__(self, num_envs: int, num_joints: int, device: str, seed: int = 42): - rng = np.random.RandomState(seed) - - # --- Joint state (float32 2D) --- - self.joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32), device=device) - self.joint_vel = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 2.0, device=device) - self.joint_acc = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.5, device=device) - self.default_joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.01, device=device) - self.default_joint_vel = wp.array(np.zeros((num_envs, num_joints), dtype=np.float32), device=device) - self.applied_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) - self.computed_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) - - # --- Soft joint position limits (vec2f 2D) --- - limits_np = np.zeros((num_envs, num_joints, 2), dtype=np.float32) - limits_np[:, :, 0] = -3.14 # lower - limits_np[:, :, 1] = 3.14 # upper - self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) - - # --- Soft joint velocity limits (float32 2D) --- - self.soft_joint_vel_limits = wp.array(np.full((num_envs, num_joints), 10.0, dtype=np.float32), device=device) - - # --- Root state --- - root_pos_np = rng.randn(num_envs, 3).astype(np.float32) - root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 # positive heights - self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) - - self.root_lin_vel_b = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) - self.root_ang_vel_b = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) - - # Gravity projection (unit-ish vectors pointing mostly down) - gravity_np = np.zeros((num_envs, 3), dtype=np.float32) - gravity_np[:, 2] = -1.0 - gravity_np += rng.randn(num_envs, 3).astype(np.float32) * 0.1 - gravity_np /= np.linalg.norm(gravity_np, axis=1, keepdims=True) - self.projected_gravity_b = wp.array(gravity_np, dtype=wp.vec3f, device=device) - - # --- Additional root state for new observations --- - # Quaternion (random unit quaternions) - quat_np = rng.randn(num_envs, 4).astype(np.float32) - quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) - self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) - - # World-frame velocities - self.root_lin_vel_w = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) - self.root_ang_vel_w = wp.array(rng.randn(num_envs, 3).astype(np.float32), dtype=wp.vec3f, device=device) - - # --- Event-specific data --- - # Spatial velocity (6-component: lin + ang) - self.root_vel_w = wp.array(rng.randn(num_envs, 6).astype(np.float32), dtype=wp.spatial_vectorf, device=device) - - # Default root pose (transformf = position vec3f + quaternion quatf) - default_pose_np = np.zeros((num_envs, 7), dtype=np.float32) - default_pose_np[:, 0:3] = rng.randn(num_envs, 3).astype(np.float32) * 0.1 # small position offsets - default_pose_np[:, 3:7] = [0.0, 0.0, 0.0, 1.0] # identity quaternion (xyzw) - self.default_root_pose = wp.array(default_pose_np, dtype=wp.transformf, device=device) - - # Default root velocity (spatial_vectorf) - self.default_root_vel = wp.array( - np.zeros((num_envs, 6), dtype=np.float32), dtype=wp.spatial_vectorf, device=device - ) - - -class MockArticulation: - def __init__(self, data: MockArticulationData): - self.data = data - self.num_bodies = 1 - self.device = DEVICE - - # Stub write APIs for events (no-ops — we verify scratch buffer contents instead) - def write_root_velocity_to_sim(self, root_velocity, env_ids=None, env_mask=None): - pass - - def write_root_pose_to_sim(self, root_pose, env_ids=None, env_mask=None): - pass - - def set_external_force_and_torque(self, forces, torques, body_ids=None, env_ids=None, env_mask=None): - pass - - -class MockScene: - def __init__(self, assets: dict, env_origins: torch.Tensor): - self._assets = assets - self.env_origins = env_origins - - def __getitem__(self, name: str): - return self._assets[name] - - -class MockActionManagerWarp: - """Returns warp arrays (for experimental functions).""" - - def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): - self._action = action_wp - self._prev_action = prev_action_wp - - @property - def action(self) -> wp.array: - return self._action - - @property - def prev_action(self) -> wp.array: - return self._prev_action - - -class MockActionManagerTorch: - """Returns torch tensors (for stable functions).""" - - def __init__(self, action_wp: wp.array, prev_action_wp: wp.array): - self._action = wp.to_torch(action_wp) - self._prev_action = wp.to_torch(prev_action_wp) - - @property - def action(self) -> torch.Tensor: - return self._action - - @property - def prev_action(self) -> torch.Tensor: - return self._prev_action - - class MockSceneEntityCfg: """Unified cfg that works for both stable (joint_ids) and experimental (joint_mask / joint_ids_wp).""" @@ -205,15 +94,11 @@ def __init__(self, name: str, joint_ids: list[int], num_joints: int, device: str def _clear_function_caches(): """Clear first-call caches on warp MDP functions so each test starts fresh. - Functions like ``current_time_s`` and ``root_pos_w`` cache warp views on - themselves (``hasattr`` pattern). Without clearing, a cached view from a - prior test's fixture would be stale when a new test creates different tensors. + Functions that cache warp views via the ``hasattr`` pattern need clearing + between tests to avoid stale references from prior fixtures. """ yield for fn in ( - warp_obs.root_pos_w, - warp_obs.current_time_s, - warp_obs.remaining_time_s, warp_evt.push_by_setting_velocity, warp_evt.apply_external_force_torque, warp_evt.reset_root_state_uniform, @@ -303,73 +188,6 @@ def subset_cfg(): return MockSceneEntityCfg("robot", [0, 2, 5, 8], NUM_JOINTS, DEVICE) -# ============================================================================ -# Helpers -# ============================================================================ - - -def _run_warp_obs(func, env, shape, device=DEVICE, **kwargs): - """Run a warp observation function and return the result as a torch tensor.""" - out = wp.zeros(shape, dtype=wp.float32, device=device) - func(env, out, **kwargs) - return wp.to_torch(out) - - -def _run_warp_obs_captured(func, env, shape, device=DEVICE, **kwargs): - """Run a warp observation function under CUDA graph capture and return the result.""" - out = wp.zeros(shape, dtype=wp.float32, device=device) - # Warm-up (triggers any first-call lazy init) - func(env, out, **kwargs) - # Capture - with wp.ScopedCapture() as capture: - func(env, out, **kwargs) - # Replay - wp.capture_launch(capture.graph) - return wp.to_torch(out) - - -def _run_warp_rew(func, env, device=DEVICE, **kwargs): - """Run a warp reward function and return the result as a torch tensor.""" - out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) - func(env, out, **kwargs) - return wp.to_torch(out) - - -def _run_warp_rew_captured(func, env, device=DEVICE, **kwargs): - """Run a warp reward function under CUDA graph capture.""" - out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device) - func(env, out, **kwargs) # warm-up - with wp.ScopedCapture() as capture: - func(env, out, **kwargs) - wp.capture_launch(capture.graph) - return wp.to_torch(out) - - -def _run_warp_term(func, env, device=DEVICE, **kwargs): - """Run a warp termination function and return the result as a torch tensor.""" - out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) - func(env, out, **kwargs) - return wp.to_torch(out) - - -def _run_warp_term_captured(func, env, device=DEVICE, **kwargs): - """Run a warp termination function under CUDA graph capture.""" - out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device) - func(env, out, **kwargs) # warm-up - with wp.ScopedCapture() as capture: - func(env, out, **kwargs) - wp.capture_launch(capture.graph) - return wp.to_torch(out) - - -def assert_close(actual: torch.Tensor, expected: torch.Tensor, atol: float = ATOL, rtol: float = RTOL): - torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) - - -def assert_equal(actual: torch.Tensor, expected: torch.Tensor): - assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}" - - # ============================================================================ # Observation parity tests # ============================================================================ @@ -383,32 +201,32 @@ class TestObservationParity: def test_base_pos_z(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.base_pos_z(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_base_lin_vel(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.base_lin_vel(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_base_ang_vel(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.base_ang_vel(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_projected_gravity(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.projected_gravity(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -417,16 +235,16 @@ def test_projected_gravity(self, warp_env, stable_env, all_joints_cfg): def test_joint_pos_all(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_joint_vel_all(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -436,8 +254,8 @@ def test_joint_pos_subset(self, warp_env, stable_env, subset_cfg): cfg = subset_cfg n_selected = len(cfg.joint_ids) expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -445,8 +263,8 @@ def test_joint_vel_subset(self, warp_env, stable_env, subset_cfg): cfg = subset_cfg n_selected = len(cfg.joint_ids) expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -455,8 +273,8 @@ def test_joint_vel_subset(self, warp_env, stable_env, subset_cfg): def test_joint_pos_limit_normalized(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_obs.joint_pos_limit_normalized(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured( + actual = run_warp_obs(warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured( warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg ) assert_close(actual, expected) @@ -467,76 +285,8 @@ def test_joint_pos_limit_normalized(self, warp_env, stable_env, all_joints_cfg): def test_last_action(self, warp_env, stable_env, action_wp): # Stable last_action returns env.action_manager.action (torch tensor) expected = stable_obs.last_action(stable_env) - actual = _run_warp_obs(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) - actual_cap = _run_warp_obs_captured(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - # -- Additional root state observations ------------------------------------- - - def test_root_pos_w(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.root_pos_w(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.root_pos_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.root_pos_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_root_quat_w(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.root_quat_w(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_root_quat_w_unique(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.root_quat_w(stable_env, make_quat_unique=True, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), make_quat_unique=True, asset_cfg=cfg) - actual_cap = _run_warp_obs_captured( - warp_obs.root_quat_w, warp_env, (NUM_ENVS, 4), make_quat_unique=True, asset_cfg=cfg - ) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_root_lin_vel_w(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.root_lin_vel_w(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.root_lin_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.root_lin_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_root_ang_vel_w(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.root_ang_vel_w(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.root_ang_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.root_ang_vel_w, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_joint_effort(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_obs.joint_effort(stable_env, asset_cfg=cfg) - actual = _run_warp_obs(warp_obs.joint_effort, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) - actual_cap = _run_warp_obs_captured(warp_obs.joint_effort, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - # -- Time observations ------------------------------------------------------ - - def test_current_time_s(self, warp_env, stable_env): - expected = stable_obs.current_time_s(stable_env) - actual = _run_warp_obs(warp_obs.current_time_s, warp_env, (NUM_ENVS, 1)) - actual_cap = _run_warp_obs_captured(warp_obs.current_time_s, warp_env, (NUM_ENVS, 1)) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_remaining_time_s(self, warp_env, stable_env): - expected = stable_obs.remaining_time_s(stable_env) - actual = _run_warp_obs(warp_obs.remaining_time_s, warp_env, (NUM_ENVS, 1)) - actual_cap = _run_warp_obs_captured(warp_obs.remaining_time_s, warp_env, (NUM_ENVS, 1)) + actual = run_warp_obs(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) + actual_cap = run_warp_obs_captured(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -554,24 +304,24 @@ class TestRewardParity: def test_lin_vel_z_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.lin_vel_z_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_ang_vel_xy_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.ang_vel_xy_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_flat_orientation_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.flat_orientation_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -580,24 +330,24 @@ def test_flat_orientation_l2(self, warp_env, stable_env, all_joints_cfg): def test_joint_vel_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.joint_vel_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_joint_acc_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.joint_acc_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) def test_joint_torques_l2(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.joint_torques_l2(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -605,15 +355,15 @@ def test_joint_torques_l2(self, warp_env, stable_env, all_joints_cfg): def test_action_l2(self, warp_env, stable_env): expected = stable_rew.action_l2(stable_env) - actual = _run_warp_rew(warp_rew.action_l2, warp_env) - actual_cap = _run_warp_rew_captured(warp_rew.action_l2, warp_env) + actual = run_warp_rew(warp_rew.action_l2, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.action_l2, warp_env) assert_close(actual, expected) assert_close(actual_cap, expected) def test_action_rate_l2(self, warp_env, stable_env): expected = stable_rew.action_rate_l2(stable_env) - actual = _run_warp_rew(warp_rew.action_rate_l2, warp_env) - actual_cap = _run_warp_rew_captured(warp_rew.action_rate_l2, warp_env) + actual = run_warp_rew(warp_rew.action_rate_l2, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.action_rate_l2, warp_env) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -622,24 +372,8 @@ def test_action_rate_l2(self, warp_env, stable_env): def test_joint_pos_limits(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.joint_pos_limits(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_joint_vel_limits(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_rew.joint_vel_limits(stable_env, soft_ratio=0.9, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_vel_limits, warp_env, soft_ratio=0.9, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_vel_limits, warp_env, soft_ratio=0.9, asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_applied_torque_limits(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_rew.applied_torque_limits(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.applied_torque_limits, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.applied_torque_limits, warp_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -648,17 +382,8 @@ def test_applied_torque_limits(self, warp_env, stable_env, all_joints_cfg): def test_joint_deviation_l1(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg expected = stable_rew.joint_deviation_l1(stable_env, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_base_height_l2(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - target = 0.5 - expected = stable_rew.base_height_l2(stable_env, target_height=target, asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.base_height_l2, warp_env, target_height=target, asset_cfg=cfg) - actual_cap = _run_warp_rew_captured(warp_rew.base_height_l2, warp_env, target_height=target, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) assert_close(actual, expected) assert_close(actual_cap, expected) @@ -675,59 +400,13 @@ def test_root_height_below_minimum(self, warp_env, stable_env, all_joints_cfg): cfg = all_joints_cfg min_h = 0.5 expected = stable_term.root_height_below_minimum(stable_env, minimum_height=min_h, asset_cfg=cfg) - actual = _run_warp_term(warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg) - actual_cap = _run_warp_term_captured( + actual = run_warp_term(warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg) + actual_cap = run_warp_term_captured( warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg ) assert_equal(actual, expected) assert_equal(actual_cap, expected) - def test_bad_orientation(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - limit = 0.5 # ~29 degrees - expected = stable_term.bad_orientation(stable_env, limit_angle=limit, asset_cfg=cfg) - actual = _run_warp_term(warp_term.bad_orientation, warp_env, limit_angle=limit, asset_cfg=cfg) - actual_cap = _run_warp_term_captured(warp_term.bad_orientation, warp_env, limit_angle=limit, asset_cfg=cfg) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - - def test_joint_pos_out_of_limit(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_term.joint_pos_out_of_limit(stable_env, asset_cfg=cfg) - actual = _run_warp_term(warp_term.joint_pos_out_of_limit, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_term_captured(warp_term.joint_pos_out_of_limit, warp_env, asset_cfg=cfg) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - - def test_joint_vel_out_of_limit(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_term.joint_vel_out_of_limit(stable_env, asset_cfg=cfg) - actual = _run_warp_term(warp_term.joint_vel_out_of_limit, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_term_captured(warp_term.joint_vel_out_of_limit, warp_env, asset_cfg=cfg) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - - # -- Additional joint terminations ------------------------------------------ - - def test_joint_vel_out_of_manual_limit(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - max_vel = 5.0 - expected = stable_term.joint_vel_out_of_manual_limit(stable_env, max_velocity=max_vel, asset_cfg=cfg) - actual = _run_warp_term(warp_term.joint_vel_out_of_manual_limit, warp_env, max_velocity=max_vel, asset_cfg=cfg) - actual_cap = _run_warp_term_captured( - warp_term.joint_vel_out_of_manual_limit, warp_env, max_velocity=max_vel, asset_cfg=cfg - ) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - - def test_joint_effort_out_of_limit(self, warp_env, stable_env, all_joints_cfg): - cfg = all_joints_cfg - expected = stable_term.joint_effort_out_of_limit(stable_env, asset_cfg=cfg) - actual = _run_warp_term(warp_term.joint_effort_out_of_limit, warp_env, asset_cfg=cfg) - actual_cap = _run_warp_term_captured(warp_term.joint_effort_out_of_limit, warp_env, asset_cfg=cfg) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - # ============================================================================ # Capture-then-mutate-then-replay tests @@ -738,44 +417,23 @@ def test_joint_effort_out_of_limit(self, warp_env, stable_env, all_joints_cfg): # ============================================================================ -def _copy_np_to_wp(dest: wp.array, src_np: np.ndarray): - """In-place overwrite of a warp array's contents from numpy (preserves pointer).""" - tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device)) - wp.copy(dest, tmp) - - def _mutate_art_data(art_data: MockArticulationData, warp_env, rng_seed: int = 200): """Mutate every data array in-place so captured graphs see fresh values.""" rng = np.random.RandomState(rng_seed) - _copy_np_to_wp(art_data.joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 1.5) - _copy_np_to_wp(art_data.joint_vel, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 3.0) - _copy_np_to_wp(art_data.joint_acc, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.8) - _copy_np_to_wp(art_data.default_joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.02) - _copy_np_to_wp(art_data.applied_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) - _copy_np_to_wp(art_data.computed_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) - - root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) - root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 - _copy_np_to_wp(art_data.root_pos_w, root_pos_np) - _copy_np_to_wp(art_data.root_lin_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) - _copy_np_to_wp(art_data.root_ang_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) - _copy_np_to_wp(art_data.root_lin_vel_w, rng.randn(NUM_ENVS, 3).astype(np.float32)) - _copy_np_to_wp(art_data.root_ang_vel_w, rng.randn(NUM_ENVS, 3).astype(np.float32)) - - gravity_np = np.zeros((NUM_ENVS, 3), dtype=np.float32) - gravity_np[:, 2] = -1.0 - gravity_np += rng.randn(NUM_ENVS, 3).astype(np.float32) * 0.15 - gravity_np /= np.linalg.norm(gravity_np, axis=1, keepdims=True) - _copy_np_to_wp(art_data.projected_gravity_b, gravity_np) - - quat_np = rng.randn(NUM_ENVS, 4).astype(np.float32) - quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) - _copy_np_to_wp(art_data.root_quat_w, quat_np) + copy_np_to_wp(art_data.joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 1.5) + copy_np_to_wp(art_data.joint_vel, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 3.0) + copy_np_to_wp(art_data.joint_acc, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.8) + copy_np_to_wp(art_data.default_joint_pos, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.02) + copy_np_to_wp(art_data.applied_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) + copy_np_to_wp(art_data.computed_torque, rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 12.0) + + # Root state + Tier 1 compounds + derived body-frame (including projected_gravity_b) + mutate_root_state(rng, art_data) # Actions (in-place via warp copy — torch views auto-update) - _copy_np_to_wp(warp_env.action_manager._action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) - _copy_np_to_wp(warp_env.action_manager._prev_action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) + copy_np_to_wp(warp_env.action_manager._action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) + copy_np_to_wp(warp_env.action_manager._prev_action, rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32)) # Episode length (in-place torch update — warp zero-copy view auto-updates) warp_env.episode_length_buf[:] = torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) @@ -908,93 +566,6 @@ def test_last_action(self, warp_env, stable_env, art_data): (NUM_ENVS, NUM_ACTIONS), ) - def test_root_pos_w(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.root_pos_w, - stable_obs.root_pos_w, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 3), - asset_cfg=all_joints_cfg, - ) - - def test_root_quat_w(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.root_quat_w, - stable_obs.root_quat_w, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 4), - asset_cfg=all_joints_cfg, - ) - - def test_root_quat_w_unique(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.root_quat_w, - stable_obs.root_quat_w, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 4), - make_quat_unique=True, - asset_cfg=all_joints_cfg, - ) - - def test_root_lin_vel_w(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.root_lin_vel_w, - stable_obs.root_lin_vel_w, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 3), - asset_cfg=all_joints_cfg, - ) - - def test_root_ang_vel_w(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.root_ang_vel_w, - stable_obs.root_ang_vel_w, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 3), - asset_cfg=all_joints_cfg, - ) - - def test_joint_effort(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_obs( - warp_obs.joint_effort, - stable_obs.joint_effort, - warp_env, - stable_env, - art_data, - (NUM_ENVS, NUM_JOINTS), - asset_cfg=all_joints_cfg, - ) - - def test_current_time_s(self, warp_env, stable_env, art_data): - self._capture_mutate_check_obs( - warp_obs.current_time_s, - stable_obs.current_time_s, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 1), - ) - - def test_remaining_time_s(self, warp_env, stable_env, art_data): - self._capture_mutate_check_obs( - warp_obs.remaining_time_s, - stable_obs.remaining_time_s, - warp_env, - stable_env, - art_data, - (NUM_ENVS, 1), - ) - # -- rewards ---------------------------------------------------------------- def test_lin_vel_z_l2(self, warp_env, stable_env, art_data, all_joints_cfg): @@ -1085,27 +656,6 @@ def test_joint_pos_limits(self, warp_env, stable_env, art_data, all_joints_cfg): asset_cfg=all_joints_cfg, ) - def test_joint_vel_limits(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_rew( - warp_rew.joint_vel_limits, - stable_rew.joint_vel_limits, - warp_env, - stable_env, - art_data, - soft_ratio=0.9, - asset_cfg=all_joints_cfg, - ) - - def test_applied_torque_limits(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_rew( - warp_rew.applied_torque_limits, - stable_rew.applied_torque_limits, - warp_env, - stable_env, - art_data, - asset_cfg=all_joints_cfg, - ) - def test_joint_deviation_l1(self, warp_env, stable_env, art_data, all_joints_cfg): self._capture_mutate_check_rew( warp_rew.joint_deviation_l1, @@ -1116,17 +666,6 @@ def test_joint_deviation_l1(self, warp_env, stable_env, art_data, all_joints_cfg asset_cfg=all_joints_cfg, ) - def test_base_height_l2(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_rew( - warp_rew.base_height_l2, - stable_rew.base_height_l2, - warp_env, - stable_env, - art_data, - target_height=0.5, - asset_cfg=all_joints_cfg, - ) - # -- terminations ----------------------------------------------------------- def test_root_height_below_minimum(self, warp_env, stable_env, art_data, all_joints_cfg): @@ -1140,58 +679,6 @@ def test_root_height_below_minimum(self, warp_env, stable_env, art_data, all_joi asset_cfg=all_joints_cfg, ) - def test_bad_orientation(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_term( - warp_term.bad_orientation, - stable_term.bad_orientation, - warp_env, - stable_env, - art_data, - limit_angle=0.5, - asset_cfg=all_joints_cfg, - ) - - def test_joint_pos_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_term( - warp_term.joint_pos_out_of_limit, - stable_term.joint_pos_out_of_limit, - warp_env, - stable_env, - art_data, - asset_cfg=all_joints_cfg, - ) - - def test_joint_vel_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_term( - warp_term.joint_vel_out_of_limit, - stable_term.joint_vel_out_of_limit, - warp_env, - stable_env, - art_data, - asset_cfg=all_joints_cfg, - ) - - def test_joint_vel_out_of_manual_limit(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_term( - warp_term.joint_vel_out_of_manual_limit, - stable_term.joint_vel_out_of_manual_limit, - warp_env, - stable_env, - art_data, - max_velocity=5.0, - asset_cfg=all_joints_cfg, - ) - - def test_joint_effort_out_of_limit(self, warp_env, stable_env, art_data, all_joints_cfg): - self._capture_mutate_check_term( - warp_term.joint_effort_out_of_limit, - stable_term.joint_effort_out_of_limit, - warp_env, - stable_env, - art_data, - asset_cfg=all_joints_cfg, - ) - # ============================================================================ # Event tests @@ -1228,7 +715,7 @@ def test_reset_joints_by_offset(self, warp_env, art_data, all_joints_cfg): # Mutate defaults in-place new_defaults = np.full((NUM_ENVS, NUM_JOINTS), 0.5, dtype=np.float32) - _copy_np_to_wp(art_data.default_joint_pos, new_defaults) + copy_np_to_wp(art_data.default_joint_pos, new_defaults) # Replay wp.capture_launch(cap.graph) @@ -1255,7 +742,7 @@ def test_reset_joints_by_scale(self, warp_env, art_data, all_joints_cfg): ) new_defaults = np.full((NUM_ENVS, NUM_JOINTS), 0.25, dtype=np.float32) - _copy_np_to_wp(art_data.default_joint_pos, new_defaults) + copy_np_to_wp(art_data.default_joint_pos, new_defaults) wp.capture_launch(cap.graph) wp.synchronize() @@ -1284,7 +771,7 @@ def test_push_by_setting_velocity(self, warp_env, art_data, all_joints_cfg): # Mutate root_vel_w new_vel = np.tile([1.0, 2.0, 3.0, 0.1, 0.2, 0.3], (NUM_ENVS, 1)).astype(np.float32) - _copy_np_to_wp(art_data.root_vel_w, new_vel) + copy_np_to_wp(art_data.root_vel_w, new_vel) wp.capture_launch(cap.graph) wp.synchronize() @@ -1313,43 +800,6 @@ def test_apply_external_force_torque(self, warp_env, art_data, all_joints_cfg): # -- reset_root_state_uniform ----------------------------------------------- - def test_reset_root_state_uniform(self, warp_env, art_data, all_joints_cfg, env_origins): - """With zero-width ranges, pose = default + env_origin, vel = default. Mutate defaults → tracks.""" - mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) - zero_pose = { - "x": (0.0, 0.0), - "y": (0.0, 0.0), - "z": (0.0, 0.0), - "roll": (0.0, 0.0), - "pitch": (0.0, 0.0), - "yaw": (0.0, 0.0), - } - zero_vel = dict(zero_pose) - - warp_evt.reset_root_state_uniform(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) - with wp.ScopedCapture() as cap: - warp_evt.reset_root_state_uniform(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) - - # Mutate default_root_pose: set all positions to (1, 2, 3), identity quat - new_pose = np.zeros((NUM_ENVS, 7), dtype=np.float32) - new_pose[:, 0:3] = [1.0, 2.0, 3.0] - new_pose[:, 3:7] = [0.0, 0.0, 0.0, 1.0] # identity (xyzw) - _copy_np_to_wp(art_data.default_root_pose, new_pose) - - wp.capture_launch(cap.graph) - wp.synchronize() - - scratch_pose = wp.to_torch(warp_evt.reset_root_state_uniform._scratch_pose) - origins_t = wp.to_torch(env_origins) - - # position = default(1,2,3) + env_origin + 0 - expected_pos = torch.tensor([1.0, 2.0, 3.0], device=DEVICE).unsqueeze(0) + origins_t - assert_close(scratch_pose[:, :3], expected_pos) - - # quaternion = identity * identity_delta = identity = (0,0,0,1) in xyzw - expected_quat = torch.tensor([0.0, 0.0, 0.0, 1.0], device=DEVICE).expand(NUM_ENVS, -1) - assert_close(scratch_pose[:, 3:7], expected_quat) - # -- env_mask selectivity --------------------------------------------------- def test_reset_joints_mask_selectivity(self, warp_env, art_data, all_joints_cfg): @@ -1361,10 +811,10 @@ def test_reset_joints_mask_selectivity(self, warp_env, art_data, all_joints_cfg) # Set joint_pos to a known value sentinel = np.full((NUM_ENVS, NUM_JOINTS), 999.0, dtype=np.float32) - _copy_np_to_wp(art_data.joint_pos, sentinel) + copy_np_to_wp(art_data.joint_pos, sentinel) # Set defaults to 0 - _copy_np_to_wp(art_data.default_joint_pos, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) + copy_np_to_wp(art_data.default_joint_pos, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) warp_evt.reset_joints_by_offset( warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg diff --git a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py index 1aadc590f04..ace60e8ce5d 100644 --- a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py +++ b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py @@ -33,137 +33,49 @@ import isaaclab_experimental.envs.mdp.rewards as warp_rew import isaaclab_experimental.envs.mdp.terminations as warp_term +# --------------------------------------------------------------------------- +# Shared utilities (from parity_helpers.py) +# --------------------------------------------------------------------------- +from parity_helpers import ( + DEVICE, + NUM_ACTIONS, + NUM_ENVS, + NUM_JOINTS, + MockActionManagerWarp, + MockArticulation, + MockArticulationData, + MockScene, + assert_close, + assert_equal, + copy_np_to_wp, + mutate_root_state, + run_warp_obs, + run_warp_obs_captured, + run_warp_rew, + run_warp_rew_captured, + run_warp_term, + run_warp_term_captured, +) + import isaaclab.envs.mdp.observations as stable_obs import isaaclab.envs.mdp.rewards as stable_rew import isaaclab.envs.mdp.terminations as stable_term -# --------------------------------------------------------------------------- -NUM_ENVS = 64 -NUM_JOINTS = 12 +# File-specific constants NUM_BODIES = 4 -NUM_ACTIONS = 6 NUM_HISTORY = 3 CMD_DIM = 3 -DEVICE = "cuda:0" -ATOL = 1e-5 -RTOL = 1e-5 BODY_IDS = [0, 2] # subset of bodies to test # ============================================================================ -# Mock infrastructure +# File-specific mock infrastructure # ============================================================================ -def _make_rng(seed=42): - return np.random.RandomState(seed) - - -class MockMultiBodyArticulationData: - """Mock articulation data with multi-body arrays for body-level observations.""" - - def __init__(self, device=DEVICE, seed=42): - rng = _make_rng(seed) - - # --- Joint state --- - self.joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32), device=device) - self.joint_vel = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 2.0, device=device) - self.default_joint_pos = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.01, device=device) - self.default_joint_vel = wp.array(np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32), device=device) - self.joint_acc = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 0.5, device=device) - self.applied_torque = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 10.0, device=device) - self.computed_torque = wp.array(rng.randn(NUM_ENVS, NUM_JOINTS).astype(np.float32) * 10.0, device=device) - - # --- Root state --- - root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) - root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 - self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) - self.root_lin_vel_b = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) - self.root_ang_vel_b = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) - - # --- Soft limits --- - limits_np = np.zeros((NUM_ENVS, NUM_JOINTS, 2), dtype=np.float32) - limits_np[:, :, 0] = -3.14 - limits_np[:, :, 1] = 3.14 - self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) - self.soft_joint_vel_limits = wp.array(np.full((NUM_ENVS, NUM_JOINTS), 10.0, dtype=np.float32), device=device) - - # --- Body-level data (2D vec3f / transformf) --- - # projected_gravity_b: (num_envs, num_bodies) vec3f - grav_np = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) - grav_np[:, :, 2] = -1.0 - norms = np.linalg.norm(grav_np, axis=2, keepdims=True) - grav_np /= norms - self.projected_gravity_b = wp.array(grav_np, dtype=wp.vec3f, device=device) - - # body_pose_w: (num_envs, num_bodies) transformf — pos + identity quat - pose_np = np.zeros((NUM_ENVS, NUM_BODIES, 7), dtype=np.float32) - pose_np[:, :, :3] = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) - pose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] - self.body_pose_w = wp.array(pose_np, dtype=wp.transformf, device=device) - - # body_lin_acc_w: (num_envs, num_bodies) vec3f - self.body_lin_acc_w = wp.array( - rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32), dtype=wp.vec3f, device=device - ) - - # body_com_pos_b: (num_envs, num_bodies) vec3f - self.body_com_pos_b = wp.array( - rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) * 0.01, dtype=wp.vec3f, device=device - ) - - # Event-specific - self.root_vel_w = wp.array(rng.randn(NUM_ENVS, 6).astype(np.float32), dtype=wp.spatial_vectorf, device=device) - default_pose_np = np.zeros((NUM_ENVS, 7), dtype=np.float32) - default_pose_np[:, 0:3] = rng.randn(NUM_ENVS, 3).astype(np.float32) * 0.1 - default_pose_np[:, 3:7] = [0.0, 0.0, 0.0, 1.0] - self.default_root_pose = wp.array(default_pose_np, dtype=wp.transformf, device=device) - self.default_root_vel = wp.array( - np.zeros((NUM_ENVS, 6), dtype=np.float32), dtype=wp.spatial_vectorf, device=device - ) - - quat_np = rng.randn(NUM_ENVS, 4).astype(np.float32) - quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) - self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) - self.root_lin_vel_w = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) - self.root_ang_vel_w = wp.array(rng.randn(NUM_ENVS, 3).astype(np.float32), dtype=wp.vec3f, device=device) - - def resolve_joint_mask(self, joint_ids=None): - mask = [False] * NUM_JOINTS - if joint_ids is None or isinstance(joint_ids, slice): - mask = [True] * NUM_JOINTS - else: - for j in joint_ids: - mask[j] = True - return wp.array(mask, dtype=wp.bool, device=DEVICE) - - -class MockMultiBodyArticulation: - def __init__(self, data: MockMultiBodyArticulationData): - self.data = data - self.num_bodies = NUM_BODIES - self.num_joints = NUM_JOINTS - self.device = DEVICE - - def write_root_velocity_to_sim(self, *a, **kw): - pass - - def write_root_pose_to_sim(self, *a, **kw): - pass - - def write_joint_state_to_sim(self, *a, **kw): - pass - - def set_external_force_and_torque(self, *a, **kw): - pass - - def find_joints(self, names, preserve_order=False): - return None, [f"j{i}" for i in range(NUM_JOINTS)], list(range(NUM_JOINTS)) - - class MockContactSensorData: def __init__(self, device=DEVICE, seed=77): - rng = _make_rng(seed) + rng = np.random.RandomState(seed) self.net_forces_w_history = torch.tensor( rng.randn(NUM_ENVS, NUM_HISTORY, NUM_BODIES, 3).astype(np.float32), device=device ) @@ -177,7 +89,7 @@ def __init__(self, data: MockContactSensorData): class MockCommandTerm: def __init__(self, device=DEVICE, seed=88): - rng = _make_rng(seed) + rng = np.random.RandomState(seed) self.time_left = torch.tensor(rng.rand(NUM_ENVS).astype(np.float32) * 0.05, device=device) self.command_counter = torch.tensor(rng.randint(0, 3, (NUM_ENVS,)), dtype=torch.float32, device=device) @@ -210,33 +122,6 @@ def __init__(self, name="contact_sensor", body_ids=None): self.body_ids = body_ids if body_ids is not None else BODY_IDS -class MockScene: - def __init__(self, assets: dict, env_origins, sensors=None): - self._assets = assets - self.env_origins = env_origins - self.sensors = sensors or {} - self.articulations = {k: v for k, v in assets.items()} - self.rigid_objects = {} - self.num_envs = NUM_ENVS - - def __getitem__(self, name: str): - return self._assets[name] - - -class MockActionManagerWarp: - def __init__(self, action_wp, prev_action_wp): - self._action = action_wp - self._prev_action = prev_action_wp - - @property - def action(self): - return self._action - - @property - def prev_action(self): - return self._prev_action - - # ============================================================================ # Fixtures # ============================================================================ @@ -247,19 +132,11 @@ def _clear_caches(): yield # Clear function-level caches from all new warp functions fns_to_clear = [ - warp_obs.body_projected_gravity_b, - warp_obs.body_pose_w, warp_obs.generated_commands, - warp_rew.body_lin_acc_l2, warp_rew.track_lin_vel_xy_exp, warp_rew.track_ang_vel_z_exp, warp_rew.undesired_contacts, - warp_rew.desired_contacts, - warp_rew.contact_forces, - warp_term.command_resample, warp_term.illegal_contact, - warp_evt.reset_root_state_with_random_orientation, - warp_evt.reset_scene_to_default, warp_evt.randomize_rigid_body_com, ] for fn in fns_to_clear: @@ -270,12 +147,12 @@ def _clear_caches(): @pytest.fixture() def art_data(): - return MockMultiBodyArticulationData() + return MockArticulationData(num_bodies=NUM_BODIES) @pytest.fixture() def env_origins(): - origins_np = _make_rng(77).randn(NUM_ENVS, 3).astype(np.float32) + origins_np = np.random.RandomState(77).randn(NUM_ENVS, 3).astype(np.float32) return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) @@ -286,7 +163,7 @@ def contact_data(): @pytest.fixture() def cmd_tensor(): - rng = _make_rng(99) + rng = np.random.RandomState(99) return torch.tensor(rng.randn(NUM_ENVS, CMD_DIM).astype(np.float32), device=DEVICE) @@ -297,14 +174,14 @@ def cmd_term(): @pytest.fixture() def scene(art_data, env_origins, contact_data): - art = MockMultiBodyArticulation(art_data) + art = MockArticulation(art_data, num_bodies=NUM_BODIES) sensor = MockContactSensor(contact_data) return MockScene({"robot": art}, env_origins, sensors={"contact_sensor": sensor}) @pytest.fixture() def action_wp(): - rng = _make_rng(55) + rng = np.random.RandomState(55) a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) return a, b @@ -371,94 +248,18 @@ def sensor_cfg(): # ============================================================================ -def _run_warp_obs(func, env, shape, **kwargs): - out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) - func(env, out, **kwargs) - return wp.to_torch(out).clone() - - -def _run_warp_obs_captured(func, env, shape, **kwargs): - out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) - func(env, out, **kwargs) - with wp.ScopedCapture() as cap: - func(env, out, **kwargs) - wp.capture_launch(cap.graph) - return wp.to_torch(out).clone() - - -def _run_warp_rew(func, env, **kwargs): - out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) - func(env, out, **kwargs) - return wp.to_torch(out).clone() - - -def _run_warp_rew_captured(func, env, **kwargs): - out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) - func(env, out, **kwargs) - with wp.ScopedCapture() as cap: - func(env, out, **kwargs) - wp.capture_launch(cap.graph) - return wp.to_torch(out).clone() - - -def _run_warp_term(func, env, **kwargs): - out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) - func(env, out, **kwargs) - return wp.to_torch(out).clone() - - -def _run_warp_term_captured(func, env, **kwargs): - out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) - func(env, out, **kwargs) - with wp.ScopedCapture() as cap: - func(env, out, **kwargs) - wp.capture_launch(cap.graph) - return wp.to_torch(out).clone() - - -def assert_close(actual, expected, atol=ATOL, rtol=RTOL): - torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) - - -def assert_equal(actual, expected): - assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}" - - # ============================================================================ # Body observation parity tests # ============================================================================ -class TestBodyObservationParity: - """Verify body-level observation Warp kernels match stable torch implementations.""" - - def test_body_projected_gravity_b(self, warp_env, stable_env, body_cfg): - n_sel = len(body_cfg.body_ids) - expected = stable_obs.body_projected_gravity_b(stable_env, asset_cfg=body_cfg) - actual = _run_warp_obs(warp_obs.body_projected_gravity_b, warp_env, (NUM_ENVS, n_sel * 3), asset_cfg=body_cfg) - actual_cap = _run_warp_obs_captured( - warp_obs.body_projected_gravity_b, warp_env, (NUM_ENVS, n_sel * 3), asset_cfg=body_cfg - ) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_body_pose_w(self, warp_env, stable_env, body_cfg): - n_sel = len(body_cfg.body_ids) - # Stable body_pose_w calls env.scene.env_origins.unsqueeze(1) — needs torch tensor. - # Temporarily swap env_origins to torch for the stable call. - orig_origins = stable_env.scene.env_origins - stable_env.scene.env_origins = wp.to_torch(orig_origins) - expected = stable_obs.body_pose_w(stable_env, asset_cfg=body_cfg) - stable_env.scene.env_origins = orig_origins # restore - actual = _run_warp_obs(warp_obs.body_pose_w, warp_env, (NUM_ENVS, n_sel * 7), asset_cfg=body_cfg) - actual_cap = _run_warp_obs_captured(warp_obs.body_pose_w, warp_env, (NUM_ENVS, n_sel * 7), asset_cfg=body_cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) +class TestObservationParity: + """Verify observation Warp kernels match stable torch implementations.""" def test_generated_commands(self, warp_env, stable_env): expected = stable_obs.generated_commands(stable_env, command_name="vel") - actual = _run_warp_obs(warp_obs.generated_commands, warp_env, (NUM_ENVS, CMD_DIM), command_name="vel") - actual_cap = _run_warp_obs_captured( + actual = run_warp_obs(warp_obs.generated_commands, warp_env, (NUM_ENVS, CMD_DIM), command_name="vel") + actual_cap = run_warp_obs_captured( warp_obs.generated_commands, warp_env, (NUM_ENVS, CMD_DIM), command_name="vel" ) assert_close(actual, expected) @@ -473,20 +274,13 @@ def test_generated_commands(self, warp_env, stable_env): class TestNewRewardParity: """Verify newly migrated reward Warp kernels match stable torch implementations.""" - def test_body_lin_acc_l2(self, warp_env, stable_env, body_cfg): - expected = stable_rew.body_lin_acc_l2(stable_env, asset_cfg=body_cfg) - actual = _run_warp_rew(warp_rew.body_lin_acc_l2, warp_env, asset_cfg=body_cfg) - actual_cap = _run_warp_rew_captured(warp_rew.body_lin_acc_l2, warp_env, asset_cfg=body_cfg) - assert_close(actual, expected) - assert_close(actual_cap, expected) - def test_track_lin_vel_xy_exp(self, warp_env, stable_env, body_cfg): cfg = MockBodyCfg("robot") cfg.joint_ids = list(range(NUM_JOINTS)) # needed for stable std = 0.25 expected = stable_rew.track_lin_vel_xy_exp(stable_env, std=std, command_name="vel", asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.track_lin_vel_xy_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) - actual_cap = _run_warp_rew_captured( + actual = run_warp_rew(warp_rew.track_lin_vel_xy_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) + actual_cap = run_warp_rew_captured( warp_rew.track_lin_vel_xy_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg ) assert_close(actual, expected) @@ -497,267 +291,54 @@ def test_track_ang_vel_z_exp(self, warp_env, stable_env, body_cfg): cfg.joint_ids = list(range(NUM_JOINTS)) std = 0.25 expected = stable_rew.track_ang_vel_z_exp(stable_env, std=std, command_name="vel", asset_cfg=cfg) - actual = _run_warp_rew(warp_rew.track_ang_vel_z_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) - actual_cap = _run_warp_rew_captured( + actual = run_warp_rew(warp_rew.track_ang_vel_z_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg) + actual_cap = run_warp_rew_captured( warp_rew.track_ang_vel_z_exp, warp_env, std=std, command_name="vel", asset_cfg=cfg ) assert_close(actual, expected) assert_close(actual_cap, expected) - def test_undesired_contacts(self, warp_env, stable_env, sensor_cfg): - threshold = 0.5 - # Stable returns int64 (torch.sum of bools); warp returns float32 — cast for comparison. - expected = stable_rew.undesired_contacts(stable_env, threshold=threshold, sensor_cfg=sensor_cfg).float() - actual = _run_warp_rew(warp_rew.undesired_contacts, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) - actual_cap = _run_warp_rew_captured( - warp_rew.undesired_contacts, warp_env, threshold=threshold, sensor_cfg=sensor_cfg - ) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_desired_contacts(self, warp_env, stable_env, sensor_cfg): - threshold = 0.5 - expected = stable_rew.desired_contacts(stable_env, sensor_cfg=sensor_cfg, threshold=threshold) - actual = _run_warp_rew(warp_rew.desired_contacts, warp_env, sensor_cfg=sensor_cfg, threshold=threshold) - actual_cap = _run_warp_rew_captured( - warp_rew.desired_contacts, warp_env, sensor_cfg=sensor_cfg, threshold=threshold - ) - assert_close(actual, expected) - assert_close(actual_cap, expected) - - def test_contact_forces(self, warp_env, stable_env, sensor_cfg): - threshold = 0.5 - expected = stable_rew.contact_forces(stable_env, threshold=threshold, sensor_cfg=sensor_cfg) - actual = _run_warp_rew(warp_rew.contact_forces, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) - actual_cap = _run_warp_rew_captured( - warp_rew.contact_forces, warp_env, threshold=threshold, sensor_cfg=sensor_cfg - ) - assert_close(actual, expected) - assert_close(actual_cap, expected) - # ============================================================================ -# New termination parity tests +# Termination parity tests # ============================================================================ -class TestNewTerminationParity: - """Verify newly migrated termination Warp kernels match stable torch implementations.""" +class TestTerminationParity: + """Verify termination Warp kernels match stable torch implementations.""" def test_time_out(self, warp_env, stable_env): expected = stable_term.time_out(stable_env) - actual = _run_warp_term(warp_term.time_out, warp_env) - actual_cap = _run_warp_term_captured(warp_term.time_out, warp_env) - assert_equal(actual, expected) - assert_equal(actual_cap, expected) - - def test_illegal_contact(self, warp_env, stable_env, sensor_cfg): - threshold = 0.5 - expected = stable_term.illegal_contact(stable_env, threshold=threshold, sensor_cfg=sensor_cfg) - actual = _run_warp_term(warp_term.illegal_contact, warp_env, threshold=threshold, sensor_cfg=sensor_cfg) - actual_cap = _run_warp_term_captured( - warp_term.illegal_contact, warp_env, threshold=threshold, sensor_cfg=sensor_cfg - ) + actual = run_warp_term(warp_term.time_out, warp_env) + actual_cap = run_warp_term_captured(warp_term.time_out, warp_env) assert_equal(actual, expected) assert_equal(actual_cap, expected) -# ============================================================================ -# New event capture-safety tests -# ============================================================================ - - -def _copy_np_to_wp(dest: wp.array, src_np: np.ndarray): - tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device)) - wp.copy(dest, tmp) - - -class TestNewEventCaptureSafety: - """Verify new event functions are capture-safe.""" - - def test_reset_root_state_with_random_orientation(self, warp_env, art_data, env_origins): - """With zero-width position ranges, positions = default + env_origin.""" - mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) - zero_pose = {"x": (0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0)} - zero_vel = { - "x": (0.0, 0.0), - "y": (0.0, 0.0), - "z": (0.0, 0.0), - "roll": (0.0, 0.0), - "pitch": (0.0, 0.0), - "yaw": (0.0, 0.0), - } - - warp_evt.reset_root_state_with_random_orientation(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) - with wp.ScopedCapture() as cap: - warp_evt.reset_root_state_with_random_orientation( - warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel - ) - - # Mutate defaults - new_pose = np.zeros((NUM_ENVS, 7), dtype=np.float32) - new_pose[:, 0:3] = [1.0, 2.0, 3.0] - new_pose[:, 3:7] = [0.0, 0.0, 0.0, 1.0] - _copy_np_to_wp(art_data.default_root_pose, new_pose) - - wp.capture_launch(cap.graph) - wp.synchronize() - - fn = warp_evt.reset_root_state_with_random_orientation - scratch_pose = wp.to_torch(fn._scratch_pose) - origins_t = wp.to_torch(env_origins) - - # Positions: default(1,2,3) + env_origin + 0 - expected_pos = torch.tensor([1.0, 2.0, 3.0], device=DEVICE).unsqueeze(0) + origins_t - assert_close(scratch_pose[:, :3], expected_pos) - - # Quaternions: should be unit quaternions (random SO(3)) - qnorm = scratch_pose[:, 3:7].norm(dim=1) - assert_close(qnorm, torch.ones(NUM_ENVS, device=DEVICE), atol=1e-4, rtol=1e-4) - - def test_reset_scene_to_default(self, warp_env, art_data, env_origins): - """With all envs masked, joints should be reset to defaults.""" - mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) - - # Set defaults to known values - _copy_np_to_wp(art_data.default_joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 0.42, dtype=np.float32)) - _copy_np_to_wp(art_data.default_joint_vel, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) - - warp_evt.reset_scene_to_default(warp_env, mask) - wp.synchronize() - - result_pos = wp.to_torch(art_data.joint_pos) - expected = torch.full((NUM_ENVS, NUM_JOINTS), 0.42, device=DEVICE) - assert_close(result_pos, expected) - - def test_reset_scene_to_default_mask_selectivity(self, warp_env, art_data, env_origins): - """Only masked envs are reset.""" - mask_np = np.array([i < NUM_ENVS // 2 for i in range(NUM_ENVS)]) - mask = wp.array(mask_np, dtype=wp.bool, device=DEVICE) - - # Set joint_pos to sentinel - _copy_np_to_wp(art_data.joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 999.0, dtype=np.float32)) - # Set defaults to 0 - _copy_np_to_wp(art_data.default_joint_pos, np.zeros((NUM_ENVS, NUM_JOINTS), dtype=np.float32)) - - warp_evt.reset_scene_to_default(warp_env, mask) - wp.synchronize() - - result = wp.to_torch(art_data.joint_pos) - # Masked: reset to 0 - assert_close(result[: NUM_ENVS // 2], torch.zeros(NUM_ENVS // 2, NUM_JOINTS, device=DEVICE)) - # Unmasked: still 999 - assert_close(result[NUM_ENVS // 2 :], torch.full((NUM_ENVS // 2, NUM_JOINTS), 999.0, device=DEVICE)) - - def test_randomize_rigid_body_com(self, warp_env, art_data): - """With zero-width range, CoM should not change. With nonzero range, CoM should differ.""" - mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) - body_cfg = MockBodyCfg("robot", list(range(NUM_BODIES))) - - # Snapshot original CoM - original_com = wp.to_torch(art_data.body_com_pos_b).clone() - - # Zero range: no change - warp_evt.randomize_rigid_body_com( - warp_env, mask, com_range={"x": (0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0)}, asset_cfg=body_cfg - ) - wp.synchronize() - assert_close(wp.to_torch(art_data.body_com_pos_b), original_com) - - def test_reset_root_state_from_terrain(self, warp_env, art_data, env_origins): - """With zero-width orientation and velocity ranges, verify positions come from terrain patches.""" - # Create mock terrain - rng = _make_rng(123) - num_levels, num_types, num_patches = 2, 2, 5 - flat_patches_np = rng.randn(num_levels, num_types, num_patches, 3).astype(np.float32) - flat_patches_torch = torch.tensor(flat_patches_np, device=DEVICE) - - terrain_levels = torch.zeros(NUM_ENVS, dtype=torch.int32, device=DEVICE) - terrain_types = torch.zeros(NUM_ENVS, dtype=torch.int32, device=DEVICE) - - # Attach terrain mock to scene - warp_env.scene.terrain = type( - "_T", - (), - { - "flat_patches": {"init_pos": flat_patches_torch}, - "terrain_levels": terrain_levels, - "terrain_types": terrain_types, - }, - )() - - mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) - zero_pose = {"roll": (0.0, 0.0), "pitch": (0.0, 0.0), "yaw": (0.0, 0.0)} - zero_vel = { - "x": (0.0, 0.0), - "y": (0.0, 0.0), - "z": (0.0, 0.0), - "roll": (0.0, 0.0), - "pitch": (0.0, 0.0), - "yaw": (0.0, 0.0), - } - - warp_evt.reset_root_state_from_terrain(warp_env, mask, pose_range=zero_pose, velocity_range=zero_vel) - wp.synchronize() - - fn = warp_evt.reset_root_state_from_terrain - scratch_pose = wp.to_torch(fn._scratch_pose) - - # All envs use level=0, type=0 so positions must come from flat_patches[0, 0, *, :] - valid_positions = flat_patches_torch[0, 0] # (num_patches, 3) - default_pos = wp.to_torch(art_data.default_root_pose)[:, :3] - - # Each env's position should be one of the valid patches + default offset - for i in range(min(8, NUM_ENVS)): # spot check first 8 - pos = scratch_pose[i, :3] - diffs = (valid_positions + default_pos[i]) - pos - min_dist = diffs.norm(dim=1).min() - assert min_dist < 1e-4, f"env {i}: position {pos} not near any valid patch" - - def test_command_resample(self, warp_env, cmd_term): - """Parity check for command_resample termination.""" - # Set up deterministic data: half the envs should trigger - cmd_term.time_left[:] = 0.01 # all below step_dt=0.02 - cmd_term.command_counter[: NUM_ENVS // 2] = 1.0 # match num_resamples=1 - cmd_term.command_counter[NUM_ENVS // 2 :] = 0.0 # no match - - expected = torch.logical_and( - cmd_term.time_left <= warp_env.step_dt, - cmd_term.command_counter == 1.0, - ) - - actual = _run_warp_term(warp_term.command_resample, warp_env, command_name="vel", num_resamples=1) - assert_equal(actual, expected) - - # ============================================================================ # Capture-mutate-replay tests for new terms # ============================================================================ -def _mutate_body_data(art_data: MockMultiBodyArticulationData, rng_seed=200): +def _mutate_body_data(art_data: MockArticulationData, rng_seed=200): """Mutate body-level and root-level data in-place so captured graphs see fresh values.""" - rng = _make_rng(rng_seed) + rng = np.random.RandomState(rng_seed) - # Root state - root_pos_np = rng.randn(NUM_ENVS, 3).astype(np.float32) - root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 - _copy_np_to_wp(art_data.root_pos_w, root_pos_np) - _copy_np_to_wp(art_data.root_lin_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) - _copy_np_to_wp(art_data.root_ang_vel_b, rng.randn(NUM_ENVS, 3).astype(np.float32)) + # Root state + Tier 1 compounds + derived body-frame velocities + mutate_root_state(rng, art_data) # Body data grav_np = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) grav_np[:, :, 2] = -1.0 grav_np /= np.linalg.norm(grav_np, axis=2, keepdims=True) - _copy_np_to_wp(art_data.projected_gravity_b, grav_np) + copy_np_to_wp(art_data.projected_gravity_b, grav_np) - _copy_np_to_wp(art_data.body_lin_acc_w, rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32)) + copy_np_to_wp(art_data.body_lin_acc_w, rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32)) pose_np = np.zeros((NUM_ENVS, NUM_BODIES, 7), dtype=np.float32) pose_np[:, :, :3] = rng.randn(NUM_ENVS, NUM_BODIES, 3).astype(np.float32) pose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] - _copy_np_to_wp(art_data.body_pose_w, pose_np) + copy_np_to_wp(art_data.body_pose_w, pose_np) wp.synchronize() @@ -798,39 +379,7 @@ def _capture_mutate_check_term(self, warp_fn, stable_fn, warp_env, stable_env, a expected = stable_fn(stable_env, **kwargs) assert_equal(wp.to_torch(out).clone(), expected) - # -- body observations ------------------------------------------------- - - def test_body_projected_gravity_b(self, warp_env, stable_env, art_data, body_cfg): - n_sel = len(body_cfg.body_ids) - self._capture_mutate_check_obs( - warp_obs.body_projected_gravity_b, - stable_obs.body_projected_gravity_b, - warp_env, - stable_env, - art_data, - (NUM_ENVS, n_sel * 3), - asset_cfg=body_cfg, - ) - - def test_body_pose_w(self, warp_env, stable_env, art_data, body_cfg): - n_sel = len(body_cfg.body_ids) - - # Stable needs torch env_origins for unsqueeze - def stable_body_pose_w_fixed(env, **kw): - orig = env.scene.env_origins - env.scene.env_origins = wp.to_torch(orig) - result = stable_obs.body_pose_w(env, **kw) - env.scene.env_origins = orig - return result - - out = wp.zeros((NUM_ENVS, n_sel * 7), dtype=wp.float32, device=DEVICE) - warp_obs.body_pose_w(warp_env, out, asset_cfg=body_cfg) - with wp.ScopedCapture() as cap: - warp_obs.body_pose_w(warp_env, out, asset_cfg=body_cfg) - _mutate_body_data(art_data) - wp.capture_launch(cap.graph) - expected = stable_body_pose_w_fixed(stable_env, asset_cfg=body_cfg) - assert_close(wp.to_torch(out).clone(), expected) + # -- observations ---------------------------------------------------------- def test_generated_commands(self, warp_env, stable_env, art_data, cmd_tensor): """Mutate command tensor, replay captured graph, verify new commands are read.""" @@ -846,16 +395,6 @@ def test_generated_commands(self, warp_env, stable_env, art_data, cmd_tensor): # -- rewards ----------------------------------------------------------- - def test_body_lin_acc_l2(self, warp_env, stable_env, art_data, body_cfg): - self._capture_mutate_check_rew( - warp_rew.body_lin_acc_l2, - stable_rew.body_lin_acc_l2, - warp_env, - stable_env, - art_data, - asset_cfg=body_cfg, - ) - def test_track_lin_vel_xy_exp(self, warp_env, stable_env, art_data): cfg = MockBodyCfg("robot") cfg.joint_ids = list(range(NUM_JOINTS)) @@ -884,18 +423,6 @@ def test_track_ang_vel_z_exp(self, warp_env, stable_env, art_data): asset_cfg=cfg, ) - def test_contact_forces(self, warp_env, stable_env, art_data, contact_data, sensor_cfg): - """Mutate contact force history, verify captured graph picks up changes.""" - out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) - warp_rew.contact_forces(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) - with wp.ScopedCapture() as cap: - warp_rew.contact_forces(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) - # Mutate contact sensor data in-place - contact_data.net_forces_w_history[:] = torch.randn_like(contact_data.net_forces_w_history) * 3.0 - wp.capture_launch(cap.graph) - expected = stable_rew.contact_forces(stable_env, threshold=0.5, sensor_cfg=sensor_cfg) - assert_close(wp.to_torch(out).clone(), expected) - # -- terminations ------------------------------------------------------ def test_time_out(self, warp_env, stable_env, art_data): @@ -908,13 +435,3 @@ def test_time_out(self, warp_env, stable_env, art_data): wp.capture_launch(cap.graph) expected = stable_term.time_out(stable_env) assert_equal(wp.to_torch(out).clone(), expected) - - def test_illegal_contact(self, warp_env, stable_env, art_data, contact_data, sensor_cfg): - out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) - warp_term.illegal_contact(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) - with wp.ScopedCapture() as cap: - warp_term.illegal_contact(warp_env, out, threshold=0.5, sensor_cfg=sensor_cfg) - contact_data.net_forces_w_history[:] = torch.randn_like(contact_data.net_forces_w_history) * 5.0 - wp.capture_launch(cap.graph) - expected = stable_term.illegal_contact(stable_env, threshold=0.5, sensor_cfg=sensor_cfg) - assert_equal(wp.to_torch(out).clone(), expected) diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md b/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md index 24bbea777ee..f6e5760ea67 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/GRAPH_CAPTURE_MIGRATION.md @@ -97,104 +97,55 @@ _wp_capture_or_launch: - Data is stale from warmup ``` -## Proposed Fix: `materialize_derived()` - -Add a method to `ArticulationData` that unconditionally launches all Tier 2 kernels -and updates timestamps. Call from `scene.update()` which runs outside capture scopes. - -```python -# ArticulationData -def materialize_derived(self) -> None: - """Eagerly compute all Tier 2 derived properties. - - Call before any captured graph that reads derived data. - Safe to call every step — cost is the same as accessing each property once. - """ - # Root-level derived - _ = self.projected_gravity_b # forces timestamp check → launches if stale - _ = self.heading_w - _ = self.root_link_vel_w - _ = self.root_link_vel_b - _ = self.root_com_vel_b - _ = self.root_com_pose_w - # Body-level derived - _ = self.body_link_vel_w - _ = self.body_com_pose_w -``` +## Key Insight: Tier 2 Kernels ARE Capturable -Integration point — `scene.update()` or `ArticulationData.update()`: +The preparation kernels (`project_vec_from_pose_single`, `project_velocities_to_frame`, +`compute_heading`) are plain `@wp.kernel` with no Python conditionals. They are fully +capturable. The ONLY problem is the Python `if timestamp < sim_timestamp` guard. -```python -def update(self, dt: float): - self._sim_timestamp += dt - # Existing: finite-difference quantities (need previous-step snapshot) - self.joint_acc - self.body_com_acc_w - # NEW: eagerly materialize all derived properties for graph capture - self.materialize_derived() -``` +`scene.update()` runs outside any `wp.ScopedCapture` scope. Kernels launched there +execute eagerly every step. MDP terms then read from pre-computed `.data` buffers +(stable pointers), which is capturable. -**Trade-off:** This removes the lazy optimization — every derived property computes -every step, even if unused. For capture-mode envs this is the correct trade-off (the -kernel cost is negligible vs graph replay savings). For non-capture envs, the extra -kernels add overhead for unused properties. +## Affected MDP Terms -**Better approach — opt-in materialization:** +See "Non-Capturable MDP Terms" in `isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md`. -Only materialize properties that the env actually uses. The `ManagerCallSwitch` knows -which managers are in capture mode. The env can call `materialize_derived()` only when -capture mode is active: +Per-step Tier 2 access counts for tested envs: -```python -# In ManagerBasedRLEnvWarp, after scene.update(): -if any_manager_in_capture_mode: - for articulation in self.scene.articulations.values(): - articulation.data.materialize_derived() -``` +| Env | `root_com_vel_b` | `projected_gravity_b` | `body_link_vel_w` | Total | +|-----|---:|---:|---:|---:| +| Cartpole | 0 | 0 | 0 | 0 | +| Reach-Franka | 0 | 0 | 0 | 0 | +| Humanoid/Ant | 2 | 2 | 0 | 4 | +| Quadruped velocity | 6 | 2 | 0 | 8 | +| Biped velocity (G1/H1) | 4 | 2 | 1 | 7 | -Or more selectively, track which properties were accessed during warmup and only -materialize those on subsequent steps. - -## Alternative: Use Compound Types in MDP Kernels - -Instead of fixing the data class, modify MDP terms to use Tier 1 compound types directly -(`root_link_pose_w` as `wp.transformf`, `root_com_vel_w` as `wp.spatial_vectorf`) and -extract components inside warp kernels: - -```python -@wp.kernel -def _projected_gravity_kernel( - pose_w: wp.array(dtype=wp.transformf), - gravity: wp.vec3f, - out: wp.array(dtype=wp.float32, ndim=2), -): - i = wp.tid() - q = wp.transform_get_rotation(pose_w[i]) - g_b = wp.quat_rotate_inv(q, gravity) - out[i, 0] = g_b[0] - out[i, 1] = g_b[1] - out[i, 2] = g_b[2] -``` +## Fix Plan -**Pros:** No changes to articulation data class. Eliminates all Tier 2/3 overhead. -**Cons:** Every MDP term must be rewritten. Duplicates split logic across terms. +### Phase 1: Inline Tier 1 access in MDP kernels (applied) -## Affected MDP Terms +Rewrite affected MDP kernels to consume Tier 1 compound types directly +(`root_link_pose_w` as `wp.transformf`, `root_com_vel_w` as `wp.spatial_vectorf`) +and perform the frame rotation inline. Remove `@warp_capturable(False)`. -See "Non-Capturable MDP Terms" section in -`isaaclab_experimental/envs/mdp/WARP_MIGRATION_GAP_ANALYSIS.md` for the full list of -MDP terms marked `@warp_capturable(False)` due to Tier 2 access, and the pending fix -(`materialize_derived()`) that would make them capturable again. +This is viable because the affected MDP terms do minimal work on top of the +derived property — observations are pure format copies (`vec3f → float32[3]`), +rewards extract a component and do a simple op (square, exp, threshold). Folding +the rotation into the same kernel adds negligible cost and eliminates the Tier 2 +dependency entirely. -## Recommendation +No changes to `ArticulationData`. All managers become fully capturable. -Short-term: Mark affected MDP terms `@warp_capturable(False)` so they fall back to -mode=1 automatically. No incorrect results, modest perf regression for those terms. +### Phase 2: Fix lazy update for graph capture (future) -Medium-term: Add `materialize_derived()` to `ArticulationData` and call it from -`scene.update()` when capture mode is active. Minimal changes, preserves lazy -optimization for non-capture users. Once applied, all `@warp_capturable(False)` -annotations for Tier 2 access can be removed and these terms become fully capturable. +If `ArticulationData` Tier 2 properties are made graph-safe (e.g. via unconditional +materialization in `update()` or selective on-demand computation), MDP terms can +revert to the simpler pattern of reading pre-computed `.data` buffers. This would +be preferable when many MDP terms access the same derived property per step, as it +avoids redundant inline rotations. -Long-term: Migrate MDP kernels to use compound Tier 1 types directly. Best performance, -no derived property overhead at all. +The Tier 2 kernels themselves are fully capturable `@wp.kernel` — only the Python +timestamp guard needs removal. A fused kernel in `update()` computing all (or +selectively needed) derived properties in one launch would make Tier 2 graph-safe +with minimal overhead. diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py index 04002b0cfc9..8bcaf6ee92f 100644 --- a/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py +++ b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py @@ -474,6 +474,24 @@ def project_velocities_to_frame( resulting_velocity[index] = project_velocity_to_frame(velocity[index], pose[index]) +@wp.func +def rotate_vec_to_body_frame(vec_w: wp.vec3f, pose_w: wp.transformf) -> wp.vec3f: + """Rotate a world-frame vector into the body frame defined by pose_w.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), vec_w) + + +@wp.func +def body_lin_vel_from_root(pose_w: wp.transformf, vel_w: wp.spatial_vectorf) -> wp.vec3f: + """Extract body-frame linear velocity from root pose and spatial velocity.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), wp.spatial_top(vel_w)) + + +@wp.func +def body_ang_vel_from_root(pose_w: wp.transformf, vel_w: wp.spatial_vectorf) -> wp.vec3f: + """Extract body-frame angular velocity from root pose and spatial velocity.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), wp.spatial_bottom(vel_w)) + + """ Heading utility kernels """ diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py index 59906f4b5ef..a79e9fdce53 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py @@ -16,6 +16,7 @@ import warp as wp from isaaclab_experimental.envs.utils.io_descriptors import generic_io_descriptor_warp from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_newton.kernels.state_kernels import rotate_vec_to_body_frame from isaaclab.assets import Articulation @@ -23,7 +24,6 @@ from isaaclab.envs import ManagerBasedEnv -# Reviewed(jichuanh): file reviewed @wp.kernel def _base_yaw_roll_kernel( root_quat_w: wp.array(dtype=wp.quatf), @@ -60,14 +60,20 @@ def base_yaw_roll(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn ) +# Inline Tier 1 access: derives projected gravity directly from root_link_pose_w, +# avoiding the lazy TimestampedWarpBuffer which is not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. + + @wp.kernel def _base_up_proj_kernel( - projected_gravity_b: wp.array(dtype=wp.vec3f), + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.vec3f, out: wp.array(dtype=wp.float32, ndim=2), ): """Project base up vector onto world up: -gravity_b[2].""" i = wp.tid() - out[i, 0] = -projected_gravity_b[i][2] + out[i, 0] = -rotate_vec_to_body_frame(gravity_w, root_pose_w[i])[2] @generic_io_descriptor_warp(out_dim=1, observation_type="RootState") @@ -77,7 +83,7 @@ def base_up_proj(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEnt wp.launch( kernel=_base_up_proj_kernel, dim=env.num_envs, - inputs=[asset.data.projected_gravity_b, out], + inputs=[asset.data.root_link_pose_w, asset.data.GRAVITY_VEC_W, out], device=env.device, ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py index de6e24be978..033c17a5179 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py @@ -17,7 +17,7 @@ import warp as wp from isaaclab_experimental.managers import SceneEntityCfg from isaaclab_experimental.managers.manager_base import ManagerTermBase -from isaaclab_experimental.utils.warp import warp_capturable +from isaaclab_newton.kernels.state_kernels import rotate_vec_to_body_frame import isaaclab.utils.string as string_utils from isaaclab.assets import Articulation @@ -32,21 +32,26 @@ # Function-based reward terms # --------------------------------------------------------------------------- -# Reviewed(jichuanh): file roughly reviewed + +# Inline Tier 1 access: derives projected gravity directly from root_link_pose_w, +# avoiding the lazy TimestampedWarpBuffer which is not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. +# If ArticulationData Tier 2 lazy update is made graph-safe, this can revert to +# reading asset.data.projected_gravity_b directly. @wp.kernel def _upright_posture_bonus_kernel( - projected_gravity_b: wp.array(dtype=wp.vec3f), + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.vec3f, threshold: float, out: wp.array(dtype=wp.float32), ): i = wp.tid() - up_proj = -projected_gravity_b[i][2] + up_proj = -rotate_vec_to_body_frame(gravity_w, root_pose_w[i])[2] out[i] = wp.where(up_proj > threshold, 1.0, 0.0) -@warp_capturable(False) # accesses projected_gravity_b → lazy TimestampedWarpBuffer (Tier 2) def upright_posture_bonus( env: ManagerBasedRLEnv, out, threshold: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") ) -> None: @@ -55,7 +60,7 @@ def upright_posture_bonus( wp.launch( kernel=_upright_posture_bonus_kernel, dim=env.num_envs, - inputs=[asset.data.projected_gravity_b, threshold, out], + inputs=[asset.data.root_link_pose_w, asset.data.GRAVITY_VEC_W, threshold, out], device=env.device, ) From c908c662abcb3e78ca190623cf9b8f25938a53ce Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Fri, 27 Feb 2026 00:43:02 -0800 Subject: [PATCH 5/5] Cleanups --- .../envs/manager_based_env_warp.py | 6 +++--- .../isaaclab_experimental/envs/mdp/events.py | 11 ----------- .../envs/utils/io_descriptors.py | 1 - .../isaaclab_experimental/managers/event_manager.py | 2 +- .../managers/scene_entity_cfg.py | 1 - .../test/envs/mdp/test_action_warp_parity.py | 4 ---- .../test/envs/mdp/test_mdp_warp_parity_new_terms.py | 4 ---- .../locomotion/velocity/config/a1/rough_env_cfg.py | 2 -- .../manager_based/locomotion/velocity/mdp/rewards.py | 2 -- .../locomotion/velocity/mdp/terminations.py | 1 - .../manager_based/manipulation/reach/mdp/rewards.py | 2 -- 11 files changed, 4 insertions(+), 32 deletions(-) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py index 5cfd3303298..935358d4189 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -90,13 +90,13 @@ def __init__(self, cfg_source: dict | str | None = None, max_modes: dict[str, in self._wp_graphs: dict[str, Any] = {} self._cfg = self._load_cfg(cfg_source) self._max_modes = self._validate_max_modes(max_modes) - print("[INFO] ManagerCallSwitch configuration:") - print(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}") + logger.info("ManagerCallSwitch configuration:") + logger.info(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}") for manager_name in self.MANAGER_NAMES: mode = int(self.get_mode_for_manager(manager_name)) cap = self._max_modes.get(manager_name) cap_str = f" (cap={cap})" if cap is not None else "" - print(f" - {manager_name}: {mode}{cap_str}") + logger.info(f" - {manager_name}: {mode}{cap_str}") def invalidate_graphs(self) -> None: """Invalidate cached capture graphs.""" diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py index ad60100ea44..160b5f5aa5a 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py @@ -24,21 +24,12 @@ from __future__ import annotations -import logging -from typing import TYPE_CHECKING - import warp as wp from isaaclab_experimental.managers import SceneEntityCfg from isaaclab_experimental.utils.warp import warp_capturable from isaaclab.assets import Articulation -if TYPE_CHECKING: - from isaaclab.envs import ManagerBasedEnv - -logger = logging.getLogger(__name__) - - # --------------------------------------------------------------------------- # Randomize rigid body center of mass # --------------------------------------------------------------------------- @@ -563,5 +554,3 @@ def reset_joints_by_offset( ], device=env.device, ) - - diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py index 0b26a2205f1..4b730dae4e4 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py @@ -65,7 +65,6 @@ def _make_descriptor(**kwargs: Any) -> GenericObservationIODescriptor: return desc -# TODO(jichuanh): The exact usage is unclear and this need revisit # Decorator factory for Warp-first IO descriptors. def generic_io_descriptor_warp( _func: Callable[Concatenate[ManagerBasedEnv, P], R] | None = None, diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py index a560f664255..7dbc6e8e88a 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py @@ -323,7 +323,7 @@ def apply( self._apply_reset(env_mask_wp, global_env_step_count) return - # other modes keep the stable convention (env_ids forwarded) + # other modes (startup, prestartup, custom) — env_mask forwarded for term_cfg in self._mode_term_cfgs[mode]: term_cfg.func(self._env, env_mask_wp, **term_cfg.params) diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py index 554f667bd02..01589fceac5 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py @@ -25,7 +25,6 @@ class SceneEntityCfg(_SceneEntityCfg): - `joint_mask` is intended for Warp kernels only. """ - # TODO(jichuanh): review the necessity of these two attributes. joint_mask: wp.array | None = None joint_ids_wp: wp.array | None = ( None # Needed for subset-sized outputs/gathers (len(selected)); mask can't map k→joint/order. diff --git a/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py index 9af90f6ec71..bc419825038 100644 --- a/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py +++ b/source/isaaclab_experimental/test/envs/mdp/test_action_warp_parity.py @@ -3,10 +3,6 @@ # # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2022-2026, The Isaac Lab Project Developers. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - """Parity tests for Warp-first action term classes. Tests all 10 experimental action classes: process_actions, apply_actions, reset. diff --git a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py index ace60e8ce5d..1ae47cefd83 100644 --- a/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py +++ b/source/isaaclab_experimental/test/envs/mdp/test_mdp_warp_parity_new_terms.py @@ -3,10 +3,6 @@ # # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2022-2026, The Isaac Lab Project Developers. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - """Parity tests for newly migrated Warp-first MDP terms. Tests: body observations, command-dependent rewards, contact sensor rewards/terminations, diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py index ace7b9cf8a6..79d83b2b129 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py @@ -14,8 +14,6 @@ from isaaclab_assets.robots.unitree import UNITREE_A1_CFG # isort: skip -# reviewed(jichuanh): file roughly reviewed - class TerminationsCfg_A1(TerminationsCfg): base_too_low = DoneTerm(func=mdp.root_height_below_minimum, params={"minimum_height": 0.2}) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py index e19ce6de96d..31baccee1e7 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py @@ -22,8 +22,6 @@ if TYPE_CHECKING: from isaaclab.envs import ManagerBasedRLEnv -# Review(jichuanh): Needs revisit. - # --------------------------------------------------------------------------- # feet_air_time # --------------------------------------------------------------------------- diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py index 3d81bb0c2df..6dd29b49a46 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py @@ -18,7 +18,6 @@ from isaaclab.envs import ManagerBasedRLEnv -# Review(jichuanh): Needs revisit. @wp.kernel def _terrain_out_of_bounds_kernel( root_pos_w: wp.array(dtype=wp.vec3f), diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py index ea15d91f831..811163ec973 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py @@ -22,7 +22,6 @@ from isaaclab.envs import ManagerBasedRLEnv -# Review(jichuanh): Needs revisit. # --------------------------------------------------------------------------- # position_command_error # --------------------------------------------------------------------------- @@ -140,7 +139,6 @@ def _orientation_command_error_kernel( i = wp.tid() # desired quat in body frame -> world frame: q_des_w = q_root * q_des_b des_b = wp.quatf(cmd[i, 3], cmd[i, 4], cmd[i, 5], cmd[i, 6]) - des_w = wp.quat_inverse(root_quat_w[i]) * des_b # TODO: verify if mul order matches stable des_w = root_quat_w[i] * des_b # current ee orientation cur_w = body_quat_w[i, body_idx]