From b40e6dc5eae191eb2892af08f2e8c506d7a5077e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Sousa-Pinto?= Date: Tue, 12 May 2026 23:50:21 -0700 Subject: [PATCH 1/3] Add Primal-Dual LIPA solver to mpx Wires primal-dual-lipa (https://github.com/joaospinto/primal-dual-lipa) into mpx as a third solver_mode option alongside primal_dual and fddp. --- mpx/config/config_aliengo_trot_two_step.py | 15 ++ mpx/config/config_barrel_roll.py | 17 +- mpx/config/config_h1_jump_forward.py | 15 ++ mpx/examples/acrobot.py | 2 +- mpx/examples/offline_task.py | 65 ++++--- mpx/utils/lipa_solver.py | 212 +++++++++++++++++++++ mpx/utils/mpc_wrapper.py | 54 ++++-- pyproject.toml | 3 +- 8 files changed, 341 insertions(+), 42 deletions(-) create mode 100644 mpx/utils/lipa_solver.py diff --git a/mpx/config/config_aliengo_trot_two_step.py b/mpx/config/config_aliengo_trot_two_step.py index 3bb9a37..8009101 100644 --- a/mpx/config/config_aliengo_trot_two_step.py +++ b/mpx/config/config_aliengo_trot_two_step.py @@ -64,3 +64,18 @@ solver_mode = "fddp" max_torque = base.max_torque min_torque = base.min_torque + +def _lipa_settings(): + from primal_dual_lipa.types import SolverSettings + return SolverSettings( + max_iterations=2000, + η0=1e9, + η_update_factor=1.0, + µ_update_factor=0.9, + cost_improvement_threshold=1e-3, + primal_violation_threshold=1e-5, + use_parallel_lqr=True, + num_parallel_line_search_steps=8, + ) + +lipa_settings = _lipa_settings() diff --git a/mpx/config/config_barrel_roll.py b/mpx/config/config_barrel_roll.py index 05f6b12..ca6005d 100644 --- a/mpx/config/config_barrel_roll.py +++ b/mpx/config/config_barrel_roll.py @@ -105,4 +105,19 @@ def dynamics(model, mjx_model, contact_id, body_id): # dynamics = mpc_dyn_model.quadruped_wb_dynamics_learned_contact_model # dynamics = mpc_dyn_model.quadruped_wb_dynamics_explicit_contact max_torque = 40 -min_torque = -40 \ No newline at end of file +min_torque = -40 + +def _lipa_settings(): + from primal_dual_lipa.types import SolverSettings + return SolverSettings( + max_iterations=2000, + η0=1e9, + η_update_factor=1.1, + µ_update_factor=0.9, + cost_improvement_threshold=1e-3, + primal_violation_threshold=1e-5, + use_parallel_lqr=True, + num_parallel_line_search_steps=8, + ) + +lipa_settings = _lipa_settings() diff --git a/mpx/config/config_h1_jump_forward.py b/mpx/config/config_h1_jump_forward.py index a13822b..ca0fcec 100644 --- a/mpx/config/config_h1_jump_forward.py +++ b/mpx/config/config_h1_jump_forward.py @@ -58,3 +58,18 @@ solver_mode = "fddp" max_torque = base.max_torque min_torque = base.min_torque + +def _lipa_settings(): + from primal_dual_lipa.types import SolverSettings + return SolverSettings( + max_iterations=2000, + η0=1e9, + η_update_factor=1.0, + µ_update_factor=0.9, + cost_improvement_threshold=1e-3, + primal_violation_threshold=1e-5, + use_parallel_lqr=True, + num_parallel_line_search_steps=8, + ) + +lipa_settings = _lipa_settings() diff --git a/mpx/examples/acrobot.py b/mpx/examples/acrobot.py index 7d71ae3..15423a6 100644 --- a/mpx/examples/acrobot.py +++ b/mpx/examples/acrobot.py @@ -228,6 +228,6 @@ def step_controller(viewer=None): parser = argparse.ArgumentParser() parser.add_argument("--headless", action="store_true") parser.add_argument("--steps", type=int, default=500) - parser.add_argument("--solver", choices=("primal_dual", "fddp"), default="primal_dual") + parser.add_argument("--solver", choices=("primal_dual", "fddp", "lipa"), default="primal_dual") args = parser.parse_args() main(headless=args.headless, steps=args.steps, solver_mode=args.solver) diff --git a/mpx/examples/offline_task.py b/mpx/examples/offline_task.py index c72fe0c..a80992c 100644 --- a/mpx/examples/offline_task.py +++ b/mpx/examples/offline_task.py @@ -52,7 +52,7 @@ }, } -SOLVERS = ("primal_dual", "fddp") +SOLVERS = ("primal_dual", "fddp", "lipa") def _clone_config(module_name, solver_mode): @@ -88,29 +88,46 @@ def _solve_wrapper_task(config, max_iter, verbose): def _solve_direct_task(config, max_iter, verbose): - _, solve = base_mpc_wrapper.build_solver_step( - config, - config.cost, - config.dynamics, - config.hessian_approx, - limited_memory=False, - ) - solve = jax.jit(solve) - X, U, V, history, stats = offline_solver.run_offline_solve( - solve, - config.cost, - config.dynamics, - config.solver_mode, - config.reference, - config.parameter, - config.W, - config.x0, - config.initial_X0, - config.initial_U0, - config.initial_V0, - max_iter=max_iter, - verbose=verbose, - ) + if getattr(config, "solver_mode", None) == "lipa": + from mpx.utils.lipa_solver import run_lipa_offline + + X, U, V, history, stats = run_lipa_offline( + config.cost, + config.dynamics, + config.reference, + config.parameter, + config.W, + config.x0, + config.initial_X0, + config.initial_U0, + config.initial_V0, + settings=getattr(config, "lipa_settings", None), + verbose=verbose, + ) + else: + _, solve = base_mpc_wrapper.build_solver_step( + config, + config.cost, + config.dynamics, + config.hessian_approx, + limited_memory=False, + ) + solve = jax.jit(solve) + X, U, V, history, stats = offline_solver.run_offline_solve( + solve, + config.cost, + config.dynamics, + config.solver_mode, + config.reference, + config.parameter, + config.W, + config.x0, + config.initial_X0, + config.initial_U0, + config.initial_V0, + max_iter=max_iter, + verbose=verbose, + ) return { "config": config, "X": X, diff --git a/mpx/utils/lipa_solver.py b/mpx/utils/lipa_solver.py new file mode 100644 index 0000000..fc2e800 --- /dev/null +++ b/mpx/utils/lipa_solver.py @@ -0,0 +1,212 @@ +"""Adapter that exposes the Primal-Dual LIPA solver via the mpx solver API. + +mpx solvers all share the signature + solve(reference, parameter, W, x0, X0, U0, V0) -> (X, U, V) +with V having shape (N+1, n). LIPA expects a different problem statement +(`Variables` pytree, cost/dynamics with a (x, u, theta, t) signature, no +externalised W/reference/parameter). This module bridges the two. + +Note on offline use vs mpx's other solvers: mpx's primal_dual / fddp do +*one* SQP/iLQR step per call and rely on `run_offline_solve`'s outer loop +to converge. LIPA is a complete NLP solver — its main loop schedules µ +(IPM barrier) and η (per-constraint AL penalty) internally. Calling it +many times restarts those parameters at every call (see +`primal_dual_lipa.optimizers.solve` lines 78-81), wasting iterations and +producing misleading benchmark numbers. So for offline mode use +`run_lipa_offline`, which calls LIPA exactly once and reports its +internal iteration count and wall time. +""" + +from functools import partial +from timeit import default_timer as timer + +import jax +import jax.numpy as jnp +import numpy as np + +from primal_dual_lipa.optimizers import solve as lipa_solve +from primal_dual_lipa.types import SolverSettings, Variables + + +def _wrap_cost(cost): + def lipa_cost(W, reference, x, u, theta, t): + del theta + return cost(W, reference, x, u, t) + + return lipa_cost + + +def _wrap_dynamics(dynamics): + def lipa_dynamics(parameter, x, u, theta, t): + del theta + return dynamics(x, u, t, parameter=parameter) + + return lipa_dynamics + + +@partial(jax.jit, static_argnums=(0, 1)) +def _lipa_solve_with_stats(cost, dynamics, settings, reference, parameter, W, x0, X_in, U_in, V_in): + """Single LIPA call that returns the final variables plus solver stats.""" + + lipa_cost = partial(_wrap_cost(cost), W, reference) + lipa_dynamics = partial(_wrap_dynamics(dynamics), parameter) + + T = U_in.shape[0] + vars_in = Variables( + X=X_in, + U=U_in, + S=jnp.zeros((T + 1, 0), dtype=X_in.dtype), + Y_dyn=V_in, + Y_eq=jnp.zeros((T + 1, 0), dtype=X_in.dtype), + Z=jnp.zeros((T + 1, 0), dtype=X_in.dtype), + Theta=jnp.empty(0, dtype=X_in.dtype), + ) + + vars_out, iterations, no_errors = lipa_solve( + vars_in=vars_in, + x0=x0, + cost=lipa_cost, + dynamics=lipa_dynamics, + settings=settings, + ) + return vars_out.X, vars_out.U, vars_out.Y_dyn, iterations, no_errors + + +def _default_settings(): + """Pick conservative defaults for an unseen problem. + + The goal here is robustness, not peak performance. Aggressive + settings belong as per-config `lipa_settings` overrides. + """ + + on_gpu = any(d.platform == "gpu" for d in jax.devices()) + common = dict( + max_iterations=2000, + η0=1e3, + η_update_factor=1.0, + µ_update_factor=0.9, + cost_improvement_threshold=1e-3, + primal_violation_threshold=1e-5, + ) + if on_gpu: + return SolverSettings( + use_parallel_lqr=True, + num_parallel_line_search_steps=8, + **common, + ) + return SolverSettings(**common) + + +def build_lipa_solve(cost, dynamics, settings=None): + """Return a `solve(reference, parameter, W, x0, X0, U0, V0) -> (X, U, V)`. + + Used by online MPC (e.g. `MPCWrapper.run`). For offline benchmarks, + prefer `run_lipa_offline`, which is a single-call path that surfaces + LIPA's internal iteration count and avoids resetting µ/η repeatedly. + + Defaults differ by backend (parallel LQR + parallel line search on GPU). + Override via `config.lipa_settings`. + """ + + if settings is None: + settings = _default_settings() + + def solve(reference, parameter, W, x0, X0, U0, V0): + X, U, V, _iters, _no_errors = _lipa_solve_with_stats( + cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + ) + return X, U, V + + return solve + + +def run_lipa_offline( + cost, + dynamics, + reference, + parameter, + W, + x0, + X0, + U0, + V0, + *, + settings=None, + warmup=True, + verbose=True, +): + """Solve a single OCP with LIPA and return stats matching `run_offline_solve`. + + Unlike `run_offline_solve`, which loops one-step solvers until cost + plateaus, this calls LIPA exactly once. Reported `n_iterations` is + LIPA's internal IPM iteration count. + """ + + from mpx.jax_ocp_solvers.jax_ocp_solvers import optimizers as ocp_opt + + if settings is None: + settings = _default_settings() + + offline_cost = partial(cost, W, reference) + offline_dynamics = partial(dynamics, parameter=parameter) + model_evaluator = jax.jit( + partial(ocp_opt.model_evaluator_helper, offline_cost, offline_dynamics, x0) + ) + + g0, c0 = model_evaluator(X0, U0) + initial_objective = float(g0) + initial_l2_cost = float(np.sqrt(np.sum(np.asarray(g0) * np.asarray(g0)))) + initial_dynamics_violation = float(np.sum(np.asarray(c0) * np.asarray(c0))) + + if verbose: + print("{:<10} {:<20} {:<20} {:<20}".format("Iter", "Cost", "Constraint", "Time [ms]")) + print("{:<10d} {:<20.5f} {:<20.5f} {:<20}".format(0, initial_l2_cost, initial_dynamics_violation, "-")) + + if warmup: + Xw, _, _, _, _ = _lipa_solve_with_stats( + cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + ) + Xw.block_until_ready() + + start = timer() + X, U, V, iterations, no_errors = _lipa_solve_with_stats( + cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + ) + X.block_until_ready() + stop = timer() + iteration_time_ms = 1e3 * (stop - start) + + g, c = model_evaluator(X, U) + final_objective = float(g) + final_l2_cost = float(np.sqrt(np.sum(np.asarray(g) * np.asarray(g)))) + final_dynamics_violation = float(np.sum(np.asarray(c) * np.asarray(c))) + n_iters = int(iterations) + converged = bool(no_errors) + + if verbose: + print( + "{:<10d} {:<20.5f} {:<20.5f} {:<20.5f}".format( + 1, final_l2_cost, final_dynamics_violation, iteration_time_ms + ) + ) + print(f" LIPA internal iterations: {n_iters}, no_errors: {converged}") + + history = [X0, X] + stats = { + "n_iterations": n_iters, + "converged": converged, + "warmup_discarded": warmup, + "objective_history": [initial_objective, final_objective], + "l2_cost_history": [initial_l2_cost, final_l2_cost], + "dynamics_violation_history": [initial_dynamics_violation, final_dynamics_violation], + "metric_iteration_history": [0, 1], + "iteration_time_ms_history": [iteration_time_ms], + "initial_objective": initial_objective, + "initial_l2_cost": initial_l2_cost, + "initial_dynamics_violation": initial_dynamics_violation, + "average_iteration_time_ms": iteration_time_ms, + "final_objective": final_objective, + "final_l2_cost": final_l2_cost, + "final_dynamics_violation": final_dynamics_violation, + } + return X, U, V, history, stats diff --git a/mpx/utils/mpc_wrapper.py b/mpx/utils/mpc_wrapper.py index 4ca4b59..43e0810 100644 --- a/mpx/utils/mpc_wrapper.py +++ b/mpx/utils/mpc_wrapper.py @@ -7,6 +7,7 @@ from mujoco.mjx._src.dataclasses import PyTreeNode from mpx.jax_ocp_solvers.jax_ocp_solvers import optimizers +from mpx.utils.lipa_solver import build_lipa_solve, run_lipa_offline import mpx.utils.offline_solver as offline_solver import mpx.utils.mpc_utils as mpc_utils @@ -54,6 +55,11 @@ def solve(reference, parameter, W, x0, X0, U0, V0): return solver_mode, solve + if solver_mode == "lipa": + lipa_settings = getattr(config, "lipa_settings", None) + solve = build_lipa_solve(cost, dynamics, settings=lipa_settings) + return solver_mode, solve + raise ValueError(f"Unsupported MPC solver_mode: {solver_mode}") @@ -326,21 +332,39 @@ def runOffline(self, qpos, qvel, *, return_stats=False, verbose=True, max_iter=1 U0 = self.initial_U0 V0 = self.initial_V0 - X0, U0, _, output, stats = offline_solver.run_offline_solve( - self._solve, - self.cost, - self.dynamics, - self.config.solver_mode, - reference, - parameter, - W, - x0, - X0, - U0, - V0, - max_iter=max_iter, - verbose=verbose, - ) + if self.solver_mode == "lipa": + # LIPA is a complete NLP solver; one call converges. Looping it + # restarts the IPM µ/η each time, which inflates "iterations" + # and wall time without improving the solution. + X0, U0, _, output, stats = run_lipa_offline( + self.cost, + self.dynamics, + reference, + parameter, + W, + x0, + X0, + U0, + V0, + settings=getattr(self.config, "lipa_settings", None), + verbose=verbose, + ) + else: + X0, U0, _, output, stats = offline_solver.run_offline_solve( + self._solve, + self.cost, + self.dynamics, + self.config.solver_mode, + reference, + parameter, + W, + x0, + X0, + U0, + V0, + max_iter=max_iter, + verbose=verbose, + ) if return_stats: return X0, U0, reference, output, stats diff --git a/pyproject.toml b/pyproject.toml index 1bd39d6..b6a1ac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ dependencies = [ "jax[cuda12]", "mujoco", "mujoco-mjx", - "trajax @ git+https://github.com/google/trajax" + "trajax @ git+https://github.com/google/trajax", + "primal-dual-lipa @ git+https://github.com/joaospinto/primal-dual-lipa" ] [project.urls] From 6ef875b1ca91f1d404322cc13e4bdd3c5934d701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Sousa-Pinto?= Date: Tue, 12 May 2026 23:50:21 -0700 Subject: [PATCH 2/3] Add support for constrained LIPA solves of offline MPC problems. --- mpx/config/config_aliengo_trot_two_step.py | 4 + mpx/config/config_barrel_roll.py | 6 ++ mpx/config/config_h1_jump_forward.py | 21 ++++ mpx/examples/mjx_h1.py | 46 +++++++- mpx/examples/mjx_h1_kinodynamic.py | 46 +++++++- mpx/examples/mjx_quad.py | 56 +++++++++- mpx/examples/mjx_talos.py | 46 +++++++- mpx/examples/offline_task.py | 58 ++++++++-- mpx/examples/srbd_quad.py | 50 ++++++++- mpx/utils/lipa_solver.py | 116 +++++++++++++++++--- mpx/utils/mpc_wrapper.py | 61 ++++++++++- mpx/utils/objectives.py | 119 ++++++++++++++++----- mpx/utils/sim.py | 62 +++++++++++ 13 files changed, 620 insertions(+), 71 deletions(-) diff --git a/mpx/config/config_aliengo_trot_two_step.py b/mpx/config/config_aliengo_trot_two_step.py index 8009101..2874f63 100644 --- a/mpx/config/config_aliengo_trot_two_step.py +++ b/mpx/config/config_aliengo_trot_two_step.py @@ -48,6 +48,8 @@ initial_state = base.initial_state cost = partial(mpc_objectives.quadruped_wb_obj, True, n_joints, n_contact, N) +cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, True, n_joints, n_contact, N) +inequalities = partial(mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 44.0, 10.0) hessian_approx = base.hessian_approx dynamics = base.dynamics @@ -65,6 +67,8 @@ max_torque = base.max_torque min_torque = base.min_torque +lipa_enforce_inequalities = True + def _lipa_settings(): from primal_dual_lipa.types import SolverSettings return SolverSettings( diff --git a/mpx/config/config_barrel_roll.py b/mpx/config/config_barrel_roll.py index ca6005d..9d793e6 100644 --- a/mpx/config/config_barrel_roll.py +++ b/mpx/config/config_barrel_roll.py @@ -88,6 +88,7 @@ ) cost = partial(mpc_objectives.quadruped_wb_obj, False, n_joints, n_contact, N) +cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, False, n_joints, n_contact, N) hessian_approx = None def dynamics(model, mjx_model, contact_id, body_id): @@ -107,6 +108,11 @@ def dynamics(model, mjx_model, contact_id, body_id): max_torque = 40 min_torque = -40 +inequalities = partial( + mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 50.0, 20.0 +) +lipa_enforce_inequalities = True + def _lipa_settings(): from primal_dual_lipa.types import SolverSettings return SolverSettings( diff --git a/mpx/config/config_h1_jump_forward.py b/mpx/config/config_h1_jump_forward.py index ca0fcec..f2e2277 100644 --- a/mpx/config/config_h1_jump_forward.py +++ b/mpx/config/config_h1_jump_forward.py @@ -41,6 +41,8 @@ torque_limits = base.torque_limits cost = partial(mpc_objectives.h1_kinodynamic_obj, n_joints, n_contact, N) +cost_smooth = partial(mpc_objectives.h1_kinodynamic_smooth_cost, n_joints, n_contact, N) +inequalities = partial(mpc_objectives.h1_kinodynamic_inequalities, n_joints, n_contact, 0.7) hessian_approx = base.hessian_approx dynamics = base.dynamics MPCWrapper = base.MPCWrapper @@ -59,6 +61,8 @@ max_torque = base.max_torque min_torque = base.min_torque +lipa_enforce_inequalities = True + def _lipa_settings(): from primal_dual_lipa.types import SolverSettings return SolverSettings( @@ -68,8 +72,25 @@ def _lipa_settings(): µ_update_factor=0.9, cost_improvement_threshold=1e-3, primal_violation_threshold=1e-5, + num_iterative_refinement_steps=2, use_parallel_lqr=True, num_parallel_line_search_steps=8, ) lipa_settings = _lipa_settings() + +def _lipa_settings_enforce(): + from primal_dual_lipa.types import SolverSettings + return SolverSettings( + max_iterations=500, + η0=1e5, + η_update_factor=2.0, + µ_update_factor=0.9, + cost_improvement_threshold=1e-3, + primal_violation_threshold=1e-5, + num_iterative_refinement_steps=2, + use_parallel_lqr=True, + num_parallel_line_search_steps=8, + ) + +lipa_settings_enforce = _lipa_settings_enforce() diff --git a/mpx/examples/mjx_h1.py b/mpx/examples/mjx_h1.py index 7969fb7..df6944f 100644 --- a/mpx/examples/mjx_h1.py +++ b/mpx/examples/mjx_h1.py @@ -7,6 +7,10 @@ sys.path.append(os.path.abspath(os.path.join(dir_path, ".."))) os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=") +if "--video" in sys.argv: + os.environ.setdefault("MUJOCO_GL", "egl") + os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + import jax import jax.numpy as jnp import mujoco @@ -36,7 +40,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact): return solve_mpc -def main(steps=500): +def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False): model = mujoco.MjModel.from_xml_path( dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml" ) @@ -45,7 +49,7 @@ def main(steps=500): model.opt.timestep = 1 / sim_frequency mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True) - command_handle = sim_utils.KeyboardVelocityCommand() + command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz) solve_mpc = _build_solve_fn(mpc) reset_mpc = jax.jit(mpc.reset) @@ -102,6 +106,27 @@ def step_controller(): mujoco.mj_step(model, data) counter += 1 + if headless or video is not None: + recorder = None + capture_period = max(1, int(round(sim_frequency / fps))) + if video is not None: + os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True) + recorder = sim_utils.VideoRecorder(model, video, fps=fps) + p_start = np.asarray(data.qpos[:3]).copy() + try: + for i in range(steps): + step_controller() + if recorder is not None and i % capture_period == 0: + recorder.capture(data) + finally: + if recorder is not None: + recorder.close() + print(f"Wrote video: {video}") + p_end = np.asarray(data.qpos[:3]) + delta = p_end - p_start + print(f"Base position: start={p_start} end={p_end} delta={delta}") + return + with mujoco.viewer.launch_passive( model, data, @@ -119,5 +144,20 @@ def step_controller(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--steps", type=int, default=500) + parser.add_argument("--headless", action="store_true") + parser.add_argument("--video", type=str, default=None, + help="Write an mp4 of the run to this path (forces headless).") + parser.add_argument("--vx", type=float, default=0.0) + parser.add_argument("--vy", type=float, default=0.0) + parser.add_argument("--wz", type=float, default=0.0) + parser.add_argument("--fps", type=int, default=30) args = parser.parse_args() - main(steps=args.steps) + main( + steps=args.steps, + video=args.video, + vx=args.vx, + vy=args.vy, + wz=args.wz, + fps=args.fps, + headless=args.headless, + ) diff --git a/mpx/examples/mjx_h1_kinodynamic.py b/mpx/examples/mjx_h1_kinodynamic.py index 043a29e..39d2893 100644 --- a/mpx/examples/mjx_h1_kinodynamic.py +++ b/mpx/examples/mjx_h1_kinodynamic.py @@ -7,6 +7,10 @@ sys.path.append(os.path.abspath(os.path.join(dir_path, ".."))) os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=") +if "--video" in sys.argv: + os.environ.setdefault("MUJOCO_GL", "egl") + os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + import jax import jax.numpy as jnp import mujoco @@ -35,7 +39,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact): return solve_mpc -def main(steps=500): +def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False): model = mujoco.MjModel.from_xml_path( dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml" ) @@ -44,7 +48,7 @@ def main(steps=500): model.opt.timestep = 1 / sim_frequency mpc = config.MPCWrapper(config, limited_memory=True) - command_handle = sim_utils.KeyboardVelocityCommand() + command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz) solve_mpc = _build_solve_fn(mpc) reset_mpc = jax.jit(mpc.reset) @@ -100,6 +104,27 @@ def step_controller(): mujoco.mj_step(model, data) counter += 1 + if headless or video is not None: + recorder = None + capture_period = max(1, int(round(sim_frequency / fps))) + if video is not None: + os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True) + recorder = sim_utils.VideoRecorder(model, video, fps=fps) + p_start = np.asarray(data.qpos[:3]).copy() + try: + for i in range(steps): + step_controller() + if recorder is not None and i % capture_period == 0: + recorder.capture(data) + finally: + if recorder is not None: + recorder.close() + print(f"Wrote video: {video}") + p_end = np.asarray(data.qpos[:3]) + delta = p_end - p_start + print(f"Base position: start={p_start} end={p_end} delta={delta}") + return + with mujoco.viewer.launch_passive( model, data, @@ -117,5 +142,20 @@ def step_controller(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--steps", type=int, default=500) + parser.add_argument("--headless", action="store_true") + parser.add_argument("--video", type=str, default=None, + help="Write an mp4 of the run to this path (forces headless).") + parser.add_argument("--vx", type=float, default=0.0) + parser.add_argument("--vy", type=float, default=0.0) + parser.add_argument("--wz", type=float, default=0.0) + parser.add_argument("--fps", type=int, default=30) args = parser.parse_args() - main(steps=args.steps) + main( + steps=args.steps, + video=args.video, + vx=args.vx, + vy=args.vy, + wz=args.wz, + fps=args.fps, + headless=args.headless, + ) diff --git a/mpx/examples/mjx_quad.py b/mpx/examples/mjx_quad.py index c223900..880ca75 100644 --- a/mpx/examples/mjx_quad.py +++ b/mpx/examples/mjx_quad.py @@ -8,6 +8,12 @@ sys.path.append(os.path.abspath(os.path.join(dir_path, ".."))) os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=") +# Headless video recording uses `mujoco.Renderer`, which requires an OpenGL +# backend to be configured before the first `import mujoco` in the process. +if "--video" in sys.argv: + os.environ.setdefault("MUJOCO_GL", "egl") + os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + import jax import jax.numpy as jnp import mujoco @@ -37,7 +43,16 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact): return solve_mpc -def main(headless=False, steps=500, scene="flat"): +def main( + headless=False, + steps=500, + scene="flat", + video=None, + vx=0.0, + vy=0.0, + wz=0.0, + fps=30, +): model = mujoco.MjModel.from_xml_path( dir_path + f"/../data/aliengo/scene_{scene}.xml" ) @@ -47,7 +62,9 @@ def main(headless=False, steps=500, scene="flat"): contact_ids = sim_utils.geom_ids(model, config.contact_frame) mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True) - command_handle = sim_utils.KeyboardVelocityCommand() + # Headless+video: scripted velocity (no keyboard); viewer mode keeps the + # interactive arrow-key handle. + command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz) solve_mpc = _build_solve_fn(mpc) reset_mpc = jax.jit(mpc.reset) @@ -112,9 +129,25 @@ def step_controller(): mujoco.mj_step(model, data) counter += 1 - if headless: - for _ in range(steps): - step_controller() + if headless or video is not None: + recorder = None + capture_period = max(1, int(round(sim_frequency / fps))) + if video is not None: + os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True) + recorder = sim_utils.VideoRecorder(model, video, fps=fps) + p_start = np.asarray(data.qpos[:3]).copy() + try: + for i in range(steps): + step_controller() + if recorder is not None and i % capture_period == 0: + recorder.capture(data) + finally: + if recorder is not None: + recorder.close() + print(f"Wrote video: {video}") + p_end = np.asarray(data.qpos[:3]) + delta = p_end - p_start + print(f"Base position: start={p_start} end={p_end} delta={delta}") return with mujoco.viewer.launch_passive( @@ -141,9 +174,22 @@ def step_controller(): parser.add_argument("--steps", type=int, default=500) parser.add_argument("--scene", type=str, default="flat") parser.add_argument("--headless", action="store_true") + parser.add_argument("--video", type=str, default=None, + help="Write an mp4 of the run to this path (forces headless).") + parser.add_argument("--vx", type=float, default=0.0, + help="Forward velocity command (m/s) for headless/video runs.") + parser.add_argument("--vy", type=float, default=0.0) + parser.add_argument("--wz", type=float, default=0.0, + help="Yaw-rate command (rad/s).") + parser.add_argument("--fps", type=int, default=30) args = parser.parse_args() main( headless=args.headless, steps=args.steps, scene=args.scene, + video=args.video, + vx=args.vx, + vy=args.vy, + wz=args.wz, + fps=args.fps, ) diff --git a/mpx/examples/mjx_talos.py b/mpx/examples/mjx_talos.py index d14ea2c..5c95e78 100644 --- a/mpx/examples/mjx_talos.py +++ b/mpx/examples/mjx_talos.py @@ -7,6 +7,10 @@ sys.path.append(os.path.abspath(os.path.join(dir_path, ".."))) os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=") +if "--video" in sys.argv: + os.environ.setdefault("MUJOCO_GL", "egl") + os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + import jax import jax.numpy as jnp import mujoco @@ -36,7 +40,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact): return solve_mpc -def main(steps=500): +def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False): model = mujoco.MjModel.from_xml_path( dir_path + "/../data/pal_talos/talos_motor_rough.xml" ) @@ -45,7 +49,7 @@ def main(steps=500): model.opt.timestep = 1 / sim_frequency mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True) - command_handle = sim_utils.KeyboardVelocityCommand() + command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz) solve_mpc = _build_solve_fn(mpc) reset_mpc = jax.jit(mpc.reset) @@ -102,6 +106,27 @@ def step_controller(): mujoco.mj_step(model, data) counter += 1 + if headless or video is not None: + recorder = None + capture_period = max(1, int(round(sim_frequency / fps))) + if video is not None: + os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True) + recorder = sim_utils.VideoRecorder(model, video, fps=fps) + p_start = np.asarray(data.qpos[:3]).copy() + try: + for i in range(steps): + step_controller() + if recorder is not None and i % capture_period == 0: + recorder.capture(data) + finally: + if recorder is not None: + recorder.close() + print(f"Wrote video: {video}") + p_end = np.asarray(data.qpos[:3]) + delta = p_end - p_start + print(f"Base position: start={p_start} end={p_end} delta={delta}") + return + with mujoco.viewer.launch_passive( model, data, @@ -119,5 +144,20 @@ def step_controller(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--steps", type=int, default=500) + parser.add_argument("--headless", action="store_true") + parser.add_argument("--video", type=str, default=None, + help="Write an mp4 of the run to this path (forces headless).") + parser.add_argument("--vx", type=float, default=0.0) + parser.add_argument("--vy", type=float, default=0.0) + parser.add_argument("--wz", type=float, default=0.0) + parser.add_argument("--fps", type=int, default=30) args = parser.parse_args() - main(steps=args.steps) + main( + steps=args.steps, + video=args.video, + vx=args.vx, + vy=args.vy, + wz=args.wz, + fps=args.fps, + headless=args.headless, + ) diff --git a/mpx/examples/offline_task.py b/mpx/examples/offline_task.py index a80992c..77909e3 100644 --- a/mpx/examples/offline_task.py +++ b/mpx/examples/offline_task.py @@ -55,12 +55,14 @@ SOLVERS = ("primal_dual", "fddp", "lipa") -def _clone_config(module_name, solver_mode): +def _clone_config(module_name, solver_mode, lipa_enforce_inequalities=None): module = importlib.import_module(module_name) attrs = {name: getattr(module, name) for name in dir(module) if not name.startswith("__")} config = SimpleNamespace(**attrs) if solver_mode is not None: config.solver_mode = solver_mode + if lipa_enforce_inequalities is not None: + config.lipa_enforce_inequalities = lipa_enforce_inequalities return config @@ -90,9 +92,17 @@ def _solve_wrapper_task(config, max_iter, verbose): def _solve_direct_task(config, max_iter, verbose): if getattr(config, "solver_mode", None) == "lipa": from mpx.utils.lipa_solver import run_lipa_offline - + from mpx.utils.mpc_wrapper import lipa_pick_cost_and_inequalities + + ( + lipa_cost, + lipa_inequalities, + lipa_settings, + lipa_warmup_cost, + lipa_warmup_settings, + ) = lipa_pick_cost_and_inequalities(config, config.cost) X, U, V, history, stats = run_lipa_offline( - config.cost, + lipa_cost, config.dynamics, config.reference, config.parameter, @@ -101,7 +111,10 @@ def _solve_direct_task(config, max_iter, verbose): config.initial_X0, config.initial_U0, config.initial_V0, - settings=getattr(config, "lipa_settings", None), + settings=lipa_settings, + inequalities=lipa_inequalities, + warmup_cost=lipa_warmup_cost, + warmup_settings=lipa_warmup_settings, verbose=verbose, ) else: @@ -140,9 +153,11 @@ def _solve_direct_task(config, max_iter, verbose): } -def solve_task(task_name, solver_mode=None, max_iter=100, verbose=True): +def solve_task( + task_name, solver_mode=None, max_iter=100, verbose=True, lipa_enforce_inequalities=None +): task = TASKS[task_name] - config = _clone_config(task["config"], solver_mode) + config = _clone_config(task["config"], solver_mode, lipa_enforce_inequalities) benchmark_mode = task["benchmark_mode"] if benchmark_mode == "direct": result = _solve_direct_task(config, max_iter=max_iter, verbose=verbose) @@ -269,16 +284,27 @@ def _play_mujoco_trajectory(result, headless=False, loop=True, ghost_stride=1): time.sleep(config.dt) -def run_task(task_name, solver_mode=None, headless=False, max_iter=100, verbose=True, loop=True): +def run_task( + task_name, + solver_mode=None, + headless=False, + max_iter=100, + verbose=True, + loop=True, + lipa_enforce_inequalities=None, +): result = solve_task( task_name, solver_mode=solver_mode, max_iter=max_iter, verbose=verbose, + lipa_enforce_inequalities=lipa_enforce_inequalities, ) stats = result["stats"] + enforce = getattr(result["config"], "lipa_enforce_inequalities", False) + enforce_tag = " | enforce-ineq" if (result["config"].solver_mode == "lipa" and enforce) else "" print( - f"{task_name} | {result['config'].solver_mode} | " + f"{task_name} | {result['config'].solver_mode}{enforce_tag} | " f"iterations {stats['n_iterations']} | " f"avg iter time {stats['average_iteration_time_ms']:.3f} ms" ) @@ -301,6 +327,21 @@ def build_parser(default_task=None): parser.add_argument("--max-iter", type=int, default=100) parser.add_argument("--quiet", action="store_true") parser.add_argument("--no-loop", action="store_true") + enforce_group = parser.add_mutually_exclusive_group() + enforce_group.add_argument( + "--lipa-enforce-inequalities", + dest="lipa_enforce_inequalities", + action="store_true", + default=None, + help="(LIPA only) Enforce config inequalities as true constraints; overrides config attr.", + ) + enforce_group.add_argument( + "--no-lipa-enforce-inequalities", + dest="lipa_enforce_inequalities", + action="store_false", + default=None, + help="(LIPA only) Disable enforcement; revert to soft-penalty cost shared with FDDP/PD.", + ) return parser @@ -316,6 +357,7 @@ def main(default_task=None): max_iter=args.max_iter, verbose=not args.quiet, loop=not args.no_loop, + lipa_enforce_inequalities=args.lipa_enforce_inequalities, ) diff --git a/mpx/examples/srbd_quad.py b/mpx/examples/srbd_quad.py index 7547cd9..fbd1625 100644 --- a/mpx/examples/srbd_quad.py +++ b/mpx/examples/srbd_quad.py @@ -8,6 +8,10 @@ sys.path.append(os.path.abspath(os.path.join(dir_path, ".."))) os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=") +if "--video" in sys.argv: + os.environ.setdefault("MUJOCO_GL", "egl") + os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + import jax import jax.numpy as jnp import mujoco @@ -44,7 +48,16 @@ def _srbd_state(qpos, qvel): ) -def main(headless=False, steps=500, scene="flat"): +def main( + headless=False, + steps=500, + scene="flat", + video=None, + vx=0.0, + vy=0.0, + wz=0.0, + fps=30, +): model = mujoco.MjModel.from_xml_path( dir_path + f"/../data/aliengo/scene_{scene}.xml" ) @@ -53,7 +66,7 @@ def main(headless=False, steps=500, scene="flat"): model.opt.timestep = 1.0 / sim_frequency contact_ids = sim_utils.geom_ids(model, config.contact_frame) - command_handle = sim_utils.KeyboardVelocityCommand(vx=0.0, vy=0.0, wz=0.0) + command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz) mpc = mpc_wrapper_srbd.BatchedMPCControllerWrapper(config, n_env=1) _reset_to_initial_state(model, data) @@ -104,9 +117,25 @@ def step_controller(): mujoco.mj_step(model, data) counter += 1 - if headless: - for _ in range(steps): - step_controller() + if headless or video is not None: + recorder = None + capture_period = max(1, int(round(sim_frequency / fps))) + if video is not None: + os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True) + recorder = sim_utils.VideoRecorder(model, video, fps=fps) + p_start = np.asarray(data.qpos[:3]).copy() + try: + for i in range(steps): + step_controller() + if recorder is not None and i % capture_period == 0: + recorder.capture(data) + finally: + if recorder is not None: + recorder.close() + print(f"Wrote video: {video}") + p_end = np.asarray(data.qpos[:3]) + delta = p_end - p_start + print(f"Base position: start={p_start} end={p_end} delta={delta}") return with mujoco.viewer.launch_passive( @@ -132,9 +161,20 @@ def step_controller(): parser.add_argument("--steps", type=int, default=500) parser.add_argument("--scene", type=str, default="flat") parser.add_argument("--headless", action="store_true") + parser.add_argument("--video", type=str, default=None, + help="Write an mp4 of the run to this path (forces headless).") + parser.add_argument("--vx", type=float, default=0.0) + parser.add_argument("--vy", type=float, default=0.0) + parser.add_argument("--wz", type=float, default=0.0) + parser.add_argument("--fps", type=int, default=30) args = parser.parse_args() main( headless=args.headless, steps=args.steps, scene=args.scene, + video=args.video, + vx=args.vx, + vy=args.vy, + wz=args.wz, + fps=args.fps, ) diff --git a/mpx/utils/lipa_solver.py b/mpx/utils/lipa_solver.py index fc2e800..7eafd74 100644 --- a/mpx/utils/lipa_solver.py +++ b/mpx/utils/lipa_solver.py @@ -44,21 +44,47 @@ def lipa_dynamics(parameter, x, u, theta, t): return lipa_dynamics -@partial(jax.jit, static_argnums=(0, 1)) -def _lipa_solve_with_stats(cost, dynamics, settings, reference, parameter, W, x0, X_in, U_in, V_in): - """Single LIPA call that returns the final variables plus solver stats.""" +def _wrap_inequalities(inequalities): + def lipa_inequalities(reference, x, u, theta, t): + del theta + return inequalities(reference, x, u, t) + + return lipa_inequalities + + +def _empty_inequalities(reference, x, u, t): + del reference, x, u, t + return jnp.empty(0) + + +@partial(jax.jit, static_argnames=("cost", "dynamics", "inequalities")) +def _lipa_solve_with_stats( + cost, dynamics, inequalities, settings, reference, parameter, W, x0, X_in, U_in, V_in +): + """Single LIPA call that returns the final variables plus solver stats. + + `inequalities=None` keeps the prior behavior (no constraint blocks, ``g_dim=0``). + Otherwise the constraint shape is inferred from a trace-time evaluation of the + user callable on the warm-start sample. + """ lipa_cost = partial(_wrap_cost(cost), W, reference) lipa_dynamics = partial(_wrap_dynamics(dynamics), parameter) + ineq_callable = inequalities if inequalities is not None else _empty_inequalities + lipa_inequalities = partial(_wrap_inequalities(ineq_callable), reference) + T = U_in.shape[0] + sample_g = lipa_inequalities(X_in[0], U_in[0], jnp.empty(0, dtype=X_in.dtype), 0) + g_dim = sample_g.shape[0] + vars_in = Variables( X=X_in, U=U_in, - S=jnp.zeros((T + 1, 0), dtype=X_in.dtype), + S=jnp.zeros((T + 1, g_dim), dtype=X_in.dtype), Y_dyn=V_in, Y_eq=jnp.zeros((T + 1, 0), dtype=X_in.dtype), - Z=jnp.zeros((T + 1, 0), dtype=X_in.dtype), + Z=jnp.zeros((T + 1, g_dim), dtype=X_in.dtype), Theta=jnp.empty(0, dtype=X_in.dtype), ) @@ -67,6 +93,7 @@ def _lipa_solve_with_stats(cost, dynamics, settings, reference, parameter, W, x0 x0=x0, cost=lipa_cost, dynamics=lipa_dynamics, + inequalities=lipa_inequalities, settings=settings, ) return vars_out.X, vars_out.U, vars_out.Y_dyn, iterations, no_errors @@ -97,7 +124,7 @@ def _default_settings(): return SolverSettings(**common) -def build_lipa_solve(cost, dynamics, settings=None): +def build_lipa_solve(cost, dynamics, settings=None, *, inequalities=None): """Return a `solve(reference, parameter, W, x0, X0, U0, V0) -> (X, U, V)`. Used by online MPC (e.g. `MPCWrapper.run`). For offline benchmarks, @@ -105,7 +132,9 @@ def build_lipa_solve(cost, dynamics, settings=None): LIPA's internal iteration count and avoids resetting µ/η repeatedly. Defaults differ by backend (parallel LQR + parallel line search on GPU). - Override via `config.lipa_settings`. + Override via `config.lipa_settings`. Pass `inequalities=callable(reference, + x, u, t) -> g` to enforce ``g <= 0`` constraints; omit to keep the prior + inequality-free behavior shared with the FDDP / primal-dual solvers. """ if settings is None: @@ -113,7 +142,7 @@ def build_lipa_solve(cost, dynamics, settings=None): def solve(reference, parameter, W, x0, X0, U0, V0): X, U, V, _iters, _no_errors = _lipa_solve_with_stats( - cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + cost, dynamics, inequalities, settings, reference, parameter, W, x0, X0, U0, V0 ) return X, U, V @@ -132,6 +161,9 @@ def run_lipa_offline( V0, *, settings=None, + inequalities=None, + warmup_cost=None, + warmup_settings=None, warmup=True, verbose=True, ): @@ -140,6 +172,16 @@ def run_lipa_offline( Unlike `run_offline_solve`, which loops one-step solvers until cost plateaus, this calls LIPA exactly once. Reported `n_iterations` is LIPA's internal IPM iteration count. + + Two-phase warm start: if `warmup_cost` is provided (typically the soft- + penalty version of `cost`), an initial LIPA solve is run on that + inequality-free formulation, then the main inequality-enforcing solve + starts from its result. This sidesteps a class of local-basin pitfalls + where the AL term η·Jᵀc dominates and the IPM parks at a degenerate + iterate (e.g. on barrel_roll, the multi-shooting quaternion defect at + the apex of the maneuver hits a sign-flip singularity that the cold- + start solve cannot escape). The warm-start phase uses the same LIPA + solver — this is not bootstrapping from a different solver. """ from mpx.jax_ocp_solvers.jax_ocp_solvers import optimizers as ocp_opt @@ -162,15 +204,52 @@ def run_lipa_offline( print("{:<10} {:<20} {:<20} {:<20}".format("Iter", "Cost", "Constraint", "Time [ms]")) print("{:<10d} {:<20.5f} {:<20.5f} {:<20}".format(0, initial_l2_cost, initial_dynamics_violation, "-")) - if warmup: + do_warmup_phase = warmup_cost is not None and inequalities is not None + warmup_phase_settings = warmup_settings if warmup_settings is not None else settings + warmup_iters = 0 + warmup_time_ms = 0.0 + + if do_warmup_phase: + # Phase 1: solve the inequality-free (soft-penalty) problem once and + # use its (X, U, V) as the warm start for phase 2. We deliberately do + # NOT call _lipa_solve_with_stats twice (warmup + timed) here — the + # parallel-LQR scan reduction is not bit-deterministic across + # back-to-back invocations of the same compiled function on the same + # inputs (different floating-point summation order can land on + # numerically different iterates), and on stiff problems like + # h1_jump_forward that's enough drift to make phase 2 sometimes + # converge in 100 iters and sometimes hit max_iterations. The trade + # here is mildly inaccurate phase-1 wall-time accounting (first call + # includes any JIT compile that wasn't already cached) for + # reproducible phase-2 starting iterates. + start = timer() + Xp1, Up1, Vp1, iters_p1, _ = _lipa_solve_with_stats( + warmup_cost, dynamics, None, warmup_phase_settings, + reference, parameter, W, x0, X0, U0, V0, + ) + Xp1.block_until_ready() + warmup_time_ms = 1e3 * (timer() - start) + warmup_iters = int(iters_p1) + if verbose: + print( + "{:<10s} {:<20s} {:<20s} {:<20.5f}".format( + "ph1", "(warmup)", "(warmup)", warmup_time_ms + ) + ) + print(f" Phase 1 (soft-penalty warm start): {warmup_iters} iters") + # Phase 2 starts from phase 1's iterate. + X0, U0, V0 = Xp1, Up1, Vp1 + + if warmup and not do_warmup_phase: + # Single-phase mode: traditional warmup-then-timed pattern. Xw, _, _, _, _ = _lipa_solve_with_stats( - cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + cost, dynamics, inequalities, settings, reference, parameter, W, x0, X0, U0, V0 ) Xw.block_until_ready() start = timer() X, U, V, iterations, no_errors = _lipa_solve_with_stats( - cost, dynamics, settings, reference, parameter, W, x0, X0, U0, V0 + cost, dynamics, inequalities, settings, reference, parameter, W, x0, X0, U0, V0 ) X.block_until_ready() stop = timer() @@ -180,7 +259,7 @@ def run_lipa_offline( final_objective = float(g) final_l2_cost = float(np.sqrt(np.sum(np.asarray(g) * np.asarray(g)))) final_dynamics_violation = float(np.sum(np.asarray(c) * np.asarray(c))) - n_iters = int(iterations) + n_iters = int(iterations) + warmup_iters converged = bool(no_errors) if verbose: @@ -189,22 +268,29 @@ def run_lipa_offline( 1, final_l2_cost, final_dynamics_violation, iteration_time_ms ) ) - print(f" LIPA internal iterations: {n_iters}, no_errors: {converged}") + if do_warmup_phase: + print( + f" Phase 2 (constrained): {int(iterations)} iters, no_errors: {converged}\n" + f" Total LIPA internal iterations: {n_iters}" + ) + else: + print(f" LIPA internal iterations: {n_iters}, no_errors: {converged}") history = [X0, X] stats = { "n_iterations": n_iters, + "warmup_iterations": warmup_iters, "converged": converged, "warmup_discarded": warmup, "objective_history": [initial_objective, final_objective], "l2_cost_history": [initial_l2_cost, final_l2_cost], "dynamics_violation_history": [initial_dynamics_violation, final_dynamics_violation], "metric_iteration_history": [0, 1], - "iteration_time_ms_history": [iteration_time_ms], + "iteration_time_ms_history": [iteration_time_ms + warmup_time_ms], "initial_objective": initial_objective, "initial_l2_cost": initial_l2_cost, "initial_dynamics_violation": initial_dynamics_violation, - "average_iteration_time_ms": iteration_time_ms, + "average_iteration_time_ms": iteration_time_ms + warmup_time_ms, "final_objective": final_objective, "final_l2_cost": final_l2_cost, "final_dynamics_violation": final_dynamics_violation, diff --git a/mpx/utils/mpc_wrapper.py b/mpx/utils/mpc_wrapper.py index 43e0810..157cfc7 100644 --- a/mpx/utils/mpc_wrapper.py +++ b/mpx/utils/mpc_wrapper.py @@ -29,6 +29,41 @@ class MPCData(PyTreeNode): mpx_data = MPCData +def lipa_pick_cost_and_inequalities(config, cost): + """Pick the LIPA call configuration based on the config. + + Returns ``(main_cost, inequalities, main_settings, warmup_cost, + warmup_settings)``: + + * Off path (no enforce): main = ``cost`` (the soft-penalty cost), no + inequalities, settings from ``config.lipa_settings``. + * Enforce path: main = ``cost_smooth + inequalities`` with + ``config.lipa_settings_enforce or config.lipa_settings``. The warm-start + pair (``cost``, ``lipa_settings``) is also returned so the offline path + can do a two-phase solve — phase 1 on the inequality-free formulation, + phase 2 on the constrained one starting from phase 1's iterate. This + sidesteps local-basin pitfalls (notably the multi-shooting quaternion + singularity at the apex of the barrel-roll maneuver) without + bootstrapping from a different solver. + + Configs opt in by setting ``lipa_enforce_inequalities = True`` and + providing both ``cost_smooth`` and ``inequalities``. + """ + enforce = getattr(config, "lipa_enforce_inequalities", False) + base_settings = getattr(config, "lipa_settings", None) + if not enforce: + return cost, None, base_settings, None, None + cost_smooth = getattr(config, "cost_smooth", None) + inequalities = getattr(config, "inequalities", None) + if cost_smooth is None or inequalities is None: + raise ValueError( + "lipa_enforce_inequalities=True requires both `cost_smooth` and " + "`inequalities` to be defined on the config." + ) + enforce_settings = getattr(config, "lipa_settings_enforce", None) or base_settings + return cost_smooth, inequalities, enforce_settings, cost, base_settings + + def build_solver_step(config, cost, dynamics, hessian_approx, limited_memory): solver_mode = getattr(config, "solver_mode", "primal_dual") @@ -56,8 +91,16 @@ def solve(reference, parameter, W, x0, X0, U0, V0): return solver_mode, solve if solver_mode == "lipa": - lipa_settings = getattr(config, "lipa_settings", None) - solve = build_lipa_solve(cost, dynamics, settings=lipa_settings) + # Online MPC stays single-phase: per-step warm-start via the data + # carry already chains across calls, and a per-step phase-1 would + # double the compile + per-step compute. The two-phase flow is + # offline-only (see run_lipa_offline / runOffline). + lipa_cost, lipa_inequalities, lipa_settings, _, _ = lipa_pick_cost_and_inequalities( + config, cost + ) + solve = build_lipa_solve( + lipa_cost, dynamics, settings=lipa_settings, inequalities=lipa_inequalities + ) return solver_mode, solve raise ValueError(f"Unsupported MPC solver_mode: {solver_mode}") @@ -336,8 +379,15 @@ def runOffline(self, qpos, qvel, *, return_stats=False, verbose=True, max_iter=1 # LIPA is a complete NLP solver; one call converges. Looping it # restarts the IPM µ/η each time, which inflates "iterations" # and wall time without improving the solution. + ( + lipa_cost, + lipa_inequalities, + lipa_settings, + lipa_warmup_cost, + lipa_warmup_settings, + ) = lipa_pick_cost_and_inequalities(self.config, self.cost) X0, U0, _, output, stats = run_lipa_offline( - self.cost, + lipa_cost, self.dynamics, reference, parameter, @@ -346,7 +396,10 @@ def runOffline(self, qpos, qvel, *, return_stats=False, verbose=True, max_iter=1 X0, U0, V0, - settings=getattr(self.config, "lipa_settings", None), + settings=lipa_settings, + inequalities=lipa_inequalities, + warmup_cost=lipa_warmup_cost, + warmup_settings=lipa_warmup_settings, verbose=verbose, ) else: diff --git a/mpx/utils/objectives.py b/mpx/utils/objectives.py index 9c14941..b69c3fa 100644 --- a/mpx/utils/objectives.py +++ b/mpx/utils/objectives.py @@ -74,8 +74,40 @@ def friction_constraint(u): H_constraint = J_friction_cone(u).T@H_penalty@J_friction_cone(u) return J_x(x,u).T@W@J_x(x,u), J_u(x,u).T@W@J_u(x,u) + H_constraint, J_x(x,u).T@W@J_u(x,u) -def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t): - +def _quadruped_wb_constraint_slacks(n_joints, n_contact, mu, torque_limit, dq_limit, x, u, friction_eps=1e-2): + grf = x[13 + 2 * n_joints + 3 * n_contact:] + tau = u[:n_joints] + dq = x[13 + n_joints:13 + 2 * n_joints] + Fx = grf[0::3] + Fy = grf[1::3] + Fz = grf[2::3] + s_friction = mu * Fz - jnp.sqrt(jnp.square(Fx) + jnp.square(Fy) + jnp.ones(n_contact) * friction_eps) + sym = jnp.kron(jnp.eye(n_joints), jnp.array([-1.0, 1.0])).T + s_torque = sym @ tau + (torque_limit + 1e-2) + s_dq = sym @ dq + (dq_limit + 1e-2) + return s_friction, s_torque, s_dq + + +def quadruped_wb_inequalities( + n_joints, n_contact, mu, torque_limit, dq_limit, reference, x, u, t, friction_eps=1e-12 +): + """LIPA-form inequalities ``g(x,u,t) <= 0`` for the quadruped whole-body problem. + + Friction is gated by the reference contact mask (vacuous in swing); torque and + joint-speed limits are always active. At the terminal stage there is no control + input, so all entries collapse to zero. + """ + s_friction, s_torque, s_dq = _quadruped_wb_constraint_slacks( + n_joints, n_contact, mu, torque_limit, dq_limit, x, u, friction_eps=friction_eps + ) + contact = reference[t, 13 + n_joints + 3 * n_contact:13 + n_joints + 4 * n_contact] + g = jnp.concatenate([-contact * s_friction, -s_torque, -s_dq]) + N = reference.shape[0] - 1 + return jnp.where(t == N, jnp.zeros_like(g), g) + + +def quadruped_wb_smooth_cost(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t): + """Stage cost without any soft-inequality penalties (friction/torque/dq).""" p = x[:3] quat = x[3:7] q = x[7:7+n_joints] @@ -94,18 +126,6 @@ def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t): p_leg_ref = reference[t,13+n_joints:13+n_joints+3*n_contact] contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact] grf_ref = reference[t,13+n_joints+4*n_contact:13+n_joints+7*n_contact] - mu = 0.5 - friction_cone = mu*grf[2::3] - jnp.sqrt(jnp.square(grf[1::3]) + jnp.square(grf[::3]) + jnp.ones(n_contact)*1e-2) - friction_cone = penalty(friction_cone) - torque_limits = jnp.array([ - 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, - 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44 ]) - #min grf - # min_force = grf[2::3] - jnp.ones(n_contact)*10 - torque_limits = jnp.kron(jnp.eye(n_joints),(jnp.array([-1,1]))).T@tau+torque_limits + jnp.ones_like(torque_limits)*1e-2 - - joint_speed_limits = jnp.ones(2*n_joints)*10 - joint_speed_limits = jnp.kron(jnp.eye(n_joints),(jnp.array([-1,1]))).T@dq + joint_speed_limits + jnp.ones_like(joint_speed_limits)*1e-2 if swing_tracking: contact_map = jnp.ones(3*n_contact) @@ -116,14 +136,26 @@ def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t): (dp - dp_ref).T @ W[6+n_joints:9+n_joints,6+n_joints:9+n_joints] @ (dp - dp_ref) + (omega - omega_ref).T @ W[9+n_joints:12+n_joints,9+n_joints:12+n_joints] @ (omega - omega_ref) + dq.T @ W[12+n_joints:12+2*n_joints,12+n_joints:12+2*n_joints] @ dq +\ (contact_map*(p_leg - p_leg_ref)).T @W[12+2*n_joints:12+2*n_joints+3*n_contact,12+2*n_joints:12+2*n_joints+3*n_contact]@ (contact_map*(p_leg - p_leg_ref))+ \ tau.T @ W[12+2*n_joints+3*n_contact:12+3*n_joints+3*n_contact,12+2*n_joints+3*n_contact:12+3*n_joints+3*n_contact] @ tau +\ - (grf-grf_ref).T @ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact] @ (grf-grf_ref) +\ - jnp.sum(penalty(torque_limits,1,1)) + jnp.sum(friction_cone*contact) + jnp.sum(penalty(joint_speed_limits,1,1)) + (grf-grf_ref).T @ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact] @ (grf-grf_ref) term_cost = (p - p_ref).T @ W[:3,:3] @ (p - p_ref) + math.quat_sub(quat,quat_ref).T@W[3:6,3:6]@math.quat_sub(quat,quat_ref) + (q - q_ref).T @ W[6:6+n_joints,6:6+n_joints] @ (q - q_ref) +\ (dp - dp_ref).T @ W[6+n_joints:9+n_joints,6+n_joints:9+n_joints] @ (dp - dp_ref) + (omega - omega_ref).T @ W[9+n_joints:12+n_joints,9+n_joints:12+n_joints] @ (omega - omega_ref) + dq.T @ W[12+n_joints:12+2*n_joints,12+n_joints:12+2*n_joints] @ dq - return jnp.where(t == N, 0.5 * term_cost, 0.5 * stage_cost) + +def quadruped_wb_obj(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t): + smooth = quadruped_wb_smooth_cost(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t) + s_friction, s_torque, s_dq = _quadruped_wb_constraint_slacks( + n_joints, n_contact, 0.5, 44.0, 10.0, x, u + ) + contact = reference[t, 13 + n_joints + 3 * n_contact:13 + n_joints + 4 * n_contact] + soft = ( + jnp.sum(penalty(s_friction) * contact) + + jnp.sum(penalty(s_torque, 1, 1)) + + jnp.sum(penalty(s_dq, 1, 1)) + ) + return smooth + jnp.where(t == N, 0.0, 0.5 * soft) + def quadruped_wb_hessian_gn(swing_tracking,n_joints,n_contact,W,reference,x, u, t): contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact] @@ -323,7 +355,43 @@ def torque_constraint(u): return J_x(x,u).T@W@J_x(x,u), J_u(x,u).T@W@J_u(x,u), J_x(x,u).T@W@J_u(x,u) -def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t): +def _h1_kinodynamic_friction_slack(n_joints, n_contact, mu, u, friction_eps=1e-1): + grf = u[n_joints:] + Fx = grf[0::3] + Fy = grf[1::3] + Fz = grf[2::3] + return mu * Fz - jnp.sqrt(jnp.square(Fx) + jnp.square(Fy) + jnp.ones(n_contact) * friction_eps) + + +def h1_kinodynamic_inequalities(n_joints, n_contact, mu, reference, x, u, t, friction_eps=1e-12): + """LIPA-form ``g <= 0`` inequalities for the H1 kinodynamic problem. + + Two physical constraints, both gated by the reference contact mask + (vacuous during swing): + + * ``Fz >= 0`` — a foot can only push into the ground, not pull. Without + this, the soft-penalty optimizer happily uses negative Fz to "anchor" + the foot, which produces unphysical jump take-offs and breaks the + Coulomb-cone interpretation: with Fz < 0 and `g = mu*Fz - sqrt(Fx²+Fy²)` + the cone becomes infeasible by `≈ |mu*Fz|` regardless of (Fx, Fy). + * Friction cone: ``sqrt(Fx² + Fy²) <= mu * Fz``. + + Other limits (joint-velocity, torque) live elsewhere and are not enforced + as constraints in this solver. + """ + grf = u[n_joints:] + Fz = grf[2::3] + s_friction = _h1_kinodynamic_friction_slack(n_joints, n_contact, mu, u, friction_eps=friction_eps) + contact = reference[t, 13 + n_joints + 3 * n_contact:13 + n_joints + 4 * n_contact] + g_friction = -contact * s_friction + g_fz = -contact * Fz + g = jnp.concatenate([g_friction, g_fz]) + N = reference.shape[0] - 1 + return jnp.where(t == N, jnp.zeros_like(g), g) + + +def h1_kinodynamic_smooth_cost(n_joints, n_contact, N, W, reference, x, u, t): + """H1 kinodynamic stage cost with the friction soft-penalty stripped out.""" p = x[:3] quat = x[3:7] @@ -342,14 +410,8 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t): dp_ref = reference[t,7+n_joints:10+n_joints] omega_ref = reference[t,10+n_joints:13+n_joints] p_leg_ref = reference[t,13+n_joints:13+n_joints+3*n_contact] - contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact] grf_ref = reference[t,13+n_joints+4*n_contact:13+n_joints+7*n_contact] - mu = 0.7 - friction_cone = mu * grf[2::3] - jnp.sqrt( - jnp.square(grf[1::3]) + jnp.square(grf[::3]) + jnp.ones(n_contact) * 1e-1 - ) - stage_cost = ( (p - p_ref).T @ W[:3,:3] @ (p - p_ref) + math.quat_sub(quat,quat_ref).T @ W[3:6,3:6] @ math.quat_sub(quat,quat_ref) @@ -366,7 +428,6 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t): + (grf - grf_ref).T @ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact] @ (grf - grf_ref) - + jnp.sum(penalty(friction_cone) * contact) ) term_cost = ( (p - p_ref).T @ W[:3,:3] @ (p - p_ref) @@ -379,6 +440,14 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t): return jnp.where(t == N, 0.5 * term_cost, 0.5 * stage_cost) + +def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t): + smooth = h1_kinodynamic_smooth_cost(n_joints, n_contact, N, W, reference, x, u, t) + s_friction = _h1_kinodynamic_friction_slack(n_joints, n_contact, 0.7, u) + contact = reference[t, 13 + n_joints + 3 * n_contact:13 + n_joints + 4 * n_contact] + soft = jnp.sum(penalty(s_friction) * contact) + return smooth + jnp.where(t == N, 0.0, 0.5 * soft) + def talos_wb_obj(n_joints,n_contact,N,W,reference,x, u, t): p = x[:3] diff --git a/mpx/utils/sim.py b/mpx/utils/sim.py index 55b2aae..cf0dcb9 100644 --- a/mpx/utils/sim.py +++ b/mpx/utils/sim.py @@ -393,3 +393,65 @@ def render_ghost_trajectory( ) return ghost_geoms, scratch_data + + +class VideoRecorder: + """Offscreen mp4 recorder built around `mujoco.Renderer`. + + Designed for headless execution of the online MPC examples (mjx_quad, + mjx_h1, ...). Tracks the robot base by default (lookat = qpos[:3]). + + Requires: + * `MUJOCO_GL=egl` (or another working backend) set BEFORE `import mujoco` + — see `enable_offscreen_gl_for_video()`. + * `imageio[ffmpeg]` for libx264 mp4 output. + """ + + def __init__( + self, + model: mujoco.MjModel, + path: str, + *, + fps: int = 30, + width: int = 640, + height: int = 480, + distance: float = 3.0, + azimuth: float = 90.0, + elevation: float = -20.0, + ): + import imageio # late import: only needed when recording + + self._renderer = mujoco.Renderer(model, height=height, width=width) + self._writer = imageio.get_writer( + path, + format="FFMPEG", + codec="libx264", + fps=fps, + macro_block_size=1, + ) + self._cam = mujoco.MjvCamera() + self._cam.distance = float(distance) + self._cam.azimuth = float(azimuth) + self._cam.elevation = float(elevation) + self._cam.lookat[:] = [0.0, 0.0, 0.0] + + def capture(self, data: mujoco.MjData, lookat: np.ndarray | None = None) -> None: + """Render and append one frame; lookat defaults to the floating-base position.""" + + if lookat is None: + lookat = np.asarray(data.qpos[:3], dtype=np.float64) + self._cam.lookat[:] = np.asarray(lookat, dtype=np.float64).reshape(3) + self._renderer.update_scene(data, self._cam) + self._writer.append_data(self._renderer.render()) + + def close(self) -> None: + try: + self._writer.close() + finally: + self._renderer.close() + + def __enter__(self): + return self + + def __exit__(self, *_exc): + self.close() From 4f40512813c16b2cf8d16ff4b96cd24df876e23e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Sousa-Pinto?= Date: Tue, 12 May 2026 23:50:21 -0700 Subject: [PATCH 3/3] Add real-time mp4 recording for offline OCP runs Adds `--video PATH` to `offline_task.py`, with playback fps defaulted to `1/config.dt` so trajectories play at wall-clock speed. The three LIPA configs switch to sequential LQR so offline runs converge bit-deterministically. --- mpx/config/config_aliengo_trot_two_step.py | 4 +- mpx/config/config_barrel_roll.py | 4 +- mpx/config/config_h1_jump_forward.py | 8 +-- mpx/data/acrobot/scene.xml | 2 +- mpx/data/aliengo/scene_flat.xml | 2 +- mpx/data/unitree_h1/mjx_scene_h1_walk.xml | 2 +- mpx/examples/offline_task.py | 69 +++++++++++++++++++++- mpx/utils/sim.py | 23 +++++++- 8 files changed, 99 insertions(+), 15 deletions(-) diff --git a/mpx/config/config_aliengo_trot_two_step.py b/mpx/config/config_aliengo_trot_two_step.py index 2874f63..22456ff 100644 --- a/mpx/config/config_aliengo_trot_two_step.py +++ b/mpx/config/config_aliengo_trot_two_step.py @@ -78,8 +78,8 @@ def _lipa_settings(): µ_update_factor=0.9, cost_improvement_threshold=1e-3, primal_violation_threshold=1e-5, - use_parallel_lqr=True, - num_parallel_line_search_steps=8, + use_parallel_lqr=False, + num_parallel_line_search_steps=1, ) lipa_settings = _lipa_settings() diff --git a/mpx/config/config_barrel_roll.py b/mpx/config/config_barrel_roll.py index 9d793e6..8e8bbee 100644 --- a/mpx/config/config_barrel_roll.py +++ b/mpx/config/config_barrel_roll.py @@ -122,8 +122,8 @@ def _lipa_settings(): µ_update_factor=0.9, cost_improvement_threshold=1e-3, primal_violation_threshold=1e-5, - use_parallel_lqr=True, - num_parallel_line_search_steps=8, + use_parallel_lqr=False, + num_parallel_line_search_steps=1, ) lipa_settings = _lipa_settings() diff --git a/mpx/config/config_h1_jump_forward.py b/mpx/config/config_h1_jump_forward.py index f2e2277..a9ec2cd 100644 --- a/mpx/config/config_h1_jump_forward.py +++ b/mpx/config/config_h1_jump_forward.py @@ -73,8 +73,8 @@ def _lipa_settings(): cost_improvement_threshold=1e-3, primal_violation_threshold=1e-5, num_iterative_refinement_steps=2, - use_parallel_lqr=True, - num_parallel_line_search_steps=8, + use_parallel_lqr=False, + num_parallel_line_search_steps=1, ) lipa_settings = _lipa_settings() @@ -89,8 +89,8 @@ def _lipa_settings_enforce(): cost_improvement_threshold=1e-3, primal_violation_threshold=1e-5, num_iterative_refinement_steps=2, - use_parallel_lqr=True, - num_parallel_line_search_steps=8, + use_parallel_lqr=False, + num_parallel_line_search_steps=1, ) lipa_settings_enforce = _lipa_settings_enforce() diff --git a/mpx/data/acrobot/scene.xml b/mpx/data/acrobot/scene.xml index 2ca267d..1fd4d9f 100644 --- a/mpx/data/acrobot/scene.xml +++ b/mpx/data/acrobot/scene.xml @@ -3,7 +3,7 @@