diff --git a/.claude/skills/benchmark/SKILL.md b/.claude/skills/benchmark/SKILL.md index 30b66ef90..0ea7b66b1 100644 --- a/.claude/skills/benchmark/SKILL.md +++ b/.claude/skills/benchmark/SKILL.md @@ -24,13 +24,35 @@ When a run completes (`dstack ps` shows `exited (0)`): 2. **Find HF folder name**: `dstack logs NAME 2>&1 | grep "Uploading data/"` → extract folder name from the upload log line 3. **Update table score** in BENCHMARKS.md 4. **Update table HF link**: `[FOLDER](https://huggingface.co/datasets/SLM-Lab/benchmark-dev/tree/main/data/FOLDER)` -5. **Pull HF data locally**: `source .env && hf download SLM-Lab/benchmark-dev --local-dir data/benchmark-dev --repo-type dataset --include "data/FOLDER/*"` -6. **Generate plot**: `uv run slm-lab plot -t "EnvName" -f data/benchmark-dev/data/FOLDER1,data/benchmark-dev/data/FOLDER2` +5. **Pull HF data locally**: `source .env && huggingface-cli download SLM-Lab/benchmark-dev --local-dir data/benchmark-dev --repo-type dataset --include "data/FOLDER/*"` +6. **Generate plot**: List ALL data folders for that env (`ls data/benchmark-dev/data/ | grep -i envname`), then generate with ONLY the folders matching BENCHMARKS.md entries: + ```bash + uv run slm-lab plot -t "EnvName" -d data/benchmark-dev/data -f FOLDER1,FOLDER2,... + ``` + NOTE: `-d` sets the base data dir, `-f` takes folder names (NOT full paths). + If some folders are in `data/` (local runs) and some in `data/benchmark-dev/data/`, use `data/` as base (it has the `info/` subfolder needed for metrics). 7. **Verify plot exists** in `docs/plots/` 8. **Commit** score + link + plot together A row in BENCHMARKS.md is NOT complete until it has: score, HF link, and plot. +## Per-Run Graduation Checklist + +**After intake, graduate each finalized run to public HF benchmark:** + +1. **Upload folder to public HF**: + ```bash + source .env && huggingface-cli upload SLM-Lab/benchmark data/benchmark-dev/data/FOLDER data/FOLDER --repo-type dataset + ``` +2. **Update BENCHMARKS.md link**: Change `SLM-Lab/benchmark-dev` → `SLM-Lab/benchmark` for that entry +3. **Upload docs/ to public HF** (updated plots + BENCHMARKS.md): + ```bash + source .env && huggingface-cli upload SLM-Lab/benchmark docs docs --repo-type dataset + source .env && huggingface-cli upload SLM-Lab/benchmark README.md README.md --repo-type dataset + ``` +4. **Commit** link update +5. **Push** to origin + ## Launch ```bash @@ -75,26 +97,28 @@ source .env && hf download SLM-Lab/benchmark-dev \ ### Generate Plots ```bash -# Find folders for a game +# Find folders for a game (check both local data/ and benchmark-dev) +ls data/ | grep -i pong ls data/benchmark-dev/data/ | grep -i pong -# Generate comparison plot (include all algorithms available) -uv run slm-lab plot -t "Pong" \ - -f data/benchmark-dev/data/ppo_folder,data/benchmark-dev/data/sac_folder +# Generate comparison plot — use -d for base dir, -f for folder names only +# Use data/ as base (has info/ subfolder with trial_metrics) +uv run slm-lab plot -t "Pong-v5" -f ppo_pong_folder,sac_pong_folder,crossq_pong_folder ``` ### Graduate to Public HF -When benchmarks are finalized, publish from `benchmark-dev` → `benchmark`: +When a run is finalized, graduate individually from `benchmark-dev` → `benchmark`: ```bash -source .env && hf upload SLM-Lab/benchmark \ - data/benchmark-dev/data data --repo-type dataset - -# Update BENCHMARKS.md links: benchmark-dev → benchmark -# Upload docs and README -source .env && hf upload SLM-Lab/benchmark docs docs --repo-type dataset -source .env && hf upload SLM-Lab/benchmark README.md README.md --repo-type dataset +# Upload individual folder +source .env && huggingface-cli upload SLM-Lab/benchmark \ + data/benchmark-dev/data/FOLDER data/FOLDER --repo-type dataset + +# Update BENCHMARKS.md link for that entry: benchmark-dev → benchmark +# Then upload docs/ (includes updated plots + BENCHMARKS.md) +source .env && huggingface-cli upload SLM-Lab/benchmark docs docs --repo-type dataset +source .env && huggingface-cli upload SLM-Lab/benchmark README.md README.md --repo-type dataset ``` | Repo | Purpose | diff --git a/.dstack/run-cpu-search.yml b/.dstack/run-cpu-search.yml index 924a5855f..e8fbfcf45 100644 --- a/.dstack/run-cpu-search.yml +++ b/.dstack/run-cpu-search.yml @@ -3,8 +3,8 @@ name: slm-lab python: 3.12 -files: - - ..:/workflow +repos: + - "..:/workflow" env: - HF_TOKEN @@ -13,6 +13,9 @@ env: - SPEC_NAME - LAB_MODE - SPEC_VARS # --set overrides, e.g. "-s env=ALE/Breakout-v5" + - PROFILE + - PROF_SKIP + - PROF_ACTIVE commands: - apt-get update && apt-get install -y swig libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 diff --git a/.dstack/run-cpu-train.yml b/.dstack/run-cpu-train.yml index 127702859..efea759d8 100644 --- a/.dstack/run-cpu-train.yml +++ b/.dstack/run-cpu-train.yml @@ -3,8 +3,8 @@ name: slm-lab python: 3.12 -files: - - ..:/workflow +repos: + - "..:/workflow" env: - HF_TOKEN @@ -13,6 +13,9 @@ env: - SPEC_NAME - LAB_MODE - SPEC_VARS # --set overrides, e.g. "-s env=ALE/Breakout-v5" + - PROFILE + - PROF_SKIP + - PROF_ACTIVE commands: - apt-get update && apt-get install -y swig libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 diff --git a/.dstack/run-gpu-search.yml b/.dstack/run-gpu-search.yml index 05ddf1e45..d8d2cf28d 100644 --- a/.dstack/run-gpu-search.yml +++ b/.dstack/run-gpu-search.yml @@ -3,8 +3,8 @@ name: slm-lab python: 3.12 -files: - - ..:/workflow +repos: + - "..:/workflow" env: - HF_TOKEN @@ -13,6 +13,9 @@ env: - SPEC_NAME - LAB_MODE - SPEC_VARS # --set overrides, e.g. "-s env=ALE/Breakout-v5" + - PROFILE + - PROF_SKIP + - PROF_ACTIVE commands: - apt-get update && apt-get install -y swig libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 diff --git a/.dstack/run-gpu-train.yml b/.dstack/run-gpu-train.yml index a6dc20e3f..ac3e34865 100644 --- a/.dstack/run-gpu-train.yml +++ b/.dstack/run-gpu-train.yml @@ -3,8 +3,8 @@ name: slm-lab python: 3.12 -files: - - ..:/workflow +repos: + - "..:/workflow" env: - HF_TOKEN @@ -13,6 +13,9 @@ env: - SPEC_NAME - LAB_MODE - SPEC_VARS # --set overrides, e.g. "-s env=ALE/Breakout-v5" + - PROFILE + - PROF_SKIP + - PROF_ACTIVE commands: - apt-get update && apt-get install -y swig libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 @@ -21,12 +24,12 @@ commands: resources: gpu: - name: [RTX3090] + memory: 20GB.. count: 1 memory: 32GB.. spot_policy: auto -max_duration: 6h +max_duration: 8h max_price: 0.50 retry: on_events: [no-capacity] diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..82e22a59e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,51 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.lz4 filter=lfs diff=lfs merge=lfs -text +*.mds filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +# Audio files - uncompressed +*.pcm filter=lfs diff=lfs merge=lfs -text +*.sam filter=lfs diff=lfs merge=lfs -text +*.raw filter=lfs diff=lfs merge=lfs -text +# Audio files - compressed +*.aac filter=lfs diff=lfs merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +# Image files - small plot PNGs tracked as regular git objects (no LFS needed) +# Video files - compressed +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text diff --git a/CLAUDE.md b/CLAUDE.md index dc17f441f..9809998c3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,6 +4,7 @@ You are a seasoned software engineer with the following traits: +- **Supervisor-first**: Delegate implementation to agent teams — your role is to orchestrate, review, and commit, not to implement directly - **Quality-driven**: Code quality is non-negotiable - clean, idiomatic, maintainable code every time - **Autonomous**: Make informed technical decisions independently - only ask when requirements are genuinely unclear - **Pragmatic**: Balance perfect with practical - ship working solutions, iterate when needed @@ -22,11 +23,17 @@ You are a seasoned software engineer with the following traits: Apply these six principles to every decision. 1. **Consistent** — Design from first principles — unified naming, patterns, and conventions throughout. + Establish naming conventions and structural patterns first. When the same concept uses the same name everywhere, the codebase becomes searchable, replaceable, and predictable. 2. **Correct** — Constructed from known truths, not debugged into shape. + Build upward from solid foundations — each layer verified before the next is added. Correctness is built from the start, not tested into existence. 3. **Clear** — Code does what it says — intent is obvious from naming and logic alone. + A lot of coding is naming. If you need a comment to explain what code does, the code is not clear enough. 4. **Concise** — Simplified to the essence — nothing left to remove. + Brevity is about fewer concepts to hold in your head, not fewer characters. Eliminate duplication, remove dead code, strip unnecessary abstraction. 5. **Simple** — Few moving parts, easy to explain, cheap to maintain — complexity is not sophistication. + A complex architecture with dozens of tangled dependencies is not intelligence — it is poor design. Reduce to the fewest moving parts while losing nothing essential. 6. **Salient** — Essential enough to be used widely, fundamental enough to last. + Code that follows the preceding principles naturally endures — used broadly, needed deeply, lasting because it was built right. ## Style Guide @@ -60,14 +67,17 @@ Apply these six principles to every decision. ## Agent Teams -**For any non-trivial task, deploy agent teams.** This is the standard operating mode — do not default to working solo. The lead orchestrates (breaks down work, assigns tasks, reviews outputs, commits) — it should never get buried in implementation. Delegation keeps the lead strategic, enables parallel execution, and protects context window from long-running tasks. +**You are the lead. You do not implement — you delegate, supervise, and review.** -**Guidelines:** -1. **Give enough context in spawn prompts** - teammates don't inherit conversation history, only CLAUDE.md and project context -2. **Size tasks appropriately** - self-contained units with clear deliverables, ~5-6 per teammate -3. **Avoid file conflicts** - each teammate owns different files +For any non-trivial task, use TeamCreate with multiple teammates (not single-Agent subagents). Teammates share a task list, claim work, and message each other directly. Solo work is only acceptable for trivial, single-file changes. -> Work autonomously: run things in parallel, continue without pausing, pick up the next task immediately. For long-running tasks, use `sleep N` to actively wait and check in — do NOT delegate to background processes. Stay engaged in the conversation. +**Do NOT:** use subagents as a substitute for teams, implement tasks yourself (spawn new teammates instead), or start implementing while teammates are still working. + +**Workflow:** Break into parallel units → TeamCreate → TaskCreate per unit → spawn 3-5 teammates with full context (they only inherit CLAUDE.md, not conversation history) → require plan approval for risky tasks → supervise and review → commit final result yourself. + +**Sizing:** ~5-6 tasks per teammate, self-contained units, each teammate owns different files. + +**Panel of agents:** For design decisions or ambiguous requirements, spawn 3+ teammates with different perspectives. Have them debate and challenge each other — adversarial review beats independent comparison. Converge on the approach that survives scrutiny. ## Documentation diff --git a/docs/BENCHMARKS.md b/docs/BENCHMARKS.md index f1024d086..7c15e5a08 100644 --- a/docs/BENCHMARKS.md +++ b/docs/BENCHMARKS.md @@ -107,12 +107,12 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 ## Progress -| Phase | Category | Envs | REINFORCE | SARSA | DQN | DDQN+PER | A2C | PPO | SAC | Overall | -|-------|----------|------|-----------|-------|-----|----------|-----|-----|-----|---------| -| 1 | Classic Control | 3 | ✅ | ✅ | ⚠️ | ✅ | ✅ | ✅ | ✅ | Done | -| 2 | Box2D | 2 | N/A | N/A | ⚠️ | ✅ | ❌ | ⚠️ | ⚠️ | Done | -| 3 | MuJoCo | 11 | N/A | N/A | N/A | N/A | N/A | ⚠️ | ⚠️ | Done | -| 4 | Atari | 57 | N/A | N/A | N/A | Skip | Done | Done | Done | Done | +| Phase | Category | Envs | REINFORCE | SARSA | DQN | DDQN+PER | A2C | PPO | SAC | CrossQ | Overall | +|-------|----------|------|-----------|-------|-----|----------|-----|-----|-----|--------|---------| +| 1 | Classic Control | 3 | ✅ | ✅ | ⚠️ | ✅ | ✅ | ✅ | ✅ | ⚠️ | Done | +| 2 | Box2D | 2 | N/A | N/A | ⚠️ | ✅ | ❌ | ⚠️ | ⚠️ | ⚠️ | Done | +| 3 | MuJoCo | 11 | N/A | N/A | N/A | N/A | N/A | ⚠️ | ⚠️ | ⚠️ | Done | +| 4 | Atari | 57 | N/A | N/A | N/A | Skip | Done | Done | Done | ❌ | Done | **Legend**: ✅ Solved | ⚠️ Close (>80%) | 📊 Acceptable | ❌ Failed | 🔄 In progress/Pending | Skip Not started | N/A Not applicable @@ -137,6 +137,7 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | A2C | ✅ | 496.68 | [slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml) | a2c_gae_cartpole_arc | [a2c_gae_cartpole_arc_2026_02_11_142531](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_cartpole_arc_2026_02_11_142531) | | PPO | ✅ | 498.94 | [slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml) | ppo_cartpole_arc | [ppo_cartpole_arc_2026_02_11_144029](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_cartpole_arc_2026_02_11_144029) | | SAC | ✅ | 406.09 | [slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml) | sac_cartpole_arc | [sac_cartpole_arc_2026_02_11_144155](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_cartpole_arc_2026_02_11_144155) | +| CrossQ | ⚠️ | 334.59 | [slm_lab/spec/benchmark/crossq/crossq_classic.yaml](../slm_lab/spec/benchmark/crossq/crossq_classic.yaml) | crossq_cartpole | [crossq_cartpole_2026_03_02_100434](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_cartpole_2026_03_02_100434) | ![CartPole-v1](plots/CartPole-v1_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -153,6 +154,7 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | A2C | ✅ | -83.99 | [slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml) | a2c_gae_acrobot_arc | [a2c_gae_acrobot_arc_2026_02_11_153806](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_acrobot_arc_2026_02_11_153806) | | PPO | ✅ | -81.28 | [slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml) | ppo_acrobot_arc | [ppo_acrobot_arc_2026_02_11_153758](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_acrobot_arc_2026_02_11_153758) | | SAC | ✅ | -92.60 | [slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml) | sac_acrobot_arc | [sac_acrobot_arc_2026_02_11_162211](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_acrobot_arc_2026_02_11_162211) | +| CrossQ | ✅ | -103.13 | [slm_lab/spec/benchmark/crossq/crossq_classic.yaml](../slm_lab/spec/benchmark/crossq/crossq_classic.yaml) | crossq_acrobot | [crossq_acrobot_2026_02_23_153622](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_acrobot_2026_02_23_153622) | ![Acrobot-v1](plots/Acrobot-v1_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -167,12 +169,13 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | A2C | ❌ | -820.74 | [slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml) | a2c_gae_pendulum_arc | [a2c_gae_pendulum_arc_2026_02_11_162217](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_pendulum_arc_2026_02_11_162217) | | PPO | ✅ | -174.87 | [slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_classic_arc.yaml) | ppo_pendulum_arc | [ppo_pendulum_arc_2026_02_11_162156](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_pendulum_arc_2026_02_11_162156) | | SAC | ✅ | -150.97 | [slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_classic_arc.yaml) | sac_pendulum_arc | [sac_pendulum_arc_2026_02_11_162240](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_pendulum_arc_2026_02_11_162240) | +| CrossQ | ✅ | -145.66 | [slm_lab/spec/benchmark/crossq/crossq_classic.yaml](../slm_lab/spec/benchmark/crossq/crossq_classic.yaml) | crossq_pendulum | [crossq_pendulum_2026_02_28_130648](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_pendulum_2026_02_28_130648) | ![Pendulum-v1](plots/Pendulum-v1_multi_trial_graph_mean_returns_ma_vs_frames.png) ### Phase 2: Box2D -#### 2.1 LunarLander-v3 (Discrete) +#### 2.1 LunarLander-v3 **Docs**: [LunarLander](https://gymnasium.farama.org/environments/box2d/lunar_lander/) | State: Box(8) | Action: Discrete(4) | Target reward MA > 200 @@ -185,10 +188,11 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | A2C | ❌ | 27.38 | [slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml) | a2c_gae_lunar_arc | [a2c_gae_lunar_arc_2026_02_11_224304](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_lunar_arc_2026_02_11_224304) | | PPO | ⚠️ | 183.30 | [slm_lab/spec/benchmark_arc/ppo/ppo_box2d_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_box2d_arc.yaml) | ppo_lunar_arc | [ppo_lunar_arc_2026_02_11_201303](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_lunar_arc_2026_02_11_201303) | | SAC | ⚠️ | 106.17 | [slm_lab/spec/benchmark_arc/sac/sac_box2d_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_box2d_arc.yaml) | sac_lunar_arc | [sac_lunar_arc_2026_02_11_201417](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_lunar_arc_2026_02_11_201417) | +| CrossQ | ❌ | 139.21 | [slm_lab/spec/benchmark/crossq/crossq_box2d.yaml](../slm_lab/spec/benchmark/crossq/crossq_box2d.yaml) | crossq_lunar | [crossq_lunar_2026_02_28_130733](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_lunar_2026_02_28_130733) | -![LunarLander-v3 Discrete](plots/LunarLander-v3_Discrete_multi_trial_graph_mean_returns_ma_vs_frames.png) +![LunarLander-v3](plots/LunarLander-v3_multi_trial_graph_mean_returns_ma_vs_frames.png) -#### 2.2 LunarLander-v3 (Continuous) +#### 2.2 LunarLanderContinuous-v3 **Docs**: [LunarLander](https://gymnasium.farama.org/environments/box2d/lunar_lander/) | State: Box(8) | Action: Box(2) | Target reward MA > 200 @@ -199,8 +203,9 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | A2C | ❌ | -76.81 | [slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_classic_arc.yaml) | a2c_gae_lunar_continuous_arc | [a2c_gae_lunar_continuous_arc_2026_02_11_224301](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_lunar_continuous_arc_2026_02_11_224301) | | PPO | ⚠️ | 132.58 | [slm_lab/spec/benchmark_arc/ppo/ppo_box2d_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_box2d_arc.yaml) | ppo_lunar_continuous_arc | [ppo_lunar_continuous_arc_2026_02_11_224229](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_lunar_continuous_arc_2026_02_11_224229) | | SAC | ⚠️ | 125.00 | [slm_lab/spec/benchmark_arc/sac/sac_box2d_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_box2d_arc.yaml) | sac_lunar_continuous_arc | [sac_lunar_continuous_arc_2026_02_12_222203](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_lunar_continuous_arc_2026_02_12_222203) | +| CrossQ | ✅ | 268.91 | [slm_lab/spec/benchmark/crossq/crossq_box2d.yaml](../slm_lab/spec/benchmark/crossq/crossq_box2d.yaml) | crossq_lunar_continuous | [crossq_lunar_continuous_2026_03_01_140517](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_lunar_continuous_2026_03_01_140517) | -![LunarLander-v3 Continuous](plots/LunarLander-v3_Continuous_multi_trial_graph_mean_returns_ma_vs_frames.png) +![LunarLanderContinuous-v3](plots/LunarLanderContinuous-v3_multi_trial_graph_mean_returns_ma_vs_frames.png) ### Phase 3: MuJoCo @@ -208,13 +213,14 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 **Settings**: max_frame 4e6-10e6 | num_envs 16 | max_session 4 | log_frequency 1e4 -**Algorithms**: PPO and SAC. Network: MLP [256,256], orthogonal init. PPO uses tanh activation; SAC uses relu. +**Algorithms**: PPO, SAC, and CrossQ. Network: MLP [256,256], orthogonal init. PPO uses tanh activation; SAC and CrossQ use relu. CrossQ uses Batch Renormalization in critics (no target networks). -**Note on SAC frame budgets**: SAC uses higher update-to-data ratios (more gradient updates per step), making it more sample-efficient but slower per frame than PPO. SAC benchmarks use 1-4M frames (vs PPO's 4-10M) to fit within practical GPU wall-time limits (~6h). Scores may still be improving at cutoff. +**Note on SAC/CrossQ frame budgets**: SAC uses higher update-to-data ratios (more gradient updates per step), making it more sample-efficient but slower per frame than PPO. SAC benchmarks use 1-4M frames (vs PPO's 4-10M) to fit within practical GPU wall-time limits (~6h). CrossQ uses UTD=1 (like PPO) but eliminates target network overhead, achieving ~700 fps — its frame budgets (3-7.5M) reflect this speed advantage. Scores may still be improving at cutoff. **Spec Files** (one file per algorithm, all envs via YAML anchors): - **PPO**: [ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) - **SAC**: [sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) +- **CrossQ**: [crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) **Spec Variants**: Each file has a base config (shared via YAML anchors) with per-env overrides: @@ -225,6 +231,8 @@ Search budget: ~3-4 trials per dimension (8 trials = 2-3 dims, 16 = 3-4 dims, 20 | ppo_{env}_arc | Ant, Hopper, Swimmer, IP, IDP | Per-env tuned (gamma, lam, lr) | | sac_mujoco_arc | (generic, use with -s flags) | Base: gamma=0.99, iter=4, lr=3e-4, [256,256] | | sac_{env}_arc | All 11 envs | Per-env tuned (iter, gamma, lr, net size) | +| crossq_mujoco | (generic base) | Base: gamma=0.99, iter=1, lr=1e-3, policy_delay=3 | +| crossq_{env} | All 11 envs | Per-env tuned (critic width, actor LN, iter) | **Reproduce**: Copy `SPEC_NAME` and `MAX_FRAME` from the table below. @@ -236,32 +244,47 @@ source .env && slm-lab run-remote --gpu -s env=ENV -s max_frame=MAX_FRAME \ # SAC: env and max_frame are hardcoded per spec — no -s flags needed source .env && slm-lab run-remote --gpu \ slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml SPEC_NAME train -n NAME + +# CrossQ: env and max_frame are hardcoded per spec — no -s flags needed +source .env && slm-lab run-remote --gpu \ + slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml SPEC_NAME train -n NAME ``` | ENV | SPEC_NAME | MAX_FRAME | |-----|-----------|-----------| | Ant-v5 | ppo_ant_arc | 10e6 | | | sac_ant_arc | 2e6 | +| | crossq_ant | 3e6 | | HalfCheetah-v5 | ppo_mujoco_arc | 10e6 | | | sac_halfcheetah_arc | 4e6 | +| | crossq_halfcheetah | 4e6 | | Hopper-v5 | ppo_hopper_arc | 4e6 | | | sac_hopper_arc | 3e6 | +| | crossq_hopper | 3e6 | | Humanoid-v5 | ppo_mujoco_arc | 10e6 | | | sac_humanoid_arc | 1e6 | +| | crossq_humanoid | 2e6 | | HumanoidStandup-v5 | ppo_mujoco_arc | 4e6 | | | sac_humanoid_standup_arc | 1e6 | +| | crossq_humanoid_standup | 2e6 | | InvertedDoublePendulum-v5 | ppo_inverted_double_pendulum_arc | 10e6 | | | sac_inverted_double_pendulum_arc | 2e6 | +| | crossq_inverted_double_pendulum | 2e6 | | InvertedPendulum-v5 | ppo_inverted_pendulum_arc | 4e6 | | | sac_inverted_pendulum_arc | 2e6 | +| | crossq_inverted_pendulum | 7e6 | | Pusher-v5 | ppo_mujoco_longhorizon_arc | 4e6 | | | sac_pusher_arc | 1e6 | +| | crossq_pusher | 2e6 | | Reacher-v5 | ppo_mujoco_longhorizon_arc | 4e6 | | | sac_reacher_arc | 1e6 | +| | crossq_reacher | 2e6 | | Swimmer-v5 | ppo_swimmer_arc | 4e6 | | | sac_swimmer_arc | 2e6 | +| | crossq_swimmer | 3e6 | | Walker2d-v5 | ppo_mujoco_arc | 10e6 | | | sac_walker2d_arc | 3e6 | +| | crossq_walker2d | 7e6 | #### 3.1 Ant-v5 @@ -273,6 +296,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 2138.28 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_ant_arc | [ppo_ant_arc_ant_2026_02_12_190644](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_ant_arc_ant_2026_02_12_190644) | | SAC | ✅ | 4942.91 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_ant_arc | [sac_ant_arc_2026_02_11_225529](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_ant_arc_2026_02_11_225529) | +| CrossQ | ✅ | 4517.00 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_ant | [crossq_ant_2026_03_01_102428](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_ant_2026_03_01_102428) | ![Ant-v5](plots/Ant-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -286,6 +310,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 6240.68 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_arc | [ppo_mujoco_arc_halfcheetah_2026_02_12_195553](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_arc_halfcheetah_2026_02_12_195553) | | SAC | ✅ | 9815.16 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_halfcheetah_arc | [sac_halfcheetah_4m_i2_arc_2026_02_14_185522](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_halfcheetah_4m_i2_arc_2026_02_14_185522) | +| CrossQ | ✅ | 8616.52 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_halfcheetah | [crossq_halfcheetah_2026_03_01_101317](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_halfcheetah_2026_03_01_101317) | ![HalfCheetah-v5](plots/HalfCheetah-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -299,6 +324,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ⚠️ | 1653.74 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_hopper_arc | [ppo_hopper_arc_hopper_2026_02_12_222206](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_hopper_arc_hopper_2026_02_12_222206) | | SAC | ⚠️ | 1416.52 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_hopper_arc | [sac_hopper_3m_i4_arc_2026_02_14_185434](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_hopper_3m_i4_arc_2026_02_14_185434) | +| CrossQ | ⚠️ | 1168.53 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_hopper | [crossq_hopper_2026_02_21_101148](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_hopper_2026_02_21_101148) | ![Hopper-v5](plots/Hopper-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -312,6 +338,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 2661.26 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_arc | [ppo_mujoco_arc_humanoid_2026_02_12_185439](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_arc_humanoid_2026_02_12_185439) | | SAC | ✅ | 1989.65 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_humanoid_arc | [sac_humanoid_arc_2026_02_12_020016](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_humanoid_arc_2026_02_12_020016) | +| CrossQ | ✅ | 1755.29 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_humanoid | [crossq_humanoid_2026_03_01_165208](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_humanoid_2026_03_01_165208) | ![Humanoid-v5](plots/Humanoid-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -325,6 +352,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 150104.59 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_arc | [ppo_mujoco_arc_humanoidstandup_2026_02_12_115050](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_arc_humanoidstandup_2026_02_12_115050) | | SAC | ✅ | 137357.00 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_humanoid_standup_arc | [sac_humanoid_standup_arc_2026_02_12_225150](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_humanoid_standup_arc_2026_02_12_225150) | +| CrossQ | ✅ | 150912.66 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_humanoid_standup | [crossq_humanoid_standup_2026_02_28_184305](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_humanoid_standup_2026_02_28_184305) | ![HumanoidStandup-v5](plots/HumanoidStandup-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -338,6 +366,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 8383.76 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_inverted_double_pendulum_arc | [ppo_inverted_double_pendulum_arc_inverteddoublependulum_2026_02_12_225231](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_inverted_double_pendulum_arc_inverteddoublependulum_2026_02_12_225231) | | SAC | ✅ | 9032.67 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_inverted_double_pendulum_arc | [sac_inverted_double_pendulum_arc_2026_02_12_025206](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_inverted_double_pendulum_arc_2026_02_12_025206) | +| CrossQ | ✅ | 8027.38 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_inverted_double_pendulum | [crossq_inverted_double_pendulum_2026_03_01_101354](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_inverted_double_pendulum_2026_03_01_101354) | ![InvertedDoublePendulum-v5](plots/InvertedDoublePendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -345,12 +374,13 @@ source .env && slm-lab run-remote --gpu \ **Docs**: [InvertedPendulum](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/) | State: Box(4) | Action: Box(1) | Target reward MA ~1000 -**Settings**: max_frame 4e6 | num_envs 16 | max_session 4 | log_frequency 1e4 +**Settings**: max_frame 10e6 | num_envs 16 | max_session 4 | log_frequency 1e4 | Algorithm | Status | MA | SPEC_FILE | SPEC_NAME | HF Data | |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 949.94 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_inverted_pendulum_arc | [ppo_inverted_pendulum_arc_invertedpendulum_2026_02_12_062037](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_inverted_pendulum_arc_invertedpendulum_2026_02_12_062037) | | SAC | ✅ | 928.43 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_inverted_pendulum_arc | [sac_inverted_pendulum_arc_2026_02_12_225503](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_inverted_pendulum_arc_2026_02_12_225503) | +| CrossQ | ⚠️ | 877.83 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_inverted_pendulum | [crossq_inverted_pendulum_2026_02_28_184348](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_inverted_pendulum_2026_02_28_184348) | ![InvertedPendulum-v5](plots/InvertedPendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -364,6 +394,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | -49.59 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_longhorizon_arc | [ppo_mujoco_longhorizon_arc_pusher_2026_02_12_222228](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_longhorizon_arc_pusher_2026_02_12_222228) | | SAC | ✅ | -43.00 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_pusher_arc | [sac_pusher_arc_2026_02_12_053603](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_pusher_arc_2026_02_12_053603) | +| CrossQ | ✅ | -37.08 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_pusher | [crossq_pusher_2026_02_21_134637](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_pusher_2026_02_21_134637) | ![Pusher-v5](plots/Pusher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -377,6 +408,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | -5.03 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_longhorizon_arc | [ppo_mujoco_longhorizon_arc_reacher_2026_02_12_115033](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_longhorizon_arc_reacher_2026_02_12_115033) | | SAC | ✅ | -6.31 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_reacher_arc | [sac_reacher_arc_2026_02_12_055200](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_reacher_arc_2026_02_12_055200) | +| CrossQ | ✅ | -5.65 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_reacher | [crossq_reacher_2026_02_28_184304](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_reacher_2026_02_28_184304) | ![Reacher-v5](plots/Reacher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -390,6 +422,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 282.44 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_swimmer_arc | [ppo_swimmer_arc_swimmer_2026_02_12_100445](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_swimmer_arc_swimmer_2026_02_12_100445) | | SAC | ✅ | 301.34 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_swimmer_arc | [sac_swimmer_arc_2026_02_12_054349](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_swimmer_arc_2026_02_12_054349) | +| CrossQ | ✅ | 221.12 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_swimmer | [crossq_swimmer_2026_02_21_184204](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_swimmer_2026_02_21_184204) | ![Swimmer-v5](plots/Swimmer-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -403,6 +436,7 @@ source .env && slm-lab run-remote --gpu \ |-----------|--------|-----|-----------|-----------|---------| | PPO | ✅ | 4378.62 | [slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_arc.yaml) | ppo_mujoco_arc | [ppo_mujoco_arc_walker2d_2026_02_12_190312](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_mujoco_arc_walker2d_2026_02_12_190312) | | SAC | ⚠️ | 3123.66 | [slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_mujoco_arc.yaml) | sac_walker2d_arc | [sac_walker2d_3m_i4_arc_2026_02_14_185550](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_walker2d_3m_i4_arc_2026_02_14_185550) | +| CrossQ | ✅ | 4389.62 | [slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml](../slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml) | crossq_walker2d | [crossq_walker2d_2026_02_28_184343](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_walker2d_2026_02_28_184343) | ![Walker2d-v5](plots/Walker2d-v5_multi_trial_graph_mean_returns_ma_vs_frames.png) @@ -421,6 +455,7 @@ source .env && slm-lab run-remote --gpu \ - **A2C**: [a2c_atari_arc.yaml](../slm_lab/spec/benchmark_arc/a2c/a2c_atari_arc.yaml) - RMSprop (lr=7e-4), training_frequency=32 - **PPO**: [ppo_atari_arc.yaml](../slm_lab/spec/benchmark_arc/ppo/ppo_atari_arc.yaml) - AdamW (lr=2.5e-4), minibatch=256, horizon=128, epochs=4, max_frame=10e6 - **SAC**: [sac_atari_arc.yaml](../slm_lab/spec/benchmark_arc/sac/sac_atari_arc.yaml) - Categorical SAC, AdamW (lr=3e-4), training_iter=3, training_frequency=4, max_frame=2e6 +- **CrossQ**: [crossq_atari.yaml](../slm_lab/spec/benchmark/crossq/crossq_atari.yaml) - Categorical CrossQ, AdamW (lr=1e-3), training_iter=3, training_frequency=4, max_frame=2e6 (experimental — limited results on 6 games) **PPO Lambda Variants** (table shows best result per game): @@ -443,6 +478,10 @@ source .env && slm-lab run-remote --gpu -s env=ENV -s max_frame=1e7 \ # SAC (2M frames - off-policy, more sample-efficient but slower per frame) source .env && slm-lab run-remote --gpu -s env=ENV \ slm_lab/spec/benchmark_arc/sac/sac_atari_arc.yaml sac_atari_arc train -n NAME + +# CrossQ (2M frames - experimental, limited games tested) +source .env && slm-lab run-remote --gpu -s env=ENV \ + slm_lab/spec/benchmark/crossq/crossq_atari.yaml crossq_atari train -n NAME ``` > **Note**: HF Data links marked "-" indicate runs completed but not yet uploaded to HuggingFace. Scores are extracted from local trial_metrics. @@ -491,6 +530,7 @@ source .env && slm-lab run-remote --gpu -s env=ENV \ | ALE/Breakout-v5 | 326.47 | ppo_atari_lam70_arc | [ppo_atari_lam70_arc_breakout_2026_02_13_230455](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam70_arc_breakout_2026_02_13_230455) | | | 20.23 | sac_atari_arc | [sac_atari_arc_breakout_2026_02_15_201235](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_breakout_2026_02_15_201235) | | | 273 | a2c_gae_atari_arc | [a2c_gae_atari_breakout_2026_01_31_213610](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_breakout_2026_01_31_213610) | +| | ❌ 4.40 | crossq_atari | [crossq_atari_breakout_2026_02_25_030241](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_breakout_2026_02_25_030241) | | ALE/Carnival-v5 | 3912.59 | ppo_atari_lam70_arc | [ppo_atari_lam70_arc_carnival_2026_02_13_230438](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam70_arc_carnival_2026_02_13_230438) | | | 3501.37 | sac_atari_arc | [sac_atari_arc_carnival_2026_02_17_105834](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_carnival_2026_02_17_105834) | | | 2170 | a2c_gae_atari_arc | [a2c_gae_atari_carnival_2026_02_01_082726](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_carnival_2026_02_01_082726) | @@ -554,6 +594,7 @@ source .env && slm-lab run-remote --gpu -s env=ENV \ | ALE/MsPacman-v5 | 2330.74 | ppo_atari_lam85_arc | [ppo_atari_lam85_arc_mspacman_2026_02_14_102435](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam85_arc_mspacman_2026_02_14_102435) | | | 1336.96 | sac_atari_arc | [sac_atari_arc_mspacman_2026_02_17_221523](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_mspacman_2026_02_17_221523) | | | 2110 | a2c_gae_atari_arc | [a2c_gae_atari_mspacman_2026_02_01_001100](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_mspacman_2026_02_01_001100) | +| | ❌ 327.79 | crossq_atari | [crossq_atari_mspacman_2026_02_23_171317](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_mspacman_2026_02_23_171317) | | ALE/NameThisGame-v5 | 6879.23 | ppo_atari_arc | [ppo_atari_arc_namethisgame_2026_02_14_103319](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_arc_namethisgame_2026_02_14_103319) | | | 3992.71 | sac_atari_arc | [sac_atari_arc_namethisgame_2026_02_17_220905](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_namethisgame_2026_02_17_220905) | | | 5412 | a2c_gae_atari_arc | [a2c_gae_atari_namethisgame_2026_02_01_132733](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_namethisgame_2026_02_01_132733) | @@ -563,12 +604,14 @@ source .env && slm-lab run-remote --gpu -s env=ENV \ | ALE/Pong-v5 | 16.69 | ppo_atari_lam85_arc | [ppo_atari_lam85_arc_pong_2026_02_14_103722](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam85_arc_pong_2026_02_14_103722) | | | 10.89 | sac_atari_arc | [sac_atari_arc_pong_2026_02_17_160429](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_pong_2026_02_17_160429) | | | 10.17 | a2c_gae_atari_arc | [a2c_gae_atari_pong_2026_01_31_213635](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_pong_2026_01_31_213635) | +| | ❌ -20.59 | crossq_atari | [crossq_atari_pong_2026_02_23_171158](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_pong_2026_02_23_171158) | | ALE/Pooyan-v5 | 5308.66 | ppo_atari_lam70_arc | [ppo_atari_lam70_arc_pooyan_2026_02_14_114730](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam70_arc_pooyan_2026_02_14_114730) | | | 2530.78 | sac_atari_arc | [sac_atari_arc_pooyan_2026_02_17_220346](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_pooyan_2026_02_17_220346) | | | 2997 | a2c_gae_atari_arc | [a2c_gae_atari_pooyan_2026_02_01_132748](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_pooyan_2026_02_01_132748) | | ALE/Qbert-v5 | 15460.48 | ppo_atari_arc | [ppo_atari_arc_qbert_2026_02_14_120409](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_arc_qbert_2026_02_14_120409) | | | 3331.98 | sac_atari_arc | [sac_atari_arc_qbert_2026_02_17_223117](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_qbert_2026_02_17_223117) | | | 12619 | a2c_gae_atari_arc | [a2c_gae_atari_qbert_2026_01_31_213720](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_qbert_2026_01_31_213720) | +| | ❌ 3189.73 | crossq_atari | [crossq_atari_qbert_2026_02_25_030458](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_qbert_2026_02_25_030458) | | ALE/Riverraid-v5 | 9599.75 | ppo_atari_lam85_arc | [ppo_atari_lam85_arc_riverraid_2026_02_14_124700](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam85_arc_riverraid_2026_02_14_124700) | | | 4744.95 | sac_atari_arc | [sac_atari_arc_riverraid_2026_02_18_014310](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_riverraid_2026_02_18_014310) | | | 6558 | a2c_gae_atari_arc | [a2c_gae_atari_riverraid_2026_02_01_132507](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_riverraid_2026_02_01_132507) | @@ -581,6 +624,7 @@ source .env && slm-lab run-remote --gpu -s env=ENV \ | ALE/Seaquest-v5 | 1775.14 | ppo_atari_arc | [ppo_atari_arc_seaquest_2026_02_11_095444](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_arc_seaquest_2026_02_11_095444) | | | 1565.44 | sac_atari_arc | [sac_atari_arc_seaquest_2026_02_18_020822](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_seaquest_2026_02_18_020822) | | | 850 | a2c_gae_atari_arc | [a2c_gae_atari_seaquest_2026_02_01_001001](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_seaquest_2026_02_01_001001) | +| | ❌ 234.63 | crossq_atari | [crossq_atari_seaquest_2026_02_25_030441](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_seaquest_2026_02_25_030441) | | ALE/Skiing-v5 | -28217.28 | ppo_atari_arc | [ppo_atari_arc_skiing_2026_02_14_174807](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_arc_skiing_2026_02_14_174807) | | | -17464.22 | sac_atari_arc | [sac_atari_arc_skiing_2026_02_18_024444](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_skiing_2026_02_18_024444) | | | -14235 | a2c_gae_atari_arc | [a2c_gae_atari_skiing_2026_02_01_132451](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_skiing_2026_02_01_132451) | @@ -590,6 +634,7 @@ source .env && slm-lab run-remote --gpu -s env=ENV \ | ALE/SpaceInvaders-v5 | 892.49 | ppo_atari_arc | [ppo_atari_arc_spaceinvaders_2026_02_14_131114](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_arc_spaceinvaders_2026_02_14_131114) | | | 507.33 | sac_atari_arc | [sac_atari_arc_spaceinvaders_2026_02_18_033139](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_spaceinvaders_2026_02_18_033139) | | | 784 | a2c_gae_atari_arc | [a2c_gae_atari_spaceinvaders_2026_02_01_000950](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_spaceinvaders_2026_02_01_000950) | +| | ❌ 404.50 | crossq_atari | [crossq_atari_spaceinvaders_2026_02_25_030410](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/crossq_atari_spaceinvaders_2026_02_25_030410) | | ALE/StarGunner-v5 | 49328.73 | ppo_atari_lam70_arc | [ppo_atari_lam70_arc_stargunner_2026_02_14_131149](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/ppo_atari_lam70_arc_stargunner_2026_02_14_131149) | | | 4295.97 | sac_atari_arc | [sac_atari_arc_stargunner_2026_02_18_033151](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/sac_atari_arc_stargunner_2026_02_18_033151) | | | 8665 | a2c_gae_atari_arc | [a2c_gae_atari_stargunner_2026_02_01_132406](https://huggingface.co/datasets/SLM-Lab/benchmark/tree/main/data/a2c_gae_atari_stargunner_2026_02_01_132406) | diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 3950639cc..ee4067959 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,25 @@ +# SLM-Lab v5.2.0 + +Training path performance optimization. **+15% SAC throughput on GPU**, verified with no score regression. + +**What changed (18 files):** +- `polyak_update`: in-place `lerp_()` replaces 3-op manual arithmetic +- `SAC`: single `log_softmax→exp` replaces dual softmax+log_softmax; cached entropy between policy/alpha loss; cached `_is_per` and `_LOG2` +- `to_torch_batch`: uint8/float16 sent directly to GPU then `.float()` — avoids 4x CPU float32 intermediate (matters for Atari 84x84x4) +- `SumTree`: iterative propagation/retrieval replaces recursion; vectorized sampling +- `forward_tails`: cached output (was called twice per step) +- `VectorFullGameStatistics`: `deque(maxlen=N)` + `np.flatnonzero` replaces list+pop(0)+loop +- `pydash→builtins`: `isinstance` over `ps.is_list/is_dict`, dict comprehensions over `ps.pick/ps.omit` in hot paths +- `PPO`: `total_loss` as plain float prevents computation graph leak across epochs +- Minor: `hasattr→is not None` in conv/recurrent forward, cached `_is_dev`, `no_decay` early exit in VarScheduler + +**Measured gains (normalized, same hardware A/B on RTX 3090):** +- SAC MuJoCo: +15-17% fps +- SAC Atari: +14% fps +- PPO: ~0% (env-bound; most optimizations target SAC's training-heavy inner loop — PPO doesn't use polyak, replay buffer, twin Q, or entropy tuning) + +--- + # SLM-Lab v5.1.0 TorchArc YAML benchmarks replace original hardcoded network architectures across all benchmark categories. diff --git a/docs/plots/Acrobot-v1_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Acrobot-v1_multi_trial_graph_mean_returns_ma_vs_frames.png index 8321f1276..8f014efcc 100644 Binary files a/docs/plots/Acrobot-v1_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Acrobot-v1_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/AirRaid_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/AirRaid_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index a97da85b3..000000000 Binary files a/docs/plots/AirRaid_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Alien_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Alien_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index f78c6bffd..000000000 Binary files a/docs/plots/Alien_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Amidar_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Amidar_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 9da3a5a93..000000000 Binary files a/docs/plots/Amidar_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Ant-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Ant-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 086958bd4..d17d2cab6 100644 Binary files a/docs/plots/Ant-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Ant-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Assault_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Assault_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 86ea544ea..000000000 Binary files a/docs/plots/Assault_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Asterix_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Asterix_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 8eb8d84c5..000000000 Binary files a/docs/plots/Asterix_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Asteroids_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Asteroids_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index cdad7518d..000000000 Binary files a/docs/plots/Asteroids_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Atlantis_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Atlantis_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 005e7238f..000000000 Binary files a/docs/plots/Atlantis_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/BankHeist_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/BankHeist_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 3e77903bc..000000000 Binary files a/docs/plots/BankHeist_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/BattleZone_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/BattleZone_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index d57ecbc7b..000000000 Binary files a/docs/plots/BattleZone_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/BeamRider_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/BeamRider_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 682f09b2b..000000000 Binary files a/docs/plots/BeamRider_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Berzerk_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Berzerk_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index cf72035e0..000000000 Binary files a/docs/plots/Berzerk_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Bowling_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Bowling_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 83899ea1d..000000000 Binary files a/docs/plots/Bowling_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Boxing_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Boxing_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index ebecb01ec..000000000 Binary files a/docs/plots/Boxing_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Breakout-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Breakout-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 0bf06b954..62806b30c 100644 Binary files a/docs/plots/Breakout-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Breakout-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Breakout_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Breakout_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index e504c3154..000000000 Binary files a/docs/plots/Breakout_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Carnival_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Carnival_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index c994720e9..000000000 Binary files a/docs/plots/Carnival_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/CartPole-v1_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/CartPole-v1_multi_trial_graph_mean_returns_ma_vs_frames.png index 0151bf4fb..00604faee 100644 Binary files a/docs/plots/CartPole-v1_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/CartPole-v1_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Centipede_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Centipede_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1f4eacce5..000000000 Binary files a/docs/plots/Centipede_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/ChopperCommand_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/ChopperCommand_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 70031ced5..000000000 Binary files a/docs/plots/ChopperCommand_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/CrazyClimber_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/CrazyClimber_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index c7d7a1d83..000000000 Binary files a/docs/plots/CrazyClimber_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Defender_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Defender_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1010b0fef..000000000 Binary files a/docs/plots/Defender_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/DemonAttack_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/DemonAttack_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index bb9971e12..000000000 Binary files a/docs/plots/DemonAttack_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/DoubleDunk_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/DoubleDunk_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 51c95d429..000000000 Binary files a/docs/plots/DoubleDunk_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/ElevatorAction_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/ElevatorAction_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 258a546a7..000000000 Binary files a/docs/plots/ElevatorAction_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Enduro_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Enduro_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index f62f2fb48..000000000 Binary files a/docs/plots/Enduro_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/FishingDerby_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/FishingDerby_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1b3877887..000000000 Binary files a/docs/plots/FishingDerby_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Freeway_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Freeway_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 3edee5aad..000000000 Binary files a/docs/plots/Freeway_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Frostbite_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Frostbite_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index a8915039c..000000000 Binary files a/docs/plots/Frostbite_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Gopher_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Gopher_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 889eb2255..000000000 Binary files a/docs/plots/Gopher_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Gravitar_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Gravitar_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index f1181d722..000000000 Binary files a/docs/plots/Gravitar_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/HalfCheetah-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/HalfCheetah-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 05c65c99f..8e6d3f652 100644 Binary files a/docs/plots/HalfCheetah-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/HalfCheetah-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Hero_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Hero_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 7704374b8..000000000 Binary files a/docs/plots/Hero_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Hopper-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Hopper-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 5f1e80ebe..fec1aa4d3 100644 Binary files a/docs/plots/Hopper-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Hopper-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Humanoid-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Humanoid-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 99d9e55a2..df7b11967 100644 Binary files a/docs/plots/Humanoid-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Humanoid-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/HumanoidStandup-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/HumanoidStandup-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 0177b5a73..09c4d8791 100644 Binary files a/docs/plots/HumanoidStandup-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/HumanoidStandup-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/IceHockey_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/IceHockey_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 47eef9b04..000000000 Binary files a/docs/plots/IceHockey_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/InvertedDoublePendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/InvertedDoublePendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index fd511f9d6..e44e41658 100644 Binary files a/docs/plots/InvertedDoublePendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/InvertedDoublePendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/InvertedPendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/InvertedPendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index d6a59c049..c215e9730 100644 Binary files a/docs/plots/InvertedPendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/InvertedPendulum-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Jamesbond_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Jamesbond_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1e4b4861e..000000000 Binary files a/docs/plots/Jamesbond_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/JourneyEscape_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/JourneyEscape_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index b7ed746d4..000000000 Binary files a/docs/plots/JourneyEscape_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Kangaroo_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Kangaroo_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index d6917b9fe..000000000 Binary files a/docs/plots/Kangaroo_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Krull_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Krull_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 76f0060c0..000000000 Binary files a/docs/plots/Krull_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/KungFuMaster_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/KungFuMaster_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 248483026..000000000 Binary files a/docs/plots/KungFuMaster_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/LunarLander-v3_Continuous_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/LunarLander-v3_Continuous_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1f1b8306b..000000000 Binary files a/docs/plots/LunarLander-v3_Continuous_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/LunarLander-v3_Discrete_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/LunarLander-v3_Discrete_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 5a4eeaa36..000000000 Binary files a/docs/plots/LunarLander-v3_Discrete_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/LunarLander-v3_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/LunarLander-v3_multi_trial_graph_mean_returns_ma_vs_frames.png index 6904e0cb4..7a451d825 100644 Binary files a/docs/plots/LunarLander-v3_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/LunarLander-v3_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/LunarLanderContinuous-v3_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/LunarLanderContinuous-v3_multi_trial_graph_mean_returns_ma_vs_frames.png new file mode 100644 index 000000000..a4c88546b Binary files /dev/null and b/docs/plots/LunarLanderContinuous-v3_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/MsPacman-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/MsPacman-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 7cb68ea77..ee320d73f 100644 Binary files a/docs/plots/MsPacman-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/MsPacman-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/MsPacman_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/MsPacman_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index cefd4cd85..000000000 Binary files a/docs/plots/MsPacman_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/NameThisGame_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/NameThisGame_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 103d86744..000000000 Binary files a/docs/plots/NameThisGame_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Pendulum-v1_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Pendulum-v1_multi_trial_graph_mean_returns_ma_vs_frames.png index c33cd0c54..40b1e8a36 100644 Binary files a/docs/plots/Pendulum-v1_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Pendulum-v1_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Phoenix_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Phoenix_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index df6623d97..000000000 Binary files a/docs/plots/Phoenix_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Pong-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Pong-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 03cd49ff6..9363b9dc2 100644 Binary files a/docs/plots/Pong-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Pong-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Pong_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Pong_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 67aff01bf..000000000 Binary files a/docs/plots/Pong_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Pooyan_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Pooyan_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 90f8296f7..000000000 Binary files a/docs/plots/Pooyan_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Pusher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Pusher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 18fe8d84c..d24f830ed 100644 Binary files a/docs/plots/Pusher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Pusher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Qbert-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Qbert-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 55ad74861..9b5b156ff 100644 Binary files a/docs/plots/Qbert-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Qbert-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Qbert_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Qbert_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 2d9146787..000000000 Binary files a/docs/plots/Qbert_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Reacher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Reacher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index c8492e4a8..7608d2efd 100644 Binary files a/docs/plots/Reacher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Reacher-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Riverraid_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Riverraid_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index b7948bebb..000000000 Binary files a/docs/plots/Riverraid_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/RoadRunner_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/RoadRunner_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 4cfc61d08..000000000 Binary files a/docs/plots/RoadRunner_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Robotank_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Robotank_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1063444c3..000000000 Binary files a/docs/plots/Robotank_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Seaquest-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Seaquest-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 7c48e901f..f5fd1d12b 100644 Binary files a/docs/plots/Seaquest-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Seaquest-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Seaquest_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Seaquest_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 03e04ef3a..000000000 Binary files a/docs/plots/Seaquest_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Skiing_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Skiing_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 995c3ab91..000000000 Binary files a/docs/plots/Skiing_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Solaris_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Solaris_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 75b9e3638..000000000 Binary files a/docs/plots/Solaris_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/SpaceInvaders-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/SpaceInvaders-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 22c5706fe..8385ffbc1 100644 Binary files a/docs/plots/SpaceInvaders-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/SpaceInvaders-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/SpaceInvaders_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/SpaceInvaders_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index f6f4e593e..000000000 Binary files a/docs/plots/SpaceInvaders_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/StarGunner_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/StarGunner_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 9ef59c994..000000000 Binary files a/docs/plots/StarGunner_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Surround_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Surround_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 64dda3e42..000000000 Binary files a/docs/plots/Surround_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Swimmer-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Swimmer-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 945b27557..0159f2503 100644 Binary files a/docs/plots/Swimmer-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Swimmer-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/Tennis_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Tennis_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 1ea9bf18a..000000000 Binary files a/docs/plots/Tennis_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/TimePilot_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/TimePilot_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 8b5ecc21d..000000000 Binary files a/docs/plots/TimePilot_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Tutankham_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Tutankham_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 39b7f573f..000000000 Binary files a/docs/plots/Tutankham_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/UpNDown_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/UpNDown_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index bce831156..000000000 Binary files a/docs/plots/UpNDown_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/VideoPinball_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/VideoPinball_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 631964c39..000000000 Binary files a/docs/plots/VideoPinball_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Walker2d-v5_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Walker2d-v5_multi_trial_graph_mean_returns_ma_vs_frames.png index 81cd9ed43..d368f5f6f 100644 Binary files a/docs/plots/Walker2d-v5_multi_trial_graph_mean_returns_ma_vs_frames.png and b/docs/plots/Walker2d-v5_multi_trial_graph_mean_returns_ma_vs_frames.png differ diff --git a/docs/plots/WizardOfWor_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/WizardOfWor_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index e9f1351be..000000000 Binary files a/docs/plots/WizardOfWor_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/YarsRevenge_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/YarsRevenge_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index c1a9f123d..000000000 Binary files a/docs/plots/YarsRevenge_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/docs/plots/Zaxxon_multi_trial_graph_mean_returns_ma_vs_frames.png b/docs/plots/Zaxxon_multi_trial_graph_mean_returns_ma_vs_frames.png deleted file mode 100644 index 5138ead7e..000000000 Binary files a/docs/plots/Zaxxon_multi_trial_graph_mean_returns_ma_vs_frames.png and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index 5aa54b510..624956e0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "slm-lab" -version = "5.1.0" +version = "5.2.0" description = "Modular Deep Reinforcement Learning framework in PyTorch." readme = "README.md" requires-python = ">=3.12.0" @@ -60,13 +60,14 @@ ml = [ dev = [ "coverage>=7.6.1", "flaky>=3.5.3", + "glances>=4.3.3", + "HolisticTraceAnalysis>=0.5.0", "ipykernel>=6.29.5", + "nvidia-ml-py>=13.580.65", "pytest-cov>=2.7.1", "pytest-timeout>=1.3.3", "pytest>=6.0.0", "ruff>=0.8.3", - "glances>=4.3.3", - "nvidia-ml-py>=13.580.65", ] [tool.uv] diff --git a/slm_lab/agent/__init__.py b/slm_lab/agent/__init__.py index f50c8c564..e2affde37 100644 --- a/slm_lab/agent/__init__.py +++ b/slm_lab/agent/__init__.py @@ -177,6 +177,7 @@ def __init__(self, env: "gym.Env", spec: dict[str, Any]): self.eval_df = self.train_df self.metrics = {} # store scalar metrics for Ray Tune reporting + self._is_dev = lab_mode() == "dev" # cache for hot-path check def register_algo_var(self, var_name: str, source_obj: object) -> None: """Register a variable for logging. Expects source_obj to have an attribute named var_name.""" @@ -191,7 +192,7 @@ def update( done: bool, ) -> None: """Interface update method for tracker at agent.update()""" - if lab_mode() == "dev": # log tensorboard only on dev mode + if self._is_dev: # log tensorboard only on dev mode self.track_tensorboard(action) def __str__(self) -> str: @@ -278,7 +279,7 @@ def get_mean_lr(self) -> float: return np.nan lrs = [] for attr, obj in self.agent.algorithm.__dict__.items(): - if attr.endswith("lr_scheduler"): + if attr.endswith("lr_scheduler") and obj is not None: lr = obj.get_last_lr() if hasattr(lr, "cpu"): lr = lr.cpu().item() @@ -346,7 +347,7 @@ def log_summary(self, df_mode: str) -> None: logger.info("\n".join(lines)) if ( - lab_mode() == "dev" and df_mode == "train" + self._is_dev and df_mode == "train" ): # log tensorboard only on dev mode and train df data self.log_tensorboard() diff --git a/slm_lab/agent/algorithm/__init__.py b/slm_lab/agent/algorithm/__init__.py index 8da7c05cd..b6ba3b52f 100644 --- a/slm_lab/agent/algorithm/__init__.py +++ b/slm_lab/agent/algorithm/__init__.py @@ -2,6 +2,7 @@ # Contains implementations of reinforcement learning algorithms. # Uses the nets module to build neural networks as the relevant function approximators from .actor_critic import * +from .crossq import * from .dqn import * from .ppo import * from .random import * diff --git a/slm_lab/agent/algorithm/actor_critic.py b/slm_lab/agent/algorithm/actor_critic.py index a1a0b7026..4ef5dba2a 100644 --- a/slm_lab/agent/algorithm/actor_critic.py +++ b/slm_lab/agent/algorithm/actor_critic.py @@ -5,7 +5,6 @@ from slm_lab.lib import logger, math_util, util from slm_lab.lib.decorator import lab_api import numpy as np -import pydash as ps import torch logger = logger.get_logger(__name__) @@ -32,15 +31,16 @@ def __init__(self, epsilon: float = 1e-8, clip: float = 10.0): self._warmup = 1000 # Number of samples before trusting variance def update(self, values: torch.Tensor) -> None: - """Update running statistics with new values (batched Welford's)""" - values_np = values.detach().cpu().numpy().flatten() - for v in values_np: - self.count += 1 - delta = v - self.mean - self.mean += delta / self.count - delta2 = v - self.mean - self.m2 += delta * delta2 - # Update variance after enough samples + """Update running statistics with new values (Chan's parallel merge)""" + batch = values.detach().flatten() + batch_count = len(batch) + batch_mean = batch.mean().item() + batch_var = batch.var(correction=0).item() if batch_count > 1 else 0.0 + delta = batch_mean - self.mean + total = self.count + batch_count + self.mean = self.mean + delta * batch_count / max(total, 1) + self.m2 += batch_var * batch_count + delta**2 * self.count * batch_count / max(total, 1) + self.count = total if self.count > 1: self.var = self.m2 / self.count @@ -58,8 +58,28 @@ def denormalize(self, values: torch.Tensor) -> torch.Tensor: return values * std +class PercentileNormalizer: + """EMA-tracked 5th/95th percentile advantage normalization. DreamerV3. + Alternative to ReturnNormalizer; selected via spec key normalize_advantages.""" + + def __init__(self, decay=0.99): + self.perc5 = 0.0 + self.perc95 = 0.0 + self.decay = decay + + def update(self, values): + p5 = torch.quantile(values, 0.05) + p95 = torch.quantile(values, 0.95) + self.perc5 = self.decay * self.perc5 + (1 - self.decay) * p5.item() + self.perc95 = self.decay * self.perc95 + (1 - self.decay) * p95.item() + + def normalize(self, values): + scale = max(1.0, self.perc95 - self.perc5) + return values / scale + + class ActorCritic(Reinforce): - ''' + """ Implementation of single threaded Advantage Actor Critic Original paper: "Asynchronous Methods for Deep Reinforcement Learning" https://arxiv.org/abs/1602.01783 @@ -115,43 +135,52 @@ class ActorCritic(Reinforce): "type": "MLPNet", "shared": true, ... - ''' + """ @lab_api def init_algorithm_params(self): - '''Initialize other algorithm parameters''' + """Initialize other algorithm parameters""" # set default - util.set_attr(self, dict( - action_pdtype='default', - action_policy='default', - explore_var_spec=None, - entropy_coef_spec=None, - policy_loss_coef=1.0, - val_loss_coef=1.0, - normalize_v_targets=False, # Normalize value targets to prevent gradient explosion - )) - util.set_attr(self, self.algorithm_spec, [ - 'action_pdtype', - 'action_policy', - # theoretically, AC does not have policy update; but in this implementation we have such option - 'explore_var_spec', - 'gamma', # the discount factor - 'lam', - 'num_step_returns', - 'entropy_coef_spec', - 'policy_loss_coef', - 'val_loss_coef', - 'training_frequency', - 'normalize_v_targets', - ]) + util.set_attr( + self, + dict( + action_pdtype="default", + action_policy="default", + explore_var_spec=None, + entropy_coef_spec=None, + policy_loss_coef=1.0, + val_loss_coef=1.0, + normalize_v_targets=False, # Normalize value targets to prevent gradient explosion + ), + ) + util.set_attr( + self, + self.algorithm_spec, + [ + "action_pdtype", + "action_policy", + # theoretically, AC does not have policy update; but in this implementation we have such option + "explore_var_spec", + "gamma", # the discount factor + "lam", + "num_step_returns", + "entropy_coef_spec", + "policy_loss_coef", + "val_loss_coef", + "training_frequency", + "normalize_v_targets", + ], + ) self.to_train = 0 self.action_policy = getattr(policy_util, self.action_policy) self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec) self.agent.explore_var = self.explore_var_scheduler.start_val if self.entropy_coef_spec is not None: - self.entropy_coef_scheduler = policy_util.VarScheduler(self.entropy_coef_spec) + self.entropy_coef_scheduler = policy_util.VarScheduler( + self.entropy_coef_spec + ) self.agent.entropy_coef = self.entropy_coef_scheduler.start_val - self.agent.mt.register_algo_var('entropy_coef', self.agent) + self.agent.mt.register_algo_var("entropy_coef", self.agent) # Initialize return normalizer for value target scaling (VecNormalize-style) if self.normalize_v_targets: self.return_normalizer = ReturnNormalizer() @@ -169,7 +198,7 @@ def init_algorithm_params(self): @lab_api def init_nets(self, global_nets=None): - ''' + """ Initialize the neural networks used to learn the actor and critic from the spec Below we automatically select an appropriate net based on two different conditions 1. If the action space is discrete or continuous action @@ -182,53 +211,61 @@ def init_nets(self, global_nets=None): 3. If the network type is feedforward, convolutional, or recurrent - Feedforward and convolutional networks take a single state as input and require an OnPolicyReplay or OnPolicyBatchReplay memory - Recurrent networks take n states as input and use gymnasium's FrameStackObservation wrapper for sequence handling - ''' - assert 'shared' in self.net_spec, 'Specify "shared" for ActorCritic network in net_spec' - self.shared = self.net_spec['shared'] + """ + assert "shared" in self.net_spec, ( + 'Specify "shared" for ActorCritic network in net_spec' + ) + self.shared = self.net_spec["shared"] # create actor/critic specific specs actor_net_spec = self.net_spec.copy() critic_net_spec = self.net_spec.copy() for k in self.net_spec: - if 'actor_' in k: - actor_net_spec[k.replace('actor_', '')] = actor_net_spec.pop(k) + if "actor_" in k: + actor_net_spec[k.replace("actor_", "")] = actor_net_spec.pop(k) critic_net_spec.pop(k) - if 'critic_' in k: - critic_net_spec[k.replace('critic_', '')] = critic_net_spec.pop(k) + if "critic_" in k: + critic_net_spec[k.replace("critic_", "")] = critic_net_spec.pop(k) actor_net_spec.pop(k) - if critic_net_spec['use_same_optim']: + if critic_net_spec["use_same_optim"]: critic_net_spec = actor_net_spec in_dim = self.agent.state_dim out_dim = net_util.get_out_dim(self.agent, add_critic=self.shared) # main actor network, may contain out_dim self.shared == True - NetClass = getattr(net, actor_net_spec['type']) + NetClass = getattr(net, actor_net_spec["type"]) self.net = NetClass(actor_net_spec, in_dim, out_dim) - self.net_names = ['net'] + self.net_names = ["net"] if not self.shared: # add separate network for critic critic_out_dim = 1 - CriticNetClass = getattr(net, critic_net_spec['type']) + CriticNetClass = getattr(net, critic_net_spec["type"]) self.critic_net = CriticNetClass(critic_net_spec, in_dim, critic_out_dim) - self.net_names.append('critic_net') + self.net_names.append("critic_net") # init net optimizer and its lr scheduler # steps_per_schedule: frames processed per scheduler.step() call steps_per_schedule = self.training_frequency * self.agent.env.num_envs self.optim = net_util.get_optim(self.net, self.net.optim_spec) - self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec, steps_per_schedule) + self.lr_scheduler = net_util.get_lr_scheduler( + self.optim, self.net.lr_scheduler_spec, steps_per_schedule + ) if not self.shared: - self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec) - self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec, steps_per_schedule) + self.critic_optim = net_util.get_optim( + self.critic_net, self.critic_net.optim_spec + ) + self.critic_lr_scheduler = net_util.get_lr_scheduler( + self.critic_optim, self.critic_net.lr_scheduler_spec, steps_per_schedule + ) net_util.set_global_nets(self, global_nets) self.end_init_nets() @lab_api def calc_pdparam(self, x, net=None): - ''' + """ The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist. - ''' + """ out = super().calc_pdparam(x, net=net) if self.shared: - assert ps.is_list(out), 'Shared output should be a list [pdparam, v]' + assert isinstance(out, list), "Shared output should be a list [pdparam, v]" if len(out) == 2: # single policy pdparam = out[0] else: # multiple-task policies, still assumes 1 value @@ -239,9 +276,9 @@ def calc_pdparam(self, x, net=None): return pdparam def calc_v(self, x, net=None, use_cache=True): - ''' + """ Forward-pass to calculate the predicted state-value from critic_net. - ''' + """ if self.shared: # output: policy, value if use_cache: # uses cache from calc_pdparam to prevent double-pass v_pred = self.v_pred @@ -254,34 +291,38 @@ def calc_v(self, x, net=None, use_cache=True): return v_pred def calc_pdparam_v(self, batch): - '''Efficiently forward to get pdparam and v by batch for loss computation''' - states = batch['states'] + """Efficiently forward to get pdparam and v by batch for loss computation""" + states = batch["states"] if self.agent.env.is_venv: states = math_util.venv_unpack(states) pdparam = self.calc_pdparam(states) - v_pred = self.calc_v(states) # uses self.v_pred from calc_pdparam if self.shared + v_pred = self.calc_v( + states + ) # uses self.v_pred from calc_pdparam if self.shared return pdparam, v_pred def calc_ret_advs_v_targets(self, batch, v_preds): - '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets''' + """Calculate plain returns, and advs = rets - v_preds, v_targets = rets""" v_preds = v_preds.detach() # adv does not accumulate grad if self.agent.env.is_venv: v_preds = math_util.venv_pack(v_preds, self.agent.env.num_envs) - rets = math_util.calc_returns(batch['rewards'], batch['terminateds'], self.gamma) + rets = math_util.calc_returns( + batch["rewards"], batch["terminateds"], self.gamma + ) advs = rets - v_preds v_targets = rets if self.agent.env.is_venv: advs = math_util.venv_unpack(advs) v_targets = math_util.venv_unpack(v_targets) - logger.debug(f'advs: {advs}\nv_targets: {v_targets}') + logger.debug(f"advs: {advs}\nv_targets: {v_targets}") return advs, v_targets def calc_nstep_advs_v_targets(self, batch, v_preds): - ''' + """ Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf - ''' - next_states = batch['next_states'][-1] + """ + next_states = batch["next_states"][-1] if not self.agent.env.is_venv: next_states = next_states.unsqueeze(dim=0) with torch.no_grad(): @@ -289,21 +330,27 @@ def calc_nstep_advs_v_targets(self, batch, v_preds): v_preds = v_preds.detach() # adv does not accumulate grad if self.agent.env.is_venv: v_preds = math_util.venv_pack(v_preds, self.agent.env.num_envs) - nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['terminateds'], next_v_pred, self.gamma, self.num_step_returns) + nstep_rets = math_util.calc_nstep_returns( + batch["rewards"], + batch["terminateds"], + next_v_pred, + self.gamma, + self.num_step_returns, + ) advs = nstep_rets - v_preds v_targets = nstep_rets if self.agent.env.is_venv: advs = math_util.venv_unpack(advs) v_targets = math_util.venv_unpack(v_targets) - logger.debug(f'advs: {advs}\nv_targets: {v_targets}') + logger.debug(f"advs: {advs}\nv_targets: {v_targets}") return advs, v_targets def calc_gae_advs_v_targets(self, batch, v_preds): - ''' + """ Calculate GAE, and advs = GAE, v_targets = advs + v_preds See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf - ''' - next_states = batch['next_states'][-1] + """ + next_states = batch["next_states"][-1] if not self.agent.env.is_venv: next_states = next_states.unsqueeze(dim=0) with torch.no_grad(): @@ -313,28 +360,30 @@ def calc_gae_advs_v_targets(self, batch, v_preds): v_preds = math_util.venv_pack(v_preds, self.agent.env.num_envs) next_v_pred = next_v_pred.unsqueeze(dim=0) v_preds_all = torch.cat((v_preds, next_v_pred), dim=0) - advs = math_util.calc_gaes(batch['rewards'], batch['terminateds'], v_preds_all, self.gamma, self.lam) + advs = math_util.calc_gaes( + batch["rewards"], batch["terminateds"], v_preds_all, self.gamma, self.lam + ) v_targets = advs + v_preds # NOTE: Advantage normalization moved to per-minibatch in training loop (like SB3) if self.agent.env.is_venv: advs = math_util.venv_unpack(advs) v_targets = math_util.venv_unpack(v_targets) - logger.debug(f'advs: {advs}\nv_targets: {v_targets}') + logger.debug(f"advs: {advs}\nv_targets: {v_targets}") return advs, v_targets def calc_policy_loss(self, batch, pdparams, advs): - '''Calculate the actor's policy loss''' + """Calculate the actor's policy loss""" return super().calc_policy_loss(batch, pdparams, advs) def calc_val_loss(self, v_preds, v_targets): - '''Calculate the critic's value loss. + """Calculate the critic's value loss. If normalize_v_targets is enabled with return_normalizer, uses running statistics to normalize targets consistently across training (like SB3's VecNormalize). This enables the critic to learn values in a stable range regardless of the environment's actual return scale. - ''' - assert v_preds.shape == v_targets.shape, f'{v_preds.shape} != {v_targets.shape}' + """ + assert v_preds.shape == v_targets.shape, f"{v_preds.shape} != {v_targets.shape}" if self.return_normalizer is not None: # Update running statistics with new targets @@ -343,7 +392,9 @@ def calc_val_loss(self, v_preds, v_targets): v_targets_norm = self.return_normalizer.normalize(v_targets) # Normalize predictions using same statistics v_preds_norm = self.return_normalizer.normalize(v_preds) - val_loss = self.val_loss_coef * self.net.loss_fn(v_preds_norm, v_targets_norm) + val_loss = self.val_loss_coef * self.net.loss_fn( + v_preds_norm, v_targets_norm + ) elif self.normalize_v_targets: # Fallback: batch normalization (less stable but prevents explosion) v_max = v_targets.abs().max() * 2 + 1e-8 @@ -351,15 +402,17 @@ def calc_val_loss(self, v_preds, v_targets): v_std = v_targets.std() + 1e-8 v_preds_norm = v_preds_clipped / v_std v_targets_norm = v_targets / v_std - val_loss = self.val_loss_coef * self.net.loss_fn(v_preds_norm, v_targets_norm) + val_loss = self.val_loss_coef * self.net.loss_fn( + v_preds_norm, v_targets_norm + ) else: val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets) - logger.debug(f'Critic value loss: {val_loss:g}') + logger.debug(f"Critic value loss: {val_loss:g}") return val_loss def train(self): - '''Train actor critic by computing the loss in batch efficiently''' + """Train actor critic by computing the loss in batch efficiently""" if self.to_train == 1: batch = self.sample() self.agent.env.set_batch_size(len(batch)) @@ -369,22 +422,40 @@ def train(self): val_loss = self.calc_val_loss(v_preds, v_targets) # from critic if self.shared: # shared network loss = policy_loss + val_loss - self.net.train_step(loss, self.optim, self.lr_scheduler, global_net=self.global_net) + self.net.train_step( + loss, self.optim, self.lr_scheduler, global_net=self.global_net + ) self.agent.env.tick_opt_step() else: - self.net.train_step(policy_loss, self.optim, self.lr_scheduler, global_net=self.global_net) - self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, global_net=self.global_critic_net) + self.net.train_step( + policy_loss, + self.optim, + self.lr_scheduler, + global_net=self.global_net, + ) + self.critic_net.train_step( + val_loss, + self.critic_optim, + self.critic_lr_scheduler, + global_net=self.global_critic_net, + ) self.agent.env.tick_opt_step() self.agent.env.tick_opt_step() loss = policy_loss + val_loss # Step LR scheduler once per training iteration if self.lr_scheduler is not None: self.lr_scheduler.step() - if not self.shared and hasattr(self, 'critic_lr_scheduler') and self.critic_lr_scheduler is not None: + if ( + not self.shared + and hasattr(self, "critic_lr_scheduler") + and self.critic_lr_scheduler is not None + ): self.critic_lr_scheduler.step() # reset self.to_train = 0 - logger.debug(f'Trained {self.name} at epi: {self.agent.env.get("epi")}, frame: {self.agent.env.get("frame")}, t: {self.agent.env.get("t")}, total_reward so far: {self.agent.env.total_reward}, loss: {loss:g}') + logger.debug( + f"Trained {self.name} at epi: {self.agent.env.get('epi')}, frame: {self.agent.env.get('frame')}, t: {self.agent.env.get('t')}, total_reward so far: {self.agent.env.total_reward}, loss: {loss:g}" + ) return loss.item() else: return np.nan @@ -393,5 +464,7 @@ def train(self): def update(self): self.agent.explore_var = self.explore_var_scheduler.update(self, self.agent.env) if self.entropy_coef_spec is not None: - self.agent.entropy_coef = self.entropy_coef_scheduler.update(self, self.agent.env) + self.agent.entropy_coef = self.entropy_coef_scheduler.update( + self, self.agent.env + ) return self.agent.explore_var diff --git a/slm_lab/agent/algorithm/base.py b/slm_lab/agent/algorithm/base.py index d796304e3..dece8469d 100644 --- a/slm_lab/agent/algorithm/base.py +++ b/slm_lab/agent/algorithm/base.py @@ -10,126 +10,137 @@ class Algorithm(ABC): - '''Abstract Algorithm class to define the API methods''' + """Abstract Algorithm class to define the API methods""" def __init__(self, agent, global_nets=None): - ''' + """ @param {*} agent is the container for algorithm and related components, and interfaces with env. - ''' + """ self.agent = agent - self.algorithm_spec = agent.agent_spec['algorithm'] - self.name = self.algorithm_spec['name'] - self.memory_spec = agent.agent_spec['memory'] - self.net_spec = agent.agent_spec['net'] + self.algorithm_spec = agent.agent_spec["algorithm"] + self.name = self.algorithm_spec["name"] + self.memory_spec = agent.agent_spec["memory"] + self.net_spec = agent.agent_spec["net"] self.init_algorithm_params() self.init_nets(global_nets) @lab_api def init_algorithm_params(self): - '''Initialize other algorithm parameters and schedulers''' + """Initialize other algorithm parameters and schedulers""" # Initialize common scheduler attributes - if hasattr(self, 'explore_var_spec') and self.explore_var_spec is not None: + if hasattr(self, "explore_var_spec") and self.explore_var_spec is not None: from slm_lab.agent.algorithm import policy_util + self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec) self.agent.explore_var = self.explore_var_scheduler.start_val # Register for logging - self.agent.mt.register_algo_var('explore_var', self.agent) + self.agent.mt.register_algo_var("explore_var", self.agent) @abstractmethod @lab_api def init_nets(self, global_nets=None): - '''Initialize the neural network from the spec''' + """Initialize the neural network from the spec""" raise NotImplementedError @lab_api def end_init_nets(self): - '''Checkers and conditional loaders called at the end of init_nets()''' + """Checkers and conditional loaders called at the end of init_nets()""" # check all nets naming - assert hasattr(self, 'net_names') + assert hasattr(self, "net_names") for net_name in self.net_names: - assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}' + assert net_name.endswith("net"), ( + f'Naming convention: net_name must end with "net"; got {net_name}' + ) # load algorithm if is in train@ resume or enjoy mode - if self.agent.spec['meta']['resume'] or lab_mode() == 'enjoy': + if self.agent.spec["meta"]["resume"] or lab_mode() == "enjoy": self.load() @lab_api def calc_pdparam(self, x, net=None): - ''' + """ To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs. The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist. - ''' + """ raise NotImplementedError def to_action(self, action: torch.Tensor) -> np.ndarray: - '''Convert tensor action to numpy with gymnasium-compatible shapes - + """Convert tensor action to numpy with gymnasium-compatible shapes + Handles 8 action type combinations: 1. Single CartPole (2 actions): (1,) → scalar int - 2. Vector CartPole (2 actions): (2,) → (2,) + 2. Vector CartPole (2 actions): (2,) → (2,) 3. Single LunarLander (4 actions): (1,) → scalar int 4. Vector LunarLander (4 actions): (2,) → (2,) 5. Single Pendulum (1D): (1, 1) → (1,) 6. Vector Pendulum (1D): (2, 1) → (2, 1) 7. Single BipedalWalker (4D): (1, 4) → (4,) 8. Vector BipedalWalker (4D): (2, 4) → (2, 4) - ''' + """ action_np = action.cpu().numpy() - + # Single environments need scalars for discrete, squeezed arrays for continuous if not self.agent.env.is_venv: if self.agent.env.is_discrete and action_np.size == 1: action_np = action_np.item() # (1,) or scalar → int elif not self.agent.env.is_discrete and action_np.ndim == 2: action_np = action_np.squeeze(0) # (1, action_dim) → (action_dim,) - + # Vector continuous environments need (num_envs, action_dim) shape elif self.agent.env.is_venv and not self.agent.env.is_discrete: - if action_np.ndim == 1: # Got (num_envs*action_dim,), need (num_envs, action_dim) - action_np = action_np.reshape(self.agent.env.num_envs, self.agent.env.action_dim) - + if ( + action_np.ndim == 1 + ): # Got (num_envs*action_dim,), need (num_envs, action_dim) + action_np = action_np.reshape( + self.agent.env.num_envs, self.agent.env.action_dim + ) + return action_np @lab_api def act(self, state): - '''Standard act method.''' + """Standard act method.""" raise NotImplementedError @abstractmethod @lab_api def sample(self): - '''Samples a batch from memory''' + """Samples a batch from memory""" raise NotImplementedError @abstractmethod @lab_api def train(self): - '''Implement algorithm train, or throw NotImplementedError''' + """Implement algorithm train, or throw NotImplementedError""" raise NotImplementedError @abstractmethod @lab_api def update(self): - '''Implement algorithm update, or throw NotImplementedError''' + """Implement algorithm update, or throw NotImplementedError""" raise NotImplementedError @lab_api def save(self, ckpt=None): - '''Save net models for algorithm given the required property self.net_names''' - if not hasattr(self, 'net_names'): - logger.info('No net declared in self.net_names in init_nets(); no models to save.') + """Save net models for algorithm given the required property self.net_names""" + if not hasattr(self, "net_names"): + logger.info( + "No net declared in self.net_names in init_nets(); no models to save." + ) else: net_util.save_algorithm(self, ckpt=ckpt) @lab_api def load(self): - '''Load net models for algorithm given the required property self.net_names''' - if not hasattr(self, 'net_names'): - logger.info('No net declared in self.net_names in init_nets(); no models to load.') + """Load net models for algorithm given the required property self.net_names""" + if not hasattr(self, "net_names"): + logger.info( + "No net declared in self.net_names in init_nets(); no models to load." + ) else: net_util.load_algorithm(self) # set decayable variables to initial values for k, v in vars(self).items(): - if k.endswith('_scheduler') and hasattr(v, 'start_val'): - var_name = k.replace('_scheduler', '') + if k.endswith("_scheduler") and hasattr(v, "start_val"): + var_name = k.replace("_scheduler", "") setattr(self.agent, var_name, v.start_val) diff --git a/slm_lab/agent/algorithm/crossq.py b/slm_lab/agent/algorithm/crossq.py new file mode 100644 index 000000000..fc9b94e44 --- /dev/null +++ b/slm_lab/agent/algorithm/crossq.py @@ -0,0 +1,259 @@ +import numpy as np +import torch + +from slm_lab.agent import net +from slm_lab.agent.algorithm import policy_util +from slm_lab.agent.algorithm.sac import SoftActorCritic +from slm_lab.agent.net import net_util +from slm_lab.lib import logger, math_util +from slm_lab.lib.decorator import lab_api + +logger = logger.get_logger(__name__) + + +class CrossQ(SoftActorCritic): + """CrossQ: Batch Normalization in Deep RL (Bhatt et al., ICLR 2024). + + Eliminates target networks via cross batch normalization in critics. + Key differences from SAC: + - No target networks (BatchNorm provides sufficient regularization) + - Cross batch norm: current (s,a) and next (s',a') share BN statistics + - UTD=1 (training_iter=1) — 20x fewer gradient steps for same performance + """ + + @lab_api + def init_nets(self, global_nets=None): + self.shared = False + # steps_per_schedule: frames processed per scheduler.step() call + steps_per_schedule = self.training_frequency * self.agent.env.num_envs + + # Actor network (identical to SAC) + ActorNetClass = getattr(net, self.net_spec["type"]) + self.net = ActorNetClass( + self.net_spec, self.agent.state_dim, net_util.get_out_dim(self.agent) + ) + self.optim = net_util.get_optim(self.net, self.net.optim_spec) + self.lr_scheduler = net_util.get_lr_scheduler( + self.optim, self.net.lr_scheduler_spec, steps_per_schedule + ) + + # Critic networks — use critic_net_spec if provided, else net_spec + critic_net_spec = self.agent.agent_spec.get("critic_net", self.net_spec) + CriticNetClass = getattr(net, critic_net_spec["type"]) + + if self.agent.is_discrete: + q_in_dim, q_out_dim = self.agent.state_dim, self.agent.action_dim + else: + q_in_dim = self.agent.state_dim + self.agent.action_dim + q_out_dim = 1 + + self.q1_net = CriticNetClass(critic_net_spec, q_in_dim, q_out_dim) + self.q2_net = CriticNetClass(critic_net_spec, q_in_dim, q_out_dim) + if self.spectral_norm: + net_util.apply_spectral_norm_penultimate(self.q1_net) + net_util.apply_spectral_norm_penultimate(self.q2_net) + self.q1_optim = net_util.get_optim(self.q1_net, self.q1_net.optim_spec) + self.q1_lr_scheduler = net_util.get_lr_scheduler( + self.q1_optim, self.q1_net.lr_scheduler_spec, steps_per_schedule + ) + self.q2_optim = net_util.get_optim(self.q2_net, self.q2_net.optim_spec) + self.q2_lr_scheduler = net_util.get_lr_scheduler( + self.q2_optim, self.q2_net.lr_scheduler_spec, steps_per_schedule + ) + + # No target networks — this is CrossQ's key distinction + self.net_names = ["net", "q1_net", "q2_net"] + + self._init_entropy_tuning(steps_per_schedule) + + net_util.set_global_nets(self, global_nets) + self.end_init_nets() + + def calc_q_cross(self, states, actions, next_states, next_actions, q_net): + """Cross batch normalization forward pass. + + Concatenates current (s,a) and next (s',a') into a single batch, + forwards through the critic, then splits. This shares BatchNorm + statistics between current and next state batches — the core + innovation that eliminates the need for target networks. + """ + current = torch.cat([states, actions], dim=-1) + future = torch.cat([next_states, next_actions], dim=-1) + batch = torch.cat([current, future], dim=0) + q_all = q_net(batch) + q_current, q_next = q_all.chunk(2, dim=0) + return q_current.view(-1), q_next.view(-1) + + def calc_q_cross_discrete(self, states, next_states, q_net): + """Cross batch norm forward for discrete actions. + + For discrete actions, Q-network takes only states (outputs Q for all actions). + """ + batch = torch.cat([states, next_states], dim=0) + q_all = q_net(batch) + q_current, q_next = q_all.chunk(2, dim=0) + return q_current, q_next + + def train(self): + """Override SAC's train to use cross batch norm forward pass. + + One cross forward in train mode processes current (s,a) and next (s',a') + together through each critic, sharing BatchNorm statistics. Q_next values + from this same forward are detached for target computation — no separate + eval-mode forward needed. This is CrossQ's core mechanism. + """ + if self.to_train == 1: + self._anneal_target_entropy() + self._anneal_alpha() + for _ in range(self.training_iter): + batch = self.sample() + self.agent.env.set_batch_size(len(batch)) + + states = batch["states"] + actions = batch["actions"] + next_states = batch["next_states"] + + # Get next actions from policy (no gradient through policy for critic update) + with torch.no_grad(): + next_pdparams = self.calc_pdparam(next_states) + next_action_pd = policy_util.init_action_pd( + self.agent.ActionPD, next_pdparams + ) + + # Cross batch norm forward: one pass through each critic with both + # current and next batches concatenated. BN statistics are shared. + self.q1_net.train() + self.q2_net.train() + + if self.agent.is_discrete: + q1_current, q1_next = self.calc_q_cross_discrete( + states, next_states, self.q1_net + ) + q2_current, q2_next = self.calc_q_cross_discrete( + states, next_states, self.q2_net + ) + q1_preds = q1_current.gather( + 1, actions.long().unsqueeze(1) + ).squeeze(1) + q2_preds = q2_current.gather( + 1, actions.long().unsqueeze(1) + ).squeeze(1) + q1_all, q2_all = q1_current, q2_current + + # V(s') from cross forward Q_next (detached — no gradient into targets) + with torch.no_grad(): + next_log_probs = torch.nn.functional.log_softmax( + next_action_pd.logits, dim=-1 + ) + next_probs = next_log_probs.exp() + avg_q_next = (q1_next.detach() + q2_next.detach()) / 2 + v_next = ( + next_probs * (avg_q_next - self.alpha * next_log_probs) + ).sum(dim=-1) + else: + with torch.no_grad(): + next_log_probs, next_actions = self.calc_log_prob_action( + next_action_pd + ) + q1_preds, q1_next = self.calc_q_cross( + states, actions, next_states, next_actions, self.q1_net + ) + q2_preds, q2_next = self.calc_q_cross( + states, actions, next_states, next_actions, self.q2_net + ) + q1_all, q2_all = None, None + + # V(s') from cross forward Q_next (detached) + with torch.no_grad(): + min_q_next = torch.min(q1_next.detach(), q2_next.detach()) + v_next = min_q_next - self.alpha * next_log_probs + + # Compute targets from cross forward Q_next values + with torch.no_grad(): + q_targets = ( + batch["rewards"] + + self.gamma * (1 - batch["terminateds"]) * v_next + ) + + # Apply symlog compression to Q-values if enabled + if self.symlog: + symlog_targets = math_util.symlog(q_targets) + q1_loss = self.net.loss_fn( + math_util.symlog(q1_preds), symlog_targets + ) + q2_loss = self.net.loss_fn( + math_util.symlog(q2_preds), symlog_targets + ) + else: + q1_loss = self.net.loss_fn(q1_preds, q_targets) + q2_loss = self.net.loss_fn(q2_preds, q_targets) + self.q1_net.train_step( + q1_loss, + self.q1_optim, + self.q1_lr_scheduler, + global_net=self.global_q1_net, + ) + + self.q2_net.train_step( + q2_loss, + self.q2_optim, + self.q2_lr_scheduler, + global_net=self.global_q2_net, + ) + + self._train_step += 1 + loss = q1_loss + q2_loss + + # Policy and alpha updates with optional delay + if self._train_step % self.policy_delay == 0: + # Critics in eval mode for policy update: use frozen running stats, + # don't corrupt BN statistics with policy-only batch (paper requirement) + self.q1_net.eval() + self.q2_net.eval() + action_pd = policy_util.init_action_pd( + self.agent.ActionPD, self.calc_pdparam(states) + ) + policy_loss = self.calc_policy_loss( + states, action_pd, q1_all, q2_all + ) + self.net.train_step( + policy_loss, + self.optim, + self.lr_scheduler, + global_net=self.global_net, + ) + + # Alpha update: skip when using fixed alpha + if self.fixed_alpha is None: + alpha_loss = self.calc_alpha_loss(action_pd) + self.train_alpha(alpha_loss) + loss = loss + alpha_loss + + loss = loss + policy_loss + + self.agent.env.tick_opt_step() + self.try_update_per(torch.min(q1_preds, q2_preds), q_targets) + + # Step LR schedulers once per training iteration + if self.lr_scheduler is not None: + self.lr_scheduler.step() + if self.q1_lr_scheduler is not None: + self.q1_lr_scheduler.step() + if self.q2_lr_scheduler is not None: + self.q2_lr_scheduler.step() + if self.alpha_lr_scheduler is not None: + self.alpha_lr_scheduler.step() + # reset + self.to_train = 0 + logger.debug( + f"Trained {self.name} at epi: {self.agent.env.get('epi')}, " + f"frame: {self.agent.env.get('frame')}, t: {self.agent.env.get('t')}, " + f"total_reward so far: {self.agent.env.total_reward}, loss: {loss.item():g}" + ) + return loss.item() + else: + return np.nan + + def update_nets(self): + """No-op: CrossQ has no target networks to update.""" + pass diff --git a/slm_lab/agent/algorithm/policy_util.py b/slm_lab/agent/algorithm/policy_util.py index 3ddf8565e..212a71ed3 100644 --- a/slm_lab/agent/algorithm/policy_util.py +++ b/slm_lab/agent/algorithm/policy_util.py @@ -57,14 +57,10 @@ def get_action_pd_cls(action_pdtype, action_type): def guard_tensor(state, agent): '''Guard-cast tensor before being input to network''' - # Modern gymnasium handles frame stacking efficiently, no LazyFrames needed if not isinstance(state, np.ndarray): - state = np.array(state, dtype=np.float32) - elif state.dtype != np.float32: - state = state.astype(np.float32) - state = torch.from_numpy(state) + state = np.asarray(state) + state = torch.from_numpy(np.ascontiguousarray(state)) if not agent.env.is_venv: - # singleton state, unsqueeze as minibatch for net input state = state.unsqueeze(dim=0) return state @@ -83,7 +79,7 @@ def calc_pdparam(state, algorithm): ''' if not torch.is_tensor(state): # dont need to cast from numpy state = guard_tensor(state, algorithm.agent) - state = state.to(algorithm.net.device) + state = state.to(algorithm.net.device, non_blocking=True).float() pdparam = algorithm.calc_pdparam(state) return pdparam @@ -208,9 +204,8 @@ def __init__(self, var_decay_spec=None): def update(self, algorithm, clock): '''Get an updated value for var''' - if (util.in_eval_lab_mode()) or self._updater_name == 'no_decay': + if self._updater_name == 'no_decay' or util.in_eval_lab_mode(): return self.end_val - # Handle both old Clock objects and new ClockWrapper environments - step = clock.get() if hasattr(clock, 'get') else clock.get('frame') + step = clock.get() val = self._updater(self.start_val, self.end_val, self.start_step, self.end_step, step) return val diff --git a/slm_lab/agent/algorithm/ppo.py b/slm_lab/agent/algorithm/ppo.py index 6949d9230..cfb030a7a 100644 --- a/slm_lab/agent/algorithm/ppo.py +++ b/slm_lab/agent/algorithm/ppo.py @@ -1,6 +1,10 @@ from copy import deepcopy from slm_lab.agent.algorithm import policy_util -from slm_lab.agent.algorithm.actor_critic import ActorCritic, ReturnNormalizer +from slm_lab.agent.algorithm.actor_critic import ( + ActorCritic, + PercentileNormalizer, + ReturnNormalizer, +) from slm_lab.agent.net import net_util from slm_lab.lib import logger, math_util, util from slm_lab.lib.decorator import lab_api @@ -12,7 +16,7 @@ class PPO(ActorCritic): - ''' + """ Implementation of PPO This is actually just ActorCritic with a custom loss function Original paper: "Proximal Policy Optimization Algorithms" @@ -60,52 +64,71 @@ class PPO(ActorCritic): "type": "MLPNet", "shared": true, ... - ''' + """ @lab_api def init_algorithm_params(self): - '''Initialize other algorithm parameters''' + """Initialize other algorithm parameters""" # set default - util.set_attr(self, dict( - action_pdtype='default', - action_policy='default', - explore_var_spec=None, - entropy_coef_spec=None, - minibatch_size=4, - val_loss_coef=1.0, - normalize_v_targets=False, # Normalize value targets to prevent gradient explosion - clip_vloss=False, # CleanRL-style value loss clipping (uses clip_eps) - )) - util.set_attr(self, self.algorithm_spec, [ - 'action_pdtype', - 'action_policy', - # theoretically, PPO does not have policy update; but in this implementation we have such option - 'explore_var_spec', - 'gamma', - 'lam', - 'clip_eps_spec', - 'entropy_coef_spec', - 'val_loss_coef', - 'minibatch_size', - 'time_horizon', # training_frequency = actor * horizon - 'training_epoch', - 'normalize_v_targets', - 'clip_vloss', - ]) + util.set_attr( + self, + dict( + action_pdtype="default", + action_policy="default", + explore_var_spec=None, + entropy_coef_spec=None, + minibatch_size=4, + val_loss_coef=1.0, + normalize_v_targets=False, # Normalize value targets to prevent gradient explosion + clip_vloss=False, # CleanRL-style value loss clipping (uses clip_eps) + symlog=False, # Symlog value compression (DreamerV3) + normalize_advantages="standardize", # 'standardize' or 'percentile' (DreamerV3) + ), + ) + util.set_attr( + self, + self.algorithm_spec, + [ + "action_pdtype", + "action_policy", + # theoretically, PPO does not have policy update; but in this implementation we have such option + "explore_var_spec", + "gamma", + "lam", + "clip_eps_spec", + "entropy_coef_spec", + "val_loss_coef", + "minibatch_size", + "time_horizon", # training_frequency = actor * horizon + "training_epoch", + "normalize_v_targets", + "clip_vloss", + "symlog", + "normalize_advantages", + ], + ) self.to_train = 0 # guard num_envs = self.agent.env.num_envs if self.minibatch_size % num_envs != 0 or self.time_horizon % num_envs != 0: self.minibatch_size = math.ceil(self.minibatch_size / num_envs) * num_envs self.time_horizon = math.ceil(self.time_horizon / num_envs) * num_envs - logger.info(f'minibatch_size and time_horizon needs to be multiples of num_envs; autocorrected values: minibatch_size: {self.minibatch_size} time_horizon {self.time_horizon}') + logger.info( + f"minibatch_size and time_horizon needs to be multiples of num_envs; autocorrected values: minibatch_size: {self.minibatch_size} time_horizon {self.time_horizon}" + ) # Ensure minibatch_size doesn't exceed batch_size batch_size = self.time_horizon * num_envs if self.minibatch_size > batch_size: self.minibatch_size = batch_size - logger.info(f'minibatch_size cannot exceed batch_size ({batch_size}); autocorrected to: {self.minibatch_size}') - self.training_frequency = self.time_horizon # since all memories stores num_envs by batch in list - assert self.memory_spec['name'] == 'OnPolicyBatchReplay', f'PPO only works with OnPolicyBatchReplay, but got {self.memory_spec["name"]}' + logger.info( + f"minibatch_size cannot exceed batch_size ({batch_size}); autocorrected to: {self.minibatch_size}" + ) + self.training_frequency = ( + self.time_horizon + ) # since all memories stores num_envs by batch in list + assert self.memory_spec["name"] == "OnPolicyBatchReplay", ( + f"PPO only works with OnPolicyBatchReplay, but got {self.memory_spec['name']}" + ) self.action_policy = getattr(policy_util, self.action_policy) self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec) self.agent.explore_var = self.explore_var_scheduler.start_val @@ -113,30 +136,35 @@ def init_algorithm_params(self): self.clip_eps_scheduler = policy_util.VarScheduler(self.clip_eps_spec) self.clip_eps = self.clip_eps_scheduler.start_val if self.entropy_coef_spec is not None: - self.entropy_coef_scheduler = policy_util.VarScheduler(self.entropy_coef_spec) + self.entropy_coef_scheduler = policy_util.VarScheduler( + self.entropy_coef_spec + ) self.agent.entropy_coef = self.entropy_coef_scheduler.start_val # Initialize return normalizer for value target scaling (VecNormalize-style) if self.normalize_v_targets: self.return_normalizer = ReturnNormalizer() else: self.return_normalizer = None + # Initialize percentile normalizer if selected + if self.normalize_advantages == "percentile": + self.percentile_normalizer = PercentileNormalizer() # PPO uses GAE self.calc_advs_v_targets = self.calc_gae_advs_v_targets # Register PPO-specific variables for logging - self.agent.mt.register_algo_var('clip_eps', self) + self.agent.mt.register_algo_var("clip_eps", self) if self.entropy_coef_spec is not None: - self.agent.mt.register_algo_var('entropy', self.agent) + self.agent.mt.register_algo_var("entropy", self.agent) @lab_api def init_nets(self, global_nets=None): - '''PPO uses old and new to calculate ratio for loss''' + """PPO uses old and new to calculate ratio for loss""" super().init_nets(global_nets) # create old net to calculate ratio self.old_net = deepcopy(self.net) assert id(self.old_net) != id(self.net) def calc_policy_loss(self, batch, pdparams, advs): - ''' + """ The PPO loss function (subscript t is omitted) L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * H[pi](s) ] @@ -147,11 +175,11 @@ def calc_policy_loss(self, batch, pdparams, advs): 2. L^VF = E[ mse(V(s_t), V^target) ] 3. H = E[ entropy ] - ''' + """ clip_eps = self.clip_eps action_pd = policy_util.init_action_pd(self.agent.ActionPD, pdparams) - states = batch['states'] - actions = batch['actions'] + states = batch["states"] + actions = batch["actions"] if self.agent.env.is_venv: states = math_util.venv_unpack(states) actions = math_util.venv_unpack(actions) @@ -159,17 +187,26 @@ def calc_policy_loss(self, batch, pdparams, advs): # Ensure advs is always 1D regardless of venv to match log_probs shape advs = advs.view(-1) - # Normalize advantages per minibatch (like SB3) - if len(advs) > 1: + # Normalize advantages per minibatch + if self.normalize_advantages == "percentile": + self.percentile_normalizer.update(advs) + advs = self.percentile_normalizer.normalize(advs) + elif len(advs) > 1: advs = math_util.standardize(advs) # L^CLIP log_probs = policy_util.reduce_multi_action(action_pd.log_prob(actions)) with torch.no_grad(): old_pdparams = self.calc_pdparam(states, net=self.old_net) - old_action_pd = policy_util.init_action_pd(self.agent.ActionPD, old_pdparams) - old_log_probs = policy_util.reduce_multi_action(old_action_pd.log_prob(actions)) - assert log_probs.shape == old_log_probs.shape, f'log_probs shape {log_probs.shape} != old_log_probs shape {old_log_probs.shape}' + old_action_pd = policy_util.init_action_pd( + self.agent.ActionPD, old_pdparams + ) + old_log_probs = policy_util.reduce_multi_action( + old_action_pd.log_prob(actions) + ) + assert log_probs.shape == old_log_probs.shape, ( + f"log_probs shape {log_probs.shape} != old_log_probs shape {old_log_probs.shape}" + ) # Clip log ratio to prevent numerical instability (exp overflow) log_ratio = torch.clamp(log_probs - old_log_probs, -20.0, 20.0) ratios = torch.exp(log_ratio) @@ -186,36 +223,43 @@ def calc_policy_loss(self, batch, pdparams, advs): ent_penalty = -self.agent.entropy_coef * entropy policy_loss = clip_loss + ent_penalty - logger.debug(f'PPO policy loss: {policy_loss:g}') + logger.debug(f"PPO policy loss: {policy_loss:g}") return policy_loss def calc_val_loss(self, v_preds, v_targets, old_v_preds=None): - '''Calculate PPO value loss with optional CleanRL-style value clipping. - - When clip_vloss=True, clips value predictions relative to old predictions - similar to policy clipping. This can improve stability for some environments. + """Calculate PPO value loss with optional CleanRL-style value clipping and symlog compression. Args: v_preds: Current value predictions v_targets: GAE-computed value targets old_v_preds: Value predictions from before network update (for clipping) - ''' + """ + # Apply symlog compression to both preds and targets if enabled + if self.symlog: + v_preds_loss = math_util.symlog(v_preds) + v_targets_loss = math_util.symlog(v_targets) + else: + v_preds_loss = v_preds + v_targets_loss = v_targets + if self.clip_vloss and old_v_preds is not None: # CleanRL-style value clipping - v_loss_unclipped = (v_preds - v_targets) ** 2 + if self.symlog: + old_v_preds = math_util.symlog(old_v_preds) + v_loss_unclipped = (v_preds_loss - v_targets_loss) ** 2 v_clipped = old_v_preds + torch.clamp( - v_preds - old_v_preds, + v_preds_loss - old_v_preds, -self.clip_eps, self.clip_eps, ) - v_loss_clipped = (v_clipped - v_targets) ** 2 + v_loss_clipped = (v_clipped - v_targets_loss) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) val_loss = 0.5 * self.val_loss_coef * v_loss_max.mean() - logger.debug(f'PPO clipped value loss: {val_loss:g}') + logger.debug(f"PPO clipped value loss: {val_loss:g}") return val_loss else: - # Standard value loss (inherited from ActorCritic) - return super().calc_val_loss(v_preds, v_targets) + # Standard value loss with symlog-compressed inputs + return super().calc_val_loss(v_preds_loss, v_targets_loss) def train(self): if self.to_train == 1: @@ -223,55 +267,93 @@ def train(self): batch = self.sample() self.agent.env.set_batch_size(len(batch)) with torch.no_grad(): - states = batch['states'] + states = batch["states"] if self.agent.env.is_venv: states = math_util.venv_unpack(states) # NOTE states is massive with batch_size = time_horizon * num_envs. Chunk up so forward pass can fit into device esp. GPU num_chunks = max(1, int(len(states) / self.minibatch_size)) - v_preds_chunks = [self.calc_v(states_chunk, use_cache=False) for states_chunk in torch.chunk(states, num_chunks)] + v_preds_chunks = [ + self.calc_v(states_chunk, use_cache=False) + for states_chunk in torch.chunk(states, num_chunks) + ] v_preds = torch.cat(v_preds_chunks) advs, v_targets = self.calc_advs_v_targets(batch, v_preds) # piggy back on batch, but remember to not pack or unpack # Store old v_preds for value clipping (CleanRL-style) - batch['advs'], batch['v_targets'], batch['old_v_preds'] = advs, v_targets, v_preds + batch["advs"], batch["v_targets"], batch["old_v_preds"] = ( + advs, + v_targets, + v_preds, + ) if self.agent.env.is_venv: # unpack if venv for minibatch sampling for k, v in batch.items(): - if k not in ('advs', 'v_targets', 'old_v_preds'): + if k not in ("advs", "v_targets", "old_v_preds"): batch[k] = math_util.venv_unpack(v) - total_loss = torch.tensor(0.0, device=self.net.device) + total_loss = 0.0 for _ in range(self.training_epoch): minibatches = util.split_minibatch(batch, self.minibatch_size) for minibatch in minibatches: if self.agent.env.is_venv: # re-pack to restore proper shape for k, v in minibatch.items(): - if k not in ('advs', 'v_targets', 'old_v_preds'): - minibatch[k] = math_util.venv_pack(v, self.agent.env.num_envs) - advs, v_targets, old_v_preds = minibatch['advs'], minibatch['v_targets'], minibatch['old_v_preds'] + if k not in ("advs", "v_targets", "old_v_preds"): + minibatch[k] = math_util.venv_pack( + v, self.agent.env.num_envs + ) + advs, v_targets, old_v_preds = ( + minibatch["advs"], + minibatch["v_targets"], + minibatch["old_v_preds"], + ) pdparams, v_preds = self.calc_pdparam_v(minibatch) - policy_loss = self.calc_policy_loss(minibatch, pdparams, advs) # from actor - val_loss = self.calc_val_loss(v_preds, v_targets, old_v_preds) # from critic + policy_loss = self.calc_policy_loss( + minibatch, pdparams, advs + ) # from actor + val_loss = self.calc_val_loss( + v_preds, v_targets, old_v_preds + ) # from critic if self.shared: # shared network loss = policy_loss + val_loss - self.net.train_step(loss, self.optim, self.lr_scheduler, global_net=self.global_net) + self.net.train_step( + loss, + self.optim, + self.lr_scheduler, + global_net=self.global_net, + ) self.agent.env.tick_opt_step() else: - self.net.train_step(policy_loss, self.optim, self.lr_scheduler, global_net=self.global_net) - self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, global_net=self.global_critic_net) + self.net.train_step( + policy_loss, + self.optim, + self.lr_scheduler, + global_net=self.global_net, + ) + self.critic_net.train_step( + val_loss, + self.critic_optim, + self.critic_lr_scheduler, + global_net=self.global_critic_net, + ) self.agent.env.tick_opt_step() self.agent.env.tick_opt_step() loss = policy_loss + val_loss - total_loss += loss + total_loss += loss.item() # Step LR scheduler once per training iteration (per batch of collected experience) # This ensures proper LR decay matching CleanRL's approach if self.lr_scheduler is not None: self.lr_scheduler.step() - if not self.shared and hasattr(self, 'critic_lr_scheduler') and self.critic_lr_scheduler is not None: + if ( + not self.shared + and hasattr(self, "critic_lr_scheduler") + and self.critic_lr_scheduler is not None + ): self.critic_lr_scheduler.step() loss = total_loss / self.training_epoch / len(minibatches) # reset self.to_train = 0 - logger.debug(f'Trained {self.name} at epi: {self.agent.env.get("epi")}, frame: {self.agent.env.get("frame")}, t: {self.agent.env.get("t")}, total_reward so far: {self.agent.env.total_reward}, loss: {loss:g}') - return loss.item() + logger.debug( + f"Trained {self.name} at epi: {self.agent.env.get('epi')}, frame: {self.agent.env.get('frame')}, t: {self.agent.env.get('t')}, total_reward so far: {self.agent.env.total_reward}, loss: {loss:g}" + ) + return loss else: return np.nan @@ -279,6 +361,8 @@ def train(self): def update(self): self.agent.explore_var = self.explore_var_scheduler.update(self, self.agent.env) if self.entropy_coef_spec is not None: - self.agent.entropy_coef = self.entropy_coef_scheduler.update(self, self.agent.env) + self.agent.entropy_coef = self.entropy_coef_scheduler.update( + self, self.agent.env + ) self.clip_eps = self.clip_eps_scheduler.update(self, self.agent.env) return self.agent.explore_var diff --git a/slm_lab/agent/algorithm/sac.py b/slm_lab/agent/algorithm/sac.py index 48a6ccb21..a3eba9254 100644 --- a/slm_lab/agent/algorithm/sac.py +++ b/slm_lab/agent/algorithm/sac.py @@ -6,11 +6,13 @@ from slm_lab.agent.algorithm import policy_util from slm_lab.agent.algorithm.actor_critic import ActorCritic from slm_lab.agent.net import net_util -from slm_lab.lib import logger, util +from slm_lab.lib import logger, math_util, util from slm_lab.lib.decorator import lab_api logger = logger.get_logger(__name__) +_LOG2 = np.log(2) # constant for squash correction + class SoftActorCritic(ActorCritic): """ @@ -31,6 +33,13 @@ def init_algorithm_params(self): training_start_step=max(1000, self.agent.memory.batch_size), policy_delay=1, # update actor every N critic updates (1 = every step, 2 = TD3-style) entropy_penalty_coef=0.0, # SD-SAC entropy penalty coefficient (0 = disabled) + symlog=False, # Symlog Q-value compression (DreamerV3) + log_alpha_min=-5.0, # alpha clamp lower bound (exp(-5) ≈ 0.007) + log_alpha_max=2.0, # alpha clamp upper bound (exp(2) ≈ 7.4) + alpha_lr=None, # separate lr for alpha optimizer (None = use actor lr) + fixed_alpha=None, # Fixed alpha (no auto-tuning). Float or None. SAC-BBF uses 0.02. + alpha_anneal_frames=0, # Linearly anneal alpha to 0 over this many frames (0 = no anneal) + spectral_norm=False, # Spectral norm on penultimate critic Linear (Gogianu et al. 2021) ), ) util.set_attr( @@ -45,6 +54,13 @@ def init_algorithm_params(self): "training_start_step", "policy_delay", "entropy_penalty_coef", + "symlog", + "log_alpha_min", + "log_alpha_max", + "alpha_lr", + "fixed_alpha", + "alpha_anneal_frames", + "spectral_norm", ], ) if self.agent.is_discrete: @@ -55,6 +71,8 @@ def init_algorithm_params(self): self.action_policy = getattr(policy_util, self.action_policy) self._train_step = 0 # counter for policy delay self._entropy_ema = None # running entropy for SD-SAC penalty + self._cached_entropy = None # cached from policy loss for alpha loss + self._is_per = "Prioritized" in util.get_class_name(self.agent.memory) @lab_api def init_nets(self, global_nets=None): @@ -81,6 +99,9 @@ def init_nets(self, global_nets=None): self.q1_net = NetClass(self.net_spec, q_in_dim, q_out_dim) self.target_q1_net = NetClass(self.net_spec, q_in_dim, q_out_dim) + if self.spectral_norm: + net_util.apply_spectral_norm_penultimate(self.q1_net) + net_util.apply_spectral_norm_penultimate(self.target_q1_net) net_util.copy(self.q1_net, self.target_q1_net) self.q1_optim = net_util.get_optim(self.q1_net, self.q1_net.optim_spec) self.q1_lr_scheduler = net_util.get_lr_scheduler( @@ -89,6 +110,9 @@ def init_nets(self, global_nets=None): self.q2_net = NetClass(self.net_spec, q_in_dim, q_out_dim) self.target_q2_net = NetClass(self.net_spec, q_in_dim, q_out_dim) + if self.spectral_norm: + net_util.apply_spectral_norm_penultimate(self.q2_net) + net_util.apply_spectral_norm_penultimate(self.target_q2_net) net_util.copy(self.q2_net, self.target_q2_net) self.q2_optim = net_util.get_optim(self.q2_net, self.q2_net.optim_spec) self.q2_lr_scheduler = net_util.get_lr_scheduler( @@ -97,28 +121,7 @@ def init_nets(self, global_nets=None): self.net_names = ["net", "q1_net", "target_q1_net", "q2_net", "target_q2_net"] - # Automatic entropy temperature tuning - # Use 'auto' (default) or specify explicit target_entropy value - target_entropy_config = self.algorithm_spec.get("target_entropy", "auto") - if target_entropy_config == "auto": - # Discrete: H_target = 0.6 * log(|A|) — lower than Christodoulou 2019's 0.98 - # to allow meaningful exploitation. 0.98 is too close to max entropy. - # Continuous: H_target = -dim(A) per Haarnoja 2018 - if self.agent.is_discrete: - self.target_entropy = 0.6 * np.log(self.agent.action_dim) - else: - action_dim = np.prod(self.agent.action_space.shape) - self.target_entropy = -action_dim - else: - self.target_entropy = float(target_entropy_config) - - self.log_alpha = torch.zeros(1, requires_grad=True, device=self.net.device) - self.alpha = self.log_alpha.detach().exp() - self.alpha_optim = net_util.get_optim(self.log_alpha, self.net.optim_spec) - self.alpha_lr_scheduler = net_util.get_lr_scheduler( - self.alpha_optim, self.net.lr_scheduler_spec, steps_per_schedule - ) - self.agent.mt.register_algo_var("alpha", self) + self._init_entropy_tuning(steps_per_schedule) net_util.set_global_nets(self, global_nets) self.end_init_nets() @@ -145,7 +148,7 @@ def calc_log_prob_action(self, action_pd, reparam=False): ) # Sum across action dimensions # Numerically stable squash correction: log(1 - tanh^2(x)) = 2*(log(2) - x - softplus(-2x)) squash_correction = ( - 2 * (np.log(2) - raw_actions - F.softplus(-2 * raw_actions)) + 2 * (_LOG2 - raw_actions - F.softplus(-2 * raw_actions)) ).sum(-1) log_probs = raw_log_probs - squash_correction actions = torch.tanh(raw_actions) @@ -175,8 +178,8 @@ def calc_v_next(self, next_states, action_pd): Continuous: V(s') = min(Q1,Q2) - α·log(π) where a ~ π """ if self.agent.is_discrete: - next_probs = action_pd.probs next_log_probs = F.log_softmax(action_pd.logits, dim=-1) + next_probs = next_log_probs.exp() next_q1_all = self.target_q1_net(next_states) next_q2_all = self.target_q2_net(next_states) avg_q = ( @@ -202,15 +205,18 @@ def calc_q_targets(self, batch): def calc_policy_loss_discrete(self, states, action_pd, q1_all, q2_all): """J_π = E[Σ_a π(a|s)[α·log(π) - avg(Q1,Q2)]] + entropy_penalty (SD-SAC)""" - action_probs = action_pd.probs action_log_probs = F.log_softmax(action_pd.logits, dim=-1) + action_probs = action_log_probs.exp() with torch.no_grad(): avg_q_all = (q1_all + q2_all) / 2 # SD-SAC: avg instead of min for discrete policy_loss = ( - (action_probs * (self.alpha.detach() * action_log_probs - avg_q_all)) + (action_probs * (self.alpha * action_log_probs - avg_q_all)) .sum(dim=1) .mean() ) + # Cache entropy for alpha loss to avoid recomputing probs/log_probs + with torch.no_grad(): + self._cached_entropy = -(action_probs * action_log_probs).sum(dim=-1).mean() # SD-SAC entropy penalty: beta * 0.5 * (H_old - H_new)^2 if self.entropy_penalty_coef > 0: entropy = -(action_probs * action_log_probs).sum(dim=-1).mean() @@ -245,10 +251,8 @@ def calc_policy_loss(self, states, action_pd, q1_all=None, q2_all=None): def calc_alpha_loss_discrete(self, action_pd): """J_α = -α * (H_target - H) — matches continuous SAC sign convention""" - action_probs = action_pd.probs - action_log_probs = F.log_softmax(action_pd.logits, dim=-1) - with torch.no_grad(): - entropy_current = -(action_probs * action_log_probs).sum(dim=-1).mean() + # Reuse cached entropy from policy loss to avoid recomputing probs/log_probs + entropy_current = self._cached_entropy # Sign must match continuous: when H > H_target, alpha decreases return -(self.log_alpha.exp() * (self.target_entropy - entropy_current)) @@ -271,7 +275,7 @@ def calc_alpha_loss(self, action_pd): return fn(action_pd) def try_update_per(self, q_preds, q_targets): - if "Prioritized" not in util.get_class_name(self.agent.memory): + if not self._is_per: return with torch.no_grad(): errors = (q_preds - q_targets).abs().cpu().numpy() @@ -283,11 +287,83 @@ def train_alpha(self, alpha_loss): self.alpha_optim.step() # Clamp log_alpha to prevent runaway growth in truncation-only envs (e.g. Acrobot) with torch.no_grad(): - self.log_alpha.clamp_(-5.0, 2.0) # alpha in [~0.007, ~7.4] + self.log_alpha.clamp_(self.log_alpha_min, self.log_alpha_max) self.alpha = self.log_alpha.detach().exp() + def _init_entropy_tuning(self, steps_per_schedule): + """Initialize entropy temperature (alpha). Shared by SAC and CrossQ. + + Supports two modes: + - Auto-tuning (default): learnable log_alpha with optimizer + - Fixed alpha: constant alpha, optionally annealed to 0 over alpha_anneal_frames + """ + if self.fixed_alpha is not None: + # Fixed alpha mode (SAC-BBF approach): no auto-tuning + self._fixed_alpha_start = float(self.fixed_alpha) + self.alpha = torch.tensor(self._fixed_alpha_start, device=self.net.device) + self.log_alpha = torch.tensor(np.log(self._fixed_alpha_start), device=self.net.device) + self.alpha_optim = None + self.alpha_lr_scheduler = None + self.target_entropy = None + self._entropy_anneal_frames = 0 + self.agent.mt.register_algo_var("alpha", self) + return + + # Auto-tuning mode + target_entropy_config = self.algorithm_spec.get("target_entropy", "auto") + if target_entropy_config == "auto": + if self.agent.is_discrete: + log_action_dim = np.log(self.agent.action_dim) + ea = self.algorithm_spec.get("entropy_anneal", {}) + start_ratio = ea.get("start_ratio", 0.6) + end_ratio = ea.get("end_ratio", start_ratio) + self._entropy_anneal_frames = ea.get("frames", 0) + self.target_entropy = start_ratio * log_action_dim + self._target_entropy_start = start_ratio * log_action_dim + self._target_entropy_end = end_ratio * log_action_dim + else: + action_dim = np.prod(self.agent.action_space.shape) + self.target_entropy = -action_dim + self._entropy_anneal_frames = 0 + else: + self.target_entropy = float(target_entropy_config) + self._entropy_anneal_frames = 0 + + self.log_alpha = torch.zeros(1, requires_grad=True, device=self.net.device) + self.alpha = self.log_alpha.detach().exp() + alpha_optim_spec = dict(self.net.optim_spec) + if self.alpha_lr is not None: + alpha_optim_spec["lr"] = self.alpha_lr + self.alpha_optim = net_util.get_optim(self.log_alpha, alpha_optim_spec) + self.alpha_lr_scheduler = net_util.get_lr_scheduler( + self.alpha_optim, self.net.lr_scheduler_spec, steps_per_schedule + ) + self.agent.mt.register_algo_var("alpha", self) + + def _anneal_target_entropy(self): + """Linearly anneal target_entropy for discrete actions.""" + if self._entropy_anneal_frames <= 0: + return + frame = self.agent.env.get("frame") + t = min(frame / self._entropy_anneal_frames, 1.0) + self.target_entropy = self._target_entropy_start + t * ( + self._target_entropy_end - self._target_entropy_start + ) + + def _anneal_alpha(self): + """Linearly anneal fixed alpha to 0 over alpha_anneal_frames.""" + if self.fixed_alpha is None or self.alpha_anneal_frames <= 0: + return + frame = self.agent.env.get("frame") + t = min(frame / self.alpha_anneal_frames, 1.0) + self.alpha = torch.tensor( + self._fixed_alpha_start * (1.0 - t), device=self.net.device + ) + def train(self): if self.to_train == 1: + self._anneal_target_entropy() + self._anneal_alpha() for _ in range(self.training_iter): batch = self.sample() self.agent.env.set_batch_size(len(batch)) @@ -300,7 +376,18 @@ def train(self): q1_preds, q1_all = self.calc_q(states, actions, self.q1_net) q2_preds, q2_all = self.calc_q(states, actions, self.q2_net) - q1_loss = self.net.loss_fn(q1_preds, q_targets) + # Apply symlog compression to Q-values if enabled + if self.symlog: + symlog_targets = math_util.symlog(q_targets) + q1_loss = self.net.loss_fn( + math_util.symlog(q1_preds), symlog_targets + ) + q2_loss = self.net.loss_fn( + math_util.symlog(q2_preds), symlog_targets + ) + else: + q1_loss = self.net.loss_fn(q1_preds, q_targets) + q2_loss = self.net.loss_fn(q2_preds, q_targets) self.q1_net.train_step( q1_loss, self.q1_optim, @@ -308,7 +395,6 @@ def train(self): global_net=self.global_q1_net, ) - q2_loss = self.net.loss_fn(q2_preds, q_targets) self.q2_net.train_step( q2_loss, self.q2_optim, @@ -334,10 +420,13 @@ def train(self): global_net=self.global_net, ) - alpha_loss = self.calc_alpha_loss(action_pd) - self.train_alpha(alpha_loss) + # Alpha update: skip when using fixed alpha + if self.fixed_alpha is None: + alpha_loss = self.calc_alpha_loss(action_pd) + self.train_alpha(alpha_loss) + loss = loss + alpha_loss - loss = loss + policy_loss + alpha_loss + loss = loss + policy_loss # update target networks only when policy is updated self.update_nets() diff --git a/slm_lab/agent/memory/prioritized.py b/slm_lab/agent/memory/prioritized.py index d790997c2..edc4f1478 100644 --- a/slm_lab/agent/memory/prioritized.py +++ b/slm_lab/agent/memory/prioritized.py @@ -1,7 +1,6 @@ from slm_lab.agent.memory.replay import Replay from slm_lab.lib import util import numpy as np -import random class SumTree: @@ -30,24 +29,22 @@ def __init__(self, capacity): self.indices = np.zeros(capacity) # Stores the indices of the experiences def _propagate(self, idx, change): - parent = (idx - 1) // 2 - - self.tree[parent] += change - - if parent != 0: - self._propagate(parent, change) + while idx != 0: + idx = (idx - 1) // 2 + self.tree[idx] += change def _retrieve(self, idx, s): - left = 2 * idx + 1 - right = left + 1 - - if left >= len(self.tree): - return idx - - if s <= self.tree[left]: - return self._retrieve(left, s) - else: - return self._retrieve(right, s - self.tree[left]) + tree = self.tree + tree_len = len(tree) + while True: + left = 2 * idx + 1 + if left >= tree_len: + return idx + if s <= tree[left]: + idx = left + else: + s -= tree[left] + idx = left + 1 def total(self): return self.tree[0] @@ -148,16 +145,15 @@ def get_priority(self, error): def sample_idxs(self, batch_size): '''Samples batch_size indices from memory in proportional to their priority.''' - batch_idxs = np.zeros(batch_size) - tree_idxs = np.zeros(batch_size, dtype=int) + batch_idxs = np.empty(batch_size, dtype=int) + tree_idxs = np.empty(batch_size, dtype=int) + total = self.tree.total() + samples = np.random.uniform(0, total, size=batch_size) for i in range(batch_size): - s = random.uniform(0, self.tree.total()) - (tree_idx, p, idx) = self.tree.get(s) - batch_idxs[i] = idx + tree_idx, p, idx = self.tree.get(samples[i]) + batch_idxs[i] = int(idx) tree_idxs[i] = tree_idx - - batch_idxs = np.asarray(batch_idxs).astype(int) if self.use_cer: old_batch_idxs = batch_idxs.copy() batch_idxs = self.apply_cer(batch_idxs) diff --git a/slm_lab/agent/memory/replay.py b/slm_lab/agent/memory/replay.py index 3b7192bfa..2d3dbf6f9 100644 --- a/slm_lab/agent/memory/replay.py +++ b/slm_lab/agent/memory/replay.py @@ -154,10 +154,10 @@ def add_experience( self.head = (self.head + 1) % self.max_size # Preserve dtype: uint8 images stay uint8 (memory efficient); everything else float16 state_dtype = np.uint8 if state.dtype == np.uint8 else np.float16 - self.states[self.head] = state.astype(state_dtype) + self.states[self.head] = state if state.dtype == state_dtype else state.astype(state_dtype) self.actions[self.head] = action self.rewards[self.head] = reward - self.ns_buffer.append(next_state.astype(state_dtype)) + self.ns_buffer.append(next_state if next_state.dtype == state_dtype else next_state.astype(state_dtype)) self.dones[self.head] = done self.terminateds[self.head] = terminated self.truncateds[self.head] = truncated diff --git a/slm_lab/agent/net/batch_renorm.py b/slm_lab/agent/net/batch_renorm.py new file mode 100644 index 000000000..ea85e4a4e --- /dev/null +++ b/slm_lab/agent/net/batch_renorm.py @@ -0,0 +1,161 @@ +"""Batch Renormalization (Ioffe 2017) for CrossQ. + +Standard BatchNorm uses noisy minibatch statistics during training but +running statistics during inference — this mismatch causes instability +in off-policy RL (CrossQ). Batch Renormalization smoothly transitions +from minibatch to running statistics via clamped correction factors. + +Reference: https://arxiv.org/abs/1702.03275 +""" + +import torch +import torch.nn as nn +from torch.nn.modules.lazy import LazyModuleMixin + + +class BatchRenorm1d(nn.Module): + """Batch Renormalization layer. + + During training, normalizes using minibatch statistics but applies + correction factors (r, d) that gradually align output with running + statistics. During warmup, behaves identically to standard BatchNorm. + + Args: + num_features: Number of features (channels). + eps: Numerical stability constant. + momentum: Running stats update rate (PyTorch convention: new = old * (1-m) + batch * m). + r_max: Maximum scale correction factor after warmup. + d_max: Maximum shift correction factor after warmup. + warmup_steps: Training steps before full BRN correction is active. + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.01, + r_max: float = 3.0, + d_max: float = 5.0, + warmup_steps: int = 10000, + ): + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.r_max_limit = r_max + self.d_max_limit = d_max + self.warmup_steps = warmup_steps + + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer("step", torch.tensor(0, dtype=torch.long)) + + def _reshape_for_broadcast(self, v: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape 1D (C,) tensor to broadcast with x of shape (B, C, ...) .""" + if x.dim() == 2: + return v + shape = [1, -1] + [1] * (x.dim() - 2) # e.g. (1, C, 1, 1) for 4D + return v.view(shape) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.training: + rm = self._reshape_for_broadcast(self.running_mean, x) + rv = self._reshape_for_broadcast(self.running_var, x) + w = self._reshape_for_broadcast(self.weight, x) + b = self._reshape_for_broadcast(self.bias, x) + x_hat = (x - rm) / (rv + self.eps).sqrt() + return w * x_hat + b + + # Compute batch statistics over all dims except features + dims = [0] + list(range(2, x.dim())) + batch_mean = x.mean(dims) + batch_var = x.var(dims, unbiased=False) + batch_std = (batch_var + self.eps).sqrt() + running_std = (self.running_var + self.eps).sqrt() + + # Warmup schedule: linearly increase r_max 1->limit, d_max 0->limit + t = min(self.step.item() / max(self.warmup_steps, 1), 1.0) + r_max = 1.0 + t * (self.r_max_limit - 1.0) + d_max = t * self.d_max_limit + + # Correction factors (detached — no gradient through r, d) + r = (batch_std.detach() / running_std).clamp(1.0 / r_max, r_max) + d = ((batch_mean.detach() - self.running_mean) / running_std).clamp(-d_max, d_max) + + # Reshape for broadcasting with arbitrary-dim input + bm = self._reshape_for_broadcast(batch_mean, x) + bs = self._reshape_for_broadcast(batch_std, x) + r = self._reshape_for_broadcast(r, x) + d = self._reshape_for_broadcast(d, x) + w = self._reshape_for_broadcast(self.weight, x) + b = self._reshape_for_broadcast(self.bias, x) + + # Normalize with batch stats, correct toward running stats + x_hat = (x - bm) / bs * r + d + + # Update running statistics + with torch.no_grad(): + self.running_mean.lerp_(batch_mean, self.momentum) + self.running_var.lerp_(batch_var, self.momentum) + self.step += 1 + + return w * x_hat + b + + def extra_repr(self) -> str: + return ( + f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, " + f"r_max={self.r_max_limit}, d_max={self.d_max_limit}, " + f"warmup_steps={self.warmup_steps}" + ) + + +class LazyBatchRenorm1d(LazyModuleMixin, BatchRenorm1d): + """Lazy version that infers num_features from first input. + + Use in TorchArc YAML specs where input dimensions are unknown: + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 10000 + """ + + cls_to_become = BatchRenorm1d + weight: nn.UninitializedParameter + bias: nn.UninitializedParameter + + def __init__( + self, + eps: float = 1e-5, + momentum: float = 0.01, + r_max: float = 3.0, + d_max: float = 5.0, + warmup_steps: int = 10000, + ): + super().__init__(0, eps=eps, momentum=momentum, r_max=r_max, d_max=d_max, warmup_steps=warmup_steps) + self.weight = nn.UninitializedParameter() + self.bias = nn.UninitializedParameter() + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params(): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def initialize_parameters(self, input: torch.Tensor) -> None: + if self.has_uninitialized_params(): + with torch.no_grad(): + num_features = input.shape[1] + self.num_features = num_features + self.weight.materialize((num_features,)) + self.bias.materialize((num_features,)) + self.register_buffer("running_mean", torch.zeros(num_features, device=input.device)) + self.register_buffer("running_var", torch.ones(num_features, device=input.device)) + self.reset_parameters() + + +# Register in torch.nn so TorchArc can resolve from YAML specs +if not hasattr(nn, "BatchRenorm1d"): + setattr(nn, "BatchRenorm1d", BatchRenorm1d) +if not hasattr(nn, "LazyBatchRenorm1d"): + setattr(nn, "LazyBatchRenorm1d", LazyBatchRenorm1d) diff --git a/slm_lab/agent/net/conv.py b/slm_lab/agent/net/conv.py index 05d8504e1..b1a50460a 100644 --- a/slm_lab/agent/net/conv.py +++ b/slm_lab/agent/net/conv.py @@ -121,11 +121,11 @@ def __init__(self, net_spec, in_dim, out_dim): self.conv_model = self.build_conv_layers(self.conv_hid_layers) self.conv_out_dim = self.get_conv_output_size() - # fc body + # fc body (set default for forward() None-check instead of hasattr) + self.fc_model = None if ps.is_empty(self.fc_hid_layers): tail_in_dim = self.conv_out_dim else: - # fc body from flattened conv self.fc_model = net_util.build_fc_model([self.conv_out_dim] + self.fc_hid_layers, self.hid_layers_activation) tail_in_dim = self.fc_hid_layers[-1] @@ -169,7 +169,7 @@ def forward(self, x): x = x / 255.0 x = self.conv_model(x) x = x.view(x.size(0), -1) - if hasattr(self, 'fc_model'): + if self.fc_model is not None: x = self.fc_model(x) return net_util.forward_tails(x, self.tails, self.log_std) @@ -264,6 +264,7 @@ def __init__(self, net_spec, in_dim, out_dim): # fc body if ps.is_empty(self.fc_hid_layers): + self.fc_model = None tail_in_dim = self.conv_out_dim else: # fc layer from flattened conv @@ -286,7 +287,7 @@ def forward(self, x): x = x / 255.0 x = self.conv_model(x) x = x.view(x.size(0), -1) # to (batch_size, -1) - if hasattr(self, 'fc_model'): + if self.fc_model is not None: x = self.fc_model(x) state_value = self.v(x) raw_advantages = self.adv(x) diff --git a/slm_lab/agent/net/mlp.py b/slm_lab/agent/net/mlp.py index 0254bfeb4..7f4aea454 100644 --- a/slm_lab/agent/net/mlp.py +++ b/slm_lab/agent/net/mlp.py @@ -8,7 +8,7 @@ class MLPNet(Net, nn.Module): - ''' + """ Class for generating arbitrary sized feedforward neural network If more than 1 output tensors, will create a self.model_tails instead of making last layer part of self.model @@ -42,10 +42,10 @@ class MLPNet(Net, nn.Module): For continuous actions, you can use state-independent log_std (CleanRL-style) by setting: "log_std_init": 0.0 # initial value for log_std parameter This creates a learnable nn.Parameter for log_std instead of a state-dependent network head. - ''' + """ def __init__(self, net_spec, in_dim, out_dim): - ''' + """ net_spec: hid_layers: list containing dimensions of the hidden layers hid_layers_activation: activation function for the hidden layers @@ -60,47 +60,65 @@ def __init__(self, net_spec, in_dim, out_dim): polyak_coef: ratio of polyak weight update gpu: whether to train using a GPU. Note this will only work if a GPU is available, othewise setting gpu=True does nothing log_std_init: if set, use state-independent log_std as nn.Parameter initialized to this value (CleanRL-style) - ''' + """ nn.Module.__init__(self) super().__init__(net_spec, in_dim, out_dim) # set default - util.set_attr(self, dict( - out_layer_activation=None, - init_fn=None, - clip_grad_val=None, - loss_spec={'name': 'MSELoss'}, - optim_spec={'name': 'Adam'}, - lr_scheduler_spec=None, - update_type='replace', - update_frequency=1, - polyak_coef=0.0, - gpu=False, - log_std_init=None, # State-independent log_std (CleanRL-style) if set - actor_init_std=None, # CleanRL uses 0.01 for Atari - critic_init_std=None, # CleanRL uses 1.0 for Atari - )) - util.set_attr(self, self.net_spec, [ - 'shared', - 'hid_layers', - 'hid_layers_activation', - 'out_layer_activation', - 'init_fn', - 'clip_grad_val', - 'loss_spec', - 'optim_spec', - 'lr_scheduler_spec', - 'update_type', - 'update_frequency', - 'polyak_coef', - 'gpu', - 'log_std_init', - 'actor_init_std', - 'critic_init_std', - ]) + util.set_attr( + self, + dict( + out_layer_activation=None, + init_fn=None, + clip_grad_val=None, + loss_spec={"name": "MSELoss"}, + optim_spec={"name": "Adam"}, + lr_scheduler_spec=None, + update_type="replace", + update_frequency=1, + polyak_coef=0.0, + gpu=False, + log_std_init=None, # State-independent log_std (CleanRL-style) if set + actor_init_std=None, # CleanRL uses 0.01 for Atari + critic_init_std=None, # CleanRL uses 1.0 for Atari + layer_norm=False, # LayerNorm after each hidden layer (BRO, NeurIPS 2024) + batch_norm=False, # BatchNorm1d after each hidden layer (CrossQ, ICLR 2024) + ), + ) + util.set_attr( + self, + self.net_spec, + [ + "shared", + "hid_layers", + "hid_layers_activation", + "out_layer_activation", + "init_fn", + "clip_grad_val", + "loss_spec", + "optim_spec", + "lr_scheduler_spec", + "update_type", + "update_frequency", + "polyak_coef", + "gpu", + "log_std_init", + "actor_init_std", + "critic_init_std", + "layer_norm", + "batch_norm", + ], + ) dims = [self.in_dim] + self.hid_layers - self.model = net_util.build_fc_model(dims, self.hid_layers_activation) - self.tails, self.log_std = net_util.build_tails(dims[-1], self.out_dim, self.out_layer_activation, self.log_std_init) + self.model = net_util.build_fc_model( + dims, + self.hid_layers_activation, + layer_norm=self.layer_norm, + batch_norm=self.batch_norm, + ) + self.tails, self.log_std = net_util.build_tails( + dims[-1], self.out_dim, self.out_layer_activation, self.log_std_init + ) net_util.init_layers(self, self.init_fn) net_util.init_tails(self, self.actor_init_std, self.critic_init_std) @@ -109,12 +127,12 @@ def __init__(self, net_spec, in_dim, out_dim): self.train() def forward(self, x): - '''The feedforward step''' + """The feedforward step""" return net_util.forward_tails(self.model(x), self.tails, self.log_std) class HydraMLPNet(Net, nn.Module): - ''' + """ Class for generating arbitrary sized feedforward neural network with multiple state and action heads, and a single shared body. e.g. net_spec @@ -147,10 +165,10 @@ class HydraMLPNet(Net, nn.Module): "polyak_coef": 0.9, "gpu": true } - ''' + """ def __init__(self, net_spec, in_dim, out_dim): - ''' + """ Multi state processing heads, single shared body, and multi action tails. There is one state and action head per body/environment Example: @@ -172,39 +190,48 @@ def __init__(self, net_spec, in_dim, out_dim): |______________| |______________| | | env 1 action env 2 action - ''' + """ nn.Module.__init__(self) super().__init__(net_spec, in_dim, out_dim) # set default - util.set_attr(self, dict( - out_layer_activation=None, - init_fn=None, - clip_grad_val=None, - loss_spec={'name': 'MSELoss'}, - optim_spec={'name': 'Adam'}, - lr_scheduler_spec=None, - update_type='replace', - update_frequency=1, - polyak_coef=0.0, - gpu=False, - )) - util.set_attr(self, self.net_spec, [ - 'hid_layers', - 'hid_layers_activation', - 'out_layer_activation', - 'init_fn', - 'clip_grad_val', - 'loss_spec', - 'optim_spec', - 'lr_scheduler_spec', - 'update_type', - 'update_frequency', - 'polyak_coef', - 'gpu', - ]) - assert len(self.hid_layers) == 3, 'Your hidden layers must specify [*heads], [body], [*tails]. If not, use MLPNet' - assert isinstance(self.in_dim, list), 'Hydra network needs in_dim as list' - assert isinstance(self.out_dim, list), 'Hydra network needs out_dim as list' + util.set_attr( + self, + dict( + out_layer_activation=None, + init_fn=None, + clip_grad_val=None, + loss_spec={"name": "MSELoss"}, + optim_spec={"name": "Adam"}, + lr_scheduler_spec=None, + update_type="replace", + update_frequency=1, + polyak_coef=0.0, + gpu=False, + ), + ) + util.set_attr( + self, + self.net_spec, + [ + "hid_layers", + "hid_layers_activation", + "out_layer_activation", + "init_fn", + "clip_grad_val", + "loss_spec", + "optim_spec", + "lr_scheduler_spec", + "update_type", + "update_frequency", + "polyak_coef", + "gpu", + ], + ) + assert len(self.hid_layers) == 3, ( + "Your hidden layers must specify [*heads], [body], [*tails]. If not, use MLPNet" + ) + assert isinstance(self.in_dim, list), "Hydra network needs in_dim as list" + assert isinstance(self.out_dim, list), "Hydra network needs out_dim as list" self.head_hid_layers = self.hid_layers[0] self.body_hid_layers = self.hid_layers[1] self.tail_hid_layers = self.hid_layers[2] @@ -214,10 +241,14 @@ def __init__(self, net_spec, in_dim, out_dim): self.tail_hid_layers = self.tail_hid_layers * len(self.out_dim) self.model_heads = self.build_model_heads(in_dim) - heads_out_dim = np.sum([head_hid_layers[-1] for head_hid_layers in self.head_hid_layers]) + heads_out_dim = np.sum( + [head_hid_layers[-1] for head_hid_layers in self.head_hid_layers] + ) dims = [heads_out_dim] + self.body_hid_layers self.model_body = net_util.build_fc_model(dims, self.hid_layers_activation) - self.model_tails = self.build_model_tails(self.out_dim, self.out_layer_activation) + self.model_tails = self.build_model_tails( + self.out_dim, self.out_layer_activation + ) net_util.init_layers(self, self.init_fn) self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) @@ -225,8 +256,10 @@ def __init__(self, net_spec, in_dim, out_dim): self.train() def build_model_heads(self, in_dim): - '''Build each model_head. These are stored as Sequential models in model_heads''' - assert len(self.head_hid_layers) == len(in_dim), 'Hydra head hid_params inconsistent with number in dims' + """Build each model_head. These are stored as Sequential models in model_heads""" + assert len(self.head_hid_layers) == len(in_dim), ( + "Hydra head hid_params inconsistent with number in dims" + ) model_heads = nn.ModuleList() for in_d, hid_layers in zip(in_dim, self.head_hid_layers): dims = [in_d] + hid_layers @@ -235,17 +268,23 @@ def build_model_heads(self, in_dim): return model_heads def build_model_tails(self, out_dim, out_layer_activation): - '''Build each model_tail. These are stored as Sequential models in model_tails''' + """Build each model_tail. These are stored as Sequential models in model_tails""" if not ps.is_list(out_layer_activation): out_layer_activation = [out_layer_activation] * len(out_dim) model_tails = nn.ModuleList() if ps.is_empty(self.tail_hid_layers): for out_d, out_activ in zip(out_dim, out_layer_activation): - tail = net_util.build_fc_model([self.body_hid_layers[-1], out_d], out_activ) + tail = net_util.build_fc_model( + [self.body_hid_layers[-1], out_d], out_activ + ) model_tails.append(tail) else: - assert len(self.tail_hid_layers) == len(out_dim), 'Hydra tail hid_params inconsistent with number out dims' - for out_d, out_activ, hid_layers in zip(out_dim, out_layer_activation, self.tail_hid_layers): + assert len(self.tail_hid_layers) == len(out_dim), ( + "Hydra tail hid_params inconsistent with number out dims" + ) + for out_d, out_activ, hid_layers in zip( + out_dim, out_layer_activation, self.tail_hid_layers + ): dims = hid_layers model_tail = net_util.build_fc_model(dims, self.hid_layers_activation) tail_out = net_util.build_fc_model([dims[-1], out_d], out_activ) @@ -254,7 +293,7 @@ def build_model_tails(self, out_dim, out_layer_activation): return model_tails def forward(self, xs): - '''The feedforward step''' + """The feedforward step""" head_xs = [] for model_head, x in zip(self.model_heads, xs): head_xs.append(model_head(x)) @@ -267,7 +306,7 @@ def forward(self, xs): class DuelingMLPNet(MLPNet): - ''' + """ Class for generating arbitrary sized feedforward neural network, with dueling heads. Intended for Q-Learning algorithms only. Implementation based on "Dueling Network Architectures for Deep Reinforcement Learning" http://proceedings.mlr.press/v48/wangf16.pdf @@ -296,37 +335,44 @@ class DuelingMLPNet(MLPNet): "polyak_coef": 0.9, "gpu": true } - ''' + """ def __init__(self, net_spec, in_dim, out_dim): nn.Module.__init__(self) Net.__init__(self, net_spec, in_dim, out_dim) # set default - util.set_attr(self, dict( - init_fn=None, - clip_grad_val=None, - loss_spec={'name': 'MSELoss'}, - optim_spec={'name': 'Adam'}, - lr_scheduler_spec=None, - update_type='replace', - update_frequency=1, - polyak_coef=0.0, - gpu=False, - )) - util.set_attr(self, self.net_spec, [ - 'shared', - 'hid_layers', - 'hid_layers_activation', - 'init_fn', - 'clip_grad_val', - 'loss_spec', - 'optim_spec', - 'lr_scheduler_spec', - 'update_type', - 'update_frequency', - 'polyak_coef', - 'gpu', - ]) + util.set_attr( + self, + dict( + init_fn=None, + clip_grad_val=None, + loss_spec={"name": "MSELoss"}, + optim_spec={"name": "Adam"}, + lr_scheduler_spec=None, + update_type="replace", + update_frequency=1, + polyak_coef=0.0, + gpu=False, + ), + ) + util.set_attr( + self, + self.net_spec, + [ + "shared", + "hid_layers", + "hid_layers_activation", + "init_fn", + "clip_grad_val", + "loss_spec", + "optim_spec", + "lr_scheduler_spec", + "update_type", + "update_frequency", + "polyak_coef", + "gpu", + ], + ) # Guard against inappropriate algorithms and environments # Build model body @@ -341,7 +387,7 @@ def __init__(self, net_spec, in_dim, out_dim): self.to(self.device) def forward(self, x): - '''The feedforward step''' + """The feedforward step""" x = self.model_body(x) state_value = self.v(x) raw_advantages = self.adv(x) diff --git a/slm_lab/agent/net/net_util.py b/slm_lab/agent/net/net_util.py index e5bc38126..2002f6291 100644 --- a/slm_lab/agent/net/net_util.py +++ b/slm_lab/agent/net/net_util.py @@ -10,12 +10,12 @@ logger = logger.get_logger(__name__) # register custom torch.optim (Global variants for A3C Hogwild) -setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam) -setattr(torch.optim, 'GlobalRMSprop', optimizer.GlobalRMSprop) +setattr(torch.optim, "GlobalAdam", optimizer.GlobalAdam) +setattr(torch.optim, "GlobalRMSprop", optimizer.GlobalRMSprop) class NoOpLRScheduler: - '''Symbolic LRScheduler class for API consistency''' + """Symbolic LRScheduler class for API consistency""" def __init__(self, optim): self.optim = optim @@ -24,20 +24,24 @@ def step(self, epoch=None): pass def get_last_lr(self): - if hasattr(self.optim, 'defaults'): - return self.optim.defaults['lr'] + if hasattr(self.optim, "defaults"): + return self.optim.defaults["lr"] else: # TODO retrieve lr more generally - return self.optim.param_groups[0]['lr'] + return self.optim.param_groups[0]["lr"] -def build_fc_model(dims, activation=None): - '''Build a full-connected model by interleaving nn.Linear and activation_fn''' - assert len(dims) >= 2, 'dims need to at least contain input, output' +def build_fc_model(dims, activation=None, layer_norm=False, batch_norm=False): + """Build a full-connected model by interleaving nn.Linear, optional normalization, and activation_fn""" + assert len(dims) >= 2, "dims need to at least contain input, output" # shift dims and make pairs of (in, out) dims per layer dim_pairs = list(zip(dims[:-1], dims[1:])) layers = [] for in_d, out_d in dim_pairs: layers.append(nn.Linear(in_d, out_d)) + if batch_norm: + layers.append(nn.BatchNorm1d(out_d)) + elif layer_norm: + layers.append(nn.LayerNorm(out_d)) if activation is not None: layers.append(get_activation_fn(activation)) model = nn.Sequential(*layers) @@ -45,29 +49,29 @@ def build_fc_model(dims, activation=None): def get_nn_name(uncased_name): - '''Helper to get the proper name in PyTorch nn given a case-insensitive name''' + """Helper to get the proper name in PyTorch nn given a case-insensitive name""" for nn_name in nn.__dict__: if uncased_name.lower() == nn_name.lower(): return nn_name - raise ValueError(f'Name {uncased_name} not found in {nn.__dict__}') + raise ValueError(f"Name {uncased_name} not found in {nn.__dict__}") def get_activation_fn(activation): - '''Helper to generate activation function layers for net''' + """Helper to generate activation function layers for net""" ActivationClass = getattr(nn, get_nn_name(activation)) return ActivationClass() def get_loss_fn(cls, loss_spec): - '''Helper to parse loss param and construct loss_fn for net''' - LossClass = getattr(nn, get_nn_name(loss_spec['name'])) - loss_spec = ps.omit(loss_spec, 'name') - loss_fn = LossClass(**loss_spec) + """Helper to parse loss param and construct loss_fn for net""" + LossClass = getattr(nn, get_nn_name(loss_spec["name"])) + loss_kwargs = {k: v for k, v in loss_spec.items() if k != "name"} + loss_fn = LossClass(**loss_kwargs) return loss_fn def get_lr_scheduler(optim, lr_scheduler_spec, steps_per_schedule=1): - '''Helper to parse lr_scheduler param and construct Pytorch optim.lr_scheduler. + """Helper to parse lr_scheduler param and construct Pytorch optim.lr_scheduler. Args: optim: The optimizer to schedule @@ -75,38 +79,40 @@ def get_lr_scheduler(optim, lr_scheduler_spec, steps_per_schedule=1): steps_per_schedule: Number of env frames processed per scheduler.step() call. For PPO: training_frequency * num_envs (e.g., 128 * 8 = 1024) This converts frame-based specs to update-based scheduling. - ''' - if ps.is_empty(lr_scheduler_spec): + """ + if not lr_scheduler_spec: lr_scheduler = NoOpLRScheduler(optim) - elif lr_scheduler_spec['name'] == 'LinearToZero': - LRSchedulerClass = getattr(torch.optim.lr_scheduler, 'LambdaLR') - frame = float(lr_scheduler_spec['frame']) + elif lr_scheduler_spec["name"] == "LinearToZero": + LRSchedulerClass = getattr(torch.optim.lr_scheduler, "LambdaLR") + frame = float(lr_scheduler_spec["frame"]) # Convert from total frames to number of scheduler updates num_updates = max(1, frame / steps_per_schedule) - lr_scheduler = LRSchedulerClass(optim, lr_lambda=lambda x, n=num_updates: max(0, 1 - x / n)) + lr_scheduler = LRSchedulerClass( + optim, lr_lambda=lambda x, n=num_updates: max(0, 1 - x / n) + ) else: - LRSchedulerClass = getattr(torch.optim.lr_scheduler, lr_scheduler_spec['name']) - lr_scheduler_spec = ps.omit(lr_scheduler_spec, 'name') - lr_scheduler = LRSchedulerClass(optim, **lr_scheduler_spec) + LRSchedulerClass = getattr(torch.optim.lr_scheduler, lr_scheduler_spec["name"]) + sched_kwargs = {k: v for k, v in lr_scheduler_spec.items() if k != "name"} + lr_scheduler = LRSchedulerClass(optim, **sched_kwargs) return lr_scheduler def get_optim(net, optim_spec): - '''Helper to parse optim param and construct optim for net''' - OptimClass = getattr(torch.optim, optim_spec['name']) - optim_spec = ps.omit(optim_spec, 'name') + """Helper to parse optim param and construct optim for net""" + OptimClass = getattr(torch.optim, optim_spec["name"]) + optim_kwargs = {k: v for k, v in optim_spec.items() if k != "name"} if torch.is_tensor(net): # for non-net tensor variable - optim = OptimClass([net], **optim_spec) + optim = OptimClass([net], **optim_kwargs) else: - optim = OptimClass(net.parameters(), **optim_spec) + optim = OptimClass(net.parameters(), **optim_kwargs) return optim def get_policy_out_dim(agent): - '''Helper method to construct the policy network out_dim for an agent according to is_discrete, action_type''' + """Helper method to construct the policy network out_dim for an agent according to is_discrete, action_type""" action_dim = agent.action_dim if agent.is_discrete: - if agent.action_type == 'multi_discrete': + if agent.action_type == "multi_discrete": assert ps.is_list(action_dim), action_dim policy_out_dim = action_dim else: @@ -122,7 +128,7 @@ def get_policy_out_dim(agent): def get_out_dim(agent, add_critic=False): - '''Construct the NetClass out_dim for an agent according to is_discrete, action_type, and whether to add a critic unit''' + """Construct the NetClass out_dim for an agent according to is_discrete, action_type, and whether to add a critic unit""" policy_out_dim = get_policy_out_dim(agent) if add_critic: if ps.is_list(policy_out_dim): @@ -135,21 +141,23 @@ def get_out_dim(agent, add_critic=False): def init_layers(net, init_fn_name): - '''Primary method to initialize the weights of the layers of a network''' + """Primary method to initialize the weights of the layers of a network""" if init_fn_name is None: return # get nonlinearity nonlinearity = get_nn_name(net.hid_layers_activation).lower() - if nonlinearity == 'leakyrelu': - nonlinearity = 'leaky_relu' # guard name + if nonlinearity == "leakyrelu": + nonlinearity = "leaky_relu" # guard name # get init_fn and add arguments depending on nonlinearity init_fn = getattr(nn.init, init_fn_name) - if 'kaiming' in init_fn_name: # has 'nonlinearity' as arg - assert nonlinearity in ['relu', 'leaky_relu'], f'Kaiming initialization not supported for {nonlinearity}' + if "kaiming" in init_fn_name: # has 'nonlinearity' as arg + assert nonlinearity in ["relu", "leaky_relu"], ( + f"Kaiming initialization not supported for {nonlinearity}" + ) init_fn = partial(init_fn, nonlinearity=nonlinearity) - elif 'orthogonal' in init_fn_name or 'xavier' in init_fn_name: # has 'gain' as arg + elif "orthogonal" in init_fn_name or "xavier" in init_fn_name: # has 'gain' as arg gain = nn.init.calculate_gain(nonlinearity) init_fn = partial(init_fn, gain=gain) else: @@ -160,28 +168,31 @@ def init_layers(net, init_fn_name): def init_params(module, init_fn): - '''Initialize module's weights using init_fn, and biases to 0.0''' + """Initialize module's weights using init_fn, and biases to 0.0""" bias_init = 0.0 classname = util.get_class_name(module) - if 'Net' in classname: # skip if it's a net, not pytorch layer + if "Net" in classname: # skip if it's a net, not pytorch layer pass - elif classname == 'BatchNorm2d': + elif classname == "BatchNorm2d": pass # can't init BatchNorm2d - elif any(k in classname for k in ('Conv', 'Linear')): - init_fn(module.weight) - nn.init.constant_(module.bias, bias_init) - elif 'GRU' in classname: + elif any(k in classname for k in ("Conv", "Linear")): + if not hasattr(module, "weight") or module.weight is None: + pass # skip lazy modules not yet materialized (e.g. LazyWeightNormLinear) + else: + init_fn(module.weight) + nn.init.constant_(module.bias, bias_init) + elif "GRU" in classname: for name, param in module.named_parameters(): - if 'weight' in name: + if "weight" in name: init_fn(param) - elif 'bias' in name: + elif "bias" in name: nn.init.constant_(param, bias_init) else: pass def init_tails(net, actor_init_std=None, critic_init_std=None): - '''Reinitialize output head layers with specific stds (CleanRL-style). + """Reinitialize output head layers with specific stds (CleanRL-style). For PPO/ActorCritic with shared network, proper head initialization is critical: - Actor head: small std (0.01) for near-uniform initial policy @@ -194,8 +205,8 @@ def init_tails(net, actor_init_std=None, critic_init_std=None): net: Network with self.tails attribute (ModuleList of [actor_tail, critic_tail]) actor_init_std: std for actor output head orthogonal init (default: None = no reinit) critic_init_std: std for critic output head orthogonal init (default: None = no reinit) - ''' - if not hasattr(net, 'tails') or not isinstance(net.tails, nn.ModuleList): + """ + if not hasattr(net, "tails") or not isinstance(net.tails, nn.ModuleList): return # Only applies to multi-tail networks (shared actor-critic) tails = list(net.tails) @@ -212,7 +223,7 @@ def init_tails(net, actor_init_std=None, critic_init_std=None): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, actor_init_std) nn.init.constant_(module.bias, 0.0) - logger.debug(f'Reinitialized actor tail with std={actor_init_std}') + logger.debug(f"Reinitialized actor tail with std={actor_init_std}") # Reinitialize critic head if critic_init_std is not None: @@ -220,66 +231,70 @@ def init_tails(net, actor_init_std=None, critic_init_std=None): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, critic_init_std) nn.init.constant_(module.bias, 0.0) - logger.debug(f'Reinitialized critic tail with std={critic_init_std}') + logger.debug(f"Reinitialized critic tail with std={critic_init_std}") # params methods def save(net, model_path): - '''Save model weights to path''' + """Save model weights to path""" torch.save(net.state_dict(), util.smart_path(model_path)) def save_algorithm(algorithm, ckpt=None): - '''Save all the nets for an algorithm''' + """Save all the nets for an algorithm""" agent = algorithm.agent net_names = algorithm.net_names - model_prepath = agent.spec['meta']['model_prepath'] + model_prepath = agent.spec["meta"]["model_prepath"] if ckpt is not None: - model_prepath += f'_ckpt-{ckpt}' + model_prepath += f"_ckpt-{ckpt}" for net_name in net_names: net = getattr(algorithm, net_name) - model_path = f'{model_prepath}_{net_name}_model.pt' + model_path = f"{model_prepath}_{net_name}_model.pt" save(net, model_path) - optim_name = net_name.replace('net', 'optim') + optim_name = net_name.replace("net", "optim") optim = getattr(algorithm, optim_name, None) if optim is not None: # only trainable net has optim - optim_path = f'{model_prepath}_{net_name}_optim.pt' + optim_path = f"{model_prepath}_{net_name}_optim.pt" save(optim, optim_path) - logger.debug(f'Saved algorithm {util.get_class_name(algorithm)} nets {net_names} to {model_prepath}_*.pt') + logger.debug( + f"Saved algorithm {util.get_class_name(algorithm)} nets {net_names} to {model_prepath}_*.pt" + ) def load(net, model_path): - '''Load model weights from a path into a net module''' - device = None if torch.cuda.is_available() else 'cpu' + """Load model weights from a path into a net module""" + device = None if torch.cuda.is_available() else "cpu" net.load_state_dict(torch.load(util.smart_path(model_path), map_location=device)) def load_algorithm(algorithm): - '''Load all the nets for an algorithm''' + """Load all the nets for an algorithm""" agent = algorithm.agent net_names = algorithm.net_names - model_prepath = agent.spec['meta']['model_prepath'] - is_enjoy = lab_mode() == 'enjoy' + model_prepath = agent.spec["meta"]["model_prepath"] + is_enjoy = lab_mode() == "enjoy" if is_enjoy: - model_prepath += '_ckpt-best' - logger.info(f'Loading algorithm {util.get_class_name(algorithm)} nets {net_names} from {model_prepath}_*.pt') + model_prepath += "_ckpt-best" + logger.info( + f"Loading algorithm {util.get_class_name(algorithm)} nets {net_names} from {model_prepath}_*.pt" + ) for net_name in net_names: net = getattr(algorithm, net_name) - model_path = f'{model_prepath}_{net_name}_model.pt' + model_path = f"{model_prepath}_{net_name}_model.pt" load(net, model_path) if is_enjoy: # skip loading optim in enjoy mode - not needed for inference continue - optim_name = net_name.replace('net', 'optim') + optim_name = net_name.replace("net", "optim") optim = getattr(algorithm, optim_name, None) if optim is not None: # only trainable net has optim - optim_path = f'{model_prepath}_{net_name}_optim.pt' + optim_path = f"{model_prepath}_{net_name}_optim.pt" load(optim, optim_path) def copy(src_net, tar_net): - '''Copy model weights from src to target''' + """Copy model weights from src to target""" state_dict = src_net.state_dict() # Transfer state dict to target device if different tar_device = next(tar_net.parameters()).device @@ -290,43 +305,45 @@ def copy(src_net, tar_net): def polyak_update(src_net, tar_net, old_ratio=0.5): - ''' + """ Polyak weight update to update a target tar_net, retain old weights by its ratio, i.e. target <- old_ratio * source + (1 - old_ratio) * target - ''' + Uses in-place lerp_ to avoid allocating intermediate tensors. + """ for src_param, tar_param in zip(src_net.parameters(), tar_net.parameters()): - tar_param.data.copy_(old_ratio * src_param.data + (1.0 - old_ratio) * tar_param.data) + tar_param.data.lerp_(src_param.data, old_ratio) def update_target_net(src_net, tar_net, frame, num_envs): - ''' + """ Update target network using replace or polyak strategy. For replace: only updates every update_frequency frames. For polyak: updates every call with exponential moving average. - + @param src_net: Source network to copy/blend from @param tar_net: Target network to update @param frame: Current training frame for frequency gating @param num_envs: Number of parallel environments (for frame_mod calculation) - ''' - from slm_lab.lib import util - - if src_net.update_type == 'replace': - if util.frame_mod(frame, src_net.update_frequency, num_envs): + """ + if src_net.update_type == "replace": + remainder = num_envs or 1 + if frame % src_net.update_frequency < remainder: copy(src_net, tar_net) - elif src_net.update_type == 'polyak': + elif src_net.update_type == "polyak": polyak_update(src_net, tar_net, src_net.polyak_coef) else: - raise ValueError(f'Unknown update_type "{src_net.update_type}". Should be "replace" or "polyak".') + raise ValueError( + f'Unknown update_type "{src_net.update_type}". Should be "replace" or "polyak".' + ) def to_check_train_step(): - '''Condition for running assert_trained''' - return os.environ.get('PY_ENV') == 'test' or lab_mode() == 'dev' + """Condition for running assert_trained""" + return os.environ.get("PY_ENV") == "test" or lab_mode() == "dev" def dev_check_train_step(fn): - ''' + """ Decorator to check if net.train_step actually updates the network weights properly Triggers only if to_check_train_step is True (dev/test mode) @example @@ -334,7 +351,8 @@ def dev_check_train_step(fn): @net_util.dev_check_train_step def train_step(self, ...): ... - ''' + """ + @wraps(fn) def check_fn(*args, **kwargs): if not to_check_train_step(): @@ -362,30 +380,37 @@ def check_fn(*args, **kwargs): else: # check parameter updates try: - assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params)), f'Model parameter is not updated in train_step(), check if your tensor is detached from graph. Loss: {loss:g}' + assert not all( + torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params) + ), ( + f"Model parameter is not updated in train_step(), check if your tensor is detached from graph. Loss: {loss:g}" + ) except Exception as e: logger.error(e) - if os.environ.get('PY_ENV') == 'test': + if os.environ.get("PY_ENV") == "test": # raise error if in unit test - raise(e) + raise (e) # check grad norms min_norm, max_norm = 0.0, 1e5 for p_name, param in net.named_parameters(): try: grad_norm = param.grad.norm() - assert min_norm < grad_norm < max_norm, f'Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation.' + assert min_norm < grad_norm < max_norm, ( + f"Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation." + ) except Exception as e: logger.warning(e) - logger.debug('Passed network parameter update check.') + logger.debug("Passed network parameter update check.") # store grad norms for debugging net.store_grad_norms() return loss + return check_fn def get_grad_norms(algorithm): - '''Gather all the net's grad norms of an algorithm for debugging''' + """Gather all the net's grad norms of an algorithm for debugging""" grad_norms = [] for net_name in algorithm.net_names: net = getattr(algorithm, net_name) @@ -395,7 +420,7 @@ def get_grad_norms(algorithm): def init_global_nets(algorithm): - ''' + """ Initialize global_nets for Hogwild using an identical instance of an algorithm from an isolated Session in spec.meta.distributed, specify either: - 'shared': global network parameter is shared all the time. In this mode, algorithm local network will be replaced directly by global_net via overriding by identify attribute name @@ -403,48 +428,50 @@ def init_global_nets(algorithm): NOTE: A3C Hogwild is CPU-only because PyTorch share_memory_() requires CPU tensors. For GPU-accelerated training, use A2C or PPO instead. - ''' - dist_mode = algorithm.agent.spec['meta']['distributed'] - assert dist_mode in ('shared', 'synced'), 'Unrecognized distributed mode' + """ + dist_mode = algorithm.agent.spec["meta"]["distributed"] + assert dist_mode in ("shared", "synced"), "Unrecognized distributed mode" global_nets = {} for net_name in algorithm.net_names: - optim_name = net_name.replace('net', 'optim') - if not hasattr(algorithm, optim_name): # only for trainable network, i.e. has an optim + optim_name = net_name.replace("net", "optim") + if not hasattr( + algorithm, optim_name + ): # only for trainable network, i.e. has an optim continue g_net = getattr(algorithm, net_name) # Move to CPU for share_memory_() (required by PyTorch multiprocessing) - g_net.to('cpu') + g_net.to("cpu") g_net.share_memory() # make net global - if dist_mode == 'shared': # use the same name to override the local net + if dist_mode == "shared": # use the same name to override the local net global_nets[net_name] = g_net else: # keep a separate reference for syncing - global_nets[f'global_{net_name}'] = g_net + global_nets[f"global_{net_name}"] = g_net # if optim is Global, set to override the local optim and its scheduler optim = getattr(algorithm, optim_name) - if hasattr(optim, 'share_memory'): + if hasattr(optim, "share_memory"): optim.share_memory() # make optim global global_nets[optim_name] = optim - lr_scheduler_name = net_name.replace('net', 'lr_scheduler') + lr_scheduler_name = net_name.replace("net", "lr_scheduler") lr_scheduler = getattr(algorithm, lr_scheduler_name) global_nets[lr_scheduler_name] = lr_scheduler - logger.info(f'Initialized global_nets attr {list(global_nets.keys())} for Hogwild') + logger.info(f"Initialized global_nets attr {list(global_nets.keys())} for Hogwild") return global_nets def set_global_nets(algorithm, global_nets): - '''For Hogwild, set attr built in init_global_nets above. Use in algorithm init.''' + """For Hogwild, set attr built in init_global_nets above. Use in algorithm init.""" # set attr first so algorithm always has self.global_{net} to pass into train_step for net_name in algorithm.net_names: - setattr(algorithm, f'global_{net_name}', None) + setattr(algorithm, f"global_{net_name}", None) # set attr created in init_global_nets if global_nets is not None: # set global nets and optims util.set_attr(algorithm, global_nets) - logger.info(f'Set global_nets attr {list(global_nets.keys())} for Hogwild') + logger.info(f"Set global_nets attr {list(global_nets.keys())} for Hogwild") def push_global_grads(net, global_net): - '''Push gradients to global_net, call inside train_step between loss.backward() and optim.step()''' + """Push gradients to global_net, call inside train_step between loss.backward() and optim.step()""" for param, global_param in zip(net.parameters(), global_net.parameters()): if global_param.grad is not None: return # quick skip @@ -454,7 +481,7 @@ def push_global_grads(net, global_net): def build_tails(tail_in_dim, out_dim, out_layer_activation, log_std_init=None): - '''Build output tails with optional state-independent log_std (CleanRL-style for continuous control).''' + """Build output tails with optional state-independent log_std (CleanRL-style for continuous control).""" import numpy as np import pydash as ps @@ -464,17 +491,66 @@ def build_tails(tail_in_dim, out_dim, out_layer_activation, log_std_init=None): # State-independent log_std: out_dim = [action_dim, action_dim] for continuous actions if log_std_init is not None and len(out_dim) == 2 and out_dim[0] == out_dim[1]: action_dim = out_dim[0] - out_activ = out_layer_activation[0] if ps.is_list(out_layer_activation) else out_layer_activation - return build_fc_model([tail_in_dim, action_dim], out_activ), nn.Parameter(torch.ones(action_dim) * log_std_init) + out_activ = ( + out_layer_activation[0] + if ps.is_list(out_layer_activation) + else out_layer_activation + ) + return build_fc_model([tail_in_dim, action_dim], out_activ), nn.Parameter( + torch.ones(action_dim) * log_std_init + ) # Multi-tail output if not ps.is_list(out_layer_activation): out_layer_activation = [out_layer_activation] * len(out_dim) - return nn.ModuleList([build_fc_model([tail_in_dim, d], a) for d, a in zip(out_dim, out_layer_activation)]), None + return nn.ModuleList( + [ + build_fc_model([tail_in_dim, d], a) + for d, a in zip(out_dim, out_layer_activation) + ] + ), None def forward_tails(x, tails, log_std=None): - '''Forward pass through tails, handling log_std expansion if present.''' + """Forward pass through tails, handling log_std expansion if present.""" if log_std is not None: - return [tails(x), log_std.expand_as(tails(x))] + out = tails(x) + return [out, log_std.expand_as(out)] return [t(x) for t in tails] if isinstance(tails, nn.ModuleList) else tails(x) + + +def apply_spectral_norm_penultimate(net_module): + """Apply spectral normalization to the penultimate Linear layer of a network. + + Gogianu et al. (ICML 2021) showed that applying SN to only the penultimate + layer bounds the critic's Lipschitz constant without degrading performance + (applying to all layers hurts because optimal Q-functions are non-smooth). + + Works with MLPNet (model + tails), ConvNet (conv_model + fc_model + tails), + and TorchArcNet (body + tails). The "penultimate Linear" is the last Linear + in the body/model, just before the output tail. + """ + # Collect all Linear layers across the network + linear_layers = [] + for name, module in net_module.named_modules(): + if isinstance(module, nn.Linear): + linear_layers.append((name, module)) + + if len(linear_layers) < 2: + logger.warning( + f"Cannot apply spectral norm: found {len(linear_layers)} Linear layers, need at least 2" + ) + return + + # Penultimate = second-to-last Linear layer + target_name, target_module = linear_layers[-2] + + # Navigate to parent module and apply spectral_norm + parts = target_name.split(".") + parent = net_module + for part in parts[:-1]: + parent = getattr(parent, part) + attr_name = parts[-1] + + setattr(parent, attr_name, torch.nn.utils.spectral_norm(target_module)) + logger.info(f"Applied spectral norm to penultimate Linear: {target_name}") diff --git a/slm_lab/agent/net/recurrent.py b/slm_lab/agent/net/recurrent.py index f8736d997..dda3a1b85 100644 --- a/slm_lab/agent/net/recurrent.py +++ b/slm_lab/agent/net/recurrent.py @@ -114,6 +114,7 @@ def __init__(self, net_spec, in_dim, out_dim): self.in_dim = in_dim if isinstance(in_dim, (int, np.integer)) else in_dim[1] # fc body: state processing model if ps.is_empty(self.fc_hid_layers): + self.fc_model = None self.rnn_input_dim = self.in_dim else: fc_dims = [self.in_dim] + self.fc_hid_layers @@ -144,7 +145,7 @@ def forward(self, x): # Process through fc layers if present x = x.view(-1, self.in_dim) - if hasattr(self, 'fc_model'): + if self.fc_model is not None: x = self.fc_model(x) x = x.view(batch_size, self.seq_len, self.rnn_input_dim) diff --git a/slm_lab/agent/net/torcharc_net.py b/slm_lab/agent/net/torcharc_net.py index bf1a744d6..391492cfb 100644 --- a/slm_lab/agent/net/torcharc_net.py +++ b/slm_lab/agent/net/torcharc_net.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn +import slm_lab.agent.net.batch_renorm # noqa: F401 — registers BatchRenorm1d in torch.nn for TorchArc +import slm_lab.agent.net.weight_norm # noqa: F401 — registers WeightNormLinear in torch.nn for TorchArc from slm_lab.agent.net import net_util from slm_lab.agent.net.base import Net from slm_lab.lib import util @@ -117,6 +119,8 @@ def __init__(self, net_spec, in_dim, out_dim): def _get_body_out_dim(self): """Compute body output dimension via dummy forward pass.""" + # Use eval mode for dummy pass — BatchNorm requires batch_size > 1 in train mode + self.body.eval() with torch.no_grad(): if isinstance(self.in_dim, (int, np.integer)): dummy = torch.ones(1, self.in_dim, device=self.device) diff --git a/slm_lab/agent/net/weight_norm.py b/slm_lab/agent/net/weight_norm.py new file mode 100644 index 000000000..ffc0c19d2 --- /dev/null +++ b/slm_lab/agent/net/weight_norm.py @@ -0,0 +1,44 @@ +"""Weight Normalization linear layer (Salimans & Kingma, 2016). + +Decouples weight magnitude from direction: w = g * (v / ||v||). +Smoother optimization landscape without normalizing activations. +""" + +import torch +import torch.nn as nn +import torch.nn.utils.parametrizations as P + + +class WeightNormLinear(nn.Module): + """Linear layer with weight normalization applied.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__() + self.linear = P.weight_norm(nn.Linear(in_features, out_features, bias=bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class LazyWeightNormLinear(nn.Module): + """Lazy version of WeightNormLinear -- infers in_features from first input.""" + + def __init__(self, out_features: int, bias: bool = True): + super().__init__() + self.out_features = out_features + self.bias = bias + self._linear = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._linear is None: + self._linear = P.weight_norm( + nn.Linear(x.shape[-1], self.out_features, bias=self.bias) + ).to(x.device) + return self._linear(x) + + +# Register in torch.nn so TorchArc can resolve from YAML specs +if not hasattr(nn, "WeightNormLinear"): + setattr(nn, "WeightNormLinear", WeightNormLinear) +if not hasattr(nn, "LazyWeightNormLinear"): + setattr(nn, "LazyWeightNormLinear", LazyWeightNormLinear) diff --git a/slm_lab/cli/plot.py b/slm_lab/cli/plot.py index 902dacec8..cd00c8a0f 100644 --- a/slm_lab/cli/plot.py +++ b/slm_lab/cli/plot.py @@ -10,23 +10,24 @@ logger = logger.get_logger(__name__) # File patterns for trial metrics -TRIAL_METRICS_PATH = '*t0_trial_metrics.json' -SPEC_PATH = '*spec.json' +TRIAL_METRICS_PATH = "*t0_trial_metrics.json" +SPEC_PATH = "*spec.json" # Algorithm order for legend (fixed ordering) -ALGO_ORDER = ['REINFORCE', 'SARSA', 'DQN', 'DDQN+PER', 'A2C', 'PPO', 'SAC'] +ALGO_ORDER = ["REINFORCE", "SARSA", "DQN", "DDQN+PER", "A2C", "PPO", "SAC", "CrossQ"] # Colors by algorithm lineage: # - REINFORCE/SARSA: yellow/brown (classic methods) # - DQN/DDQN+PER: teal/green tones (value-based) # - A2C/PPO/SAC: blue/purple/red tones (actor-critic family) ALGO_PALETTE = { - 'REINFORCE': 'hsl(45, 80%, 55%)', # golden yellow - 'SARSA': 'hsl(30, 60%, 45%)', # brown - 'DQN': 'hsl(175, 55%, 45%)', # teal - 'DDQN+PER': 'hsl(145, 50%, 45%)', # green - 'A2C': 'hsl(220, 65%, 55%)', # blue - 'PPO': 'hsl(280, 55%, 55%)', # purple - 'SAC': 'hsl(350, 65%, 55%)', # red + "REINFORCE": "hsl(45, 80%, 55%)", # golden yellow + "SARSA": "hsl(30, 60%, 45%)", # brown + "DQN": "hsl(175, 55%, 45%)", # teal + "DDQN+PER": "hsl(145, 50%, 45%)", # green + "A2C": "hsl(220, 65%, 55%)", # blue + "PPO": "hsl(280, 55%, 55%)", # purple + "SAC": "hsl(350, 65%, 55%)", # red + "CrossQ": "hsl(25, 85%, 55%)", # orange } @@ -36,7 +37,7 @@ def get_spec_data(folder_path: Path) -> dict | None: # or just spec.json in the folder root or info/ folder matches = list(folder_path.glob(SPEC_PATH)) if not matches: - matches = list((folder_path / 'info').glob(SPEC_PATH)) + matches = list((folder_path / "info").glob(SPEC_PATH)) if matches: return util.read(str(matches[0])) @@ -47,44 +48,45 @@ def get_algo_name_from_spec(spec: dict) -> str: """Extract algorithm name from spec.""" # Spec structure: {spec_name: {agent: ...}} OR {agent: ...} (resolved spec) try: - if 'agent' in spec: - agent_spec = spec['agent'] + if "agent" in spec: + agent_spec = spec["agent"] else: # Get the first key (spec_name) spec_name = list(spec.keys())[0] - agent_spec = spec[spec_name]['agent'] + agent_spec = spec[spec_name]["agent"] # Handle list of agents (multi-agent) or single agent dict if isinstance(agent_spec, list): - algo_name = agent_spec[0]['algorithm']['name'] + algo_name = agent_spec[0]["algorithm"]["name"] else: - algo_name = agent_spec['algorithm']['name'] + algo_name = agent_spec["algorithm"]["name"] # Standardize names name_map = { - 'VanillaDQN': 'DQN', - 'DoubleDQN': 'DDQN+PER', # Usually used with PER in benchmarks - 'PPO': 'PPO', - 'SAC': 'SAC', - 'A2C': 'A2C', - 'ActorCritic': 'A2C', - 'SoftActorCritic': 'SAC', - 'DQN': 'DQN', - 'Reinforce': 'REINFORCE', - 'REINFORCE': 'REINFORCE', + "VanillaDQN": "DQN", + "DoubleDQN": "DDQN+PER", # Usually used with PER in benchmarks + "PPO": "PPO", + "SAC": "SAC", + "A2C": "A2C", + "ActorCritic": "A2C", + "SoftActorCritic": "SAC", + "DQN": "DQN", + "Reinforce": "REINFORCE", + "REINFORCE": "REINFORCE", + "CrossQ": "CrossQ", } # specific check for DDQN/PER - if algo_name == 'DoubleDQN': - memory_name = '' - if isinstance(agent_spec, list): - memory_name = str(agent_spec[0].get('memory', {}).get('name', '')) - else: - memory_name = str(agent_spec.get('memory', {}).get('name', '')) + if algo_name == "DoubleDQN": + memory_name = "" + if isinstance(agent_spec, list): + memory_name = str(agent_spec[0].get("memory", {}).get("name", "")) + else: + memory_name = str(agent_spec.get("memory", {}).get("name", "")) - if 'Prioritized' in memory_name: - return 'DDQN+PER' - return 'DDQN' + if "Prioritized" in memory_name: + return "DDQN+PER" + return "DDQN" return name_map.get(algo_name, algo_name) except Exception as e: @@ -95,28 +97,50 @@ def get_algo_name_from_spec(spec: dict) -> str: def get_env_name_from_spec(spec: dict) -> str: """Extract environment name from spec.""" try: - if 'env' in spec: - return spec['env']['name'] + if "env" in spec: + return spec["env"]["name"] spec_name = list(spec.keys())[0] - return spec[spec_name]['env']['name'] + return spec[spec_name]["env"]["name"] except Exception: return None def find_trial_metrics(folder_path: Path) -> str | None: """Find trial metrics file in a folder.""" - matches = list((folder_path / 'info').glob(TRIAL_METRICS_PATH)) + matches = list((folder_path / "info").glob(TRIAL_METRICS_PATH)) if matches: return str(matches[0]) return None def plot( - folders: str = typer.Option(..., "--folders", "-f", help="Comma-separated data folder names (e.g., ppo_cartpole_2026_01_11,a2c_gae_cartpole_2026_01_11)"), - title: str = typer.Option(None, "--title", "-t", help="Plot title. If omitted, extracted from spec env name."), - data_folder: str = typer.Option("data", "--data-folder", "-d", help="Base data folder path"), - output_folder: str = typer.Option("docs/plots", "--output", "-o", help="Output folder for plots"), - showlegend: bool = typer.Option(True, "--legend/--no-legend", help="Show legend on plot"), + folders: str = typer.Option( + ..., + "--folders", + "-f", + help="Comma-separated data folder names (e.g., ppo_cartpole_2026_01_11,a2c_gae_cartpole_2026_01_11)", + ), + title: str = typer.Option( + None, + "--title", + "-t", + help="Plot title. If omitted, extracted from spec env name.", + ), + legend_labels: str = typer.Option( + None, + "--legends", + "-l", + help="Comma-separated legend labels (must match number of folders)", + ), + data_folder: str = typer.Option( + "data", "--data-folder", "-d", help="Base data folder path" + ), + output_folder: str = typer.Option( + "docs/plots", "--output", "-o", help="Output folder for plots" + ), + showlegend: bool = typer.Option( + True, "--legend/--no-legend", help="Show legend on plot" + ), ): """ Plot benchmark comparison graphs from explicit folder paths. @@ -131,13 +155,21 @@ def plot( data_path = Path(util.smart_path(data_folder)) output_path = Path(util.smart_path(output_folder)) - folder_list = [f.strip() for f in folders.split(',')] + folder_list = [f.strip() for f in folders.split(",")] + custom_legends = ( + [label.strip() for label in legend_labels.split(",")] if legend_labels else None + ) + if custom_legends and len(custom_legends) != len(folder_list): + logger.error( + f"Number of legend labels ({len(custom_legends)}) must match folders ({len(folder_list)})" + ) + raise typer.Exit(1) trial_metrics_paths = [] legends = [] detected_title = None - for folder_name in folder_list: + for i, folder_name in enumerate(folder_list): folder_path = data_path / folder_name if not folder_path.exists(): logger.error(f"Folder not found: {folder_path}") @@ -153,12 +185,20 @@ def plot( # Get metadata from spec spec = get_spec_data(folder_path) if spec: - algo_name = get_algo_name_from_spec(spec) + algo_name = ( + get_algo_name_from_spec(spec) + if not custom_legends + else custom_legends[i] + ) env_name = get_env_name_from_spec(spec) if not detected_title and env_name: detected_title = env_name else: - algo_name = folder_name.split('_')[0].upper() # Fallback + algo_name = ( + custom_legends[i] + if custom_legends + else folder_name.split("_")[0].upper() + ) legends.append(algo_name) logger.info(f" {algo_name}: {metrics_path}") @@ -167,22 +207,25 @@ def plot( logger.error("Need at least 1 folder to plot") raise typer.Exit(1) - # Sort tracks by ALGO_ORDER to ensure consistent ordering and coloring - combined = [] - for path, legend in zip(trial_metrics_paths, legends): - try: - order = ALGO_ORDER.index(legend) - except ValueError: - order = 999 # Put unknown at end - combined.append((order, legend, path)) + # Sort tracks by ALGO_ORDER for consistent ordering (skip when custom legends are provided) + if not custom_legends: + combined = [] + for path, legend in zip(trial_metrics_paths, legends): + try: + order = ALGO_ORDER.index(legend) + except ValueError: + order = 999 # Put unknown at end + combined.append((order, legend, path)) - combined.sort() + combined.sort() - legends = [x[1] for x in combined] - trial_metrics_paths = [x[2] for x in combined] + legends = [x[1] for x in combined] + trial_metrics_paths = [x[2] for x in combined] # Use detected title if not provided - final_title = title if title else (detected_title if detected_title else "Benchmark") + final_title = ( + title if title else (detected_title if detected_title else "Benchmark") + ) # Build palette (consistent colors for same algorithms) palette = [ALGO_PALETTE.get(legend, None) for legend in legends] @@ -194,7 +237,7 @@ def plot( # Generate output filename from title (strips ALE/ prefix for Atari env names) filename_title = Path(final_title).name # ALE/Pong-v5 → Pong-v5 - safe_title = filename_title.replace(' ', '_').replace('(', '').replace(')', '') + safe_title = filename_title.replace(" ", "_").replace("(", "").replace(")", "") graph_prepath = str(output_path / safe_title) viz.plot_multi_trial( @@ -203,17 +246,19 @@ def plot( final_title, graph_prepath, ma=True, - name_time_pairs=[('mean_returns', 'frames')], + name_time_pairs=[("mean_returns", "frames")], palette=palette, showlegend=showlegend, ) - output_file = f'{graph_prepath}_multi_trial_graph_mean_returns_ma_vs_frames.png' + output_file = f"{graph_prepath}_multi_trial_graph_mean_returns_ma_vs_frames.png" logger.info(f"Saved: {output_file}") def list_data( - data_folder: str = typer.Option("data", "--data-folder", "-d", help="Data folder path"), + data_folder: str = typer.Option( + "data", "--data-folder", "-d", help="Data folder path" + ), ): """ List available experiments in data folder. @@ -225,8 +270,13 @@ def list_data( logger.error(f"Data folder not found: {data_folder}") raise typer.Exit(1) - experiments = sorted([d.name for d in data_path.iterdir() - if d.is_dir() and not d.name.startswith('.')]) + experiments = sorted( + [ + d.name + for d in data_path.iterdir() + if d.is_dir() and not d.name.startswith(".") + ] + ) if not experiments: logger.info(f"No experiments found in {data_folder}") @@ -234,7 +284,7 @@ def list_data( logger.info(f"\nExperiments in {data_folder}:") for exp in experiments: - metrics_files = list((data_path / exp / 'info').glob('*trial_metrics.json')) + metrics_files = list((data_path / exp / "info").glob("*trial_metrics.json")) status = "✓" if metrics_files else "○" logger.info(f" {status} {exp}") diff --git a/slm_lab/cli/remote.py b/slm_lab/cli/remote.py index 6049993cd..c8cd68c1c 100644 --- a/slm_lab/cli/remote.py +++ b/slm_lab/cli/remote.py @@ -23,6 +23,9 @@ def run_remote( gpu: bool = typer.Option( False, "--gpu", help="Use GPU hardware (default: CPU)" ), + profile: bool = typer.Option( + False, "--profile", help="Enable performance profiling (forces dev mode)" + ), ): """ Launch experiment on dstack with auto HF upload. @@ -41,6 +44,10 @@ def run_remote( slm-lab run-remote spec.json ppo_pong train --gpu # GPU train (for image envs) slm-lab run-remote spec.json ppo_pong search --gpu # GPU search (for image envs) """ + # Force dev mode when profiling (matching local behavior) + if profile and mode != "dev": + mode = "dev" + run_name = name or spec_name.replace("_", "-") # Auto-select config file based on hardware type and mode @@ -55,10 +62,15 @@ def run_remote( env["SPEC_NAME"] = spec_name env["LAB_MODE"] = mode env["SPEC_VARS"] = " ".join(f"-s {item}" for item in sets) if sets else "" + env["PROFILE"] = "true" if profile else "" + env.setdefault("PROF_SKIP", "500") + env.setdefault("PROF_ACTIVE", "20") logger.info(f"Launching: {run_name} ({config_file})") logger.info(f" {spec_file} / {spec_name} / {mode}") logger.info(f" Pull: slm-lab pull {spec_name}") + if profile: + logger.info(" Profiling: enabled (traces will be collected)") result = subprocess.run(cmd, env=env) if result.returncode != 0: diff --git a/slm_lab/env/wrappers.py b/slm_lab/env/wrappers.py index 67106b0bb..82de4ffc5 100644 --- a/slm_lab/env/wrappers.py +++ b/slm_lab/env/wrappers.py @@ -2,13 +2,10 @@ import math import time -from typing import Any - +from collections import deque import gymnasium as gym import numpy as np import pandas as pd -import pydash as ps - from slm_lab.lib import util @@ -55,8 +52,9 @@ def reset_clock(self): def load(self, train_df: pd.DataFrame): """Load clock state from training dataframe.""" last_row = train_df.iloc[-1] - last_clock_vals = ps.pick(last_row, *["epi", "t", "wall_t", "opt_step", "frame"]) - util.set_attr(self, last_clock_vals) + for key in ("epi", "t", "wall_t", "opt_step", "frame"): + if key in last_row.index: + setattr(self, key, last_row[key]) self.start_wall_t -= self.wall_t def get(self, unit: str = "frame") -> int: @@ -180,43 +178,38 @@ class VectorFullGameStatistics(gym.vector.VectorWrapper): def __init__(self, env: gym.vector.VectorEnv, buffer_length: int = 100): super().__init__(env) self.buffer_length = buffer_length - self.return_queue = [] # Full-game returns + self.return_queue = deque(maxlen=buffer_length) # Full-game returns self._ongoing_returns = np.zeros(self.num_envs, dtype=np.float64) self._prev_lives = None + self._zero_lives = np.zeros(self.num_envs) # pre-allocate fallback def reset(self, **kwargs): obs, info = self.env.reset(**kwargs) self._ongoing_returns.fill(0.0) - self._prev_lives = info.get("lives", np.zeros(self.num_envs)) + self._prev_lives = info.get("lives", self._zero_lives) return obs, info def step(self, actions): obs, rewards, terminated, truncated, info = self.env.step(actions) - # Accumulate raw rewards (note: rewards here are already clipped for training) - # We use the clipped rewards since AtariVectorEnv doesn't expose raw rewards easily self._ongoing_returns += rewards - lives = info.get("lives", np.zeros(self.num_envs)) + lives = info.get("lives", self._zero_lives) - # Check for true game-over (lives dropped to 0) - # Only record when we transition TO 0 lives (not when already at 0) + # Check for true game-over (lives dropped to 0) — vectorized if self._prev_lives is not None: game_over = (lives == 0) & (self._prev_lives > 0) - for i in range(self.num_envs): - if game_over[i]: - self.return_queue.append(self._ongoing_returns[i]) - if len(self.return_queue) > self.buffer_length: - self.return_queue.pop(0) - self._ongoing_returns[i] = 0.0 - - # Also reset on truncation (time limit) - for i in range(self.num_envs): - if truncated[i] and not terminated[i]: + done_idxs = np.flatnonzero(game_over) + for i in done_idxs: self.return_queue.append(self._ongoing_returns[i]) - if len(self.return_queue) > self.buffer_length: - self.return_queue.pop(0) - self._ongoing_returns[i] = 0.0 + self._ongoing_returns[done_idxs] = 0.0 + + # Also reset on truncation (time limit) — vectorized + trunc_only = truncated & ~terminated + trunc_idxs = np.flatnonzero(trunc_only) + for i in trunc_idxs: + self.return_queue.append(self._ongoing_returns[i]) + self._ongoing_returns[trunc_idxs] = 0.0 self._prev_lives = lives.copy() return obs, rewards, terminated, truncated, info diff --git a/slm_lab/experiment/control.py b/slm_lab/experiment/control.py index a8949d5db..59d216704 100644 --- a/slm_lab/experiment/control.py +++ b/slm_lab/experiment/control.py @@ -4,7 +4,6 @@ import gymnasium as gym import numpy as np -import pydash as ps import torch import torch.multiprocessing as mp @@ -15,6 +14,7 @@ from slm_lab.lib import logger, util from slm_lab.lib.env_var import lab_mode from slm_lab.lib.perf import log_perf_setup, optimize +from slm_lab.lib.torch_profiler import torch_profiler_context from slm_lab.spec import spec_util @@ -51,7 +51,8 @@ def __init__(self, spec: dict, global_nets=None): self.perf_setup = optimize() self.agent, self.env = make_agent_env(self.spec, global_nets) - if ps.get(self.spec, "meta.rigorous_eval"): + self._rigorous_eval = self.spec.get("meta", {}).get("rigorous_eval", False) + if self._rigorous_eval: with util.ctx_lab_mode("eval"): self.eval_env = make_env(self.spec) else: @@ -95,7 +96,7 @@ def try_ckpt(self, agent: Agent, env: gym.Env): if self.index == 0: analysis.analyze_trial(self.spec) - if ps.get(self.spec, "meta.rigorous_eval") and self.to_ckpt(env, "eval"): + if self._rigorous_eval and self.to_ckpt(env, "eval"): logger.info("Running eval ckpt") analysis.gen_avg_return(agent, self.eval_env) mt.ckpt(self.eval_env, "eval") @@ -105,28 +106,31 @@ def try_ckpt(self, agent: Agent, env: gym.Env): def run_rl(self): """Run the main RL loop until clock.max_frame""" state, info = self.env.reset() - - while self.env.get() < self.env.max_frame: - with torch.no_grad(): - action = self.agent.act(state) - next_state, reward, terminated, truncated, info = self.env.step(action) - - done = np.logical_or(terminated, truncated) - self.agent.update( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - terminated=terminated, - truncated=truncated - ) - self.try_ckpt(self.agent, self.env) - - if util.epi_done(done): - state, info = self.env.reset() - else: - state = next_state + is_venv = self.env.is_venv + + with torch_profiler_context() as prof_step: + while self.env.get() < self.env.max_frame: + action = self.agent.act(state) # Agent.act() already uses torch.no_grad() + next_state, reward, terminated, truncated, info = self.env.step(action) + + done = terminated | truncated # numpy bitwise-or, same as logical_or for bool arrays + self.agent.update( + state=state, + action=action, + reward=reward, + next_state=next_state, + done=done, + terminated=terminated, + truncated=truncated + ) + self.try_ckpt(self.agent, self.env) + + if not is_venv and done: + state, info = self.env.reset() + else: + state = next_state + + prof_step() def close(self): """Close session and clean up. Save agent, close env.""" diff --git a/slm_lab/lib/math_util.py b/slm_lab/lib/math_util.py index d9c0eebe3..c68eb3fa5 100644 --- a/slm_lab/lib/math_util.py +++ b/slm_lab/lib/math_util.py @@ -5,13 +5,24 @@ # general math methods + +def symlog(x): + """Symmetric logarithmic compression. DreamerV3 (Hafner et al., 2023).""" + return torch.sign(x) * torch.log1p(torch.abs(x)) + + +def symexp(x): + """Inverse of symlog.""" + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + + def center_mean(v): - '''Center an array by its mean''' + """Center an array by its mean""" return v - v.mean() def normalize(v): - '''Method to normalize a rank-1 np array''' + """Method to normalize a rank-1 np array""" v_min = v.min() v_max = v.max() v_range = v_max - v_min @@ -21,19 +32,19 @@ def normalize(v): def standardize(v): - '''Method to standardize a rank-1 np array''' - assert len(v) > 1, 'Cannot standardize vector of size 1' + """Method to standardize a rank-1 np array""" + assert len(v) > 1, "Cannot standardize vector of size 1" v_std = (v - v.mean()) / (v.std() + 1e-08) return v_std def to_one_hot(data, max_val): - '''Convert an int list of data into one-hot vectors''' + """Convert an int list of data into one-hot vectors""" return np.eye(max_val)[np.array(data)] def venv_pack(batch_tensor, num_envs): - '''Apply the reverse of venv_unpack to pack a batch tensor from (b*num_envs, *shape) to (b, num_envs, *shape)''' + """Apply the reverse of venv_unpack to pack a batch tensor from (b*num_envs, *shape) to (b, num_envs, *shape)""" shape = list(batch_tensor.shape) if len(shape) < 2: # scalar data (b, num_envs,) return batch_tensor.view(-1, num_envs) @@ -43,11 +54,11 @@ def venv_pack(batch_tensor, num_envs): def venv_unpack(batch_tensor): - ''' + """ Unpack a sampled vec env batch tensor e.g. for a state with original shape (4, ), vec env should return vec state with shape (num_envs, 4) to store in memory When sampled with batch_size b, we should get shape (b, num_envs, 4). But we need to unpack the num_envs dimension to get (b * num_envs, 4) for passing to a network. This method does that. - ''' + """ shape = list(batch_tensor.shape) if len(shape) < 3: # scalar data (b, num_envs,) return batch_tensor.view(-1) @@ -59,14 +70,15 @@ def venv_unpack(batch_tensor): # Policy Gradient calc # advantage functions + def calc_returns(rewards, terminateds, gamma): - ''' + """ Calculate the simple returns (full rollout) i.e. sum discounted rewards up till termination IMPORTANT: Use 'terminateds' not 'dones' for correct return calculation. When truncated (time limit), we should bootstrap from V(next_state), not zero it. Only zero out future returns on true episode termination. - ''' + """ T = len(rewards) rets = torch.zeros_like(rewards) future_ret = torch.tensor(0.0, dtype=rewards.dtype) @@ -77,7 +89,7 @@ def calc_returns(rewards, terminateds, gamma): def calc_nstep_returns(rewards, terminateds, next_v_pred, gamma, n): - ''' + """ Estimate the advantages using n-step returns. Ref: http://www-anw.cs.umass.edu/~barto/courses/cs687/Chapter%207.pdf Also see Algorithm S3 from A3C paper https://arxiv.org/pdf/1602.01783.pdf for the calculation used below R^(n)_t = r_{t} + gamma r_{t+1} + ... + gamma^(n-1) r_{t+n-1} + gamma^(n) V(s_{t+n}) @@ -85,7 +97,7 @@ def calc_nstep_returns(rewards, terminateds, next_v_pred, gamma, n): IMPORTANT: Use 'terminateds' not 'dones' for correct n-step return calculation. When truncated (time limit), we should bootstrap from V(next_state), not zero it. Only zero out future returns on true episode termination. - ''' + """ rets = torch.zeros_like(rewards) future_ret = next_v_pred not_terminateds = 1 - terminateds @@ -95,7 +107,7 @@ def calc_nstep_returns(rewards, terminateds, next_v_pred, gamma, n): def calc_gaes(rewards, terminateds, v_preds, gamma, lam): - ''' + """ Estimate the advantages using GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf v_preds are values predicted for current states, with one last element as the final next_state delta is defined as r + gamma * V(s') - V(s) in eqn 10 @@ -107,9 +119,11 @@ def calc_gaes(rewards, terminateds, v_preds, gamma, lam): This method computes in torch tensor to prevent unnecessary moves between devices (e.g. GPU tensor to CPU numpy) NOTE any standardization is done outside of this method - ''' + """ T = len(rewards) - assert T + 1 == len(v_preds), f'T+1: {T+1} v.s. v_preds.shape: {v_preds.shape}' # v_preds runs into t+1 + assert T + 1 == len(v_preds), ( + f"T+1: {T + 1} v.s. v_preds.shape: {v_preds.shape}" + ) # v_preds runs into t+1 gaes = torch.zeros_like(rewards) future_gae = torch.tensor(0.0, dtype=rewards.dtype) not_terminateds = 1 - terminateds # only reset on true termination, not truncation @@ -127,13 +141,14 @@ def calc_q_value_logits(state_value, raw_advantages): # generic variable decay methods + def no_decay(start_val, end_val, start_step, end_step, step): - '''dummy method for API consistency''' + """dummy method for API consistency""" return start_val def linear_decay(start_val, end_val, start_step, end_step, step): - '''Simple linear decay with annealing''' + """Simple linear decay with annealing""" if step < start_step: return start_val slope = (end_val - start_val) / (end_step - start_step) @@ -141,8 +156,10 @@ def linear_decay(start_val, end_val, start_step, end_step, step): return val -def rate_decay(start_val, end_val, start_step, end_step, step, decay_rate=0.9, frequency=20.): - '''Compounding rate decay that anneals in 20 decay iterations until end_step''' +def rate_decay( + start_val, end_val, start_step, end_step, step, decay_rate=0.9, frequency=20.0 +): + """Compounding rate decay that anneals in 20 decay iterations until end_step""" if step < start_step: return start_val if step >= end_step: @@ -153,13 +170,13 @@ def rate_decay(start_val, end_val, start_step, end_step, step, decay_rate=0.9, f return val -def periodic_decay(start_val, end_val, start_step, end_step, step, frequency=60.): - ''' +def periodic_decay(start_val, end_val, start_step, end_step, step, frequency=60.0): + """ Linearly decaying sinusoid that decays in roughly 10 iterations until explore_anneal_epi Plot the equation below to see the pattern suppose sinusoidal decay, start_val = 1, end_val = 0.2, stop after 60 unscaled x steps then we get 0.2+0.5*(1-0.2)(1 + cos x)*(1-x/60) - ''' + """ if step < start_step: return start_val if step >= end_step: diff --git a/slm_lab/lib/ml_util.py b/slm_lab/lib/ml_util.py index ca4cf3bf5..c5454781d 100644 --- a/slm_lab/lib/ml_util.py +++ b/slm_lab/lib/ml_util.py @@ -160,12 +160,16 @@ def to_torch_batch(batch, device, is_episodic): for k in batch: if is_episodic: # for episodic format batch[k] = np.concatenate(batch[k]) - elif ps.is_list(batch[k]): + elif isinstance(batch[k], list): batch[k] = np.array(batch[k]) - # Optimize tensor creation - direct device placement avoids intermediate CPU tensor - if batch[k].dtype != np.float32: - batch[k] = batch[k].astype(np.float32) - batch[k] = torch.from_numpy(batch[k]).to(device, non_blocking=True) + arr = batch[k] + if not arr.flags['C_CONTIGUOUS']: + arr = np.ascontiguousarray(arr) + if arr.dtype == np.float32: + batch[k] = torch.from_numpy(arr).to(device, non_blocking=True) + else: + # For uint8/float16: send to device, cast to float32 on GPU + batch[k] = torch.from_numpy(arr).to(device, non_blocking=True).float() return batch diff --git a/slm_lab/lib/torch_profiler.py b/slm_lab/lib/torch_profiler.py new file mode 100644 index 000000000..fbfb941c9 --- /dev/null +++ b/slm_lab/lib/torch_profiler.py @@ -0,0 +1,74 @@ +"""torch.profiler Kineto trace collection for HTA analysis. + +Collects GPU/CPU traces when --profile is enabled. Traces are saved as +gzipped TensorBoard files compatible with Holistic Trace Analysis (HTA). +""" + +import os +from contextlib import contextmanager +from pathlib import Path + +import torch +from torch.profiler import ProfilerActivity, schedule, tensorboard_trace_handler + +from slm_lab.lib import logger +from slm_lab.lib.env_var import profile + +logger = logger.get_logger(__name__) + + +def _get_trace_dir() -> Path: + """Resolve trace output directory from LOG_PREPATH.""" + log_prepath = os.environ.get("LOG_PREPATH", "data/profiler") + return Path(log_prepath).parent.parent / "traces" + + +def create_torch_profiler() -> torch.profiler.profile | None: + """Create a torch.profiler.profile instance for Kineto trace collection. + + Returns None if profiling is not enabled. + + Schedule is configurable via env vars (useful for different algorithms): + PROF_SKIP: steps to skip before profiling (default 500) + PROF_ACTIVE: steps to actively record (default 20) + PPO needs high skip (trains every time_horizon ~2048 steps). + SAC/CrossQ need moderate skip (train every step after training_start_step). + """ + if not profile(): + return None + + trace_dir = _get_trace_dir() + trace_dir.mkdir(parents=True, exist_ok=True) + + skip_first = int(os.environ.get("PROF_SKIP", "500")) + active = int(os.environ.get("PROF_ACTIVE", "20")) + logger.info(f"Torch profiler traces: {trace_dir} (skip={skip_first}, active={active})") + + activities = [ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + + return torch.profiler.profile( + activities=activities, + schedule=schedule(skip_first=skip_first, wait=5, warmup=2, active=active, repeat=1), + on_trace_ready=tensorboard_trace_handler(str(trace_dir), use_gzip=True), + profile_memory=True, + record_shapes=True, + with_stack=True, + ) + + +@contextmanager +def torch_profiler_context(): + """Context manager that wraps the RL loop with torch.profiler. + + Yields a step callback to call after each iteration. + If profiling is disabled, yields a no-op callable. + """ + prof = create_torch_profiler() + if prof is None: + yield lambda: None + return + + with prof: + yield prof.step diff --git a/slm_lab/lib/util.py b/slm_lab/lib/util.py index c8237a17c..05d194ad2 100644 --- a/slm_lab/lib/util.py +++ b/slm_lab/lib/util.py @@ -92,8 +92,8 @@ def cast_df(val): def cast_list(val): - """missing pydash method to cast value as list""" - if ps.is_list(val): + """Cast value as list if not already""" + if isinstance(val, list): return val else: return [val] @@ -116,17 +116,17 @@ def frame_mod(frame, frequency, num_envs): def flatten_dict(obj, delim="."): - """Missing pydash method to flatten dict""" + """Flatten a nested dict with delimited keys""" nobj = {} for key, val in obj.items(): - if ps.is_dict(val) and not ps.is_empty(val): + if isinstance(val, dict) and val: strip = flatten_dict(val, delim) for k, v in strip.items(): nobj[key + delim + k] = v - elif ps.is_list(val) and not ps.is_empty(val) and ps.is_dict(val[0]): + elif isinstance(val, list) and val and isinstance(val[0], dict): for idx, v in enumerate(val): nobj[key + delim + str(idx)] = v - if ps.is_object(v): + if isinstance(v, dict): nobj = flatten_dict(nobj, delim) else: nobj[key] = val @@ -362,7 +362,7 @@ def log_dict(data: dict, title: str = None): {k: v}, default_flow_style=False, indent=2, sort_keys=False ).rstrip() lines.append(yaml_str) - elif v is not None and not ps.reg_exp_js_match(str(v), "/<.+>/"): + elif v is not None and not (isinstance(v, str) and v.startswith("<") and v.endswith(">")): lines.append(f"{k}: {v}") logger.info("\n".join(lines)) @@ -377,14 +377,16 @@ def log_self_desc(cls, omit=None): # Fallback for minimal install (no torch) obj_dict = {k: str(v) for k, v in cls.__dict__.items() if not k.startswith("_")} if omit: - obj_dict = ps.omit(obj_dict, omit) + obj_dict = {k: v for k, v in obj_dict.items() if k not in omit} log_dict(obj_dict, get_class_name(cls)) def set_attr(obj, attr_dict, keys=None): """Set attribute of an object from a dict""" + if attr_dict is None: + return obj if keys is not None: - attr_dict = ps.pick(attr_dict, keys) + attr_dict = {k: attr_dict[k] for k in keys if k in attr_dict} for attr, val in attr_dict.items(): setattr(obj, attr, val) return obj diff --git a/slm_lab/spec/benchmark/crossq/crossq_atari.yaml b/slm_lab/spec/benchmark/crossq/crossq_atari.yaml new file mode 100644 index 000000000..1894e1194 --- /dev/null +++ b/slm_lab/spec/benchmark/crossq/crossq_atari.yaml @@ -0,0 +1,132 @@ +# CrossQ Atari — SAC without target networks + Batch Renormalization in critics +# v14: iter=1 + slow alpha_lr=3e-5 for stability. v12 iter=1 had ~320fps but Pong/Breakout/Seaquest +# diverged due to fast alpha dynamics. Slow alpha_lr (from v9c) prevents divergence without +# needing extra gradient steps. Keeps v12 speed advantage over SAC. + +# Actor ConvNet: no batch normalization +_conv_actor: &conv_actor + modules: + body: + Sequential: + - LazyConv2d: + out_channels: 32 + kernel_size: 8 + stride: 4 + - ReLU: + - LazyConv2d: + out_channels: 64 + kernel_size: 4 + stride: 2 + - ReLU: + - LazyConv2d: + out_channels: 64 + kernel_size: 3 + stride: 1 + - ReLU: + - Flatten: + - LazyLinear: + out_features: 512 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# Critic ConvNet: wider FC (1024) + BRN to compensate for UTD=1 (fewer gradient steps) +_conv_critic: &conv_critic + modules: + body: + Sequential: + - LazyConv2d: + out_channels: 32 + kernel_size: 8 + stride: 4 + - ReLU: + - LazyConv2d: + out_channels: 64 + kernel_size: 4 + stride: 2 + - ReLU: + - LazyConv2d: + out_channels: 64 + kernel_size: 3 + stride: 1 + - ReLU: + - Flatten: + - LazyLinear: + out_features: 1024 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 5000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +crossq_atari: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_start_step: 1000 + training_frequency: 4 + training_iter: 1 + policy_delay: 3 + alpha_lr: 3.0e-5 + log_alpha_max: 1.0 + symlog: false + spectral_norm: false + memory: + name: Replay + batch_size: 256 + max_size: 200000 + use_cer: false + net: + type: TorchArcNet + arc: *conv_actor + shared: false + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: 0.5 + use_same_optim: false + loss_spec: + name: SmoothL1Loss + optim_spec: + name: Adam + lr: 1.0e-3 + normalize: true + gpu: auto + critic_net: + type: TorchArcNet + arc: *conv_critic + shared: false + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: 0.5 + use_same_optim: false + loss_spec: + name: SmoothL1Loss + optim_spec: + name: Adam + lr: 1.0e-3 + normalize: true + gpu: auto + env: + name: ${env} + num_envs: 16 + max_t: null + max_frame: 2000000 + life_loss_info: true + meta: + distributed: false + eval_frequency: 10000 + log_frequency: 10000 + max_session: 4 + max_trial: 1 diff --git a/slm_lab/spec/benchmark/crossq/crossq_box2d.yaml b/slm_lab/spec/benchmark/crossq/crossq_box2d.yaml new file mode 100644 index 000000000..3a5654c04 --- /dev/null +++ b/slm_lab/spec/benchmark/crossq/crossq_box2d.yaml @@ -0,0 +1,126 @@ +# CrossQ Box2D — SAC without target networks + Batch Renormalization in critics +# Plain [256,256] actor, [256,256]+BRN critics, lr=3e-4 + +_actor: &actor + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - ReLU: + - LazyLinear: + out_features: 256 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +_critic_brn: &critic_brn + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 5000 + - ReLU: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 5000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +_net: &net + type: TorchArcNet + arc: *actor + hid_layers_activation: relu + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 3.0e-4 + gpu: auto + +_critic_net: &critic_net + type: TorchArcNet + arc: *critic_brn + hid_layers_activation: relu + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 3.0e-4 + gpu: auto + +_memory: &memory + name: Replay + batch_size: 256 + max_size: 100000 + use_cer: false + +_meta: &meta + distributed: false + log_frequency: 1000 + eval_frequency: 5000 + max_session: 4 + max_trial: 1 + +crossq_lunar: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_start_step: 1000 + policy_delay: 3 + log_alpha_max: 2.0 + memory: *memory + net: *net + critic_net: *critic_net + env: + name: LunarLander-v3 + num_envs: 8 + max_t: null + max_frame: 300000 + meta: *meta + +crossq_lunar_continuous: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Normal + action_policy: default + gamma: 0.994 + training_iter: 4 + training_frequency: 1 + training_start_step: 256 + policy_delay: 3 + log_alpha_max: 0.5 + memory: + <<: *memory + max_size: 1000000 + net: *net + critic_net: *critic_net + env: + name: LunarLanderContinuous-v3 + num_envs: 8 + max_t: null + max_frame: 300000 + meta: *meta diff --git a/slm_lab/spec/benchmark/crossq/crossq_classic.yaml b/slm_lab/spec/benchmark/crossq/crossq_classic.yaml new file mode 100644 index 000000000..17212daeb --- /dev/null +++ b/slm_lab/spec/benchmark/crossq/crossq_classic.yaml @@ -0,0 +1,149 @@ +# CrossQ Classic Control — SAC without target networks + Batch Renormalization in critics +# Plain [256,256] actor, [256,256]+BRN critics, lr=3e-4 + +_actor: &actor + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - ReLU: + - LazyLinear: + out_features: 256 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +_critic_brn: &critic_brn + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 5000 + - ReLU: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 5000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +_net: &net + type: TorchArcNet + arc: *actor + hid_layers_activation: relu + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 3.0e-4 + gpu: auto + +_critic_net: &critic_net + type: TorchArcNet + arc: *critic_brn + hid_layers_activation: relu + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 3.0e-4 + gpu: auto + +_memory: &memory + name: Replay + batch_size: 256 + max_size: 100000 + use_cer: false + +_meta: &meta + distributed: false + log_frequency: 500 + eval_frequency: 500 + max_session: 4 + max_trial: 1 + +crossq_cartpole: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_iter: 2 + training_start_step: 1000 + policy_delay: 3 + log_alpha_max: 2.0 + memory: *memory + net: *net + critic_net: *critic_net + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 200000 + meta: *meta + +crossq_acrobot: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_start_step: 1000 + policy_delay: 3 + log_alpha_max: 2.0 + memory: *memory + net: *net + critic_net: *critic_net + env: + name: Acrobot-v1 + num_envs: 4 + max_t: null + max_frame: 300000 + meta: *meta + +crossq_pendulum: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Normal + action_policy: default + gamma: 0.99 + training_iter: 4 + training_frequency: 1 + training_start_step: 256 + policy_delay: 3 + log_alpha_max: 0.5 + memory: *memory + net: *net + critic_net: *critic_net + env: + name: Pendulum-v1 + num_envs: 4 + max_t: null + max_frame: 300000 + meta: + <<: *meta + eval_frequency: 5000 diff --git a/slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml b/slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml new file mode 100644 index 000000000..8d3ab7294 --- /dev/null +++ b/slm_lab/spec/benchmark/crossq/crossq_mujoco.yaml @@ -0,0 +1,401 @@ +# CrossQ MuJoCo — SAC without target networks + Batch Renormalization in critics +# Plain [256,256] actor (no BN), critic width scales with env difficulty +# UTD=1 (training_iter=1) — the key CrossQ advantage over SAC +# Critic width formula: W = 256 * sqrt(SAC_training_iter) + +# --- Actor: plain MLP, no batch normalization --- + +_actor: &actor + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - ReLU: + - LazyLinear: + out_features: 256 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# --- Actor with LayerNorm: for harder envs --- + +_actor_ln: &actor_ln + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - LayerNorm: + normalized_shape: [256] + - ReLU: + - LazyLinear: + out_features: 256 + - LayerNorm: + normalized_shape: [256] + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# --- Critics: width scales with env difficulty --- + +# [256,256]+BRN for easy envs (SAC iter<=2, sqrt(2)~1.4 -> stay at 256) +_critic_256_brn: &critic_256_brn + modules: + body: + Sequential: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + - LazyLinear: + out_features: 256 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# [512,512]+BRN for hard envs (SAC iter=4, sqrt(4)=2 -> 256*2=512) +_critic_512_brn: &critic_512_brn + modules: + body: + Sequential: + - LazyLinear: + out_features: 512 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + - LazyLinear: + out_features: 512 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# [1024,1024]+BRN for very hard envs (SAC iter=16, sqrt(16)=4 -> 256*4=1024) +_critic_1024_brn: &critic_1024_brn + modules: + body: + Sequential: + - LazyLinear: + out_features: 1024 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + - LazyLinear: + out_features: 1024 + - LazyBatchRenorm1d: + momentum: 0.01 + eps: 0.001 + warmup_steps: 100000 + - ReLU: + graph: + input: x + modules: + body: [x] + output: body + +# --- Net specs --- + +_net: &net + type: TorchArcNet + arc: *actor + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 1.0e-3 + betas: [0.5, 0.999] + gpu: auto + +_critic_net_256: &critic_net_256 + type: TorchArcNet + arc: *critic_256_brn + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 1.0e-3 + betas: [0.5, 0.999] + gpu: auto + +_critic_net_512: &critic_net_512 + type: TorchArcNet + arc: *critic_512_brn + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 1.0e-3 + betas: [0.5, 0.999] + gpu: auto + +_critic_net_1024: &critic_net_1024 + type: TorchArcNet + arc: *critic_1024_brn + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 1.0e-3 + betas: [0.5, 0.999] + gpu: auto + +# --- Algorithm and memory --- + +_algorithm: &algorithm + name: CrossQ + action_pdtype: default + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_iter: 1 + training_start_step: 5000 + policy_delay: 3 + log_alpha_max: 0.5 + +_memory: &memory + name: Replay + batch_size: 256 + max_size: 1000000 + use_cer: false + +_agent: &agent + name: CrossQ + algorithm: *algorithm + memory: *memory + net: *net + critic_net: *critic_net_256 + +_meta: &meta + distributed: false + log_frequency: 10000 + eval_frequency: 10000 + max_session: 4 + max_trial: 1 + +# --- Generic spec (use with -s env=X -s max_frame=Y) --- + +crossq_mujoco: + agent: *agent + env: + name: ${env} + num_envs: 16 + max_t: null + max_frame: ${max_frame} + meta: *meta + +# --- Easy envs: [256,256] critics, plain actor --- + +crossq_inverted_pendulum: + agent: + <<: *agent + critic_net: *critic_net_512 + env: + name: InvertedPendulum-v5 + num_envs: 16 + max_t: null + max_frame: 7000000 + meta: + <<: *meta + eval_frequency: 1000 + +crossq_swimmer: + agent: + <<: *agent + algorithm: + <<: *algorithm + gamma: 0.9999 + env: + name: Swimmer-v5 + num_envs: 16 + max_t: null + max_frame: 3000000 + meta: + <<: *meta + eval_frequency: 1000 + +crossq_reacher: + agent: *agent + env: + name: Reacher-v5 + num_envs: 16 + max_t: null + max_frame: 2000000 + meta: + <<: *meta + eval_frequency: 1000 + +crossq_pusher: + agent: *agent + env: + name: Pusher-v5 + num_envs: 16 + max_t: null + max_frame: 2000000 + meta: + <<: *meta + eval_frequency: 1000 + +# --- Hard envs: [512,512] critics, LN actor --- + +crossq_halfcheetah: + agent: + <<: *agent + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_512 + env: + name: HalfCheetah-v5 + num_envs: 16 + max_t: null + max_frame: 4000000 + meta: *meta + +crossq_hopper: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_start_step: 10000 + critic_net: *critic_net_512 + env: + name: Hopper-v5 + num_envs: 16 + max_t: null + max_frame: 3000000 + meta: *meta + +crossq_walker2d: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_start_step: 10000 + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_512 + env: + name: Walker2d-v5 + num_envs: 16 + max_t: null + max_frame: 7000000 + meta: *meta + +crossq_ant: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_start_step: 10000 + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_512 + env: + name: Ant-v5 + num_envs: 16 + max_t: null + max_frame: 3000000 + meta: *meta + +crossq_inverted_double_pendulum: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_iter: 2 + training_start_step: 5000 + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_512 + env: + name: InvertedDoublePendulum-v5 + num_envs: 16 + max_t: null + max_frame: 2000000 + meta: + <<: *meta + eval_frequency: 1000 + +# --- Very hard envs: [1024,1024] critics, LN actor, iter=2 --- + +crossq_humanoid: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_iter: 4 + training_start_step: 10000 + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_1024 + env: + name: Humanoid-v5 + num_envs: 16 + max_t: null + max_frame: 2000000 + meta: + <<: *meta + eval_frequency: 1000 + +crossq_humanoid_standup: + agent: + <<: *agent + algorithm: + <<: *algorithm + training_iter: 2 + training_start_step: 10000 + net: + <<: *net + arc: *actor_ln + critic_net: *critic_net_1024 + env: + name: HumanoidStandup-v5 + num_envs: 16 + max_t: null + max_frame: 2000000 + meta: + <<: *meta + eval_frequency: 1000 diff --git a/slm_lab/spec/benchmark_arc/dqn/dqn_box2d_arc.yaml b/slm_lab/spec/benchmark_arc/dqn/dqn_box2d_arc.yaml index 80bcf3f56..a1a0dea0f 100644 --- a/slm_lab/spec/benchmark_arc/dqn/dqn_box2d_arc.yaml +++ b/slm_lab/spec/benchmark_arc/dqn/dqn_box2d_arc.yaml @@ -170,6 +170,7 @@ ddqn_per_mountaincar_arc: gpu: auto env: name: MountainCar-v0 + num_envs: 4 max_t: null max_frame: 500000 meta: diff --git a/slm_lab/spec/benchmark_arc/dqn/dqn_classic_arc.yaml b/slm_lab/spec/benchmark_arc/dqn/dqn_classic_arc.yaml index 509fc8870..2fc776ac2 100644 --- a/slm_lab/spec/benchmark_arc/dqn/dqn_classic_arc.yaml +++ b/slm_lab/spec/benchmark_arc/dqn/dqn_classic_arc.yaml @@ -421,6 +421,7 @@ dqn_mountaincar_arc: <<: *dqn_net_128x2_relu env: name: MountainCar-v0 + num_envs: 4 max_t: null max_frame: 500000 meta: diff --git a/slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_v2_arc.yaml b/slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_v2_arc.yaml new file mode 100644 index 000000000..ad226aa8c --- /dev/null +++ b/slm_lab/spec/benchmark_arc/ppo/ppo_mujoco_v2_arc.yaml @@ -0,0 +1,64 @@ +# PPO v2 MuJoCo with normalization stack (BRO + DreamerV3-inspired) +# percentile advantage normalization, layer norm + +ppo_mujoco_v2_arc: + agent: + name: PPO + algorithm: + name: PPO + action_pdtype: default + action_policy: default + explore_var_spec: null + gamma: 0.99 + lam: 0.95 + clip_eps_spec: + name: no_decay + start_val: 0.2 + end_val: 0.2 + start_step: 0 + end_step: 0 + entropy_coef_spec: + name: no_decay + start_val: 0.0 + end_val: 0.0 + start_step: 0 + end_step: 0 + val_loss_coef: 0.5 + time_horizon: 2048 + minibatch_size: 64 + training_epoch: 10 + normalize_v_targets: true + normalize_advantages: percentile + memory: + name: OnPolicyBatchReplay + net: + type: MLPNet + shared: false + hid_layers: [256, 256] + hid_layers_activation: tanh + init_fn: orthogonal_ + normalize: true + clip_grad_val: 0.5 + use_same_optim: true + log_std_init: 0.0 + layer_norm: true + loss_spec: + name: MSELoss + optim_spec: + name: AdamW + lr: 3e-4 + gpu: auto + env: + name: ${env} + num_envs: 16 + max_t: null + max_frame: ${max_frame} + normalize_obs: true + normalize_reward: true + meta: + distributed: false + log_frequency: 10000 + eval_frequency: 10000 + rigorous_eval: 0 + max_session: 4 + max_trial: 1 diff --git a/slm_lab/spec/benchmark_arc/sac/sac_mujoco_v2_arc.yaml b/slm_lab/spec/benchmark_arc/sac/sac_mujoco_v2_arc.yaml new file mode 100644 index 000000000..3a72b0fc0 --- /dev/null +++ b/slm_lab/spec/benchmark_arc/sac/sac_mujoco_v2_arc.yaml @@ -0,0 +1,47 @@ +# SAC v2 MuJoCo with normalization stack (BRO) +# layer norm on critic + +sac_mujoco_v2_arc: + agent: + name: SoftActorCritic + algorithm: + name: SoftActorCritic + action_pdtype: default + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_iter: 4 + training_start_step: 5000 + memory: + name: Replay + batch_size: 256 + max_size: 1000000 + use_cer: false + net: + type: MLPNet + hid_layers: [256, 256] + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + layer_norm: true + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 3e-4 + update_type: polyak + update_frequency: 1 + polyak_coef: 0.005 + gpu: auto + env: + name: ${env} + num_envs: 16 + max_t: null + max_frame: ${max_frame} + meta: + distributed: false + log_frequency: 10000 + eval_frequency: 10000 + rigorous_eval: 0 + max_session: 4 + max_trial: 1 diff --git a/slm_lab/spec/experimental/roadmap/crossq_pendulum.yaml b/slm_lab/spec/experimental/roadmap/crossq_pendulum.yaml new file mode 100644 index 000000000..c2aacb0f6 --- /dev/null +++ b/slm_lab/spec/experimental/roadmap/crossq_pendulum.yaml @@ -0,0 +1,76 @@ +# CrossQ Pendulum verification (continuous actions) + +_pendulum_meta: &pendulum_meta + distributed: false + log_frequency: 500 + eval_frequency: 500 + max_session: 1 + max_trial: 1 + +_pendulum_net: &pendulum_net + type: MLPNet + hid_layers: [64, 64] + hid_layers_activation: relu + init_fn: xavier_uniform_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.001 + gpu: auto + +# 1. SAC baseline on Pendulum +sac_pendulum_baseline: + agent: + name: SoftActorCritic + algorithm: + name: SoftActorCritic + action_pdtype: default + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_start_step: 500 + memory: + name: Replay + batch_size: 64 + max_size: 50000 + use_cer: false + net: + <<: *pendulum_net + update_type: polyak + polyak_coef: 0.005 + log_std_init: 0.0 + env: + name: Pendulum-v1 + num_envs: 1 + max_t: null + max_frame: 50000 + meta: *pendulum_meta + +# 2. CrossQ on Pendulum (continuous) +crossq_pendulum: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: default + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_iter: 1 + training_start_step: 500 + memory: + name: Replay + batch_size: 64 + max_size: 50000 + use_cer: false + net: + <<: *pendulum_net + log_std_init: 0.0 + env: + name: Pendulum-v1 + num_envs: 1 + max_t: null + max_frame: 50000 + meta: *pendulum_meta diff --git a/slm_lab/spec/experimental/roadmap/ppo_cartpole_features.yaml b/slm_lab/spec/experimental/roadmap/ppo_cartpole_features.yaml new file mode 100644 index 000000000..39d06061a --- /dev/null +++ b/slm_lab/spec/experimental/roadmap/ppo_cartpole_features.yaml @@ -0,0 +1,113 @@ +# PPO CartPole feature ablation specs (quick smoke test) + +_ppo_base_algorithm: &ppo_base_algorithm + name: PPO + action_pdtype: default + action_policy: default + gamma: 0.99 + lam: 0.95 + clip_eps_spec: + name: no_decay + start_val: 0.2 + end_val: 0.2 + start_step: 0 + end_step: 0 + entropy_coef_spec: + name: no_decay + start_val: 0.01 + end_val: 0.01 + start_step: 0 + end_step: 0 + val_loss_coef: 0.5 + time_horizon: 128 + minibatch_size: 64 + training_epoch: 4 + +_ppo_base_net: &ppo_base_net + type: MLPNet + shared: false + hid_layers: [64, 64] + hid_layers_activation: tanh + clip_grad_val: 0.5 + use_same_optim: true + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.0003 + gpu: auto + +_ppo_base_meta: &ppo_base_meta + distributed: false + log_frequency: 500 + eval_frequency: 500 + max_session: 4 + max_trial: 1 + +# 1. Baseline +ppo_cartpole_baseline: + agent: + name: PPO + algorithm: *ppo_base_algorithm + memory: + name: OnPolicyBatchReplay + net: *ppo_base_net + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *ppo_base_meta + +# 2. With layer_norm +ppo_cartpole_layernorm: + agent: + name: PPO + algorithm: *ppo_base_algorithm + memory: + name: OnPolicyBatchReplay + net: + <<: *ppo_base_net + layer_norm: true + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *ppo_base_meta + +# 3. With percentile normalization +ppo_cartpole_percentile: + agent: + name: PPO + algorithm: + <<: *ppo_base_algorithm + normalize_advantages: percentile + memory: + name: OnPolicyBatchReplay + net: *ppo_base_net + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *ppo_base_meta + +# 4. v2 stack (layer_norm + percentile) +ppo_cartpole_v2: + agent: + name: PPO + algorithm: + <<: *ppo_base_algorithm + normalize_advantages: percentile + memory: + name: OnPolicyBatchReplay + net: + <<: *ppo_base_net + layer_norm: true + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *ppo_base_meta diff --git a/slm_lab/spec/experimental/roadmap/ppo_humanoid_features.yaml b/slm_lab/spec/experimental/roadmap/ppo_humanoid_features.yaml new file mode 100644 index 000000000..7d62997ad --- /dev/null +++ b/slm_lab/spec/experimental/roadmap/ppo_humanoid_features.yaml @@ -0,0 +1,125 @@ +# PPO Humanoid feature ablation specs (showcase — features help here) + +_ppo_base_algorithm: &ppo_base_algorithm + name: PPO + action_pdtype: default + action_policy: default + gamma: 0.997 + lam: 0.97 + clip_eps_spec: + name: no_decay + start_val: 0.2 + end_val: 0.2 + start_step: 0 + end_step: 0 + entropy_coef_spec: + name: no_decay + start_val: 0.001 + end_val: 0.001 + start_step: 0 + end_step: 0 + val_loss_coef: 0.5 + time_horizon: 512 + minibatch_size: 64 + training_epoch: 10 + normalize_v_targets: true + +_ppo_base_net: &ppo_base_net + type: MLPNet + shared: false + hid_layers: [256, 256] + hid_layers_activation: tanh + init_fn: orthogonal_ + normalize: true + clip_grad_val: 0.5 + use_same_optim: true + log_std_init: 0.0 + loss_spec: + name: MSELoss + optim_spec: + name: AdamW + lr: 0.0002 + gpu: auto + +_ppo_base_meta: &ppo_base_meta + distributed: false + log_frequency: 5000 + eval_frequency: 5000 + max_session: 4 + max_trial: 1 + +# 1. Baseline +ppo_humanoid_baseline: + agent: + name: PPO + algorithm: *ppo_base_algorithm + memory: + name: OnPolicyBatchReplay + net: *ppo_base_net + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + normalize_obs: true + normalize_reward: true + meta: *ppo_base_meta + +# 2. With layer_norm +ppo_humanoid_layernorm: + agent: + name: PPO + algorithm: *ppo_base_algorithm + memory: + name: OnPolicyBatchReplay + net: + <<: *ppo_base_net + layer_norm: true + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + normalize_obs: true + normalize_reward: true + meta: *ppo_base_meta + +# 3. With percentile normalization +ppo_humanoid_percentile: + agent: + name: PPO + algorithm: + <<: *ppo_base_algorithm + normalize_advantages: percentile + memory: + name: OnPolicyBatchReplay + net: *ppo_base_net + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + normalize_obs: true + normalize_reward: true + meta: *ppo_base_meta + +# 4. v2 stack (layer_norm + percentile) +ppo_humanoid_v2: + agent: + name: PPO + algorithm: + <<: *ppo_base_algorithm + normalize_advantages: percentile + memory: + name: OnPolicyBatchReplay + net: + <<: *ppo_base_net + layer_norm: true + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + normalize_obs: true + normalize_reward: true + meta: *ppo_base_meta diff --git a/slm_lab/spec/experimental/roadmap/sac_crossq_cartpole.yaml b/slm_lab/spec/experimental/roadmap/sac_crossq_cartpole.yaml new file mode 100644 index 000000000..877061f1e --- /dev/null +++ b/slm_lab/spec/experimental/roadmap/sac_crossq_cartpole.yaml @@ -0,0 +1,84 @@ +# SAC + CrossQ CartPole verification (quick smoke test) + +_sac_base_meta: &sac_base_meta + distributed: false + log_frequency: 500 + eval_frequency: 500 + max_session: 4 + max_trial: 1 + +_sac_base_net: &sac_base_net + type: MLPNet + hid_layers: [64, 64] + hid_layers_activation: relu + init_fn: xavier_uniform_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.0003 + update_type: polyak + polyak_coef: 0.005 + gpu: auto + +# 1. SAC baseline +sac_cartpole_baseline: + agent: + name: SoftActorCritic + algorithm: + name: SoftActorCritic + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_start_step: 500 + memory: + name: Replay + batch_size: 64 + max_size: 50000 + use_cer: false + net: *sac_base_net + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *sac_base_meta + +# 2. CrossQ on CartPole (discrete) — with BatchNorm critics +crossq_cartpole: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: Categorical + action_policy: default + gamma: 0.99 + training_frequency: 1 + training_iter: 1 + training_start_step: 1000 + memory: + name: Replay + batch_size: 256 + max_size: 50000 + use_cer: false + net: + type: MLPNet + hid_layers: [256, 256] + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: 1.0 + batch_norm: true + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.0001 + gpu: auto + env: + name: CartPole-v1 + num_envs: 4 + max_t: null + max_frame: 100000 + meta: *sac_base_meta diff --git a/slm_lab/spec/experimental/roadmap/sac_crossq_humanoid.yaml b/slm_lab/spec/experimental/roadmap/sac_crossq_humanoid.yaml new file mode 100644 index 000000000..2704b2e07 --- /dev/null +++ b/slm_lab/spec/experimental/roadmap/sac_crossq_humanoid.yaml @@ -0,0 +1,86 @@ +# SAC + CrossQ Humanoid verification (showcase — CrossQ should win with proper BN) + +_sac_base_meta: &sac_base_meta + distributed: false + log_frequency: 5000 + eval_frequency: 5000 + max_session: 4 + max_trial: 1 + +_sac_base_net: &sac_base_net + type: MLPNet + hid_layers: [400, 300] + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: null + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.00073 + update_type: polyak + update_frequency: 1 + polyak_coef: 0.02 + gpu: auto + +# 1. SAC baseline +sac_humanoid_baseline: + agent: + name: SoftActorCritic + algorithm: + name: SoftActorCritic + action_pdtype: default + action_policy: default + gamma: 0.98 + training_frequency: 1 + training_iter: 1 + training_start_step: 5000 + memory: + name: Replay + batch_size: 256 + max_size: 200000 + use_cer: false + net: *sac_base_net + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + meta: *sac_base_meta + +# 2. CrossQ on Humanoid — with BatchNorm critics, paper hparams +crossq_humanoid: + agent: + name: CrossQ + algorithm: + name: CrossQ + action_pdtype: default + action_policy: default + gamma: 0.98 + training_frequency: 1 + training_iter: 1 + training_start_step: 5000 + memory: + name: Replay + batch_size: 256 + max_size: 200000 + use_cer: false + net: + type: MLPNet + hid_layers: [256, 256] + hid_layers_activation: relu + init_fn: orthogonal_ + clip_grad_val: 1.0 + batch_norm: true + loss_spec: + name: MSELoss + optim_spec: + name: Adam + lr: 0.0001 + gpu: auto + env: + name: Humanoid-v5 + num_envs: 4 + max_t: null + max_frame: 200000 + meta: *sac_base_meta diff --git a/slm_lab/spec/spec_util.py b/slm_lab/spec/spec_util.py index 43b78e431..f0e978cda 100644 --- a/slm_lab/spec/spec_util.py +++ b/slm_lab/spec/spec_util.py @@ -145,18 +145,38 @@ def set_variables(spec_str: str, sets: list[str] | None) -> tuple[str, str | Non env_short = None for item in sets: k, v = item.split("=", 1) - # For numeric values, replace quoted "${var}" with unquoted value + # For numeric values, replace quoted "${var}" with the canonical numeric string. + # YAML doesn't recognize bare scientific notation like "4e6" — must use int or float form. try: - float(v) - spec_str = spec_str.replace(f'"${{{k}}}"', v) + num = float(v) + canonical = str(int(num)) if num == int(num) else str(num) + spec_str = spec_str.replace(f'"${{{k}}}"', canonical) + spec_str = spec_str.replace(f"${{{k}}}", canonical) except ValueError: - pass - spec_str = spec_str.replace(f"${{{k}}}", v) + spec_str = spec_str.replace(f"${{{k}}}", v) if k == "env": env_short = v.split("/")[-1].split("-")[0].lower() return spec_str, env_short +VAR_PATTERN = re.compile(r"\$\{(\w+)\}") + + +def _find_unsubstituted_vars(obj, path: str = "") -> list[str]: + """Recursively find unsubstituted ${var} placeholders in a parsed spec dict.""" + found = [] + if isinstance(obj, str): + for match in VAR_PATTERN.finditer(obj): + found.append(f"{path}: ${{{match.group(1)}}}") + elif isinstance(obj, dict): + for k, v in obj.items(): + found.extend(_find_unsubstituted_vars(v, f"{path}.{k}" if path else k)) + elif isinstance(obj, list): + for i, v in enumerate(obj): + found.extend(_find_unsubstituted_vars(v, f"{path}[{i}]")) + return found + + def get(spec_file, spec_name, experiment_ts=None, sets: list[str] | None = None): """ Get an experiment spec from spec_file, spec_name. @@ -189,6 +209,16 @@ def get(spec_file, spec_name, experiment_ts=None, sets: list[str] | None = None) f"spec_name {spec_name} is not in spec_file {spec_file}. Choose from:\n {ps.join(spec_dict.keys(), ',')}" ) spec = spec_dict[spec_name] + + # Fail fast on unsubstituted ${var} placeholders + unsubstituted = _find_unsubstituted_vars(spec) + if unsubstituted: + vars_str = "\n ".join(unsubstituted) + raise ValueError( + f"Unsubstituted variables in spec '{spec_name}':\n {vars_str}\n" + f"Pass them via -s, e.g.: slm-lab run -s max_frame=4e6 ..." + ) + # fill-in info at runtime spec["name"] = spec_name if env_short: diff --git a/test/agent/algorithm/test_crossq.py b/test/agent/algorithm/test_crossq.py new file mode 100644 index 000000000..b84e84e6e --- /dev/null +++ b/test/agent/algorithm/test_crossq.py @@ -0,0 +1,223 @@ +"""Tests for CrossQ algorithm (Tier 2).""" + +import pytest +import torch +import torch.nn as nn + +from slm_lab.agent.algorithm.crossq import CrossQ +from slm_lab.agent.algorithm.sac import SoftActorCritic + +try: + import torcharc + + HAS_TORCHARC = True +except ImportError: + HAS_TORCHARC = False + + +# --------------------------------------------------------------------------- +# Unit tests — no env / agent needed +# --------------------------------------------------------------------------- + + +class TestCrossQClass: + def test_inherits_from_sac(self): + assert issubclass(CrossQ, SoftActorCritic) + + def test_class_exists_in_algorithm_module(self): + from slm_lab.agent import algorithm + + assert hasattr(algorithm, "CrossQ") + + +class TestCalcQCross: + """Test calc_q_cross with a simple Linear Q-net (no torcharc needed).""" + + @pytest.fixture + def q_net(self): + """Simple Q-net: input_dim=6 (state=4 + action=2), output=1.""" + net = nn.Linear(6, 1) + return net + + def test_output_shapes(self, q_net): + batch = 8 + states = torch.randn(batch, 4) + actions = torch.randn(batch, 2) + next_states = torch.randn(batch, 4) + next_actions = torch.randn(batch, 2) + + # Call unbound — CrossQ.calc_q_cross is a regular method, invoke via class + q_current, q_next = CrossQ.calc_q_cross( + None, states, actions, next_states, next_actions, q_net + ) + assert q_current.shape == (batch,) + assert q_next.shape == (batch,) + + def test_batch_split_correctness(self, q_net): + """Verify the first half of the concatenated batch corresponds to current.""" + batch = 4 + states = torch.ones(batch, 4) + actions = torch.ones(batch, 2) + next_states = torch.zeros(batch, 4) + next_actions = torch.zeros(batch, 2) + + q_current, q_next = CrossQ.calc_q_cross( + None, states, actions, next_states, next_actions, q_net + ) + # current and next should differ because inputs differ + assert not torch.allclose(q_current, q_next) + + def test_same_input_gives_same_output(self, q_net): + """When current == next, both halves should be identical.""" + batch = 4 + states = torch.randn(batch, 4) + actions = torch.randn(batch, 2) + + q_current, q_next = CrossQ.calc_q_cross( + None, states, actions, states, actions, q_net + ) + assert torch.allclose(q_current, q_next) + + +class TestCalcQCrossDiscrete: + """Test calc_q_cross_discrete with a simple Linear Q-net.""" + + @pytest.fixture + def q_net(self): + """Simple Q-net: input=state_dim=4, output=action_dim=2.""" + return nn.Linear(4, 2) + + def test_output_shapes(self, q_net): + batch = 8 + states = torch.randn(batch, 4) + next_states = torch.randn(batch, 4) + + q_current, q_next = CrossQ.calc_q_cross_discrete( + None, states, next_states, q_net + ) + assert q_current.shape == (batch, 2) + assert q_next.shape == (batch, 2) + + def test_same_input_gives_same_output(self, q_net): + batch = 4 + states = torch.randn(batch, 4) + + q_current, q_next = CrossQ.calc_q_cross_discrete(None, states, states, q_net) + assert torch.allclose(q_current, q_next) + + +class TestUpdateNetsNoop: + def test_update_nets_is_noop(self): + """update_nets should do nothing (no target networks).""" + crossq = CrossQ.__new__(CrossQ) + # Should not raise + crossq.update_nets() + + +# --------------------------------------------------------------------------- +# Integration tests — require agent + env via spec +# --------------------------------------------------------------------------- + + +def _get_crossq_cartpole_spec(): + """Build a minimal CrossQ spec for CartPole (discrete) using MLPNet.""" + from slm_lab.spec import spec_util + + spec = spec_util.get("benchmark/sac/sac_cartpole.json", "sac_cartpole") + # Override to CrossQ + spec["agent"]["name"] = "CrossQ" + spec["agent"]["algorithm"]["name"] = "CrossQ" + spec["agent"]["algorithm"]["training_iter"] = 1 # UTD=1 + spec = spec_util.override_spec(spec, "test") + return spec + + +def _get_crossq_pendulum_spec(): + """Build a minimal CrossQ spec for Pendulum (continuous) using MLPNet.""" + from slm_lab.spec import spec_util + + spec = spec_util.get("benchmark/sac/sac_cartpole.json", "sac_cartpole") + # Override to CrossQ + continuous env + spec["agent"]["name"] = "CrossQ" + spec["agent"]["algorithm"]["name"] = "CrossQ" + spec["agent"]["algorithm"]["action_pdtype"] = "default" + spec["agent"]["algorithm"]["training_iter"] = 1 + spec["env"]["name"] = "Pendulum-v1" + spec = spec_util.override_spec(spec, "test") + return spec + + +class TestCrossQIntegration: + def test_no_target_networks(self): + from slm_lab.experiment.control import make_agent_env + from slm_lab.spec import spec_util + + spec = _get_crossq_cartpole_spec() + spec_util.tick(spec, "trial") + agent, env = make_agent_env(spec) + algo = agent.algorithm + + assert not hasattr(algo, "target_q1_net") + assert not hasattr(algo, "target_q2_net") + + def test_net_names(self): + from slm_lab.experiment.control import make_agent_env + from slm_lab.spec import spec_util + + spec = _get_crossq_cartpole_spec() + spec_util.tick(spec, "trial") + agent, env = make_agent_env(spec) + + assert agent.algorithm.net_names == ["net", "q1_net", "q2_net"] + + def test_session_cartpole(self): + """CrossQ completes a short training session on CartPole.""" + from slm_lab.experiment.control import Session + from slm_lab.spec import spec_util + + spec = _get_crossq_cartpole_spec() + spec_util.tick(spec, "trial") + spec_util.tick(spec, "session") + spec_util.save(spec, unit="trial") + session = Session(spec) + metrics = session.run() + assert isinstance(metrics, dict) + + def test_session_pendulum(self): + """CrossQ completes a short training session on Pendulum (continuous).""" + from slm_lab.experiment.control import Session + from slm_lab.spec import spec_util + + spec = _get_crossq_pendulum_spec() + spec_util.tick(spec, "trial") + spec_util.tick(spec, "session") + spec_util.save(spec, unit="trial") + session = Session(spec) + metrics = session.run() + assert isinstance(metrics, dict) + + def test_bn_mode_switching(self): + """Critics switch between eval (target) and train (cross forward) modes.""" + from slm_lab.experiment.control import make_agent_env + from slm_lab.spec import spec_util + + spec = _get_crossq_cartpole_spec() + spec_util.tick(spec, "trial") + agent, env = make_agent_env(spec) + algo = agent.algorithm + + # After init, nets are in train mode + assert algo.q1_net.training + assert algo.q2_net.training + + # Simulate eval mode switch (as in train() target computation) + algo.q1_net.eval() + algo.q2_net.eval() + assert not algo.q1_net.training + assert not algo.q2_net.training + + # Switch back to train (as in cross batch norm forward) + algo.q1_net.train() + algo.q2_net.train() + assert algo.q1_net.training + assert algo.q2_net.training diff --git a/test/agent/algorithm/test_normalizers.py b/test/agent/algorithm/test_normalizers.py new file mode 100644 index 000000000..4f07e8c4e --- /dev/null +++ b/test/agent/algorithm/test_normalizers.py @@ -0,0 +1,73 @@ +import torch + +from slm_lab.agent.algorithm.actor_critic import PercentileNormalizer + + +def test_init_zeros(): + norm = PercentileNormalizer() + assert norm.perc5 == 0.0 + assert norm.perc95 == 0.0 + + +def test_update_tracks_percentiles(): + norm = PercentileNormalizer(decay=0.0) # no EMA, instant update + values = torch.arange(100, dtype=torch.float32) + norm.update(values) + # With decay=0, perc5/perc95 should match torch.quantile exactly + expected_p5 = torch.quantile(values, 0.05).item() + expected_p95 = torch.quantile(values, 0.95).item() + assert abs(norm.perc5 - expected_p5) < 1e-4 + assert abs(norm.perc95 - expected_p95) < 1e-4 + + +def test_normalize_divides_by_scale(): + norm = PercentileNormalizer(decay=0.0) + values = torch.arange(100, dtype=torch.float32) + norm.update(values) + scale = max(1.0, norm.perc95 - norm.perc5) + result = norm.normalize(values) + expected = values / scale + assert torch.allclose(result, expected) + + +def test_ema_decay_converges(): + """After many updates with the same distribution, percentiles should converge""" + norm = PercentileNormalizer(decay=0.99) + values = torch.randn(1000) + for _ in range(500): + norm.update(values) + # After convergence, perc5/perc95 should approximate the true quantiles + true_p5 = torch.quantile(values, 0.05).item() + true_p95 = torch.quantile(values, 0.95).item() + assert abs(norm.perc5 - true_p5) < 0.3 + assert abs(norm.perc95 - true_p95) < 0.3 + + +def test_normalize_zero_range(): + """All same values: scale should be max(1.0, 0) = 1.0""" + norm = PercentileNormalizer(decay=0.0) + values = torch.ones(100) + norm.update(values) + result = norm.normalize(values) + # scale = max(1.0, perc95 - perc5) = max(1.0, 0.0) = 1.0 + assert torch.allclose(result, values) + + +def test_normalize_uniform_distribution(): + norm = PercentileNormalizer(decay=0.0) + values = torch.linspace(0, 100, 1000) + norm.update(values) + result = norm.normalize(values) + scale = max(1.0, norm.perc95 - norm.perc5) + assert torch.allclose(result, values / scale) + + +def test_normalize_skewed_distribution(): + """Skewed distribution still produces finite output""" + norm = PercentileNormalizer(decay=0.0) + # Exponential-like skew + values = torch.exp(torch.randn(500)) + norm.update(values) + result = norm.normalize(values) + assert torch.all(torch.isfinite(result)) + assert norm.perc95 > norm.perc5 diff --git a/test/agent/algorithm/test_policy_util.py b/test/agent/algorithm/test_policy_util.py index 45f541c79..9170cc344 100644 --- a/test/agent/algorithm/test_policy_util.py +++ b/test/agent/algorithm/test_policy_util.py @@ -1,4 +1,5 @@ """Tests for policy_util module, especially ACTION_PDS configuration.""" + import pytest from torch import distributions @@ -10,29 +11,31 @@ class TestActionPds: def test_multi_continuous_includes_normal(self): """Normal should be first option for multi_continuous (standard for SAC/PPO).""" - pdtypes = policy_util.ACTION_PDS['multi_continuous'] - assert 'Normal' in pdtypes, 'Normal must be available for multi_continuous' - assert pdtypes[0] == 'Normal', 'Normal should be first (default) for multi_continuous' + pdtypes = policy_util.ACTION_PDS["multi_continuous"] + assert "Normal" in pdtypes, "Normal must be available for multi_continuous" + assert pdtypes[0] == "Normal", ( + "Normal should be first (default) for multi_continuous" + ) def test_multi_continuous_includes_multivariate_normal(self): """MultivariateNormal should also be available for multi_continuous.""" - pdtypes = policy_util.ACTION_PDS['multi_continuous'] - assert 'MultivariateNormal' in pdtypes + pdtypes = policy_util.ACTION_PDS["multi_continuous"] + assert "MultivariateNormal" in pdtypes def test_continuous_includes_normal(self): """Normal should be first option for continuous.""" - pdtypes = policy_util.ACTION_PDS['continuous'] - assert pdtypes[0] == 'Normal' + pdtypes = policy_util.ACTION_PDS["continuous"] + assert pdtypes[0] == "Normal" def test_discrete_includes_categorical(self): """Categorical should be first option for discrete.""" - pdtypes = policy_util.ACTION_PDS['discrete'] - assert pdtypes[0] == 'Categorical' + pdtypes = policy_util.ACTION_PDS["discrete"] + assert pdtypes[0] == "Categorical" def test_all_action_types_have_defaults(self): """All action types should have at least one distribution option.""" for action_type, pdtypes in policy_util.ACTION_PDS.items(): - assert len(pdtypes) > 0, f'{action_type} has no distributions' + assert len(pdtypes) > 0, f"{action_type} has no distributions" class TestGetActionPdCls: @@ -40,28 +43,28 @@ class TestGetActionPdCls: def test_normal_for_multi_continuous(self): """Normal should be valid for multi_continuous action types.""" - pd_cls = policy_util.get_action_pd_cls('Normal', 'multi_continuous') + pd_cls = policy_util.get_action_pd_cls("Normal", "multi_continuous") assert pd_cls == distributions.Normal def test_normal_for_continuous(self): """Normal should be valid for continuous action types.""" - pd_cls = policy_util.get_action_pd_cls('Normal', 'continuous') + pd_cls = policy_util.get_action_pd_cls("Normal", "continuous") assert pd_cls == distributions.Normal def test_categorical_for_discrete(self): """Categorical should be valid for discrete action types.""" - pd_cls = policy_util.get_action_pd_cls('Categorical', 'discrete') + pd_cls = policy_util.get_action_pd_cls("Categorical", "discrete") assert pd_cls == distributions.Categorical def test_invalid_pdtype_raises(self): """Invalid pdtype for action type should raise assertion.""" with pytest.raises(AssertionError): - policy_util.get_action_pd_cls('Categorical', 'continuous') + policy_util.get_action_pd_cls("Categorical", "continuous") def test_invalid_action_type_raises(self): """Invalid action type should raise KeyError.""" with pytest.raises(KeyError): - policy_util.get_action_pd_cls('Normal', 'invalid_type') + policy_util.get_action_pd_cls("Normal", "invalid_type") class TestInitActionPd: @@ -70,6 +73,7 @@ class TestInitActionPd: def test_normal_distribution_init(self): """Normal distribution should initialize with loc and scale.""" import torch + # pdparam shape: [batch, 2, action_dim] where 2 is [loc, log_scale] pdparam = [torch.zeros(2, 2), torch.zeros(2, 2)] # loc, log_scale action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) @@ -79,6 +83,7 @@ def test_normal_distribution_init(self): def test_categorical_distribution_init(self): """Categorical distribution should initialize with logits.""" import torch + pdparam = torch.randn(2, 4) # batch_size=2, num_actions=4 action_pd = policy_util.init_action_pd(distributions.Categorical, pdparam) assert isinstance(action_pd, distributions.Categorical) @@ -87,6 +92,7 @@ def test_categorical_distribution_init(self): def test_normal_1d_continuous_init(self): """Normal distribution should handle 1D continuous (Pendulum-like) with tensor input.""" import torch + # 1D action: pdparam is [batch, 2] tensor (loc, log_scale concatenated) pdparam = torch.randn(256, 2) action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) @@ -98,17 +104,19 @@ def test_normal_1d_continuous_init(self): def test_normal_1d_log_prob_shape(self): """1D continuous log_prob should have correct shape for sum(-1).""" import torch + pdparam = torch.randn(256, 2) action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) actions = action_pd.rsample() log_prob = action_pd.log_prob(actions) # sum(-1) should produce [batch] shape, not scalar result = log_prob.sum(-1) - assert result.shape == (256,), f'Expected shape (256,), got {result.shape}' + assert result.shape == (256,), f"Expected shape (256,), got {result.shape}" def test_normal_multidim_continuous_init(self): """Normal distribution should handle multi-dim continuous (Lunar-like) with list input.""" import torch + # Multi-dim action: pdparam is list of [loc, log_scale] tensors pdparam = [torch.randn(256, 2), torch.randn(256, 2)] action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) @@ -119,29 +127,36 @@ def test_normal_multidim_continuous_init(self): def test_normal_multidim_log_prob_shape(self): """Multi-dim continuous log_prob should have correct shape for sum(-1).""" import torch + pdparam = [torch.randn(256, 2), torch.randn(256, 2)] action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) actions = action_pd.rsample() log_prob = action_pd.log_prob(actions) result = log_prob.sum(-1) - assert result.shape == (256,), f'Expected shape (256,), got {result.shape}' + assert result.shape == (256,), f"Expected shape (256,), got {result.shape}" def test_entropy_1d_continuous_shape(self): """1D continuous entropy should have correct shape for sum(-1).""" import torch + pdparam = torch.randn(256, 2) action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) entropy = action_pd.entropy() # Shape should be [batch, 1] for consistent sum(-1) behavior - assert entropy.shape == (256, 1), f'Expected shape (256, 1), got {entropy.shape}' + assert entropy.shape == (256, 1), ( + f"Expected shape (256, 1), got {entropy.shape}" + ) def test_entropy_multidim_continuous_shape(self): """Multi-dim continuous entropy should have correct shape for sum(-1).""" import torch + pdparam = [torch.randn(256, 6), torch.randn(256, 6)] # HalfCheetah-like action_pd = policy_util.init_action_pd(distributions.Normal, pdparam) entropy = action_pd.entropy() - assert entropy.shape == (256, 6), f'Expected shape (256, 6), got {entropy.shape}' + assert entropy.shape == (256, 6), ( + f"Expected shape (256, 6), got {entropy.shape}" + ) def test_entropy_sum_then_mean_pattern(self): """Entropy sum(-1).mean() should scale with action dimensions (CleanRL standard). @@ -154,6 +169,7 @@ def test_entropy_sum_then_mean_pattern(self): contribution N times weaker for N-dimensional action spaces. """ import torch + batch_size = 256 # 1D continuous (Pendulum-like) @@ -165,7 +181,10 @@ def test_entropy_sum_then_mean_pattern(self): entropy_1d_result = entropy_1d.mean() # 6D continuous (HalfCheetah-like) - pdparam_6d = [torch.zeros(batch_size, 6), torch.zeros(batch_size, 6)] # zero mean, unit std + pdparam_6d = [ + torch.zeros(batch_size, 6), + torch.zeros(batch_size, 6), + ] # zero mean, unit std action_pd_6d = policy_util.init_action_pd(distributions.Normal, pdparam_6d) entropy_6d = action_pd_6d.entropy() if entropy_6d.dim() > 1: @@ -177,4 +196,4 @@ def test_entropy_sum_then_mean_pattern(self): # gives consistent entropy per dimension ratio = entropy_6d_result / entropy_1d_result # Should be close to 6.0 (6 dims vs 1 dim) - assert 5.5 < ratio < 6.5, f'Entropy should scale ~6x, got ratio={ratio:.2f}' + assert 5.5 < ratio < 6.5, f"Entropy should scale ~6x, got ratio={ratio:.2f}" diff --git a/test/agent/algorithm/test_ppo_features.py b/test/agent/algorithm/test_ppo_features.py new file mode 100644 index 000000000..f950cab11 --- /dev/null +++ b/test/agent/algorithm/test_ppo_features.py @@ -0,0 +1,34 @@ +from slm_lab.experiment.control import make_agent_env +from slm_lab.spec import spec_util + + +def _make_ppo_agent(algorithm_overrides=None): + """Create a PPO agent with optional algorithm spec overrides.""" + spec = spec_util.get("benchmark/ppo/ppo_cartpole.json", "ppo_cartpole") + spec_util.tick(spec, "trial") + spec = spec_util.override_spec(spec, "test") + if algorithm_overrides: + spec["agent"]["algorithm"].update(algorithm_overrides) + agent, env = make_agent_env(spec) + return agent + + +def test_ppo_default_symlog_false(): + agent = _make_ppo_agent() + assert agent.algorithm.symlog is False + + +def test_ppo_default_normalize_advantages_standardize(): + agent = _make_ppo_agent() + assert agent.algorithm.normalize_advantages == "standardize" + + +def test_ppo_symlog_true_sets_attribute(): + agent = _make_ppo_agent({"symlog": True}) + assert agent.algorithm.symlog is True + + +def test_ppo_percentile_normalizer_created(): + agent = _make_ppo_agent({"normalize_advantages": "percentile"}) + assert agent.algorithm.normalize_advantages == "percentile" + assert hasattr(agent.algorithm, "percentile_normalizer") diff --git a/test/agent/algorithm/test_sac_features.py b/test/agent/algorithm/test_sac_features.py new file mode 100644 index 000000000..d44f1e7e5 --- /dev/null +++ b/test/agent/algorithm/test_sac_features.py @@ -0,0 +1,23 @@ +from slm_lab.experiment.control import make_agent_env +from slm_lab.spec import spec_util + + +def _make_sac_agent(algorithm_overrides=None): + """Create a SAC agent with optional algorithm spec overrides.""" + spec = spec_util.get("benchmark/sac/sac_cartpole.json", "sac_cartpole") + spec_util.tick(spec, "trial") + spec = spec_util.override_spec(spec, "test") + if algorithm_overrides: + spec["agent"]["algorithm"].update(algorithm_overrides) + agent, env = make_agent_env(spec) + return agent + + +def test_sac_default_symlog_false(): + agent = _make_sac_agent() + assert agent.algorithm.symlog is False + + +def test_sac_symlog_true_sets_attribute(): + agent = _make_sac_agent({"symlog": True}) + assert agent.algorithm.symlog is True diff --git a/test/agent/net/test_conv.py b/test/agent/net/test_conv.py index bfbead3d1..a852d349e 100644 --- a/test/agent/net/test_conv.py +++ b/test/agent/net/test_conv.py @@ -68,7 +68,7 @@ def test_no_fc(): net = ConvNet(no_fc_net_spec, in_dim, out_dim) assert isinstance(net, nn.Module) assert hasattr(net, 'conv_model') - assert not hasattr(net, 'fc_model') + assert net.fc_model is None assert hasattr(net, 'tails') assert not isinstance(net.tails, nn.ModuleList) diff --git a/test/agent/net/test_mlp.py b/test/agent/net/test_mlp.py index 65bbbd5cd..85d62ef0f 100644 --- a/test/agent/net/test_mlp.py +++ b/test/agent/net/test_mlp.py @@ -11,22 +11,13 @@ "hid_layers_activation": "relu", "init_fn": "xavier_uniform_", "clip_grad_val": 1.0, - "loss_spec": { - "name": "MSELoss" - }, - "optim_spec": { - "name": "Adam", - "lr": 0.02 - }, - "lr_scheduler_spec": { - "name": "StepLR", - "step_size": 30, - "gamma": 0.1 - }, + "loss_spec": {"name": "MSELoss"}, + "optim_spec": {"name": "Adam", "lr": 0.02}, + "lr_scheduler_spec": {"name": "StepLR", "step_size": 30, "gamma": 0.1}, "update_type": "replace", "update_frequency": 1, "polyak_coef": 0.9, - "gpu": True + "gpu": True, } in_dim = 10 out_dim = 3 @@ -41,8 +32,8 @@ def test_init(): net = MLPNet(net_spec, in_dim, out_dim) assert isinstance(net, nn.Module) - assert hasattr(net, 'model') - assert hasattr(net, 'tails') + assert hasattr(net, "model") + assert hasattr(net, "tails") assert not isinstance(net.tails, nn.ModuleList) @@ -60,11 +51,11 @@ def test_train_step(): def test_no_lr_scheduler(): nopo_lrs_net_spec = deepcopy(net_spec) - nopo_lrs_net_spec['lr_scheduler_spec'] = None + nopo_lrs_net_spec["lr_scheduler_spec"] = None net = MLPNet(nopo_lrs_net_spec, in_dim, out_dim) assert isinstance(net, nn.Module) - assert hasattr(net, 'model') - assert hasattr(net, 'tails') + assert hasattr(net, "model") + assert hasattr(net, "tails") assert not isinstance(net.tails, nn.ModuleList) y = net.forward(x) @@ -74,8 +65,8 @@ def test_no_lr_scheduler(): def test_multitails(): net = MLPNet(net_spec, in_dim, [3, 4]) assert isinstance(net, nn.Module) - assert hasattr(net, 'model') - assert hasattr(net, 'tails') + assert hasattr(net, "model") + assert hasattr(net, "tails") assert isinstance(net.tails, nn.ModuleList) assert len(net.tails) == 2 @@ -83,3 +74,41 @@ def test_multitails(): assert len(y) == 2 assert y[0].shape == (batch_size, 3) assert y[1].shape == (batch_size, 4) + + +# layer_norm tests + + +def test_layer_norm_false_no_layernorm(): + """layer_norm=False (default) should not include LayerNorm modules""" + spec = {**net_spec, "layer_norm": False} + mlp = MLPNet(spec, in_dim, out_dim) + has_ln = any(isinstance(m, nn.LayerNorm) for m in mlp.model.modules()) + assert not has_ln + + +def test_layer_norm_true_has_layernorm(): + """layer_norm=True should add LayerNorm layers in model""" + spec = {**net_spec, "layer_norm": True} + mlp = MLPNet(spec, in_dim, out_dim) + ln_layers = [m for m in mlp.model.modules() if isinstance(m, nn.LayerNorm)] + assert len(ln_layers) > 0 + + +def test_layer_norm_forward_shape_unchanged(): + """Output shape should be the same regardless of layer_norm setting""" + spec = {**net_spec, "layer_norm": True} + mlp = MLPNet(spec, in_dim, out_dim) + y = mlp.forward(x) + assert y.shape == (batch_size, out_dim) + + +def test_build_fc_model_layer_norm_layer_count(): + """build_fc_model with layer_norm=True should have more layers (Linear + LayerNorm + activation)""" + from slm_lab.agent.net.net_util import build_fc_model + + model_no_ln = build_fc_model([in_dim, 32, 16], "relu", layer_norm=False) + model_ln = build_fc_model([in_dim, 32, 16], "relu", layer_norm=True) + # With layer_norm, each dim pair gets: Linear + LayerNorm + Activation = 3 layers + # Without: Linear + Activation = 2 layers + assert len(model_ln) > len(model_no_ln) diff --git a/test/agent/net/test_recurrent.py b/test/agent/net/test_recurrent.py index 0c3bed811..6da651aaf 100644 --- a/test/agent/net/test_recurrent.py +++ b/test/agent/net/test_recurrent.py @@ -88,7 +88,7 @@ def test_no_fc(): no_fc_net_spec['fc_hid_layers'] = [] net = RecurrentNet(no_fc_net_spec, in_dim, out_dim) assert isinstance(net, nn.Module) - assert not hasattr(net, 'fc_model') + assert net.fc_model is None assert hasattr(net, 'rnn_model') assert hasattr(net, 'tails') assert not isinstance(net.tails, nn.ModuleList) diff --git a/test/cli/test_main.py b/test/cli/test_main.py index 6120f5f99..ef014e298 100644 --- a/test/cli/test_main.py +++ b/test/cli/test_main.py @@ -48,7 +48,7 @@ def test_single_variable_substitution(self): def test_numeric_variable_substitution(self): spec_str = '{"env": {"max_frame": "${max_frame}"}}' result, _ = set_variables(spec_str, ["max_frame=3e6"]) - assert '"max_frame": 3e6' in result # unquoted number + assert '"max_frame": 3000000' in result # canonical integer form for YAML compat def test_variable_with_equals_in_value(self): spec_str = '{"meta": {"note": "${note}"}}' diff --git a/test/env/test_action_rescaling.py b/test/env/test_action_rescaling.py index ab9710593..38512e2bd 100644 --- a/test/env/test_action_rescaling.py +++ b/test/env/test_action_rescaling.py @@ -1,4 +1,5 @@ """Tests for automatic action rescaling wrapper.""" + import numpy as np import pytest diff --git a/test/profile_training.py b/test/profile_training.py new file mode 100644 index 000000000..7038ca370 --- /dev/null +++ b/test/profile_training.py @@ -0,0 +1,164 @@ +"""Profile SLM-Lab training with PyTorch profiler. + +Usage: + uv run python test/profile_training.py # PPO CartPole (default) + uv run python test/profile_training.py --algo sac # SAC CartPole + uv run python test/profile_training.py --frames 10000 --algo ppo +""" + +import argparse +import os +import sys + +# Set env vars before any SLM-Lab imports +os.environ["lab_mode"] = "train" +os.environ["LOG_LEVEL"] = "WARNING" +os.environ["OPTIMIZE_PERF"] = "true" +os.environ["PROFILE"] = "false" +os.environ["RENDER"] = "false" +os.environ["LOG_EXTRA"] = "false" +os.environ["UPLOAD_HF"] = "false" + +import torch +from torch.profiler import ProfilerActivity, profile + +from slm_lab.experiment.control import Session +from slm_lab.spec import spec_util + + +ALGO_CONFIGS = { + "ppo": { + "spec_file": "benchmark/ppo/ppo_cartpole.json", + "spec_name": "ppo_cartpole", + }, + "sac": { + "spec_file": "benchmark/sac/sac_cartpole.json", + "spec_name": "sac_cartpole", + }, +} + + +def load_spec(algo: str, max_frame: int, num_envs: int) -> dict: + """Load and configure spec for profiling.""" + config = ALGO_CONFIGS[algo] + spec = spec_util.get(config["spec_file"], config["spec_name"]) + + # Override for quick profiling + spec["env"]["max_frame"] = max_frame + spec["env"]["num_envs"] = num_envs + spec["meta"]["max_session"] = 1 + spec["meta"]["log_frequency"] = max_frame + 1 # suppress checkpointing + spec["meta"]["eval_frequency"] = max_frame + 1 + + # Tick to set up directories and indices + spec_util.tick(spec, "session") + return spec + + +def run_profile(algo: str, max_frame: int, num_envs: int): + """Run profiling for a given algorithm.""" + print(f"\n{'='*70}") + print(f"Profiling {algo.upper()} on CartPole-v1") + print(f" max_frame={max_frame}, num_envs={num_envs}") + print(f"{'='*70}\n") + + spec = load_spec(algo, max_frame, num_envs) + session = Session(spec) + + # Profile only the RL loop (skip final analysis which needs checkpoint data) + with profile( + activities=[ProfilerActivity.CPU], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + session.run_rl() + + session.close() + + # Print tables sorted by total CPU time + print(f"\n{'='*70}") + print(f"[{algo.upper()}] Top 30 ops by cpu_time_total") + print(f"{'='*70}") + print( + prof.key_averages().table( + sort_by="cpu_time_total", + row_limit=30, + ) + ) + + print(f"\n{'='*70}") + print(f"[{algo.upper()}] Top 30 ops by self_cpu_time_total") + print(f"{'='*70}") + print( + prof.key_averages().table( + sort_by="self_cpu_time_total", + row_limit=30, + ) + ) + + # Group by input shapes to find hot tensor ops + print(f"\n{'='*70}") + print(f"[{algo.upper()}] Top 20 ops grouped by input shape") + print(f"{'='*70}") + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="cpu_time_total", + row_limit=20, + ) + ) + + # Stack trace view for call hierarchy + print(f"\n{'='*70}") + print(f"[{algo.upper()}] Top 30 ops by self_cpu_time_total (with stack)") + print(f"{'='*70}") + print( + prof.key_averages(group_by_stack_n=5).table( + sort_by="self_cpu_time_total", + row_limit=30, + ) + ) + + # Save Chrome trace + trace_path = f"test/profile_trace_{algo}.json" + prof.export_chrome_trace(trace_path) + print(f"\nChrome trace saved to: {trace_path}") + print(f" Open chrome://tracing and load this file to visualize.\n") + + +def main(): + parser = argparse.ArgumentParser(description="Profile SLM-Lab training") + parser.add_argument( + "--algo", + choices=list(ALGO_CONFIGS.keys()), + default="ppo", + help="Algorithm to profile (default: ppo)", + ) + parser.add_argument( + "--frames", + type=int, + default=5000, + help="Max frames for profiling (default: 5000)", + ) + parser.add_argument( + "--num-envs", + type=int, + default=4, + help="Number of parallel envs (default: 4)", + ) + parser.add_argument( + "--all", + action="store_true", + help="Profile all algorithms", + ) + args = parser.parse_args() + + if args.all: + for algo in ALGO_CONFIGS: + run_profile(algo, args.frames, args.num_envs) + else: + run_profile(args.algo, args.frames, args.num_envs) + + +if __name__ == "__main__": + main() diff --git a/test/spec/test_crossq_spec.py b/test/spec/test_crossq_spec.py new file mode 100644 index 000000000..140724a10 --- /dev/null +++ b/test/spec/test_crossq_spec.py @@ -0,0 +1,54 @@ +"""Tests for CrossQ spec loading and validation.""" + +import pytest + +try: + import torcharc + + HAS_TORCHARC = True +except ImportError: + HAS_TORCHARC = False + + +@pytest.mark.skipif(not HAS_TORCHARC, reason="torcharc not installed") +class TestCrossQSpec: + def test_spec_loads(self): + from slm_lab.spec import spec_util + + spec = spec_util.get( + "benchmark/crossq/crossq_mujoco.yaml", "crossq_halfcheetah" + ) + assert spec is not None + assert spec["name"] == "crossq_halfcheetah" + + def test_algorithm_name(self): + from slm_lab.spec import spec_util + + spec = spec_util.get( + "benchmark/crossq/crossq_mujoco.yaml", "crossq_halfcheetah" + ) + assert spec["agent"]["algorithm"]["name"] == "CrossQ" + + def test_training_iter_utd1(self): + from slm_lab.spec import spec_util + + spec = spec_util.get( + "benchmark/crossq/crossq_mujoco.yaml", "crossq_halfcheetah" + ) + assert spec["agent"]["algorithm"]["training_iter"] == 1 + + def test_critic_net_type(self): + from slm_lab.spec import spec_util + + spec = spec_util.get( + "benchmark/crossq/crossq_mujoco.yaml", "crossq_halfcheetah" + ) + assert spec["agent"]["critic_net"]["type"] == "TorchArcNet" + + def test_spec_check_passes(self): + from slm_lab.spec import spec_util + + spec = spec_util.get( + "benchmark/crossq/crossq_mujoco.yaml", "crossq_halfcheetah" + ) + assert spec_util.check(spec) diff --git a/uv.lock b/uv.lock index 4197f7442..73eebe32d 100644 --- a/uv.lock +++ b/uv.lock @@ -2,11 +2,14 @@ version = 1 revision = 3 requires-python = ">=3.12.0" resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", ] supported-markers = [ @@ -125,6 +128,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anyio" +version = "4.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "typing-extensions", marker = "(python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, +] + [[package]] name = "appnope" version = "0.1.4" @@ -134,6 +150,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/3c0a35f46e52108d4707c44b95cfe2afcafc50800b5450c197454569b776/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f", size = 54393, upload-time = "2025-07-30T10:01:40.97Z" }, + { url = "https://files.pythonhosted.org/packages/9d/f4/98bbd6ee89febd4f212696f13c03ca302b8552e7dbf9c8efa11ea4a388c3/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b", size = 29328, upload-time = "2025-07-30T10:01:41.916Z" }, + { url = "https://files.pythonhosted.org/packages/43/24/90a01c0ef12ac91a6be05969f29944643bc1e5e461155ae6559befa8f00b/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a", size = 31269, upload-time = "2025-07-30T10:01:42.716Z" }, + { url = "https://files.pythonhosted.org/packages/0d/82/b484f702fec5536e71836fc2dbc8c5267b3f6e78d2d539b4eaa6f0db8bf8/argon2_cffi_bindings-25.1.0-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb", size = 92364, upload-time = "2025-07-30T10:01:44.887Z" }, + { url = "https://files.pythonhosted.org/packages/44/b4/678503f12aceb0262f84fa201f6027ed77d71c5019ae03b399b97caa2f19/argon2_cffi_bindings-25.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85", size = 91934, upload-time = "2025-07-30T10:01:47.203Z" }, + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, +] + +[[package]] +name = "arrow" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tzdata", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/33/032cdc44182491aa708d06a68b62434140d8c50820a087fac7af37703357/arrow-1.4.0.tar.gz", hash = "sha256:ed0cc050e98001b8779e84d461b0098c4ac597e88704a655582b21d116e526d7", size = 152931, upload-time = "2025-10-18T17:46:46.761Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/c9/d7977eaacb9df673210491da99e6a247e93df98c715fc43fd136ce1d3d33/arrow-1.4.0-py3-none-any.whl", hash = "sha256:749f0769958ebdc79c173ff0b0670d59051a535fa26e8eba02953dc19eb43205", size = 68797, upload-time = "2025-10-18T17:46:45.663Z" }, +] + [[package]] name = "asttokens" version = "3.0.0" @@ -143,6 +205,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, ] +[[package]] +name = "async-lru" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8a/ca724066c32a53fa75f59e0f21aa822fdaa8a0dffa112d223634e3caabf9/async_lru-2.2.0.tar.gz", hash = "sha256:80abae2a237dbc6c60861d621619af39f0d920aea306de34cb992c879e01370c", size = 14654, upload-time = "2026-02-20T19:11:43.848Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/5c/af990f019b8dd11c5492a6371fe74a5b0276357370030b67254a87329944/async_lru-2.2.0-py3-none-any.whl", hash = "sha256:e2c1cf731eba202b59c5feedaef14ffd9d02ad0037fcda64938699f2c380eafe", size = 7890, upload-time = "2026-02-20T19:11:42.273Z" }, +] + [[package]] name = "attrs" version = "25.3.0" @@ -152,6 +223,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "babel" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/51899539b6ceeeb420d40ed3cd4b7a40519404f9baf3d4ac99dc413a834b/babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d", size = 9959554, upload-time = "2026-02-01T12:30:56.078Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845, upload-time = "2026-02-01T12:30:53.445Z" }, +] + +[[package]] +name = "beautifulsoup4" +version = "4.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "typing-extensions", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/b0/1c6a16426d389813b48d95e26898aff79abbde42ad353958ad95cc8c9b21/beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86", size = 627737, upload-time = "2025-11-30T15:08:26.084Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" }, +] + +[[package]] +name = "bleach" +version = "6.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, +] + +[package.optional-dependencies] +css = [ + { name = "tinycss2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] + [[package]] name = "box2d-py" version = "2.3.5" @@ -379,6 +489,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" }, ] +[[package]] +name = "fastjsonschema" +version = "2.21.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/b5/23b216d9d985a956623b6bd12d4086b60f0059b27799f23016af04a74ea1/fastjsonschema-2.21.2.tar.gz", hash = "sha256:b1eb43748041c880796cd077f1a07c3d94e93ae84bba5ed36800a33554ae05de", size = 374130, upload-time = "2025-08-14T18:49:36.666Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -397,6 +516,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/42/cca66659a786567c8af98587d66d75e7d2b6e65662f8daab75db708ac35b/flaky-3.5.3-py2.py3-none-any.whl", hash = "sha256:a94931c46a33469ec26f09b652bc88f55a8f5cc77807b90ca7bbafef1108fd7d", size = 22368, upload-time = "2019-01-17T00:06:42.499Z" }, ] +[[package]] +name = "fqdn" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/3e/a80a8c077fd798951169626cde3e239adeba7dab75deb3555716415bd9b0/fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f", size = 6015, upload-time = "2021-03-11T07:16:29.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014", size = 9121, upload-time = "2021-03-11T07:16:28.351Z" }, +] + [[package]] name = "frozenlist" version = "1.5.0" @@ -567,6 +695,15 @@ mujoco = [ { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + [[package]] name = "hf-xet" version = "1.1.10" @@ -579,6 +716,53 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/3d/ab7109e607ed321afaa690f557a9ada6d6d164ec852fd6bf9979665dc3d6/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f900481cf6e362a6c549c61ff77468bd59d6dd082f3170a36acfef2eb6a6793f", size = 3353360, upload-time = "2025-09-12T20:10:25.563Z" }, ] +[[package]] +name = "holistictraceanalysis" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyterlab", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "networkx", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pandas", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "plotly", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pydot", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pytest", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/92/b069b3831ae95e0b6cdb193bf943b5128e78c2b85656a69d60cc6b97b82d/holistictraceanalysis-0.5.0.tar.gz", hash = "sha256:731b22cfd94f907ad06a1c751187b81e073242cd81fb987bf345f068801371ba", size = 354167, upload-time = "2025-05-29T21:05:55.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/2e/9db9a4b5e2c307264497ca2a9a92ed4d2ca1b31890d55cbe53d1bbdafcd0/holistictraceanalysis-0.5.0-py3-none-any.whl", hash = "sha256:dbb4d461ff5ea7488207e07de5e2f52aeb5d82837ec6acfe57e1f0a124972ec3", size = 371233, upload-time = "2025-05-29T21:05:54.011Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "h11", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "certifi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "httpcore", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "idna", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -669,9 +853,11 @@ name = "ipython" version = "8.18.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", ] dependencies = [ @@ -694,7 +880,8 @@ name = "ipython" version = "9.0.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", ] dependencies = [ @@ -725,6 +912,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, ] +[[package]] +name = "isoduration" +version = "20.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "arrow", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7c/1a/3c8edc664e06e6bd06cce40c6b22da5f1429aa4224d0c590f3be21c91ead/isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9", size = 11649, upload-time = "2020-11-01T11:00:00.312Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/55/e5326141505c5d5e34c5e0935d2908a74e4561eca44108fbfb9c13d2911a/isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042", size = 11321, upload-time = "2020-11-01T10:59:58.02Z" }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -749,6 +948,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "json5" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/77/e8/a3f261a66e4663f22700bc8a17c08cb83e91fbf086726e7a228398968981/json5-0.13.0.tar.gz", hash = "sha256:b1edf8d487721c0bf64d83c28e91280781f6e21f4a797d3261c7c828d4c165bf", size = 52441, upload-time = "2026-01-01T19:42:14.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/9e/038522f50ceb7e74f1f991bf1b699f24b0c2bbe7c390dd36ad69f4582258/json5-0.13.0-py3-none-any.whl", hash = "sha256:9a08e1dd65f6a4d4c6fa82d216cf2477349ec2346a38fd70cc11d2557499fbcc", size = 36163, upload-time = "2026-01-01T19:42:13.962Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonschema" version = "4.23.0" @@ -764,6 +981,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566", size = 88462, upload-time = "2024-07-08T18:40:00.165Z" }, ] +[package.optional-dependencies] +format-nongpl = [ + { name = "fqdn", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "idna", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "isoduration", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jsonpointer", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rfc3339-validator", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rfc3986-validator", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "uri-template", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "webcolors", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] + [[package]] name = "jsonschema-specifications" version = "2024.10.1" @@ -805,6 +1034,128 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965, upload-time = "2024-03-12T12:37:32.36Z" }, ] +[[package]] +name = "jupyter-events" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema", extra = ["format-nongpl"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "python-json-logger", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pyyaml", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "referencing", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rfc3339-validator", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rfc3986-validator", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/c3/306d090461e4cf3cd91eceaff84bede12a8e52cd821c2d20c9a4fd728385/jupyter_events-0.12.0.tar.gz", hash = "sha256:fc3fce98865f6784c9cd0a56a20644fc6098f21c8c33834a8d9fe383c17e554b", size = 62196, upload-time = "2025-02-03T17:23:41.485Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/48/577993f1f99c552f18a0428731a755e06171f9902fa118c379eb7c04ea22/jupyter_events-0.12.0-py3-none-any.whl", hash = "sha256:6464b2fa5ad10451c3d35fabc75eab39556ae1e2853ad0c0cc31b656731a97fb", size = 19430, upload-time = "2025-02-03T17:23:38.643Z" }, +] + +[[package]] +name = "jupyter-lsp" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-server", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/5a/9066c9f8e94ee517133cd98dba393459a16cd48bba71a82f16a65415206c/jupyter_lsp-2.3.0.tar.gz", hash = "sha256:458aa59339dc868fb784d73364f17dbce8836e906cd75fd471a325cba02e0245", size = 54823, upload-time = "2025-08-27T17:47:34.671Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/60/1f6cee0c46263de1173894f0fafcb3475ded276c472c14d25e0280c18d6d/jupyter_lsp-2.3.0-py3-none-any.whl", hash = "sha256:e914a3cb2addf48b1c7710914771aaf1819d46b2e5a79b0f917b5478ec93f34f", size = 76687, upload-time = "2025-08-27T17:47:33.15Z" }, +] + +[[package]] +name = "jupyter-server" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "argon2-cffi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-client", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-events", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-server-terminals", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "nbconvert", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "nbformat", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "prometheus-client", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pyzmq", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "send2trash", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "terminado", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tornado", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "websocket-client", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/ac/e040ec363d7b6b1f11304cc9f209dac4517ece5d5e01821366b924a64a50/jupyter_server-2.17.0.tar.gz", hash = "sha256:c38ea898566964c888b4772ae1ed58eca84592e88251d2cfc4d171f81f7e99d5", size = 731949, upload-time = "2025-08-21T14:42:54.042Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/80/a24767e6ca280f5a49525d987bf3e4d7552bf67c8be07e8ccf20271f8568/jupyter_server-2.17.0-py3-none-any.whl", hash = "sha256:e8cb9c7db4251f51ed307e329b81b72ccf2056ff82d50524debde1ee1870e13f", size = 388221, upload-time = "2025-08-21T14:42:52.034Z" }, +] + +[[package]] +name = "jupyter-server-terminals" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "terminado", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/a7/bcd0a9b0cbba88986fe944aaaf91bfda603e5a50bda8ed15123f381a3b2f/jupyter_server_terminals-0.5.4.tar.gz", hash = "sha256:bbda128ed41d0be9020349f9f1f2a4ab9952a73ed5f5ac9f1419794761fb87f5", size = 31770, upload-time = "2026-01-14T16:53:20.213Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/2d/6674563f71c6320841fc300911a55143925112a72a883e2ca71fba4c618d/jupyter_server_terminals-0.5.4-py3-none-any.whl", hash = "sha256:55be353fc74a80bc7f3b20e6be50a55a61cd525626f578dcb66a5708e2007d14", size = 13704, upload-time = "2026-01-14T16:53:18.738Z" }, +] + +[[package]] +name = "jupyterlab" +version = "4.5.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-lru", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "httpx", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "ipykernel", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-lsp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-server", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyterlab-server", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "notebook-shim", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "setuptools", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tornado", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/2d/953a5612a34a3c799a62566a548e711d103f631672fd49650e0f2de80870/jupyterlab-4.5.5.tar.gz", hash = "sha256:eac620698c59eb810e1729909be418d9373d18137cac66637141abba613b3fda", size = 23968441, upload-time = "2026-02-23T18:57:34.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/52/372d3494766d690dfdd286871bf5f7fb9a6c61f7566ccaa7153a163dd1df/jupyterlab-4.5.5-py3-none-any.whl", hash = "sha256:a35694a40a8e7f2e82f387472af24e61b22adcce87b5a8ab97a5d9c486202a6d", size = 12446824, upload-time = "2026-02-23T18:57:30.398Z" }, +] + +[[package]] +name = "jupyterlab-pygments" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/51/9187be60d989df97f5f0aba133fa54e7300f17616e065d1ada7d7646b6d6/jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d", size = 512900, upload-time = "2023-11-23T09:26:37.44Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884, upload-time = "2023-11-23T09:26:34.325Z" }, +] + +[[package]] +name = "jupyterlab-server" +version = "2.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "babel", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "json5", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jsonschema", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-server", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "requests", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/2c/90153f189e421e93c4bb4f9e3f59802a1f01abd2ac5cf40b152d7f735232/jupyterlab_server-2.28.0.tar.gz", hash = "sha256:35baa81898b15f93573e2deca50d11ac0ae407ebb688299d3a5213265033712c", size = 76996, upload-time = "2025-10-22T13:59:18.37Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/07/a000fe835f76b7e1143242ab1122e6362ef1c03f23f83a045c38859c2ae0/jupyterlab_server-2.28.0-py3-none-any.whl", hash = "sha256:e4355b148fdcf34d312bbbc80f22467d6d20460e8b8736bf235577dd18506968", size = 59830, upload-time = "2025-10-22T13:59:16.767Z" }, +] + [[package]] name = "kaleido" version = "0.2.0" @@ -941,6 +1292,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "mistune" +version = "3.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1023,6 +1383,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c0/fb39bd876ea2fd9509343d643690cd2f9715e6a77271e7c7b26f1eea70c1/narwhals-1.31.0-py3-none-any.whl", hash = "sha256:2a7b79bb5f511055c4c0142121fc0d4171ea171458e12d44dbd9c8fc6488e997", size = 313124, upload-time = "2025-03-17T15:26:23.87Z" }, ] +[[package]] +name = "nbclient" +version = "0.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "nbformat", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/91/1c1d5a4b9a9ebba2b4e32b8c852c2975c872aec1fe42ab5e516b2cecd193/nbclient-0.10.4.tar.gz", hash = "sha256:1e54091b16e6da39e297b0ece3e10f6f29f4ac4e8ee515d29f8a7099bd6553c9", size = 62554, upload-time = "2025-12-23T07:45:46.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/a0/5b0c2f11142ed1dddec842457d3f65eaf71a0080894eb6f018755b319c3a/nbclient-0.10.4-py3-none-any.whl", hash = "sha256:9162df5a7373d70d606527300a95a975a47c137776cd942e52d9c7e29ff83440", size = 25465, upload-time = "2025-12-23T07:45:44.51Z" }, +] + +[[package]] +name = "nbconvert" +version = "7.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "bleach", extra = ["css"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "defusedxml", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyterlab-pygments", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "markupsafe", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "mistune", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "nbclient", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "nbformat", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "packaging", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pandocfilters", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pygments", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/47/81f886b699450d0569f7bc551df2b1673d18df7ff25cc0c21ca36ed8a5ff/nbconvert-7.17.0.tar.gz", hash = "sha256:1b2696f1b5be12309f6c7d707c24af604b87dfaf6d950794c7b07acab96dda78", size = 862855, upload-time = "2026-01-29T16:37:48.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/4b/8d5f796a792f8a25f6925a96032f098789f448571eb92011df1ae59e8ea8/nbconvert-7.17.0-py3-none-any.whl", hash = "sha256:4f99a63b337b9a23504347afdab24a11faa7d86b405e5c8f9881cd313336d518", size = 261510, upload-time = "2026-01-29T16:37:46.322Z" }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastjsonschema", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jsonschema", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jupyter-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "traitlets", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749, upload-time = "2024-04-04T11:20:37.371Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454, upload-time = "2024-04-04T11:20:34.895Z" }, +] + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -1041,14 +1456,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, ] +[[package]] +name = "notebook-shim" +version = "0.2.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-server", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/d2/92fa3243712b9a3e8bafaf60aac366da1cada3639ca767ff4b5b3654ec28/notebook_shim-0.2.4.tar.gz", hash = "sha256:b4b2cfa1b65d98307ca24361f5b30fe785b53c3fd07b7a47e89acb5e6ac638cb", size = 13167, upload-time = "2024-02-14T23:35:18.353Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/33/bd5b9137445ea4b680023eb0469b2bb969d61303dedb2aac6560ff3d14a1/notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef", size = 13307, upload-time = "2024-02-14T23:35:16.286Z" }, +] + [[package]] name = "numpy" version = "2.0.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/a9/75/10dd1f8116a8b796cb2c737b674e02d02e80454bda953fa7e65d8c12b016/numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78", size = 18902015, upload-time = "2024-08-26T20:19:40.945Z" } @@ -1064,7 +1493,8 @@ name = "numpy" version = "2.2.4" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", ] sdist = { url = "https://files.pythonhosted.org/packages/e1/78/31103410a57bc2c2b93a3597340a8119588571f6a4539067546cb9a0bfac/numpy-2.2.4.tar.gz", hash = "sha256:9ba03692a45d3eef66559efe1d1096c4b9b75c0986b5dff5530c378fb8331d4f", size = 20270701, upload-time = "2025-03-16T18:27:00.648Z" } @@ -1353,6 +1783,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436, upload-time = "2024-09-20T13:09:48.112Z" }, ] +[[package]] +name = "pandocfilters" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/6f/3dd4940bbe001c06a65f88e36bad298bc7a0de5036115639926b0c5c0458/pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e", size = 8454, upload-time = "2024-01-18T20:08:13.726Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc", size = 8663, upload-time = "2024-01-18T20:08:11.28Z" }, +] + [[package]] name = "parso" version = "0.8.4" @@ -1541,7 +1980,8 @@ name = "pyarrow" version = "17.0.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", ] dependencies = [ @@ -1557,9 +1997,11 @@ name = "pyarrow" version = "21.0.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", ] sdist = { url = "https://files.pythonhosted.org/packages/ef/c2/ea068b8f00905c06329a3dfcd40d0fcc2b7d0f2e355bdb25b65e0a0e4cd4/pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc", size = 1133487, upload-time = "2025-07-18T00:57:31.761Z" } @@ -1665,6 +2107,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/86/e74c978800131c657fc5145f2c1c63e0cea01a49b6216f729cf77a2e1edf/pydash-8.0.5-py3-none-any.whl", hash = "sha256:b2625f8981862e19911daa07f80ed47b315ce20d9b5eb57aaf97aaf570c3892f", size = 102077, upload-time = "2025-01-17T16:08:47.91Z" }, ] +[[package]] +name = "pydot" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/35/b17cb89ff865484c6a20ef46bf9d95a5f07328292578de0b295f4a6beec2/pydot-4.0.1.tar.gz", hash = "sha256:c2148f681c4a33e08bf0e26a9e5f8e4099a82e0e2a068098f32ce86577364ad5", size = 162594, upload-time = "2025-06-17T20:09:56.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/32/a7125fb28c4261a627f999d5fb4afff25b523800faed2c30979949d6facd/pydot-4.0.1-py3-none-any.whl", hash = "sha256:869c0efadd2708c0be1f916eb669f3d664ca684bc57ffb7ecc08e70d5e93fee6", size = 37087, upload-time = "2025-06-17T20:09:55.25Z" }, +] + [[package]] name = "pygame" version = "2.6.1" @@ -1697,6 +2151,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl", hash = "sha256:794a943daced39300879e4e47bd94525280685f42dbb5a998d336cfff151d74f", size = 3194996, upload-time = "2025-08-18T02:32:59.902Z" }, ] +[[package]] +name = "pyparsing" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/91/9c6ee907786a473bf81c5f53cf703ba0957b23ab84c264080fb5a450416f/pyparsing-3.3.2.tar.gz", hash = "sha256:c777f4d763f140633dcb6d8a3eda953bf7a214dc4eff598413c070bcdc117cbc", size = 6851574, upload-time = "2026-01-21T03:57:59.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, +] + [[package]] name = "pytest" version = "8.3.5" @@ -1748,6 +2211,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-json-logger" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/29/bf/eca6a3d43db1dae7070f70e160ab20b807627ba953663ba07928cdd3dc58/python_json_logger-4.0.0.tar.gz", hash = "sha256:f58e68eb46e1faed27e0f574a55a0455eecd7b8a5b88b85a784519ba3cff047f", size = 17683, upload-time = "2025-10-06T04:15:18.984Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/e5/fecf13f06e5e5f67e8837d777d1bc43fac0ed2b77a676804df5c34744727/python_json_logger-4.0.0-py3-none-any.whl", hash = "sha256:af09c9daf6a813aa4cc7180395f50f2a9e5fa056034c9953aec92e381c5ba1e2", size = 15548, upload-time = "2025-10-06T04:15:17.553Z" }, +] + [[package]] name = "pytz" version = "2025.1" @@ -1888,6 +2360,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928, upload-time = "2024-05-29T15:37:47.027Z" }, ] +[[package]] +name = "rfc3339-validator" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/ea/a9387748e2d111c3c2b275ba970b735e04e15cdb1eb30693b6b5708c4dbd/rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b", size = 5513, upload-time = "2021-05-12T16:37:54.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/44/4e421b96b67b2daff264473f7465db72fbdf36a07e05494f50300cc7b0c6/rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa", size = 3490, upload-time = "2021-05-12T16:37:52.536Z" }, +] + +[[package]] +name = "rfc3986-validator" +version = "0.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/da/88/f270de456dd7d11dcc808abfa291ecdd3f45ff44e3b549ffa01b126464d0/rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055", size = 6760, upload-time = "2019-10-28T16:00:19.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9", size = 4242, upload-time = "2019-10-28T16:00:13.976Z" }, +] + [[package]] name = "rich" version = "14.1.0" @@ -1945,6 +2438,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/cf/8ab81cb7dd7a3b0a3960c2769825038f3adcd75faf46dd6376086df8b128/ruff-0.11.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:842a472d7b4d6f5924e9297aa38149e5dcb1e628773b70e6387ae2c97a63c58f", size = 11378514, upload-time = "2025-03-21T13:31:06.166Z" }, ] +[[package]] +name = "send2trash" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/f0/184b4b5f8d00f2a92cf96eec8967a3d550b52cf94362dad1100df9e48d57/send2trash-2.1.0.tar.gz", hash = "sha256:1c72b39f09457db3c05ce1d19158c2cbef4c32b8bedd02c155e49282b7ea7459", size = 17255, upload-time = "2026-01-14T06:27:36.056Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/78/504fdd027da3b84ff1aecd9f6957e65f35134534ccc6da8628eb71e76d3f/send2trash-2.1.0-py3-none-any.whl", hash = "sha256:0da2f112e6d6bb22de6aa6daa7e144831a4febf2a87261451c4ad849fe9a873c", size = 17610, upload-time = "2026-01-14T06:27:35.218Z" }, +] + [[package]] name = "setuptools" version = "77.0.3" @@ -2005,6 +2507,7 @@ dev = [ { name = "coverage", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "flaky", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "glances", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "holistictraceanalysis", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "ipykernel", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "nvidia-ml-py", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "pytest", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -2052,6 +2555,7 @@ dev = [ { name = "coverage", specifier = ">=7.6.1" }, { name = "flaky", specifier = ">=3.5.3" }, { name = "glances", specifier = ">=4.3.3" }, + { name = "holistictraceanalysis", specifier = ">=0.5.0" }, { name = "ipykernel", specifier = ">=6.29.5" }, { name = "nvidia-ml-py", specifier = ">=13.580.65" }, { name = "pytest", specifier = ">=6.0.0" }, @@ -2090,6 +2594,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/d9/460cf1d58945dd771c228c29d5664f431dfc4060d3d092fed40546b11472/smart_open-7.3.1-py3-none-any.whl", hash = "sha256:e243b2e7f69d6c0c96dd763d6fbbedbb4e0e4fc6d74aa007acc5b018d523858c", size = 61722, upload-time = "2025-09-08T10:03:52.02Z" }, ] +[[package]] +name = "soupsieve" +version = "2.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.43" @@ -2194,14 +2707,41 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/1d/b5d63f1a6b824282b57f7b581810d20b7a28ca951f2d5b59f1eb0782c12b/tensorboardx-2.6.4-py3-none-any.whl", hash = "sha256:5970cf3a1f0a6a6e8b180ccf46f3fe832b8a25a70b86e5a237048a7c0beb18e2", size = 87201, upload-time = "2025-06-10T22:37:05.44Z" }, ] +[[package]] +name = "terminado" +version = "0.18.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess", marker = "(os_name != 'nt' and platform_machine == 'x86_64' and sys_platform == 'linux') or (os_name != 'nt' and sys_platform == 'darwin')" }, + { name = "tornado", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8a/11/965c6fd8e5cc254f1fe142d547387da17a8ebfd75a3455f637c663fb38a0/terminado-0.18.1.tar.gz", hash = "sha256:de09f2c4b85de4765f7714688fff57d3e75bad1f909b589fde880460c753fd2e", size = 32701, upload-time = "2024-03-12T14:34:39.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" }, +] + +[[package]] +name = "tinycss2" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085, upload-time = "2024-10-24T14:58:29.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, +] + [[package]] name = "torch" version = "2.8.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine != 'x86_64' and sys_platform == 'darwin'", ] dependencies = [ @@ -2224,7 +2764,8 @@ name = "torch" version = "2.8.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", ] dependencies = [ @@ -2374,6 +2915,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/ed/582c4daba0f3e1688d923b5cb914ada1f9defa702df38a1916c899f7c4d1/ujson-5.10.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b9500e61fce0cfc86168b248104e954fead61f9be213087153d272e817ec7b4f", size = 1043580, upload-time = "2024-05-14T02:01:31.447Z" }, ] +[[package]] +name = "uri-template" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/31/c7/0336f2bd0bcbada6ccef7aaa25e443c118a704f828a0620c6fa0207c1b64/uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7", size = 21678, upload-time = "2023-06-21T01:49:05.374Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/00/3fca040d7cf8a32776d3d81a00c8ee7457e00f80c649f1e4a863c8321ae9/uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363", size = 11140, upload-time = "2023-06-21T01:49:03.467Z" }, +] + [[package]] name = "urllib3" version = "2.3.0" @@ -2406,6 +2956,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166, upload-time = "2024-01-06T02:10:55.763Z" }, ] +[[package]] +name = "webcolors" +version = "25.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/7a/eb316761ec35664ea5174709a68bbd3389de60d4a1ebab8808bfc264ed67/webcolors-25.10.0.tar.gz", hash = "sha256:62abae86504f66d0f6364c2a8520de4a0c47b80c03fc3a5f1815fedbef7c19bf", size = 53491, upload-time = "2025-10-31T07:51:03.977Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/cc/e097523dd85c9cf5d354f78310927f1656c422bd7b2613b2db3e3f9a0f2c/webcolors-25.10.0-py3-none-any.whl", hash = "sha256:032c727334856fc0b968f63daa252a1ac93d33db2f5267756623c210e57a4f1d", size = 14905, upload-time = "2025-10-31T07:51:01.778Z" }, +] + +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + +[[package]] +name = "websocket-client" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, +] + [[package]] name = "werkzeug" version = "3.1.3"