Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions scripts/reinforcement_learning/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
default=False,
help="Enable Isaac Lab timers (use --no-timer to disable).",
)
# Opt-in switch for the granular timer sections inside the environment step().
# BooleanOptionalAction also registers the negated form (--no-step-timer); the default is off.
parser.add_argument(
"--step-timer",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable granular timer sections in environment step().",
)
# Opt-in switch for the granular timer sections inside the environment reset().
# BooleanOptionalAction also registers the negated form (--no-reset-timer); the default is off.
parser.add_argument(
"--reset-timer",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable granular timer sections in environment reset().",
)
parser.add_argument(
"--manager_call_config",
type=str,
Expand Down Expand Up @@ -91,12 +103,18 @@
import torch
from datetime import datetime

import isaaclab_experimental.envs.manager_based_rl_env_warp as manager_based_rl_env_warp
from rsl_rl.runners import DistillationRunner, OnPolicyRunner

import isaaclab.envs.manager_based_rl_env as manager_based_rl_env
from isaaclab.utils.timer import Timer

# Drive both Timer class-level switches from the single --timer flag.
Timer.enable = args_cli.timer
Timer.enable_display_output = args_cli.timer
# Copy the --step-timer / --reset-timer flags onto the module-level TIMER_ENABLED_* globals of
# both environment modules (isaaclab.envs and isaaclab_experimental.envs), so either
# implementation sees the same setting.
# NOTE(review): presumably these must be assigned before the environment is created — confirm.
manager_based_rl_env.TIMER_ENABLED_STEP = args_cli.step_timer
manager_based_rl_env.TIMER_ENABLED_RESET_IDX = args_cli.reset_timer
manager_based_rl_env_warp.TIMER_ENABLED_STEP = args_cli.step_timer
manager_based_rl_env_warp.TIMER_ENABLED_RESET_IDX = args_cli.reset_timer

import isaaclab_tasks_experimental # noqa: F401

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def reset(self, env_ids: Sequence[int] | None = None, env_mask: wp.array(dtype=w
env_mask: The masks of the environments to reset. Defaults to None: all the environments are reset.
"""
# reset the timers and counters
super().reset(env_ids)
super().reset(env_ids, env_mask)

@abstractmethod
def find_bodies(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,50 @@ class ManagerCallSwitch:
"TerminationManager",
"RewardManager",
"CurriculumManager",
"Scene",
)

def __init__(self, cfg_source: str | None = None):
def __init__(self, cfg_source: dict | str | None = None, max_modes: dict[str, int] | None = None):
self._wp_graphs: dict[str, Any] = {}
self._cfg = self._load_cfg(cfg_source)
print("[INFO] ManagerCallSwitch configuration:")
print(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}")
self._max_modes = self._validate_max_modes(max_modes)
logger.info("ManagerCallSwitch configuration:")
logger.info(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}")
for manager_name in self.MANAGER_NAMES:
print(f" - {manager_name}: {int(self.get_mode_for_manager(manager_name))}")
mode = int(self.get_mode_for_manager(manager_name))
cap = self._max_modes.get(manager_name)
cap_str = f" (cap={cap})" if cap is not None else ""
logger.info(f" - {manager_name}: {mode}{cap_str}")

def invalidate_graphs(self) -> None:
    """Drop every cached capture graph.

    The cache dictionary is emptied in place, so the ``_wp_graphs`` object itself
    is kept and only its entries are removed.
    """
    graph_cache = self._wp_graphs
    graph_cache.clear()

def register_manager_capturability(self, manager_name: str, all_terms_capturable: bool) -> None:
    """Record whether every term of a manager can be captured, downgrading its mode if not.

    When ``all_terms_capturable`` is falsy and the manager's effective mode is
    ``ManagerCallMode.WARP_CAPTURED``, the manager's entry in ``_cfg`` is overwritten
    with ``ManagerCallMode.WARP_NOT_CAPTURED`` and a warning is logged. In every
    other case nothing changes.

    Per the original author's note, this is meant to be called from a manager's
    ``_prepare_terms()``, and the in-place downgrade is considered safe because
    mode=1 and mode=2 resolve to the same experimental manager class.

    Args:
        manager_name: Name of the manager whose capturability is being reported.
        all_terms_capturable: True when all of the manager's terms are capturable.
    """
    # Nothing to do when every term can be captured.
    if all_terms_capturable:
        return
    # Only a manager that would actually run captured needs the downgrade.
    if self.get_mode_for_manager(manager_name) != ManagerCallMode.WARP_CAPTURED:
        return
    self._cfg[manager_name] = int(ManagerCallMode.WARP_NOT_CAPTURED)
    logger.warning(f"{manager_name} has non-capturable terms — downgraded from mode=2 to mode=1.")

def call_stage(
self,
*,
stage: str,
stable_calls: Sequence[dict[str, Any]],
warp_calls: Sequence[dict[str, Any]],
mode_override: ManagerCallMode | int | None = None,
) -> Any:
"""Run the stage according to configured mode."""
"""Run the stage according to configured mode (or explicit override)."""
manager_name = self._manager_name_from_stage(stage)
mode = self.get_mode_for_manager(manager_name)
mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override)
if mode == ManagerCallMode.STABLE:
return self._run_calls(stable_calls)
if mode == ManagerCallMode.WARP_NOT_CAPTURED:
Expand All @@ -120,16 +140,29 @@ def _manager_name_from_stage(self, stage: str) -> str:
return stage.split("_", 1)[0]

def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode:
    """Return the effective execution mode for a manager.

    The mode is resolved in this order:

    1. The per-manager entry in ``_cfg`` if one exists, otherwise the entry stored
       under ``DEFAULT_KEY``.
    2. The optional cap from ``_max_modes`` for this manager, applied at query time
       with ``min``.

    ``register_manager_capturability()`` writes its mode=2 -> mode=1 downgrade
    directly into ``_cfg``, so step 1 already reflects it.

    Args:
        manager_name: Name of the manager to look up.

    Returns:
        The resolved mode converted to ``ManagerCallMode``.
    """
    # Look up the fallback through DEFAULT_KEY, the same key __init__ reports and
    # _load_cfg checks for, instead of relying on the insertion order of DEFAULT_CONFIG.
    mode_value = self._cfg.get(manager_name, self._cfg[self.DEFAULT_KEY])
    cap = self._max_modes.get(manager_name)
    if cap is not None:
        # The cap can only lower the configured mode, never raise it.
        mode_value = min(mode_value, cap)
    return ManagerCallMode(mode_value)

def resolve_manager_class(self, manager_name: str) -> type:
module_name = (
"isaaclab.managers"
if self.get_mode_for_manager(manager_name) == ManagerCallMode.STABLE
else "isaaclab_experimental.managers"
)
def resolve_manager_class(self, manager_name: str, mode_override: ManagerCallMode | int | None = None) -> type:
mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override)
module_name = "isaaclab.managers" if mode == ManagerCallMode.STABLE else "isaaclab_experimental.managers"
module = importlib.import_module(module_name)
if not hasattr(module, manager_name):
raise AttributeError(f"Manager '{manager_name}' not found in module '{module_name}'.")
Expand All @@ -138,35 +171,41 @@ def resolve_manager_class(self, manager_name: str) -> type:
def _run_calls(self, calls: Sequence[dict[str, Any]]) -> Any:
result = None
for spec in calls:
fn = spec["fn"]
fn_args = spec.get("args", ())
fn_kwargs = spec.get("kwargs", {})
result = fn(*fn_args, **fn_kwargs)
result = self._run_call(call=spec)
return result

def _run_call(self, call: dict[str, Any]) -> Any:
fn = call["fn"]
fn_args = call.get("args", ())
fn_kwargs = call.get("kwargs", {})
return fn(*fn_args, **fn_kwargs)

def _wp_capture_or_launch(self, stage: str, calls: Sequence[dict[str, Any]]) -> None:
"""Capture Warp CUDA graph for a stage on first call, then replay."""
graph = self._wp_graphs.get(stage)
if graph is None:
with wp.ScopedCapture() as capture:
for spec in calls:
fn = spec["fn"]
fn_args = spec.get("args", ())
fn_kwargs = spec.get("kwargs", {})
fn(*fn_args, **fn_kwargs)
graph = capture.graph
self._wp_graphs[stage] = graph
wp.capture_launch(graph)

def _load_cfg(self, cfg_source: str | None) -> dict[str, int]:
if cfg_source is not None and not isinstance(cfg_source, str):
raise TypeError(f"cfg_source must be a string or None, got: {type(cfg_source)}")
if cfg_source is None or cfg_source.strip() == "":
if graph is not None:
wp.capture_launch(graph)
return
# Warmup: run eagerly to trigger first-call allocations (hasattr guards, wp.zeros, etc.)
self._run_calls(calls)
# Capture: allocations already done, only wp.launch calls are recorded
with wp.ScopedCapture() as capture:
self._run_calls(calls)
self._wp_graphs[stage] = capture.graph

def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]:
if cfg_source is None:
return dict(self.DEFAULT_CONFIG)

parsed = json.loads(cfg_source)
if not isinstance(parsed, dict):
raise TypeError("manager_call_config must decode to a dict.")
if isinstance(cfg_source, dict):
parsed = cfg_source
elif isinstance(cfg_source, str):
if cfg_source.strip() == "":
return dict(self.DEFAULT_CONFIG)
parsed = json.loads(cfg_source)
if not isinstance(parsed, dict):
raise TypeError("manager_call_config must decode to a dict.")
else:
raise TypeError(f"cfg_source must be a dict, string, or None, got: {type(cfg_source)}")

cfg = dict(parsed)
if self.DEFAULT_KEY not in cfg:
Expand All @@ -186,6 +225,20 @@ def _load_cfg(self, cfg_source: str | None) -> dict[str, int]:
) from exc
return cfg

def _validate_max_modes(self, max_modes: dict[str, int] | None) -> dict[str, int]:
    """Check the per-manager mode caps and return them as a fresh dict.

    Args:
        max_modes: Mapping from manager name to the highest mode allowed for that
            manager, or None when no caps are configured.

    Returns:
        An empty dict when ``max_modes`` is None, otherwise a shallow copy of it.

    Raises:
        TypeError: If ``max_modes`` is neither a dict nor None, or a cap is not an int.
        ValueError: If a cap is not accepted by ``ManagerCallMode``.
    """
    if max_modes is None:
        return {}
    if not isinstance(max_modes, dict):
        raise TypeError(f"max_modes must be a dict or None, got: {type(max_modes)}")
    for name, cap in max_modes.items():
        if isinstance(cap, int):
            # Converting to the enum is the range check; the result is not needed.
            try:
                ManagerCallMode(cap)
            except ValueError as exc:
                raise ValueError(f"Invalid manager_call_max_mode value for '{name}': {cap}. Expected 0/1/2.") from exc
            continue
        raise TypeError(f"manager_call_max_mode value for '{name}' must be int (0/1/2), got: {type(cap)}")
    return {**max_modes}


class ManagerBasedEnvWarp:
"""The base environment for the manager-based workflow (experimental fork).
Expand All @@ -211,16 +264,9 @@ def __init__(self, cfg: ManagerBasedEnvCfg):
# initialize internal variables
self._is_closed = False
# temporary debug runtime config for manager source/call switching.
cfg_source: str | None = getattr(self.cfg, "manager_call_config", None)
# if cfg_source is None:
# try:
# import __main__

# args_cli = getattr(__main__, "args_cli", None)
# cfg_source = getattr(args_cli, "manager_call_config", None)
# except Exception:
# cfg_source = None
self._manager_call_switch = ManagerCallSwitch(cfg_source)
cfg_source: dict | str | None = getattr(self.cfg, "manager_call_config", None)
max_modes: dict[str, int] | None = getattr(self.cfg, "manager_call_max_mode", None)
self._manager_call_switch = ManagerCallSwitch(cfg_source, max_modes=max_modes)
self._apply_manager_term_cfg_profile()

# set the seed for the environment
Expand Down Expand Up @@ -409,6 +455,17 @@ def device(self):
"""The device on which the environment is running."""
return self.sim.device

@property
def env_origins_wp(self) -> wp.array:
    """Scene env origins as a warp array, cached on the instance after first access.

    If ``scene.env_origins`` is already a ``wp.array`` it is stored as-is; otherwise
    it is wrapped with ``wp.from_torch`` using the ``wp.vec3f`` dtype.
    """
    try:
        # Fast path: the cached array from an earlier access.
        return self._env_origins_wp
    except AttributeError:
        pass
    origins = self.scene.env_origins
    self._env_origins_wp = origins if isinstance(origins, wp.array) else wp.from_torch(origins, dtype=wp.vec3f)
    return self._env_origins_wp

def resolve_env_mask(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ def load_managers(self):
# note: this order is important since observation manager needs to know the command and action managers
# and the reward manager needs to know the termination manager
# -- command manager
self.command_manager = self._manager_call_switch.resolve_manager_class("CommandManager")(
self.cfg.commands, self
)
# TODO(jichuanh): switch to experimental command manager once command-term isolation is complete.
self.command_manager = self._manager_call_switch.resolve_manager_class(
"CommandManager", mode_override=ManagerCallMode.STABLE
)(self.cfg.commands, self)
print("[INFO] Command Manager: ", self.command_manager)

# call the parent class to load the managers for observations and actions.
Expand Down Expand Up @@ -379,6 +380,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
stage="CommandManager_compute",
stable_calls=[{"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}],
warp_calls=[{"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}],
mode_override=ManagerCallMode.STABLE,
)

# -- step interval events
Expand Down Expand Up @@ -669,10 +671,12 @@ def _reset_idx(
enable=TIMER_ENABLED_RESET_IDX,
format="us",
):
command_mode = ManagerCallMode.STABLE
command_info = self._manager_call_switch.call_stage(
stage="CommandManager_reset",
stable_calls=[{"fn": self.command_manager.reset, "kwargs": {"env_ids": env_ids}}],
warp_calls=[{"fn": self.command_manager.reset, "kwargs": {"env_mask": env_mask}}],
mode_override=command_mode,
)
event_info = self._manager_call_switch.call_stage(
stage="EventManager_reset",
Expand All @@ -684,7 +688,7 @@ def _reset_idx(
stable_calls=[{"fn": self.termination_manager.reset, "kwargs": {"env_ids": env_ids}}],
warp_calls=[{"fn": self.termination_manager.reset, "kwargs": {"env_mask": env_mask}}],
)
if self._manager_call_switch.get_mode_for_manager("CommandManager") == ManagerCallMode.WARP_CAPTURED:
if command_mode == ManagerCallMode.WARP_CAPTURED:
command_info = self.command_manager.reset_extras
if self._manager_call_switch.get_mode_for_manager("EventManager") == ManagerCallMode.WARP_CAPTURED:
event_info = {}
Expand Down
Loading