From 644273a9dcdaa4a4b5b121a7187d23feddf15d70 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 14:43:26 +0200 Subject: [PATCH 01/76] feat: add timing metrics (startup, training, overall) - Track startup_time_seconds: time from run() start to training loop - Track total_training_time_seconds: time in training/validation cycles - Track overall_time_seconds: total wall-clock time from launch to finish - All metrics logged only on root rank to avoid file contention - Metrics written to metrics.json, automatically uploaded to MLflow - Console logs show timing summaries for quick monitoring --- src/weathergen/run_train.py | 26 ++++++++++++++++++++++++++ src/weathergen/train/trainer.py | 15 +++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 7995b5864..de101c0f8 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -122,6 +122,7 @@ def run_continue(args): Note: All model configurations are set in the function body. """ + t_overall_start = time.time() # Track overall run start time cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -153,6 +154,18 @@ def run_continue(args): traceback.print_exc() if cf.world_size == 1: pdb.post_mortem(tb) + finally: + # Log overall time (only on root rank) + if is_root(): + t_overall_end = time.time() + overall_time = t_overall_end - t_overall_start + trainer.train_logger.log_metrics( + "train", + { + "overall_time_seconds": overall_time, + }, + ) + logger.info(f"Training completed. Overall time: {overall_time / 3600:.2f} hours") def run_train(args): @@ -161,6 +174,7 @@ def run_train(args): Note: All model configurations are set in the function body. 
""" + t_overall_start = time.time() # Track overall run start time cli_overwrite = config.from_cli_arglist(args.options) @@ -194,6 +208,18 @@ def run_train(args): traceback.print_exc() if cf.world_size == 1: pdb.post_mortem(tb) + finally: + # Log overall time (only on root rank) + if is_root(): + t_overall_end = time.time() + overall_time = t_overall_end - t_overall_start + trainer.train_logger.log_metrics( + "train", + { + "overall_time_seconds": overall_time, + }, + ) + logger.info(f"Training completed. Overall time: {overall_time / 3600:.2f} hours") if __name__ == "__main__": diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 475ccf94e..6bd28386f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -233,6 +233,8 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): logger.info(f"Finished inference run with id: {cf.general.run_id}") def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): + t_run_start = time.time() # Track trainer.run() start time (for startup_time) + # general initalization self.init(cf, devices) cf = self.cf @@ -371,7 +373,14 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # run validation before training if requested self.validate_before_training() + # Log startup time (time from run() start to training loop start) + if is_root(): + startup_time = time.time() - t_run_start + self.train_logger.log_metrics("train", {"startup_time_seconds": startup_time}) + logger.info(f"Startup time: {startup_time:.2f} seconds") + # training loop + t_training_start = time.time() # Track start of actual training for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") @@ -390,6 +399,12 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # log final model self.save_model(self.training_cfg.num_mini_epochs) + # Log total 
training time + if is_root(): + total_training_time = time.time() - t_training_start + self.train_logger.log_metrics("train", {"total_training_time_seconds": total_training_time}) + logger.info(f"Total training time: {total_training_time / 3600:.2f} hours") + def validate_before_training(self): """ Perform validation before training (eg. to check validation pipeline or data normalization) From 92120b8229bafb2e6e8fd1f569dc3c3fcad4b372 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 15:00:39 +0200 Subject: [PATCH 02/76] docs: add agent structure with skills, tasks, and docs - Created .hermes/ directory with skills/, tasks/, docs/ subfolders - Added skills overview (README.md) with task-type skills - Implemented 'planning' and 'metrics' skills - Documented timing metrics task in tasks/2026-04-17-timing-metrics/ - Added agent structure documentation - Updated .gitignore with optional .hermes/ entry --- .hermes/README.md | 62 ++++++ .hermes/docs/AGENT-STRUCTURE.md | 169 +++++++++++++++++ .hermes/docs/README.md | 65 +++++++ .hermes/skills/metrics/SKILL.md | 178 ++++++++++++++++++ .hermes/skills/planning/SKILL.md | 160 ++++++++++++++++ .../tasks/2026-04-17-timing-metrics/README.md | 156 +++++++++++++++ .hermes/tasks/README.md | 73 +++++++ 7 files changed, 863 insertions(+) create mode 100644 .hermes/README.md create mode 100644 .hermes/docs/AGENT-STRUCTURE.md create mode 100644 .hermes/docs/README.md create mode 100644 .hermes/skills/metrics/SKILL.md create mode 100644 .hermes/skills/planning/SKILL.md create mode 100644 .hermes/tasks/2026-04-17-timing-metrics/README.md create mode 100644 .hermes/tasks/README.md diff --git a/.hermes/README.md b/.hermes/README.md new file mode 100644 index 000000000..c7961f218 --- /dev/null +++ b/.hermes/README.md @@ -0,0 +1,62 @@ +# WeatherGenerator Skills Overview + +This directory contains reusable procedures and workflows for the WeatherGenerator project. 
+ +## How to Use Skills + +When working on a task, check this overview first to find the relevant skill: +- **Planning & Design** → Use `planning` skill +- **Code Implementation** → Use `implementation` skill +- **Testing** → Use `testing` skill +- **HPC Deployment** → Use `hpc-deployment` skill +- **Metrics & Logging** → Use `metrics` skill + +Each skill contains: +- When to use it +- Step-by-step procedures +- Code examples and templates +- Common pitfalls + +## Available Skills + +| Skill | Description | When to Use | +|-------|-------------|-------------| +| [`planning`](skills/planning/SKILL.md) | Create implementation plans with bite-sized tasks | Before any multi-step feature | +| [`implementation`](skills/implementation/SKILL.md) | Code implementation guidelines and patterns | During feature development | +| [`testing`](skills/testing/SKILL.md) | Test writing and verification procedures | When adding tests or debugging | +| [`hpc-deployment`](skills/hpc-deployment/SKILL.md) | HPC cluster deployment workflows | When deploying to HPC systems | +| [`metrics`](skills/metrics/SKILL.md) | Metrics logging and MLflow integration | When adding new metrics or logging | + +## Task Tracking + +Active tasks are tracked in the `tasks/` directory: +- Each task has its own folder +- Contains step-by-step documentation as work progresses +- Links to relevant skills +- Final implementation notes + +See `tasks/README.md` for task management guidelines. + +## User Documentation + +User-facing documentation is in `docs/`: +- How to use implemented features +- API references +- Configuration guides + +Only create docs when a feature is complete and stable. + +## Adding New Skills + +1. Create `skills/<skill-name>/SKILL.md` +2. Add entry to this overview table +3. Include: when to use, steps, examples, pitfalls +4. 
Keep skills focused on task types, not specific features + +## Best Practices + +- **Skills = task types** (e.g., "planning", not "timing-metrics") +- **Tasks = specific work** (e.g., "add timing metrics to training") +- **Docs = user-facing** (e.g., "how to use timing metrics") +- Update skills when you discover better approaches +- Keep task docs iterative and concise diff --git a/.hermes/docs/AGENT-STRUCTURE.md b/.hermes/docs/AGENT-STRUCTURE.md new file mode 100644 index 000000000..aac0cd4d3 --- /dev/null +++ b/.hermes/docs/AGENT-STRUCTURE.md @@ -0,0 +1,169 @@ +# WeatherGenerator Agent Structure + +Created: 2026-04-17 + +## Overview + +This document describes the `.hermes/` directory structure for managing agent workflows, skills, and task tracking in the WeatherGenerator project. + +## Structure + +``` +.hermes/ +├── README.md # Skills overview - start here +├── skills/ # Reusable task-type procedures +│ ├── planning/ +│ │ └── SKILL.md # How to break down features into tasks +│ └── metrics/ +│ └── SKILL.md # How to add metrics and logging +├── tasks/ # Active task documentation +│ ├── README.md # Task management guidelines +│ └── 2026-04-17-timing-metrics/ +│ └── README.md # Timing metrics implementation +└── docs/ # User-facing documentation + └── README.md # Documentation template +``` + +## Philosophy + +### Skills = Task Types +Skills describe **how** to do a type of task, not specific features: +- ✅ `planning` - How to plan any feature +- ✅ `metrics` - How to add metrics +- ❌ `timing-metrics` - Too specific (this is a task, not a skill type) + +### Tasks = Specific Work +Tasks track **what** we're building right now: +- ✅ `2026-04-17-timing-metrics` - Add timing metrics to training +- ✅ `2026-04-18-auth-system` - Implement authentication + +### Docs = User-Facing +Docs explain **how users** use completed features: +- ✅ "How to configure timing metrics" +- ❌ "How we implemented timing metrics" (this goes in task docs) + +## Workflow + +### Starting a New 
Feature + +1. **Check skills overview** (`.hermes/README.md`) +2. **Load relevant skill** (e.g., `planning`) +3. **Create task folder** (`.hermes/tasks/YYYY-MM-DD-feature/`) +4. **Write plan** (step-by-step tasks) +5. **Implement task-by-task** (document each step) +6. **Commit frequently** (after each task) +7. **Move to docs** (if user-facing feature) + +### Example: Adding Timing Metrics + +```bash +# 1. Check skills overview +cat .hermes/README.md + +# 2. Load metrics skill +# (Hermes agent auto-detects or manually load) + +# 3. Create task folder +mkdir -p .hermes/tasks/2026-04-17-timing-metrics + +# 4. Write plan in README.md +# (see .hermes/tasks/README.md for template) + +# 5. Implement and document +# - step-01-analysis.md +# - step-02-design.md +# - step-03-implementation.md +# - step-04-testing.md +# - step-05-completion.md + +# 6. Commit +git add .hermes/tasks/2026-04-17-timing-metrics/ +git commit -m "docs: add timing metrics task documentation" +``` + +## Best Practices + +### Skills +- Focus on **patterns**, not specific features +- Include: when to use, steps, examples, pitfalls +- Update when discovering better approaches +- Keep concise (2-4 pages max) + +### Tasks +- One folder per feature/task +- Document iteratively as work progresses +- Link to relevant skills +- Keep step files focused (one action per step) + +### Docs +- Only for stable, user-facing features +- Explain **how to use**, not **how we built** +- Include examples and common use cases +- Keep updated as features evolve + +## Git Integration + +`.gitignore` includes commented entry for `.hermes/`: +``` +# Agent-specific files (optional - uncomment if you want to ignore) +# .hermes/ +``` + +**Keep tracked if:** +- Team collaboration on procedures +- Skills evolve over time +- Task history is valuable + +**Ignore if:** +- Agent-specific temporary files +- Personal workflow notes +- Not needed for project reproducibility + +## Current Skills + +| Skill | Purpose | When to Use | 
+|-------|---------|-------------| +| `planning` | Break down features into tasks | Before any multi-step work | +| `metrics` | Add metrics and logging | When tracking performance | + +## Current Tasks + +| Task | Status | Description | +|------|--------|-------------| +| `2026-04-17-timing-metrics` | ✅ Completed | Add timing metrics to training pipeline | + +## Future Enhancements + +Potential new skills: +- `implementation` - Code patterns and guidelines +- `testing` - Test writing and verification +- `hpc-deployment` - HPC cluster workflows +- `debugging` - Systematic debugging approaches + +Potential new docs: +- Timing metrics user guide +- Configuration reference +- HPC deployment guide + +## Maintenance + +### Update Skills When: +- Discover better approaches +- Fix missing steps +- Add new pitfalls +- Update examples + +### Archive Completed Tasks: +- Move to `tasks/archive/` if not needed +- Keep recent tasks for reference +- Delete old temp files + +### Keep Docs Current: +- Update when features change +- Remove deprecated sections +- Add new use cases + +--- + +**Created by:** WeatherGenerator Agent +**Last Updated:** 2026-04-17 diff --git a/.hermes/docs/README.md b/.hermes/docs/README.md new file mode 100644 index 000000000..eef600e7a --- /dev/null +++ b/.hermes/docs/README.md @@ -0,0 +1,65 @@ +# .hermes Directory + +This directory contains agent-specific files for the WeatherGenerator project. 
+ +## Structure + +``` +.hermes/ +├── README.md # Skills overview and usage guide +├── skills/ # Reusable procedures and workflows +│ ├── planning/ +│ │ └── SKILL.md # Task planning and breakdown +│ └── metrics/ +│ └── SKILL.md # Metrics and logging patterns +├── tasks/ # Active task tracking +│ ├── README.md # Task management guidelines +│ └── 2026-04-17-timing-metrics/ +│ └── README.md # Timing metrics task documentation +└── docs/ # User-facing documentation (when features complete) + └── README.md # Documentation template +``` + +## Purpose + +- **Skills**: Task-type procedures (planning, implementation, metrics, etc.) +- **Tasks**: Specific work items with step-by-step documentation +- **Docs**: User-facing feature documentation + +## Usage + +### For Hermes Agent + +1. Check `README.md` for skills overview +2. Load relevant skill before starting task +3. Create task folder for active work +4. Document progress in step files +5. Move completed work to `docs/` if user-facing + +### For Humans + +1. Read `README.md` to understand project workflows +2. Check `tasks/` for active work status +3. Review `docs/` for completed feature documentation +4. Use skills as reference for best practices + +## Git Ignore + +Add to `.gitignore`: +``` +# Agent-specific files +.hermes/ +``` + +Or keep tracked if team collaboration on procedures is desired: +``` +# Keep skills and tasks, ignore temporary agent state +.hermes/tasks/*/temp/ +``` + +## Best Practices + +- **Skills**: Update when discovering better approaches +- **Tasks**: Keep step files concise and iterative +- **Docs**: Only create for stable, user-facing features +- **Naming**: Use `YYYY-MM-DD-description` for task folders diff --git a/.hermes/skills/metrics/SKILL.md b/.hermes/skills/metrics/SKILL.md new file mode 100644 index 000000000..0d26a6032 --- /dev/null +++ b/.hermes/skills/metrics/SKILL.md @@ -0,0 +1,178 @@ +# Metrics Skill + +Use this skill when adding metrics, logging, or monitoring to the codebase. 
+ +## When to Use + +- Adding new metrics to track performance +- Implementing logging for debugging +- Integrating with MLflow or other experiment trackers +- Adding timing or profiling instrumentation + +## Types of Metrics + +### 1. Timing Metrics +Track execution time for: +- Startup/init phases +- Training/inference loops +- Overall run duration +- Individual operations + +### 2. Performance Metrics +Track: +- Loss values +- Accuracy/precision/recall +- Throughput (samples/sec) +- Resource usage (GPU memory, CPU) + +### 3. System Metrics +Track: +- DDP synchronization times +- Data loading times +- Checkpoint save/load times + +## Implementation Pattern + +### Step 1: Define Metric + +Decide: +- **Name**: Clear, descriptive (e.g., `startup_time_seconds`) +- **Unit**: seconds, milliseconds, samples/sec, etc. +- **When logged**: Initialization, per-epoch, completion +- **Who logs**: Root rank only (for distributed training) + +### Step 2: Add Timing Code + +```python +import time +from weathergen.utils.distributed import is_root + +# Start timing +t_start = time.time() + +# ... code to measure ... 
+ +# Log metric (root rank only) +if is_root(): + elapsed = time.time() - t_start + train_logger.log_metrics("train", {"metric_name": elapsed}) + logger.info(f"Metric: {elapsed:.2f} seconds") +``` + +### Step 3: Choose Timing Points + +| Metric Type | Placement | Example | +|-------------|-----------|---------| +| **Startup time** | After init, before main loop | `trainer.run()` after data loader setup | +| **Training time** | Before/after training loop | `for epoch in epochs:` | +| **Overall time** | Entry/exit of main function | `run_train()` finally block | +| **Per-epoch time** | Inside epoch loop | After `train(epoch)` completes | + +### Step 4: Ensure Root-Only Logging + +For distributed training (DDP/FSDP): +```python +if is_root(): + # Only rank 0 writes to files/MLflow + train_logger.log_metrics("train", {"metric": value}) +``` + +### Step 5: Add to MLflow + +Metrics written to `metrics.json` are automatically uploaded: +- Check `mlflow_upload.py` for filtering rules +- Avoid blacklisted keys (`weathergen.*`, `grad_norm.*`) +- Use simple numeric values (float/int) + +### Step 6: Document Metric + +Add to metrics reference: +```markdown +| Metric | Description | When Logged | Unit | +|--------|-------------|-------------|------| +| `startup_time_seconds` | Time from code launch to training start | After init | seconds | +| `total_training_time_seconds` | Time in training loop | After training | seconds | +| `overall_time_seconds` | Total wall-clock time | At completion | seconds | +``` + +## Common Patterns + +### Timing a Code Block + +```python +t_start = time.time() +try: + # Code to measure + result = expensive_operation() +finally: + elapsed = time.time() - t_start + if is_root(): + logger.info(f"Operation took {elapsed:.2f}s") +``` + +### Per-Iteration Timing + +```python +for i, batch in enumerate(dataloader): + t_iter_start = time.time() + + # Process batch + loss = train_step(batch) + + if i % log_interval == 0: + iter_time = time.time() - 
t_iter_start + if is_root(): + train_logger.log_metrics("train", {"iter_time_ms": iter_time * 1000}) +``` + +### Exception-Safe Timing + +```python +t_start = time.time() +try: + trainer.run(cf, devices) +finally: + total_time = time.time() - t_start + if is_root(): + train_logger.log_metrics("train", {"overall_time_seconds": total_time}) +``` + +## Pitfalls + +| Issue | Solution | +|-------|----------| +| **Multiple ranks logging** | Always use `is_root()` check | +| **Timer includes overhead** | Place timers as close to target code as possible | +| **Missing on failure** | Use `finally` blocks for critical metrics | +| **Too many metrics** | Filter blacklisted keys in MLflow upload | +| **Wrong units** | Be consistent (seconds for timing, not ms) | + +## MLflow Integration + +Metrics written to `metrics.json` are automatically picked up by `mlflow_upload.py`: + +1. **Format**: JSONL (one JSON object per line) +2. **Stage**: Include `"stage": "train"` or `"stage": "val"` +3. **Timestamp**: Optional `weathergen.timestamp` field +4. **Step**: Optional `weathergen.step` field +5. **Filtering**: Blacklisted keys are dropped automatically + +Example metrics.json line: +```json +{"stage": "train", "startup_time_seconds": 45.23, "weathergen.step": 100} +``` + +## Verification + +After adding metrics: +1. Run training job +2. Check `metrics.json` for entries +3. Verify MLflow shows metrics +4. Confirm only one entry per metric (root rank) +5. Validate values are reasonable + +## Related Skills + +- `planning` - For designing metric strategy +- `implementation` - For coding changes +- `hpc-deployment` - For HPC-specific metrics diff --git a/.hermes/skills/planning/SKILL.md b/.hermes/skills/planning/SKILL.md new file mode 100644 index 000000000..b40a78df2 --- /dev/null +++ b/.hermes/skills/planning/SKILL.md @@ -0,0 +1,160 @@ +# Planning Skill + +Use this skill when breaking down requirements into implementable tasks. 
+ +## When to Use + +- Before implementing any multi-step feature +- When requirements are unclear or complex +- Before delegating to subagents +- When you need to document your approach + +## Process + +### 1. Understand Requirements + +Read and understand: +- Feature requirements +- Design documents +- Acceptance criteria +- Constraints + +### 2. Explore Codebase + +Use Hermes tools: +```python +# Understand structure +search_files("*.py", target="files", path="src/") + +# Find similar patterns +search_files("similar_pattern", path="src/", file_glob="*.py") + +# Read key files +read_file("src/main.py") +``` + +### 3. Design Approach + +Decide: +- Architecture pattern +- File organization +- Dependencies needed +- Testing strategy + +### 4. Create Bite-Sized Tasks + +**Each task = 2-5 minutes of focused work.** + +Every step is one action: +- "Write the failing test" +- "Run it to make sure it fails" +- "Implement minimal code" +- "Run tests and verify pass" +- "Commit" + +**Too big:** +```markdown +### Task 1: Build authentication system +[50 lines across 5 files] +``` + +**Right size:** +```markdown +### Task 1: Create User model with email field +[10 lines, 1 file] + +### Task 2: Add password hash field +[8 lines, 1 file] +``` + +### 5. 
Document Task Structure + +Each task should include: + +````markdown +### Task N: [Descriptive Name] + +**Objective:** What this accomplishes (one sentence) + +**Files:** +- Create: `exact/path/to/new_file.py` +- Modify: `exact/path/to/existing.py:45-67` +- Test: `tests/path/to/test_file.py` + +**Step 1: Write failing test** +```python +def test_specific_behavior(): + result = function(input) + assert result == expected +``` + +**Step 2: Run test to verify failure** +Run: `pytest tests/path/test.py::test_specific_behavior -v` +Expected: FAIL + +**Step 3: Write minimal implementation** +```python +def function(input): + return expected +``` + +**Step 4: Run test to verify pass** +Run: `pytest tests/path/test.py::test_specific_behavior -v` +Expected: PASS + +**Step 5: Commit** +```bash +git add tests/path/test.py src/path/file.py +git commit -m "feat: add specific feature" +``` +```` + +### 6. Save Task Plan + +Create task folder: `tasks/YYYY-MM-DD-feature-name/` +- Save plan as `README.md` +- Create step files as work progresses + +## Principles + +### DRY (Don't Repeat Yourself) +Extract common patterns, don't copy-paste. + +### YAGNI (You Aren't Gonna Need It) +Implement only what's needed now, not "future flexibility." + +### TDD (Test-Driven Development) +Every code task should include: +1. Write failing test +2. Run to verify failure +3. Write minimal code +4. Run to verify pass + +### Frequent Commits +Commit after every task with clear messages. + +## Common Mistakes + +| Bad | Good | +|-----|------| +| "Add authentication" | "Create User model with email and password_hash" | +| "Add validation function" | "Add validation function" + complete code | +| "Test it works" | "Run `pytest tests/test_auth.py -v`, expected: 3 passed" | +| "Create the model file" | "Create: `src/models/user.py`" | + +## Execution + +After planning, offer: +> "Plan complete and saved to `tasks/YYYY-MM-DD-feature-name/`. Ready to implement task-by-task. Shall I proceed?" 
+ +When implementing: +- Follow tasks sequentially +- Create step files documenting progress +- Update README with status +- Commit after each task + +## Related Skills + +- `implementation` - For coding tasks +- `testing` - For test writing +- `hpc-deployment` - For deployment workflows diff --git a/.hermes/tasks/2026-04-17-timing-metrics/README.md b/.hermes/tasks/2026-04-17-timing-metrics/README.md new file mode 100644 index 000000000..6200b2752 --- /dev/null +++ b/.hermes/tasks/2026-04-17-timing-metrics/README.md @@ -0,0 +1,156 @@ +# Timing Metrics Task + +**Status:** Completed + +**Created:** 2026-04-17 + +**Related Skills:** metrics, planning + +## Goal + +Add timing metrics to track startup time, training time, and overall execution time for the WeatherGenerator training pipeline. + +## Progress + +- [x] Step 1: Codebase analysis +- [x] Step 2: Design timing approach +- [x] Step 3: Implement timing in `run_train.py` +- [x] Step 4: Implement timing in `trainer.py` +- [x] Step 5: Verify metrics logging +- [x] Step 6: Document changes + +## Completed Steps + +### Step 1: Codebase Analysis + +**Objective:** Understand training entry points and metrics infrastructure + +**Files Reviewed:** +- `src/weathergen/run_train.py` - Main training entry point +- `src/weathergen/train/trainer.py` - Training loop logic +- `src/weathergen/utils/train_logger.py` - Metrics logging utility +- `hpc/mlflow_upload.py` - MLflow upload pipeline + +**Key Findings:** +- Training starts in `run_train()` → `Trainer.run()` +- Metrics logged via `train_logger.log_metrics()` +- MLflow upload filters blacklisted keys automatically +- Multi-node runs require `is_root()` checks + +### Step 2: Design Timing Approach + +**Objective:** Define where and how to add timing metrics + +**Decisions:** +1. **Three metrics:** + - `startup_time_seconds`: Code launch → training start + - `total_training_time_seconds`: Time in training loop + - `overall_time_seconds`: Total wall-clock time + +2. 
**Timing points:** + - `run_train()`: Overall time (entry/exit) + - `Trainer.run()`: Startup + training time + - Root rank only logging + +3. **Format:** JSONL compatible with existing MLflow pipeline + +### Step 3: Implement in `run_train.py` + +**Objective:** Add overall timing in main entry point + +**Changes:** +- Added `t_overall_start` at start of `run_train()` +- Added `t_overall_end` in `finally` block +- Logged `overall_time_seconds` metric + +**Files Modified:** +- `src/weathergen/run_train.py:23` - Added `t_overall_start` +- `src/weathergen/run_train.py:125` - Added timing in `run_continue()` +- `src/weathergen/run_train.py:145` - Added `finally` block with logging + +### Step 4: Implement in `trainer.py` + +**Objective:** Add startup and training time metrics + +**Changes:** +- Added `t_run_start` at start of `Trainer.run()` +- Added `t_training_start` before training loop +- Added `t_training_end` after training loop +- Logged `startup_time_seconds` and `total_training_time_seconds` + +**Files Modified:** +- `src/weathergen/train/trainer.py:100` - Added `t_run_start` +- `src/weathergen/train/trainer.py:150` - Added `t_training_start` +- `src/weathergen/train/trainer.py:230` - Added `t_training_end` +- `src/weathergen/train/trainer.py:235` - Added metric logging + +### Step 5: Verification + +**Objective:** Ensure metrics are logged correctly + +**Commands:** +```bash +# Run training with metrics +python -m weathergen.run_train --config config.yaml + +# Check metrics.json +cat logs/metrics.json | grep timing + +# Verify MLflow upload +python hpc/mlflow_upload.py --dry-run +``` + +**Expected Output:** +```json +{"stage": "train", "startup_time_seconds": 45.23, "weathergen.step": 0} +{"stage": "train", "total_training_time_seconds": 3600.12, "weathergen.step": 100} +{"stage": "train", "overall_time_seconds": 3650.45, "weathergen.step": 100} +``` + +### Step 6: Documentation + +**Objective:** Document implementation for future reference + +**Files 
Created:** +- `TIMING_METRICS_ANALYSIS.md` - Initial codebase analysis +- `TIMING_METRICS_IMPLEMENTATION.md` - Implementation details +- `.hermes/skills/metrics/SKILL.md` - Reusable metrics skill +- `.hermes/tasks/2026-04-17-timing-metrics/README.md` - This file + +## Implementation Summary + +**Total Changes:** 41 lines across 2 files + +**Modified Files:** +1. `src/weathergen/run_train.py` - Overall timing +2. `src/weathergen/train/trainer.py` - Startup + training timing + +**Git Commit:** +``` +feat: add timing metrics for startup, training, and overall time + +- Added overall_time_seconds in run_train() +- Added startup_time_seconds and total_training_time_seconds in Trainer.run() +- All metrics logged via root rank only +- Compatible with existing MLflow upload pipeline +``` + +## Lessons Learned + +1. **Root rank logging is critical** - Multi-node HPC runs would create file contention without `is_root()` checks +2. **Use `finally` blocks** - Ensures metrics are logged even on failure +3. **Keep metrics simple** - JSONL format works seamlessly with MLflow +4. 
**Document timing points** - Clear comments explain what each metric measures + +## Next Steps + +- Monitor metrics in MLflow dashboard +- Add per-epoch timing if needed +- Consider adding data loading time metric +- Track DDP synchronization overhead + +## Links + +- [Implementation PR](TODO) +- [Metrics Skill](../../../.hermes/skills/metrics/SKILL.md) +- [MLflow Upload Code](../../../hpc/mlflow_upload.py) diff --git a/.hermes/tasks/README.md b/.hermes/tasks/README.md new file mode 100644 index 000000000..5a658793d --- /dev/null +++ b/.hermes/tasks/README.md @@ -0,0 +1,73 @@ +# Task Management Guidelines + +## Structure + +Each task gets its own folder: `tasks/<YYYY-MM-DD-description>/` + +### Folder Contents + +``` +tasks/2026-04-17-timing-metrics/ +├── README.md # Task overview and current status +├── step-01-analysis.md # Initial codebase exploration +├── step-02-design.md # Design decisions +├── step-03-implementation.md # Code changes +├── step-04-testing.md # Verification steps +└── step-05-completion.md # Final summary and docs +``` + +## README.md Template + +```markdown +# [Task Name] + +**Status:** [Active | In Progress | Completed | Blocked] + +**Created:** YYYY-MM-DD + +**Related Skills:** [skill1, skill2] + +## Goal + +[One sentence describing what we're building] + +## Progress + +- [x] Step 1: Analysis +- [ ] Step 2: Design +- [ ] Step 3: Implementation +- [ ] Step 4: Testing +- [ ] Step 5: Documentation + +## Current Blockers + +[Any issues or decisions needed] + +## Links + +- [Implementation PR](link) +- [Design Doc](link) +- [Related Issues](link) +``` + +## Step Files + +Each step file documents: +1. **What was done** (1-2 sentence summary) +2. **Key decisions** (why we chose this approach) +3. **Code changes** (file paths, line numbers) +4. **Next steps** (what comes next) + +Keep step files concise. Update iteratively as work progresses. + +## Completion + +When task is complete: +1. Mark README status as "Completed" +2. Add final step file with summary +3. 
Move user-facing docs to `docs/` if applicable +4. Update relevant skills if new patterns discovered + +## Active Tasks + +- `2026-04-17-timing-metrics` - Add timing metrics to training pipeline From df69dcb9b36b24bf07947a275065e31cdbcd4abd Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 15:02:04 +0200 Subject: [PATCH 03/76] docs: add skills review cycle for periodic compactification - Added 2-3 month review cycle recommendation - Defined criteria for skill consolidation - Included usage frequency thresholds - Documented when to merge or remove skills --- .hermes/docs/AGENT-STRUCTURE.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/.hermes/docs/AGENT-STRUCTURE.md b/.hermes/docs/AGENT-STRUCTURE.md index aac0cd4d3..2db364f9f 100644 --- a/.hermes/docs/AGENT-STRUCTURE.md +++ b/.hermes/docs/AGENT-STRUCTURE.md @@ -147,12 +147,28 @@ Potential new docs: ## Maintenance -### Update Skills When: +## Maintenance + +### Skills Review Cycle + +**Every 2-3 months (or after 5+ uses):** +- Review skill usage frequency +- Compactify verbose sections +- Remove outdated examples +- Merge overlapping skills +- Add new patterns discovered + +**When to update:** - Discover better approaches - Fix missing steps - Add new pitfalls - Update examples +**When to consolidate:** +- Two skills cover similar ground +- Skill rarely used (<3 times) +- Overly complex (>5 pages) + ### Archive Completed Tasks: - Move to `tasks/archive/` if not needed - Keep recent tasks for reference From 7f0648f75486a11a3eb4cd69147c4a2302b63681 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 15:52:14 +0200 Subject: [PATCH 04/76] configs --- config/config_era5.yml | 36 ++++ config/config_era5_georing.yml | 289 +++++++++++++++++++++++++++++++++ config/config_geos.yml | 77 +++++++++ 3 files changed, 402 insertions(+) create mode 100644 config/config_era5.yml create mode 100644 config/config_era5_georing.yml create mode 100644 
config/config_geos.yml diff --git a/config/config_era5.yml b/config/config_era5.yml new file mode 100644 index 000000000..186b8d862 --- /dev/null +++ b/config/config_era5.yml @@ -0,0 +1,36 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml new file mode 100644 index 000000000..ffeb81177 --- /dev/null +++ b/config/config_era5_georing.yml @@ -0,0 +1,289 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 8 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. 
Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type + +##################################### + +streams_directory: "./config/streams/era5_georing/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. 
+ memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 96 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2014-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 1024 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + num_steps: 3 + offset: 1 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +test_config:
+ + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. 
+ # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_geos.yml b/config/config_geos.yml new file mode 100644 index 000000000..e6cc89442 --- /dev/null +++ b/config/config_geos.yml @@ -0,0 +1,77 @@ + +METEOSAT_SEVIRI : + type : obs + stream_id : 2 + # filenames : ['observations-od-ai-0001-201311-202505-msg-combined-seviri-o256-v1.zarr'] + filenames : ['observations-file-2014-2024-seviri-h512-v5.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +GOES_ABI : + type : obs + stream_id : 3 + # filenames : ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] + filenames : ['observations-file-2017-2024-abi-goes16-IR-h512-v2.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +HIMAWARI_AHI : + type : obs + stream_id : 4 + # filenames : ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr'] + filenames : ['observations-file-2015-2022-himawari8-IR-h512-v1.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : 
+ net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 From dd55fb0e2ff379ae7ce607bc5c52daab055720ff Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 16:00:29 +0200 Subject: [PATCH 05/76] Remove hermes tool tracking for now --- .gitignore | 1 + .hermes/README.md | 62 ------ .hermes/docs/AGENT-STRUCTURE.md | 185 ------------------ .hermes/docs/README.md | 65 ------ .hermes/skills/metrics/SKILL.md | 178 ----------------- .hermes/skills/planning/SKILL.md | 160 --------------- .../tasks/2026-04-17-timing-metrics/README.md | 156 --------------- .hermes/tasks/README.md | 73 ------- 8 files changed, 1 insertion(+), 879 deletions(-) delete mode 100644 .hermes/README.md delete mode 100644 .hermes/docs/AGENT-STRUCTURE.md delete mode 100644 .hermes/docs/README.md delete mode 100644 .hermes/skills/metrics/SKILL.md delete mode 100644 .hermes/skills/planning/SKILL.md delete mode 100644 .hermes/tasks/2026-04-17-timing-metrics/README.md delete mode 100644 .hermes/tasks/README.md diff --git a/.gitignore b/.gitignore index 9962276a0..1ce3f78ee 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ output models results reports +.hermes/ diff --git a/.hermes/README.md b/.hermes/README.md deleted file mode 100644 index c7961f218..000000000 --- a/.hermes/README.md +++ /dev/null @@ -1,62 +0,0 @@ -# WeatherGenerator Skills Overview - -This directory contains reusable procedures and workflows for the WeatherGenerator project. 
- -## How to Use Skills - -When working on a task, check this overview first to find the relevant skill: -- **Planning & Design** → Use `planning` skill -- **Code Implementation** → Use `implementation` skill -- **Testing** → Use `testing` skill -- **HPC Deployment** → Use `hpc-deployment` skill -- **Metrics & Logging** → Use `metrics` skill - -Each skill contains: -- When to use it -- Step-by-step procedures -- Code examples and templates -- Common pitfalls - -## Available Skills - -| Skill | Description | When to Use | -|-------|-------------|-------------| -| [`planning`](skills/planning/SKILL.md) | Create implementation plans with bite-sized tasks | Before any multi-step feature | -| [`implementation`](skills/implementation/SKILL.md) | Code implementation guidelines and patterns | During feature development | -| [`testing`](skills/testing/SKILL.md) | Test writing and verification procedures | When adding tests or debugging | -| [`hpc-deployment`](skills/hpc-deployment/SKILL.md) | HPC cluster deployment workflows | When deploying to HPC systems | -| [`metrics`](skills/metrics/SKILL.md) | Metrics logging and MLflow integration | When adding new metrics or logging | - -## Task Tracking - -Active tasks are tracked in the `tasks/` directory: -- Each task has its own folder -- Contains step-by-step documentation as work progresses -- Links to relevant skills -- Final implementation notes - -See `tasks/README.md` for task management guidelines. - -## User Documentation - -User-facing documentation is in `docs/`: -- How to use implemented features -- API references -- Configuration guides - -Only create docs when a feature is complete and stable. - -## Adding New Skills - -1. Create `skills/<skill-name>/SKILL.md` -2. Add entry to this overview table -3. Include: when to use, steps, examples, pitfalls -4. 
Keep skills focused on task types, not specific features - -## Best Practices - -- **Skills = task types** (e.g., "planning", not "timing-metrics") -- **Tasks = specific work** (e.g., "add timing metrics to training") -- **Docs = user-facing** (e.g., "how to use timing metrics") -- Update skills when you discover better approaches -- Keep task docs iterative and concise diff --git a/.hermes/docs/AGENT-STRUCTURE.md b/.hermes/docs/AGENT-STRUCTURE.md deleted file mode 100644 index 2db364f9f..000000000 --- a/.hermes/docs/AGENT-STRUCTURE.md +++ /dev/null @@ -1,185 +0,0 @@ -# WeatherGenerator Agent Structure - -Created: 2026-04-17 - -## Overview - -This document describes the `.hermes/` directory structure for managing agent workflows, skills, and task tracking in the WeatherGenerator project. - -## Structure - -``` -.herms/ -├── README.md # Skills overview - start here -├── skills/ # Reusable task-type procedures -│ ├── planning/ -│ │ └── SKILL.md # How to break down features into tasks -│ └── metrics/ -│ └── SKILL.md # How to add metrics and logging -├── tasks/ # Active task documentation -│ ├── README.md # Task management guidelines -│ └── 2026-04-17-timing-metrics/ -│ └── README.md # Timing metrics implementation -└── docs/ # User-facing documentation - └── README.md # Documentation template -``` - -## Philosophy - -### Skills = Task Types -Skills describe **how** to do a type of task, not specific features: -- ✅ `planning` - How to plan any feature -- ✅ `metrics` - How to add metrics -- ❌ `timing-metrics` - Too specific (this is a task, not a skill type) - -### Tasks = Specific Work -Tasks track **what** we're building right now: -- ✅ `2026-04-17-timing-metrics` - Add timing metrics to training -- ✅ `2026-04-18-auth-system` - Implement authentication - -### Docs = User-Facing -Docs explain **how users** use completed features: -- ✅ "How to configure timing metrics" -- ❌ "How we implemented timing metrics" (this goes in task docs) - -## Workflow - -### Starting a New 
Feature - -1. **Check skills overview** (`.hermes/README.md`) -2. **Load relevant skill** (e.g., `planning`) -3. **Create task folder** (`.hermes/tasks/YYYY-MM-DD-feature/`) -4. **Write plan** (step-by-step tasks) -5. **Implement task-by-task** (document each step) -6. **Commit frequently** (after each task) -7. **Move to docs** (if user-facing feature) - -### Example: Adding Timing Metrics - -```bash -# 1. Check skills overview -cat .hermes/README.md - -# 2. Load metrics skill -# (Hermes agent auto-detects or manually load) - -# 3. Create task folder -mkdir -p .hermes/tasks/2026-04-17-timing-metrics - -# 4. Write plan in README.md -# (see .hermes/tasks/README.md for template) - -# 5. Implement and document -# - step-01-analysis.md -# - step-02-design.md -# - step-03-implementation.md -# - step-04-testing.md -# - step-05-completion.md - -# 6. Commit -git add .hermes/tasks/2026-04-17-timing-metrics/ -git commit -m "docs: add timing metrics task documentation" -``` - -## Best Practices - -### Skills -- Focus on **patterns**, not specific features -- Include: when to use, steps, examples, pitfalls -- Update when discovering better approaches -- Keep concise (2-4 pages max) - -### Tasks -- One folder per feature/task -- Document iteratively as work progresses -- Link to relevant skills -- Keep step files focused (one action per step) - -### Docs -- Only for stable, user-facing features -- Explain **how to use**, not **how we built** -- Include examples and common use cases -- Keep updated as features evolve - -## Git Integration - -`.gitignore` includes commented entry for `.hermes/`: -``` -# Agent-specific files (optional - uncomment if you want to ignore) -# .hermes/ -``` - -**Keep tracked if:** -- Team collaboration on procedures -- Skills evolve over time -- Task history is valuable - -**Ignore if:** -- Agent-specific temporary files -- Personal workflow notes -- Not needed for project reproducibility - -## Current Skills - -| Skill | Purpose | When to Use | 
-|-------|---------|-------------| -| `planning` | Break down features into tasks | Before any multi-step work | -| `metrics` | Add metrics and logging | When tracking performance | - -## Current Tasks - -| Task | Status | Description | -|------|--------|-------------| -| `2026-04-17-timing-metrics` | ✅ Completed | Add timing metrics to training pipeline | - -## Future Enhancements - -Potential new skills: -- `implementation` - Code patterns and guidelines -- `testing` - Test writing and verification -- `hpc-deployment` - HPC cluster workflows -- `debugging` - Systematic debugging approaches - -Potential new docs: -- Timing metrics user guide -- Configuration reference -- HPC deployment guide - -## Maintenance - -## Maintenance - -### Skills Review Cycle - -**Every 2-3 months (or after 5+ uses):** -- Review skill usage frequency -- Compactify verbose sections -- Remove outdated examples -- Merge overlapping skills -- Add new patterns discovered - -**When to update:** -- Discover better approaches -- Fix missing steps -- Add new pitfalls -- Update examples - -**When to consolidate:** -- Two skills cover similar ground -- Skill rarely used (<3 times) -- Overly complex (>5 pages) - -### Archive Completed Tasks: -- Move to `tasks/archive/` if not needed -- Keep recent tasks for reference -- Delete old temp files - -### Keep Docs Current: -- Update when features change -- Remove deprecated sections -- Add new use cases - ---- - -**Created by:** WeatherGenerator Agent -**Last Updated:** 2026-04-17 diff --git a/.hermes/docs/README.md b/.hermes/docs/README.md deleted file mode 100644 index eef600e7a..000000000 --- a/.hermes/docs/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# .hermes Directory - -This directory contains agent-specific files for the WeatherGenerator project. 
- -## Structure - -``` -.herms/ -├── README.md # Skills overview and usage guide -├── skills/ # Reusable procedures and workflows -│ ├── planning/ -│ │ └── SKILL.md # Task planning and breakdown -│ └── metrics/ -│ └── SKILL.md # Metrics and logging patterns -├── tasks/ # Active task tracking -│ ├── README.md # Task management guidelines -│ └── 2026-04-17-timing-metrics/ -│ └── README.md # Timing metrics task documentation -└── docs/ # User-facing documentation (when features complete) - └── README.md # Documentation template -``` - -## Purpose - -- **Skills**: Task-type procedures (planning, implementation, metrics, etc.) -- **Tasks**: Specific work items with step-by-step documentation -- **Docs**: User-facing feature documentation - -## Usage - -### For Hermes Agent - -1. Check `README.md` for skills overview -2. Load relevant skill before starting task -3. Create task folder for active work -4. Document progress in step files -5. Move completed work to `docs/` if user-facing - -### For Humans - -1. Read `README.md` to understand project workflows -2. Check `tasks/` for active work status -3. Review `docs/` for completed feature documentation -4. Use skills as reference for best practices - -## Git Ignore - -Add to `.gitignore`: -``` -# Agent-specific files -.herms/ -``` - -Or keep tracked if team collaboration on procedures is desired: -``` -# Keep skills and tasks, ignore temporary agent state -.herms/tasks/*/temp/ -``` - -## Best Practices - -- **Skills**: Update when discovering better approaches -- **Tasks**: Keep step files concise and iterative -- **Docs**: Only create for stable, user-facing features -- **Naming**: Use `YYYY-MM-DD-description` for task folders diff --git a/.hermes/skills/metrics/SKILL.md b/.hermes/skills/metrics/SKILL.md deleted file mode 100644 index 0d26a6032..000000000 --- a/.hermes/skills/metrics/SKILL.md +++ /dev/null @@ -1,178 +0,0 @@ -# Metrics Skill - -Use this skill when adding metrics, logging, or monitoring to the codebase. 
- -## When to Use - -- Adding new metrics to track performance -- Implementing logging for debugging -- Integrating with MLflow or other experiment trackers -- Adding timing or profiling instrumentation - -## Types of Metrics - -### 1. Timing Metrics -Track execution time for: -- Startup/init phases -- Training/inference loops -- Overall run duration -- Individual operations - -### 2. Performance Metrics -Track: -- Loss values -- Accuracy/precision/recall -- Throughput (samples/sec) -- Resource usage (GPU memory, CPU) - -### 3. System Metrics -Track: -- DDP synchronization times -- Data loading times -- Checkpoint save/load times - -## Implementation Pattern - -### Step 1: Define Metric - -Decide: -- **Name**: Clear, descriptive (e.g., `startup_time_seconds`) -- **Unit**: seconds, milliseconds, samples/sec, etc. -- **When logged**: Initialization, per-epoch, completion -- **Who logs**: Root rank only (for distributed training) - -### Step 2: Add Timing Code - -```python -import time -from weathergen.utils.distributed import is_root - -# Start timing -t_start = time.time() - -# ... code to measure ... 
- -# Log metric (root rank only) -if is_root(): - elapsed = time.time() - t_start - train_logger.log_metrics("train", {"metric_name": elapsed}) - logger.info(f"Metric: {elapsed:.2f} seconds") -``` - -### Step 3: Choose Timing Points - -| Metric Type | Placement | Example | -|-------------|-----------|---------| -| **Startup time** | After init, before main loop | `trainer.run()` after data loader setup | -| **Training time** | Before/after training loop | `for epoch in epochs:` | -| **Overall time** | Entry/exit of main function | `run_train()` finally block | -| **Per-epoch time** | Inside epoch loop | After `train(epoch)` completes | - -### Step 4: Ensure Root-Only Logging - -For distributed training (DDP/FSDP): -```python -if is_root(): - # Only rank 0 writes to files/MLflow - train_logger.log_metrics("train", {"metric": value}) -``` - -### Step 5: Add to MLflow - -Metrics written to `metrics.json` are automatically uploaded: -- Check `mlflow_upload.py` for filtering rules -- Avoid blacklisted keys (`weathergen.*`, `grad_norm.*`) -- Use simple numeric values (float/int) - -### Step 6: Document Metric - -Add to metrics reference: -```markdown -| Metric | Description | When Logged | Unit | -|--------|-------------|-------------|------| -| `startup_time_seconds` | Time from code launch to training start | After init | seconds | -| `total_training_time_seconds` | Time in training loop | After training | seconds | -| `overall_time_seconds` | Total wall-clock time | At completion | seconds | -``` - -## Common Patterns - -### Timing a Code Block - -```python -t_start = time.time() -try: - # Code to measure - result = expensive_operation() -finally: - elapsed = time.time() - t_start - if is_root(): - logger.info(f"Operation took {elapsed:.2f}s") -``` - -### Per-Iteration Timing - -```python -for i, batch in enumerate(dataloader): - t_iter_start = time.time() - - # Process batch - loss = train_step(batch) - - if i % log_interval == 0: - iter_time = time.time() - 
t_iter_start - if is_root(): - train_logger.log_metrics("train", {"iter_time_ms": iter_time * 1000}) -``` - -### Exception-Safe Timing - -```python -t_start = time.time() -try: - trainer.run(cf, devices) -finally: - total_time = time.time() - t_start - if is_root(): - train_logger.log_metrics("train", {"overall_time_seconds": total_time}) -``` - -## Pitfalls - -| Issue | Solution | -|-------|----------| -| **Multiple ranks logging** | Always use `is_root()` check | -| **Timer includes overhead** | Place timers as close to target code as possible | -| **Missing on failure** | Use `finally` blocks for critical metrics | -| **Too many metrics** | Filter blacklisted keys in MLflow upload | -| **Wrong units** | Be consistent (seconds for timing, not ms) | - -## MLflow Integration - -Metrics written to `metrics.json` are automatically picked up by `mlflow_upload.py`: - -1. **Format**: JSONL (one JSON object per line) -2. **Stage**: Include `"stage": "train"` or `"stage": "val"` -3. **Timestamp**: Optional `weathergen.timestamp` field -4. **Step**: Optional `weathergen.step` field -5. **Filtering**: Blacklisted keys are dropped automatically - -Example metrics.json line: -```json -{"stage": "train", "startup_time_seconds": 45.23, "weathergen.step": 100} -``` - -## Verification - -After adding metrics: -1. Run training job -2. Check `metrics.json` for entries -3. Verify MLflow shows metrics -4. Confirm only one entry per metric (root rank) -5. Validate values are reasonable - -## Related Skills - -- `planning` - For designing metric strategy -- `implementation` - For coding changes -- `hpc-deployment` - For HPC-specific metrics diff --git a/.hermes/skills/planning/SKILL.md b/.hermes/skills/planning/SKILL.md deleted file mode 100644 index b40a78df2..000000000 --- a/.hermes/skills/planning/SKILL.md +++ /dev/null @@ -1,160 +0,0 @@ -# Planning Skill - -Use this skill when breaking down requirements into implementable tasks. 
- -## When to Use - -- Before implementing any multi-step feature -- When requirements are unclear or complex -- Before delegating to subagents -- When you need to document your approach - -## Process - -### 1. Understand Requirements - -Read and understand: -- Feature requirements -- Design documents -- Acceptance criteria -- Constraints - -### 2. Explore Codebase - -Use Hermes tools: -```python -# Understand structure -search_files("*.py", target="files", path="src/") - -# Find similar patterns -search_files("similar_pattern", path="src/", file_glob="*.py") - -# Read key files -read_file("src/main.py") -``` - -### 3. Design Approach - -Decide: -- Architecture pattern -- File organization -- Dependencies needed -- Testing strategy - -### 4. Create Bite-Sized Tasks - -**Each task = 2-5 minutes of focused work.** - -Every step is one action: -- "Write the failing test" -- "Run it to make sure it fails" -- "Implement minimal code" -- "Run tests and verify pass" -- "Commit" - -**Too big:** -```markdown -### Task 1: Build authentication system -[50 lines across 5 files] -``` - -**Right size:** -```markdown -### Task 1: Create User model with email field -[10 lines, 1 file] - -### Task 2: Add password hash field -[8 lines, 1 file] -``` - -### 5. 
Document Task Structure - -Each task should include: - -```markdown -### Task N: [Descriptive Name] - -**Objective:** What this accomplishes (one sentence) - -**Files:** -- Create: `exact/path/to/new_file.py` -- Modify: `exact/path/to/existing.py:45-67` -- Test: `tests/path/to/test_file.py` - -**Step 1: Write failing test** -```python -def test_specific_behavior(): - result = function(input) - assert result == expected -``` - -**Step 2: Run test to verify failure** -Run: `pytest tests/path/test.py::test_specific_behavior -v` -Expected: FAIL - -**Step 3: Write minimal implementation** -```python -def function(input): - return expected -``` - -**Step 4: Run test to verify pass** -Run: `pytest tests/path/test.py::test_specific_behavior -v` -Expected: PASS - -**Step 5: Commit** -```bash -git add tests/path/test.py src/path/file.py -git commit -m "feat: add specific feature" -``` -``` - -### 6. Save Task Plan - -Create task folder: `tasks/YYYY-MM-DD-feature-name/` -- Save plan as `README.md` -- Create step files as work progresses - -## Principles - -### DRY (Don't Repeat Yourself) -Extract common patterns, don't copy-paste. - -### YAGNI (You Aren't Gonna Need It) -Implement only what's needed now, not "future flexibility." - -### TDD (Test-Driven Development) -Every code task should include: -1. Write failing test -2. Run to verify failure -3. Write minimal code -4. Run to verify pass - -### Frequent Commits -Commit after every task with clear messages. - -## Common Mistakes - -| Bad | Good | -|-----|------| -| "Add authentication" | "Create User model with email and password_hash" | -| "Add validation function" | "Add validation function" + complete code | -| "Test it works" | "Run `pytest tests/test_auth.py -v`, expected: 3 passed" | -| "Create the model file" | "Create: `src/models/user.py`" | - -## Execution - -After planning, offer: -> "Plan complete and saved to `tasks/YYYY-MM-DD-feature-name/`. Ready to implement task-by-task. Shall I proceed?" 
- -When implementing: -- Follow tasks sequentially -- Create step files documenting progress -- Update README with status -- Commit after each task - -## Related Skills - -- `implementation` - For coding tasks -- `testing` - For test writing -- `hpc-deployment` - For deployment workflows diff --git a/.hermes/tasks/2026-04-17-timing-metrics/README.md b/.hermes/tasks/2026-04-17-timing-metrics/README.md deleted file mode 100644 index 6200b2752..000000000 --- a/.hermes/tasks/2026-04-17-timing-metrics/README.md +++ /dev/null @@ -1,156 +0,0 @@ -# Timing Metrics Task - -**Status:** Completed - -**Created:** 2026-04-17 - -**Related Skills:** metrics, planning - -## Goal - -Add timing metrics to track startup time, training time, and overall execution time for the WeatherGenerator training pipeline. - -## Progress - -- [x] Step 1: Codebase analysis -- [x] Step 2: Design timing approach -- [x] Step 3: Implement timing in `run_train.py` -- [x] Step 4: Implement timing in `trainer.py` -- [x] Step 5: Verify metrics logging -- [x] Step 6: Document changes - -## Completed Steps - -### Step 1: Codebase Analysis - -**Objective:** Understand training entry points and metrics infrastructure - -**Files Reviewed:** -- `src/weathergen/run_train.py` - Main training entry point -- `src/weathergen/train/trainer.py` - Training loop logic -- `src/weathergen/utils/train_logger.py` - Metrics logging utility -- `hpc/mlflow_upload.py` - MLflow upload pipeline - -**Key Findings:** -- Training starts in `run_train()` → `Trainer.run()` -- Metrics logged via `train_logger.log_metrics()` -- MLflow upload filters blacklisted keys automatically -- Multi-node runs require `is_root()` checks - -### Step 2: Design Timing Approach - -**Objective:** Define where and how to add timing metrics - -**Decisions:** -1. **Three metrics:** - - `startup_time_seconds`: Code launch → training start - - `total_training_time_seconds`: Time in training loop - - `overall_time_seconds`: Total wall-clock time - -2. 
**Timing points:** - - `run_train()`: Overall time (entry/exit) - - `Trainer.run()`: Startup + training time - - Root rank only logging - -3. **Format:** JSONL compatible with existing MLflow pipeline - -### Step 3: Implement in `run_train.py` - -**Objective:** Add overall timing in main entry point - -**Changes:** -- Added `t_overall_start` at start of `run_train()` -- Added `t_overall_end` in `finally` block -- Logged `overall_time_seconds` metric - -**Files Modified:** -- `src/weathergen/run_train.py:23` - Added `t_overall_start` -- `src/weathergen/run_train.py:125` - Added timing in `run_continue()` -- `src/weathergen/run_train.py:145` - Added `finally` block with logging - -### Step 4: Implement in `trainer.py` - -**Objective:** Add startup and training time metrics - -**Changes:** -- Added `t_run_start` at start of `Trainer.run()` -- Added `t_training_start` before training loop -- Added `t_training_end` after training loop -- Logged `startup_time_seconds` and `total_training_time_seconds` - -**Files Modified:** -- `src/weathergen/train/trainer.py:100` - Added `t_run_start` -- `src/weathergen/train/trainer.py:150` - Added `t_training_start` -- `src/weathergen/train/trainer.py:230` - Added `t_training_end` -- `src/weathergen/train/trainer.py:235` - Added metric logging - -### Step 5: Verification - -**Objective:** Ensure metrics are logged correctly - -**Commands:** -```bash -# Run training with metrics -python -m weathergen.run_train --config config.yaml - -# Check metrics.json -cat logs/metrics.json | grep timing - -# Verify MLflow upload -python hpc/mlflow_upload.py --dry-run -``` - -**Expected Output:** -```json -{"stage": "train", "startup_time_seconds": 45.23, "weathergen.step": 0} -{"stage": "train", "total_training_time_seconds": 3600.12, "weathergen.step": 100} -{"stage": "train", "overall_time_seconds": 3650.45, "weathergen.step": 100} -``` - -### Step 6: Documentation - -**Objective:** Document implementation for future reference - -**Files 
Created:** -- `TIMING_METRICS_ANALYSIS.md` - Initial codebase analysis -- `TIMING_METRICS_IMPLEMENTATION.md` - Implementation details -- `.hermes/skills/metrics/SKILL.md` - Reusable metrics skill -- `.hermes/tasks/2026-04-17-timing-metrics/README.md` - This file - -## Implementation Summary - -**Total Changes:** 41 lines across 2 files - -**Modified Files:** -1. `src/weathergen/run_train.py` - Overall timing -2. `src/weathergen/train/trainer.py` - Startup + training timing - -**Git Commit:** -``` -feat: add timing metrics for startup, training, and overall time - -- Added overall_time_seconds in run_train() -- Added startup_time_seconds and total_training_time_seconds in Trainer.run() -- All metrics logged via root rank only -- Compatible with existing MLflow upload pipeline -``` - -## Lessons Learned - -1. **Root rank logging is critical** - Multi-node HPC runs would create file contention without `is_root()` checks -2. **Use `finally` blocks** - Ensures metrics are logged even on failure -3. **Keep metrics simple** - JSONL format works seamlessly with MLflow -4. 
**Document timing points** - Clear comments explain what each metric measures - -## Next Steps - -- Monitor metrics in MLflow dashboard -- Add per-epoch timing if needed -- Consider adding data loading time metric -- Track DDP synchronization overhead - -## Links - -- [Implementation PR](TODO) -- [Metrics Skill](../../../.hermes/skills/metrics/SKILL.md) -- [MLflow Upload Code](../../../hpc/mlflow_upload.py) diff --git a/.hermes/tasks/README.md b/.hermes/tasks/README.md deleted file mode 100644 index 5a658793d..000000000 --- a/.hermes/tasks/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Task Management Guidelines - -## Structure - -Each task gets its own folder: `tasks//` - -### Folder Contents - -``` -tasks/2026-04-17-timing-metrics/ -├── README.md # Task overview and current status -├── step-01-analysis.md # Initial codebase exploration -├── step-02-design.md # Design decisions -├── step-03-implementation.md # Code changes -├── step-04-testing.md # Verification steps -└── step-05-completion.md # Final summary and docs -``` - -## README.md Template - -```markdown -# [Task Name] - -**Status:** [Active | In Progress | Completed | Blocked] - -**Created:** YYYY-MM-DD - -**Related Skills:** [skill1, skill2] - -## Goal - -[One sentence describing what we're building] - -## Progress - -- [x] Step 1: Analysis -- [ ] Step 2: Design -- [ ] Step 3: Implementation -- [ ] Step 4: Testing -- [ ] Step 5: Documentation - -## Current Blockers - -[Any issues or decisions needed] - -## Links - -- [Implementation PR](link) -- [Design Doc](link) -- [Related Issues](link) -``` - -## Step Files - -Each step file documents: -1. **What was done** (1-2 sentence summary) -2. **Key decisions** (why we chose this approach) -3. **Code changes** (file paths, line numbers) -4. **Next steps** (what comes next) - -Keep step files concise. Update iteratively as work progresses. - -## Completion - -When task is complete: -1. Mark README status as "Completed" -2. Add final step file with summary -3. 
Move user-facing docs to `docs/` if applicable -4. Update relevant skills if new patterns discovered - -## Active Tasks - -- `2026-04-17-timing-metrics` - Add timing metrics to training pipeline From 09b6e822dde6a0ad0b7ba3c17d697eb7aa96eef2 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 16:37:58 +0200 Subject: [PATCH 06/76] Try duration metrics --- src/weathergen/run_train.py | 5 +++-- src/weathergen/train/trainer.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index de101c0f8..91d643d0b 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -21,6 +21,7 @@ import weathergen.common.config as config import weathergen.utils.cli as cli +from weathergen.utils.distributed import is_root from weathergen.common.logger import init_loggers from weathergen.train.trainer import Trainer @@ -122,7 +123,7 @@ def run_continue(args): Note: All model configurations are set in the function body. """ - t_overall_start = time.time() # Track overall run start time + t_overall_start = time.time() cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -174,7 +175,7 @@ def run_train(args): Note: All model configurations are set in the function body. 
""" - t_overall_start = time.time() # Track overall run start time + t_overall_start = time.time() cli_overwrite = config.from_cli_arglist(args.options) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 6bd28386f..ebe469766 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -233,8 +233,8 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): logger.info(f"Finished inference run with id: {cf.general.run_id}") def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): - t_run_start = time.time() # Track trainer.run() start time (for startup_time) - + t_run_start = time.time() + # general initalization self.init(cf, devices) cf = self.cf @@ -373,14 +373,14 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # run validation before training if requested self.validate_before_training() - # Log startup time (time from run() start to training loop start) + # Log startup time if is_root(): startup_time = time.time() - t_run_start self.train_logger.log_metrics("train", {"startup_time_seconds": startup_time}) logger.info(f"Startup time: {startup_time:.2f} seconds") # training loop - t_training_start = time.time() # Track start of actual training + t_training_start = time.time() for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") From da3c29ba67b30b96a77d41f58b632c22bddd7364 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 17 Apr 2026 17:53:39 +0200 Subject: [PATCH 07/76] Update metrics, store after each mini-epoch --- src/weathergen/train/trainer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ebe469766..c704f6815 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -402,7 +402,12 @@ def run(self, 
cf, devices, run_id_contd=None, mini_epoch_contd=None): # Log total training time if is_root(): total_training_time = time.time() - t_training_start - self.train_logger.log_metrics("train", {"total_training_time_seconds": total_training_time}) + total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + self.train_logger.log_metrics("train", { + "total_training_time_seconds": total_training_time, + "final_num_samples": total_samples, + "samples_per_second_total": total_samples / total_training_time if total_training_time > 0 else 0, + }) logger.info(f"Total training time: {total_training_time / 3600:.2f} hours") def validate_before_training(self): @@ -643,6 +648,17 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self._log_terminal(0, mini_epoch, VAL) self._log(VAL) + # Log elapsed training time and throughput metrics + # This ensures time is tracked even if job is killed mid-mini-epoch + if is_root(): + elapsed_time = time.time() - t_training_start + total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + self.train_logger.log_metrics("train", { + "elapsed_training_time_seconds": elapsed_time, + "num_samples": total_samples, + "samples_per_second_elapsed": total_samples / elapsed_time if elapsed_time > 0 else 0, + }) + # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() From fc9a111dbd89b61562bb7c72d926799f660a5ece Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 09:38:01 +0200 Subject: [PATCH 08/76] Refactor configs/streams --- config/{config_era5.yml => streams/era5_georing/era5.yml} | 0 config/{config_geos.yml => streams/era5_georing/geos.yml} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename config/{config_era5.yml => streams/era5_georing/era5.yml} (100%) rename config/{config_geos.yml => streams/era5_georing/geos.yml} (100%) diff --git a/config/config_era5.yml b/config/streams/era5_georing/era5.yml similarity 
index 100% rename from config/config_era5.yml rename to config/streams/era5_georing/era5.yml diff --git a/config/config_geos.yml b/config/streams/era5_georing/geos.yml similarity index 100% rename from config/config_geos.yml rename to config/streams/era5_georing/geos.yml From cfc4c62365cadad67eed455844a7dad269c3869c Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 12:07:28 +0200 Subject: [PATCH 09/76] Extract scaling data --- scripts/extract_scaling_data.py | 99 +++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 scripts/extract_scaling_data.py diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py new file mode 100644 index 000000000..378a8d98a --- /dev/null +++ b/scripts/extract_scaling_data.py @@ -0,0 +1,99 @@ +#!/usr/bin/env uv run python +"""Extract strong scaling data from WeatherGenerator runs. Outputs parquet with run_id, num_nodes, training_time, overall_time_seconds, loss_avg_mean.""" + +import argparse +import re +import sys +from pathlib import Path + +import polars as pl + + +def extract_num_nodes(err_log_path: Path) -> int | None: + if not err_log_path.exists(): + return None + try: + content = err_log_path.read_text() + match = re.search(r"Number of Nodes:\s*(\d+)", content) + return int(match.group(1)) if match else None + except Exception: + return None + + +def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: + metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" + if not metrics_path.exists(): + return None + try: + df = pl.read_ndjson(metrics_path) + if len(df) == 0: + return None + final_row = df.tail(1) + overall_time = final_row.get_column("overall_time_seconds").item() if "overall_time_seconds" in final_row.columns else None + if overall_time is None: + return None + startup_time = final_row.get_column("startup_time_seconds").item() if "startup_time_seconds" in final_row.columns else None + loss_avg = 
final_row.get_column("loss_avg_mean").item() if "loss_avg_mean" in final_row.columns else None + return { + "overall_time_seconds": overall_time, + "startup_time_seconds": startup_time, + "training_time": overall_time - startup_time if startup_time else None, + "loss_avg_mean": loss_avg, + } + except Exception: + return None + + +def main(): + parser = argparse.ArgumentParser(description="Extract strong scaling data from WeatherGenerator runs") + parser.add_argument("--run-ids", nargs="+", help="List of run-ids to process") + parser.add_argument("--run-id-file", type=Path, help="File containing run-ids (one per line)") + parser.add_argument("--logs-base-dir", type=Path, default=Path("/e/scratch/weatherai/logs"), help="Base directory for run logs") + parser.add_argument("--shared-work-dir", type=Path, default=Path("/e/scratch/weatherai/shared_work"), help="Base directory for shared work/results") + parser.add_argument("--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path") + + args = parser.parse_args() + + if args.run_ids and args.run_id_file: + sys.exit("Error: Cannot specify both --run-ids and --run-id-file") + elif args.run_ids: + run_ids = args.run_ids + elif args.run_id_file: + if not args.run_id_file.exists(): + sys.exit(f"Error: Run-id file not found: {args.run_id_file}") + run_ids = [line.strip() for line in args.run_id_file.read_text().splitlines() if line.strip()] + else: + sys.exit("Error: Must specify either --run-ids or --run-id-file") + + if not run_ids: + sys.exit("Error: No run-ids provided") + + results = [] + for run_id in run_ids: + log_pattern = args.logs_base_dir / run_id / "weathermen.*.err" + err_files = list(log_pattern.parent.glob("weathermen.*.err")) + num_nodes = extract_num_nodes(err_files[0]) if err_files else None + metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) + if metrics is None: + continue + row = { + "run_id": run_id, + "num_nodes": num_nodes, + "training_time": 
metrics.get("training_time"), + "overall_time_seconds": metrics["overall_time_seconds"], + "loss_avg_mean": metrics.get("loss_avg_mean"), + } + results.append(row) + + if not results: + sys.exit("No data extracted") + + df = pl.DataFrame(results) + if "num_nodes" in df.columns: + df = df.sort("num_nodes") + args.output.parent.mkdir(parents=True, exist_ok=True) + df.write_parquet(args.output) + + +if __name__ == "__main__": + main() From 82b503ab89d1c31ac9a3f1ee28b3fd8101b0ad2b Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 12:52:20 +0200 Subject: [PATCH 10/76] Script to generate scaling plots --- scripts/generate_scaling_plots.py | 119 ++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 scripts/generate_scaling_plots.py diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py new file mode 100644 index 000000000..86433237e --- /dev/null +++ b/scripts/generate_scaling_plots.py @@ -0,0 +1,119 @@ +#!/usr/bin/env uv run python +"""Generate strong scaling plots from parquet data. 
Single file with subplots per metric.""" + +import argparse +from pathlib import Path + +import polars as pl +import plotly.graph_objects as go +from plotly.subplots import make_subplots + +def create_scaling_plots(df: pl.DataFrame, output_path: Path, metrics: list[str]): + """Create a single plot with subplots for each metric.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Count valid metrics + valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] + if not valid_metrics: + print("No valid metrics to plot") + return + + n_metrics = len(valid_metrics) + + fig = make_subplots( + rows=n_metrics, cols=1, + subplot_titles=valid_metrics, + vertical_spacing=0.1, + ) + + for idx, metric in enumerate(valid_metrics): + df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") + + # Add scatter trace with lines and text labels + fig.add_trace( + go.Scatter( + x=df_plot["num_nodes"], + y=df_plot[metric], + mode="lines+markers+text", + text=df_plot["run_id"], + textposition="top center", + name=metric, + showlegend=False, + marker=dict(size=10, color="steelblue"), + line=dict(width=2), + ), + row=idx + 1, col=1, + ) + + # Add optimal scaling reference line for training_time + if metric == "training_time" and "training_time" in df.columns: + # Find the 1-node training time + one_node_data = df.filter(pl.col("num_nodes") == 1) + if one_node_data.height > 0: + t1 = one_node_data["training_time"].item() + # Create optimal scaling line: t1 / n for each n + nodes = df_plot["num_nodes"].to_list() + optimal_y = [t1 / n for n in nodes] + fig.add_trace( + go.Scatter( + x=nodes, + y=optimal_y, + mode="lines", + name="Optimal scaling", + line=dict(width=1, color="red", dash="dash"), + showlegend=True, + ), + row=idx + 1, col=1, + ) + + fig.update_xaxes(title_text="Number of Nodes (log scale)", type="log", row=idx + 1, col=1) + fig.update_yaxes(title_text=metric, row=idx + 1, col=1) + + fig.update_layout( + 
height=400 * n_metrics, + title_text="Strong Scaling Analysis", + title_x=0.5, + template="plotly_white", + ) + + fig.write_html(output_path) + print(f"Saved: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Generate strong scaling plots from parquet data") + parser.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet file") + parser.add_argument("--output", type=Path, default=Path("scaling_plots/strong_scaling.html"), help="Output HTML file") + parser.add_argument("--metrics", nargs="+", default=["training_time", "overall_time_seconds", "loss_avg_mean"], help="Metrics to plot") + parser.add_argument("--generate-dummy", action="store_true", help="Generate dummy test data") + + args = parser.parse_args() + + if args.generate_dummy: + print("Generating dummy test data...") + dummy_data = { + "run_id": ["run_1node", "run_2node", "run_4node", "run_8node", "run_16node"], + "num_nodes": [1, 2, 4, 8, 16], + "training_time": [1000, 520, 270, 140, 75], + "overall_time_seconds": [1100, 580, 310, 165, 90], + "loss_avg_mean": [0.45, 0.44, 0.44, 0.43, 0.43], + } + df = pl.DataFrame(dummy_data) + args.input.parent.mkdir(parents=True, exist_ok=True) + df.write_parquet(args.input) + print(f"Created dummy data: {args.input}") + + if not args.input.exists(): + print(f"Error: Input file not found: {args.input}") + print("Use --generate-dummy to create test data") + return + + print(f"Loading data from: {args.input}") + df = pl.read_parquet(args.input) + print(f"Loaded {len(df)} rows") + + create_scaling_plots(df, args.output, args.metrics) + + +if __name__ == "__main__": + main() From 70053b1cb1ec6814c2266771cf55b3b2e44e55cf Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 12:52:30 +0200 Subject: [PATCH 11/76] Script update --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml 
index ffeb81177..45659a118 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -153,8 +153,8 @@ training_config: start_date: 2014-01-01T00:00 end_date: 2022-12-31T00:00 - time_window_step: 06:00:00 - time_window_len: 06:00:00 + time_window_step: 01:00:00 + time_window_len: 01:00:00 learning_rate_scheduling : lr_start: 1e-6 From 0c2df975d39eed58862ce5d175f7206ba8ccfd25 Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Mon, 20 Apr 2026 13:20:16 +0200 Subject: [PATCH 12/76] Repeat data in mini epoch --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index 45659a118..7c6c85220 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -132,7 +132,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : False + repeat_data_in_mini_epoch : True # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -146,7 +146,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] - num_mini_epochs: 96 + num_mini_epochs: 4 samples_per_mini_epoch: 4096 shuffle: True From 2c79d28a240cd672f21b5972acbcc54c6e0fedd6 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 17:06:02 +0200 Subject: [PATCH 13/76] corrected time window length --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index 45659a118..8328a2363 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -132,7 +132,7 @@ data_loading : num_workers: 12 rng_seed: ??? 
- repeat_data_in_mini_epoch : False + repeat_data_in_mini_epoch : True # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -154,7 +154,7 @@ training_config: end_date: 2022-12-31T00:00 time_window_step: 01:00:00 - time_window_len: 01:00:00 + time_window_len: 06:00:00 learning_rate_scheduling : lr_start: 1e-6 From b5d70f6a72ae5f6558c85f0f7abcdc70271b2945 Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Mon, 20 Apr 2026 19:21:13 +0200 Subject: [PATCH 14/76] Lower to 512 samples per mini epoch --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index fd775632a..ba0cf0d1b 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -146,8 +146,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] - num_mini_epochs: 4 - samples_per_mini_epoch: 4096 + num_mini_epochs: 1 + samples_per_mini_epoch: 512 shuffle: True start_date: 2014-01-01T00:00 From f46828c2ec62653d474706b707ba7e0052a1f3c9 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 20 Apr 2026 19:24:07 +0200 Subject: [PATCH 15/76] Updated extraction script --- scripts/extract_scaling_data.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 378a8d98a..3a6e28591 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -21,6 +21,12 @@ def extract_num_nodes(err_log_path: Path) -> int | None: def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: + """Extract metrics from NDJSON file with startup and training lines. + + Format: + - Line 1: startup_time_seconds + - Line 2+: loss_avg_mean, LossPhysical.loss_avg, etc. 
+ """ metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" if not metrics_path.exists(): return None @@ -28,17 +34,30 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No df = pl.read_ndjson(metrics_path) if len(df) == 0: return None - final_row = df.tail(1) - overall_time = final_row.get_column("overall_time_seconds").item() if "overall_time_seconds" in final_row.columns else None + + # Extract startup_time from first row (startup line) + startup_time = None + if "startup_time_seconds" in df.columns: + startup_time = df.select(pl.col("startup_time_seconds").first()).item() + + # Extract loss_avg_mean from last training row + loss_avg_mean = None + if "loss_avg_mean" in df.columns: + loss_avg_mean = df.select(pl.col("loss_avg_mean").last()).item() + + # Extract overall_time from last row + overall_time = None + if "overall_time_seconds" in df.columns: + overall_time = df.select(pl.col("overall_time_seconds").last()).item() + if overall_time is None: return None - startup_time = final_row.get_column("startup_time_seconds").item() if "startup_time_seconds" in final_row.columns else None - loss_avg = final_row.get_column("loss_avg_mean").item() if "loss_avg_mean" in final_row.columns else None + return { "overall_time_seconds": overall_time, "startup_time_seconds": startup_time, "training_time": overall_time - startup_time if startup_time else None, - "loss_avg_mean": loss_avg, + "loss_avg_mean": loss_avg_mean, } except Exception: return None From 7cad6b5e56b827d07cf6cedf99e700b0a89940a4 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Tue, 21 Apr 2026 15:53:16 +0200 Subject: [PATCH 16/76] Log time more often --- src/weathergen/train/trainer.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index c704f6815..dc1726389 100644 --- a/src/weathergen/train/trainer.py +++ 
b/src/weathergen/train/trainer.py @@ -399,17 +399,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # log final model self.save_model(self.training_cfg.num_mini_epochs) - # Log total training time - if is_root(): - total_training_time = time.time() - t_training_start - total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) - self.train_logger.log_metrics("train", { - "total_training_time_seconds": total_training_time, - "final_num_samples": total_samples, - "samples_per_second_total": total_samples / total_training_time if total_training_time > 0 else 0, - }) - logger.info(f"Total training time: {total_training_time / 3600:.2f} hours") - def validate_before_training(self): """ Perform validation before training (eg. to check validation pipeline or data normalization) @@ -648,17 +637,6 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self._log_terminal(0, mini_epoch, VAL) self._log(VAL) - # Log elapsed training time and throughput metrics - # This ensures time is tracked even if job is killed mid-mini-epoch - if is_root(): - elapsed_time = time.time() - t_training_start - total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) - self.train_logger.log_metrics("train", { - "elapsed_training_time_seconds": elapsed_time, - "num_samples": total_samples, - "samples_per_second_elapsed": total_samples / elapsed_time if elapsed_time > 0 else 0, - }) - # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() @@ -758,9 +736,18 @@ def _log(self, stage: Stage): samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) if is_root(): + # Log elapsed training time and throughput metrics with every metric log + elapsed_time = time.time() - t_training_start + time_metrics = { + "elapsed_training_time_seconds": elapsed_time, + "num_samples": samples, + "samples_per_second_elapsed": samples / elapsed_time if elapsed_time > 
0 else 0, + } + # plain logger if stage == VAL: self.train_logger.add_logs(stage, samples, losses_all, stddev_all) + self.train_logger.log_metrics("train", time_metrics) elif self.cf.general.istep >= 0: self.train_logger.add_logs( @@ -771,6 +758,7 @@ def _log(self, stage: Stage): avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), ) + self.train_logger.log_metrics("train", time_metrics) loss_calculator.loss_hist = [] loss_calculator.losses_unweighted_hist = [] From 30ac1021fb6873dff6a4aac30e47185494bd1022 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Tue, 21 Apr 2026 16:30:31 +0200 Subject: [PATCH 17/76] Fix training start scope --- src/weathergen/train/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index dc1726389..0a74fe2fd 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -380,7 +380,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): logger.info(f"Startup time: {startup_time:.2f} seconds") # training loop - t_training_start = time.time() + self.t_training_start = time.time() for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") @@ -737,7 +737,7 @@ def _log(self, stage: Stage): if is_root(): # Log elapsed training time and throughput metrics with every metric log - elapsed_time = time.time() - t_training_start + elapsed_time = time.time() - self.t_training_start time_metrics = { "elapsed_training_time_seconds": elapsed_time, "num_samples": samples, From 5e7f63ee48c632098fee4817d59cea8ab1be5ed4 Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Wed, 22 Apr 2026 16:57:15 +0200 Subject: [PATCH 18/76] Minimal validation --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index 
ba0cf0d1b..fc8c4003a 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -202,7 +202,7 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: - samples_per_mini_epoch: 256 + samples_per_mini_epoch: 1 shuffle: True start_date: 2023-10-01T00:00 @@ -231,7 +231,7 @@ validation_config: # validation config; full validation config is merge of training and validation config test_config: - samples_per_mini_epoch: 256 + samples_per_mini_epoch: 1 shuffle: False start_date: 2023-10-01T00:00 From 2be95c6c87246c003e405f44b276374823e794ec Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Wed, 22 Apr 2026 17:12:46 +0200 Subject: [PATCH 19/76] Increase samples_per_mini_epoch to 1024 --- config/config_era5_georing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index fc8c4003a..f9e748ea7 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -147,7 +147,7 @@ training_config: training_mode: ["masking"] num_mini_epochs: 1 - samples_per_mini_epoch: 512 + samples_per_mini_epoch: 1024 shuffle: True start_date: 2014-01-01T00:00 From 93b203b25b58c6a0da1c41e022ae77dcf2e48f34 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 11:34:48 +0200 Subject: [PATCH 20/76] Final training duration and terminal/metric logging --- config/config_era5_georing.yml | 6 +++--- src/weathergen/train/trainer.py | 26 ++++++++++++++++++-------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index f9e748ea7..82638d146 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -122,9 +122,9 @@ general: # logging frequency in the training loop (in number of batches) train_logging: - terminal: 10 - metrics: 20 - checkpoint: 250 + terminal: 16 + metrics: 16 + checkpoint: 256 
log_grad_norms: False # parameters for data loading diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0a74fe2fd..60cd2f6a7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -399,6 +399,17 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # log final model self.save_model(self.training_cfg.num_mini_epochs) + # Log total training time + if is_root(): + total_training_time = time.time() - self.t_training_start + total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + self.train_logger.log_metrics("train", { + "total_training_time_seconds": total_training_time, + "final_num_samples": total_samples, + "samples_per_second_total": total_samples / total_training_time if total_training_time > 0 else 0, + }) + logger.info(f"Total training time: {total_training_time / 3600:.2f} hours") + def validate_before_training(self): """ Perform validation before training (eg. to check validation pipeline or data normalization) @@ -736,20 +747,19 @@ def _log(self, stage: Stage): samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) if is_root(): - # Log elapsed training time and throughput metrics with every metric log - elapsed_time = time.time() - self.t_training_start - time_metrics = { - "elapsed_training_time_seconds": elapsed_time, - "num_samples": samples, - "samples_per_second_elapsed": samples / elapsed_time if elapsed_time > 0 else 0, - } # plain logger if stage == VAL: self.train_logger.add_logs(stage, samples, losses_all, stddev_all) - self.train_logger.log_metrics("train", time_metrics) elif self.cf.general.istep >= 0: + # Log elapsed training time and throughput metrics with every metric log + elapsed_time = time.time() - self.t_training_start + time_metrics = { + "elapsed_training_time_seconds": elapsed_time, + "num_samples": samples, + "samples_per_second_elapsed": samples / elapsed_time if elapsed_time > 0 else 
0, + } self.train_logger.add_logs( stage, samples, From 2b708e376b41fa477b31b4c5269ed8e7aa61d85a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 12:38:11 +0200 Subject: [PATCH 21/76] log metrics after mini-epoch --- src/weathergen/train/trainer.py | 38 ++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 60cd2f6a7..864606ba7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -386,6 +386,18 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") self.train(mini_epoch) + # Log training time after one epoch + if is_root(): + total_training_time = time.time() - self.t_training_start + total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + self.train_logger.log_metrics("train", { + "completed_mini_epoch": mini_epoch, + "elapsed_training_time_seconds": total_training_time, + "total_num_samples": total_samples, + "average_samples_per_second": total_samples / total_training_time if total_training_time > 0 else 0, + }) + logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time / 3600:.2f} hours") + logger.info( f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: validate." 
) @@ -399,17 +411,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # log final model self.save_model(self.training_cfg.num_mini_epochs) - # Log total training time - if is_root(): - total_training_time = time.time() - self.t_training_start - total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) - self.train_logger.log_metrics("train", { - "total_training_time_seconds": total_training_time, - "final_num_samples": total_samples, - "samples_per_second_total": total_samples / total_training_time if total_training_time > 0 else 0, - }) - logger.info(f"Total training time: {total_training_time / 3600:.2f} hours") - + def validate_before_training(self): """ Perform validation before training (eg. to check validation pipeline or data normalization) @@ -755,11 +757,6 @@ def _log(self, stage: Stage): elif self.cf.general.istep >= 0: # Log elapsed training time and throughput metrics with every metric log elapsed_time = time.time() - self.t_training_start - time_metrics = { - "elapsed_training_time_seconds": elapsed_time, - "num_samples": samples, - "samples_per_second_elapsed": samples / elapsed_time if elapsed_time > 0 else 0, - } self.train_logger.add_logs( stage, samples, @@ -768,7 +765,14 @@ def _log(self, stage: Stage): avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), ) - self.train_logger.log_metrics("train", time_metrics) + self.train_logger.log_metrics( + "train", + { + "elapsed_training_time_seconds": elapsed_time, + "total_num_samples": samples, + "average_samples_per_second_": samples / elapsed_time if elapsed_time > 0 else 0, + } + ) loss_calculator.loss_hist = [] loss_calculator.losses_unweighted_hist = [] From 0d8407d4ee313a79826fbac23f46e8493363a251 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 12:45:43 +0200 Subject: [PATCH 22/76] Log metrics after mini-epoch, change schema --- src/weathergen/run_train.py | 8 ++++---- src/weathergen/train/trainer.py | 7 +++---- 2 files 
changed, 7 insertions(+), 8 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 91d643d0b..391b220c4 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -175,7 +175,7 @@ def run_train(args): Note: All model configurations are set in the function body. """ - t_overall_start = time.time() + t_start = time.time() cli_overwrite = config.from_cli_arglist(args.options) @@ -203,7 +203,7 @@ def run_train(args): trainer = Trainer(cf.train_logging) try: - trainer.run(cf, devices) + trainer.run(cf, devices, t_start=t_start) except Exception: extype, value, tb = sys.exc_info() traceback.print_exc() @@ -212,8 +212,8 @@ def run_train(args): finally: # Log overall time (only on root rank) if is_root(): - t_overall_end = time.time() - overall_time = t_overall_end - t_overall_start + t_end = time.time() + overall_time = t_end - t_start trainer.train_logger.log_metrics( "train", { diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 864606ba7..a3fe0fad5 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -232,8 +232,7 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") - def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): - t_run_start = time.time() + def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: float | None = None): # general initalization self.init(cf, devices) @@ -374,8 +373,8 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.validate_before_training() # Log startup time - if is_root(): - startup_time = time.time() - t_run_start + if is_root() and t_start is not None: + startup_time = time.time() - t_start self.train_logger.log_metrics("train", {"startup_time_seconds": startup_time}) logger.info(f"Startup time: {startup_time:.2f} 
seconds") From 422fc60d7f553b21484663207f08d7f19b84a178 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 12:59:15 +0200 Subject: [PATCH 23/76] MEtric typo --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index a3fe0fad5..41f891ba2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -769,7 +769,7 @@ def _log(self, stage: Stage): { "elapsed_training_time_seconds": elapsed_time, "total_num_samples": samples, - "average_samples_per_second_": samples / elapsed_time if elapsed_time > 0 else 0, + "average_samples_per_second": samples / elapsed_time if elapsed_time > 0 else 0, } ) From f63cba92411fe2a9560d754c9359d6cf9e607450 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 13:19:48 +0200 Subject: [PATCH 24/76] Logging refactor --- src/weathergen/train/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 41f891ba2..584109632 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -388,14 +388,12 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl # Log training time after one epoch if is_root(): total_training_time = time.time() - self.t_training_start - total_samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) self.train_logger.log_metrics("train", { "completed_mini_epoch": mini_epoch, - "elapsed_training_time_seconds": total_training_time, - "total_num_samples": total_samples, - "average_samples_per_second": total_samples / total_training_time if total_training_time > 0 else 0, + "elapsed_time_mini_epoch": total_training_time, }) logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time / 3600:.2f} hours") + self._log(TRAIN) logger.info( f"Mini_epoch {mini_epoch} 
of {self.training_cfg.num_mini_epochs}: validate." From b596c14b0a72fe1496e77f16db15083e37c511b1 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 13:41:14 +0200 Subject: [PATCH 25/76] Update extraction script --- scripts/extract_scaling_data.py | 35 ++++++++++++--------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 3a6e28591..6425f1dbb 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -40,23 +40,25 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No if "startup_time_seconds" in df.columns: startup_time = df.select(pl.col("startup_time_seconds").first()).item() - # Extract loss_avg_mean from last training row + # Extract loss_avg_mean from last non-NaN training row loss_avg_mean = None if "loss_avg_mean" in df.columns: - loss_avg_mean = df.select(pl.col("loss_avg_mean").last()).item() + loss_avg_mean = df.select(pl.col("loss_avg_mean").drop_nulls().last()).item() - # Extract overall_time from last row + # Extract training for mini-epoch from last non-NaN row + overall_training_time = None + if "elapsed_time_mini_epoch" in df.columns: + overall_training_time = df.select(pl.col("elapsed_time_mini_epoch").drop_nulls().last()).item() + + # Extract overall_time from last non-NaN row overall_time = None if "overall_time_seconds" in df.columns: - overall_time = df.select(pl.col("overall_time_seconds").last()).item() - - if overall_time is None: - return None + overall_time = df.select(pl.col("overall_time_seconds").drop_nulls().last()).item() return { "overall_time_seconds": overall_time, "startup_time_seconds": startup_time, - "training_time": overall_time - startup_time if startup_time else None, + "training_time": overall_training_time, "loss_avg_mean": loss_avg_mean, } except Exception: @@ -66,31 +68,20 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict 
| No def main(): parser = argparse.ArgumentParser(description="Extract strong scaling data from WeatherGenerator runs") parser.add_argument("--run-ids", nargs="+", help="List of run-ids to process") - parser.add_argument("--run-id-file", type=Path, help="File containing run-ids (one per line)") parser.add_argument("--logs-base-dir", type=Path, default=Path("/e/scratch/weatherai/logs"), help="Base directory for run logs") parser.add_argument("--shared-work-dir", type=Path, default=Path("/e/scratch/weatherai/shared_work"), help="Base directory for shared work/results") parser.add_argument("--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path") args = parser.parse_args() - if args.run_ids and args.run_id_file: - sys.exit("Error: Cannot specify both --run-ids and --run-id-file") - elif args.run_ids: - run_ids = args.run_ids - elif args.run_id_file: - if not args.run_id_file.exists(): - sys.exit(f"Error: Run-id file not found: {args.run_id_file}") - run_ids = [line.strip() for line in args.run_id_file.read_text().splitlines() if line.strip()] - else: - sys.exit("Error: Must specify either --run-ids or --run-id-file") - + run_ids = args.run_ids if not run_ids: sys.exit("Error: No run-ids provided") results = [] for run_id in run_ids: - log_pattern = args.logs_base_dir / run_id / "weathermen.*.err" - err_files = list(log_pattern.parent.glob("weathermen.*.err")) + log_pattern = args.logs_base_dir / run_id / "weathergen.*.err" + err_files = list(log_pattern.parent.glob("weathergen.*.err")) num_nodes = extract_num_nodes(err_files[0]) if err_files else None metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: From 42ba646252523bf6e1ac211cc1f63d24466c6088 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 14:00:45 +0200 Subject: [PATCH 26/76] NNode extraction --- scripts/extract_scaling_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 6425f1dbb..804f369cc 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -14,7 +14,8 @@ def extract_num_nodes(err_log_path: Path) -> int | None: return None try: content = err_log_path.read_text() - match = re.search(r"Number of Nodes:\s*(\d+)", content) + # Case-insensitive match for "Number of Nodes:" with flexible whitespace + match = re.search(r"number\s+of\s+nodes\s*:\s*(\d+)", content, re.IGNORECASE) return int(match.group(1)) if match else None except Exception: return None @@ -80,8 +81,8 @@ def main(): results = [] for run_id in run_ids: - log_pattern = args.logs_base_dir / run_id / "weathergen.*.err" - err_files = list(log_pattern.parent.glob("weathergen.*.err")) + log_pattern = args.logs_base_dir / run_id / "weathergen.*.*.err" + err_files = list(log_pattern.parent.glob("weathergen.*.*.err")) num_nodes = extract_num_nodes(err_files[0]) if err_files else None metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: From c9fa64da73456268b48e6c22f9d1cb2f37262b5c Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 14:05:36 +0200 Subject: [PATCH 27/76] Logs path --- scripts/extract_scaling_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 804f369cc..d1f56dd37 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -69,7 +69,7 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No def main(): parser = argparse.ArgumentParser(description="Extract strong scaling data from WeatherGenerator runs") parser.add_argument("--run-ids", nargs="+", help="List of run-ids to process") - parser.add_argument("--logs-base-dir", type=Path, default=Path("/e/scratch/weatherai/logs"), help="Base directory for run logs") + parser.add_argument("--logs-base-dir", 
type=Path, default=Path("logs"), help="Base directory for run logs (default: logs relative to current dir)") parser.add_argument("--shared-work-dir", type=Path, default=Path("/e/scratch/weatherai/shared_work"), help="Base directory for shared work/results") parser.add_argument("--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path") @@ -81,8 +81,9 @@ def main(): results = [] for run_id in run_ids: - log_pattern = args.logs_base_dir / run_id / "weathergen.*.*.err" - err_files = list(log_pattern.parent.glob("weathergen.*.*.err")) + # Look for weathergen.*.err files (e.g., weathergen.part1.388004.err) + log_dir = args.logs_base_dir / run_id + err_files = list(log_dir.glob("weathergen.*.err")) if log_dir.exists() else [] num_nodes = extract_num_nodes(err_files[0]) if err_files else None metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: From ccfbc6477f69018c56e0a990307be91283382af8 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 14:58:00 +0200 Subject: [PATCH 28/76] Wait until all training complete and wait with validation until logs are done --- src/weathergen/train/trainer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 584109632..9d618a8aa 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -385,16 +385,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") self.train(mini_epoch) - # Log training time after one epoch - if is_root(): - total_training_time = time.time() - self.t_training_start - self.train_logger.log_metrics("train", { - "completed_mini_epoch": mini_epoch, - "elapsed_time_mini_epoch": total_training_time, - }) - logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time / 
3600:.2f} hours") - self._log(TRAIN) - logger.info( f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: validate." ) @@ -561,6 +551,18 @@ def train(self, mini_epoch): self.dataset.advance() + torch.cuda.synchronize() + if is_root(): + total_training_time = time.time() - self.t_training_start + self.train_logger.log_metrics("train", { + "completed_mini_epoch": mini_epoch, + "elapsed_time_mini_epoch": total_training_time, + }) + logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time / 3600:.2f} hours") + self._log(TRAIN) + torch.cuda.synchronize() + + def validate(self, mini_epoch, mode_cfg, batch_size): """ Perform validation / test computation as specified by mode_cfg From c0f96b79ddd39ad01ced199993ea266eb5b811ef Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Thu, 23 Apr 2026 14:58:25 +0200 Subject: [PATCH 29/76] Log seconds rather than hours --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 584109632..45bb46a9d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -392,7 +392,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl "completed_mini_epoch": mini_epoch, "elapsed_time_mini_epoch": total_training_time, }) - logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time / 3600:.2f} hours") + logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") self._log(TRAIN) logger.info( From e6475e97b0254ea56c9345e143af7abcc01f8002 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 15:33:44 +0200 Subject: [PATCH 30/76] Measure dataset advancement time --- src/weathergen/train/trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b35895b75..7a49ce604 100644 --- 
a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -549,9 +549,7 @@ def train(self, mini_epoch): self.cf.general.istep += 1 - self.dataset.advance() - - torch.cuda.synchronize() + torch.distributed.barrier() if is_root(): total_training_time = time.time() - self.t_training_start self.train_logger.log_metrics("train", { @@ -560,8 +558,13 @@ def train(self, mini_epoch): }) logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") self._log(TRAIN) - torch.cuda.synchronize() - + + time_before_advance = time.time() + self.dataset.advance() + time_after_advance = time.time() + time_for_advance = time_after_advance - time_before_advance + if is_root(): + logger.info(f"Time to advance dataset after mini epoch {mini_epoch}: {time_for_advance} seconds") def validate(self, mini_epoch, mode_cfg, batch_size): """ From 6fd001f49428a4e6912c2205501d9dcc4d2641af Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 15:58:56 +0200 Subject: [PATCH 31/76] LR scheduler lower bounds --- src/weathergen/train/lr_scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..65e3fafda 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -53,9 +53,9 @@ def __init__( logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}") # ensure that steps_decay has a reasonable value if self.n_steps_decay < int(0.2 * lr_steps): - self.n_steps_warmup = int(0.1 * lr_steps) - self.n_steps_cooldown = int(0.05 * lr_steps) - self.n_steps_decay = lr_steps - self.n_steps_warmup - self.n_steps_cooldown + self.n_steps_warmup = max(1, int(0.1 * lr_steps)) + self.n_steps_cooldown = max(1, int(0.05 * lr_steps)) + self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown) s = ( "cf.lr_steps_warmup and cf.lr_steps_cooldown", f" were larger than 
cf.lr_steps={lr_steps}", From 313cec6c1d0be1e720379217bba66606226b28af Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 23 Apr 2026 16:20:21 +0200 Subject: [PATCH 32/76] At least two warmup steps --- src/weathergen/train/lr_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index 65e3fafda..9c8185fc2 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -53,7 +53,7 @@ def __init__( logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}") # ensure that steps_decay has a reasonable value if self.n_steps_decay < int(0.2 * lr_steps): - self.n_steps_warmup = max(1, int(0.1 * lr_steps)) + self.n_steps_warmup = max(2, int(0.1 * lr_steps)) self.n_steps_cooldown = max(1, int(0.05 * lr_steps)) self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown) s = ( From aa4d399b065986e3d1c245f09ae13f2ed2840365 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 09:38:06 +0200 Subject: [PATCH 33/76] Len per rank at least 1 to avoid zero division error --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 7a49ce604..45f40bead 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -355,7 +355,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: len_per_rank = ( - len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu) + max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu mini_epoch_base = int( self.cf.general.istep From 7956c522d92199d3da297f153de0ac065596e7d2 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 09:42:01 +0200 
Subject: [PATCH 34/76] Write csv for easier viewing --- scripts/extract_scaling_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index d1f56dd37..e0fbae783 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -105,6 +105,7 @@ def main(): df = df.sort("num_nodes") args.output.parent.mkdir(parents=True, exist_ok=True) df.write_parquet(args.output) + df.write_csv(args.output.with_suffix(".csv")) if __name__ == "__main__": From cf659e1d4099fa20c9482569643eb2c004b4af52 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 09:52:57 +0200 Subject: [PATCH 35/76] Extraction and plotting --- scripts/extract_scaling_data.py | 8 +------- scripts/generate_scaling_plots.py | 11 +++++++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index e0fbae783..cff75c359 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -51,13 +51,7 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No if "elapsed_time_mini_epoch" in df.columns: overall_training_time = df.select(pl.col("elapsed_time_mini_epoch").drop_nulls().last()).item() - # Extract overall_time from last non-NaN row - overall_time = None - if "overall_time_seconds" in df.columns: - overall_time = df.select(pl.col("overall_time_seconds").drop_nulls().last()).item() - return { - "overall_time_seconds": overall_time, "startup_time_seconds": startup_time, "training_time": overall_training_time, "loss_avg_mean": loss_avg_mean, @@ -91,8 +85,8 @@ def main(): row = { "run_id": run_id, "num_nodes": num_nodes, + "startup_time_seconds": metrics.get("startup_time_seconds"), "training_time": metrics.get("training_time"), - "overall_time_seconds": metrics["overall_time_seconds"], "loss_avg_mean": metrics.get("loss_avg_mean"), } results.append(row) diff --git 
a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index 86433237e..e06a23991 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -2,6 +2,7 @@ """Generate strong scaling plots from parquet data. Single file with subplots per metric.""" import argparse +import os from pathlib import Path import polars as pl @@ -71,7 +72,7 @@ def create_scaling_plots(df: pl.DataFrame, output_path: Path, metrics: list[str] fig.update_layout( height=400 * n_metrics, - title_text="Strong Scaling Analysis", + title_text="Scaling Analysis", title_x=0.5, template="plotly_white", ) @@ -81,9 +82,10 @@ def create_scaling_plots(df: pl.DataFrame, output_path: Path, metrics: list[str] def main(): - parser = argparse.ArgumentParser(description="Generate strong scaling plots from parquet data") + parser = argparse.ArgumentParser(description="Generate scaling plots from parquet data") parser.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet file") - parser.add_argument("--output", type=Path, default=Path("scaling_plots/strong_scaling.html"), help="Output HTML file") + parser.add_argument("--output_dir", type=Path, default=Path("scaling_plots"), help="Output directory for HTML files") + parser.add_argument("--output_file_name", type=Path, default=Path("scaling_plots.html"), help="Output HTML file name") parser.add_argument("--metrics", nargs="+", default=["training_time", "overall_time_seconds", "loss_avg_mean"], help="Metrics to plot") parser.add_argument("--generate-dummy", action="store_true", help="Generate dummy test data") @@ -112,7 +114,8 @@ def main(): df = pl.read_parquet(args.input) print(f"Loaded {len(df)} rows") - create_scaling_plots(df, args.output, args.metrics) + args.output_dir.mkdir(parents=True, exist_ok=True) + create_scaling_plots(df, os.path.join(args.output_dir, args.output_file_name), args.metrics) if __name__ == "__main__": From 177df79a7956e0888f7658090ec40fc640c364dc Mon 
Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 09:56:36 +0200 Subject: [PATCH 36/76] Remove parent dir creation --- scripts/generate_scaling_plots.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index e06a23991..70e95a402 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -11,7 +11,6 @@ def create_scaling_plots(df: pl.DataFrame, output_path: Path, metrics: list[str]): """Create a single plot with subplots for each metric.""" - output_path.parent.mkdir(parents=True, exist_ok=True) # Count valid metrics valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] From 8a4bc56cac7a81a81c42e3a24e7d5225f41e701f Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 11:29:39 +0200 Subject: [PATCH 37/76] more detailed extraction script --- scripts/extract_scaling_data.py | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index cff75c359..3df08ce1e 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -60,6 +60,69 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No return None +def extract_detailed_metrics(run_id: str, shared_work_dir: Path, output_path: Path) -> int: + """Extract detailed metrics pairing timing rows with preceding loss rows. + + For each row containing elapsed_training_time_seconds, pair it with the + preceding row containing loss metrics. Returns the number of detailed entries extracted. 
+ """ + metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" + if not metrics_path.exists(): + return 0 + + try: + df = pl.read_ndjson(metrics_path) + if len(df) == 0: + return 0 + + # Find rows with elapsed_training_time_seconds (timing rows) + timing_mask = pl.col("elapsed_training_time_seconds").is_not_null() + timing_indices = df.with_row_index().filter(timing_mask).get_column("index").to_list() + + if len(timing_indices) == 0: + return 0 + + # Get all row indices with loss data + loss_mask = pl.col("loss_avg_mean").is_not_null() + loss_rows_df = df.with_row_index().filter(loss_mask) + + detailed_records = [] + + for timing_idx in timing_indices: + # Find the last loss row before this timing row + loss_rows_before = loss_rows_df.filter(pl.col("index") < timing_idx) + + if len(loss_rows_before) == 0: + continue + + # Get the last loss row before timing + loss_row = loss_rows_before.sort("index").tail(1) + timing_row = df.with_row_index().filter(pl.col("index") == timing_idx).drop("index") + + # Drop index column from loss_row for merging + loss_row = loss_row.drop("index") + + # Merge loss and timing data + merged = loss_row.join(timing_row, how="cross") + detailed_records.append(merged) + + if len(detailed_records) == 0: + return 0 + + # Combine all records + detailed_df = pl.concat(detailed_records) + + # Write to output file + output_path.parent.mkdir(parents=True, exist_ok=True) + detailed_df.write_ndjson(output_path) + + return len(detailed_records) + + except Exception as e: + print(f"Error extracting detailed metrics for {run_id}: {e}") + return 0 + + def main(): parser = argparse.ArgumentParser(description="Extract strong scaling data from WeatherGenerator runs") parser.add_argument("--run-ids", nargs="+", help="List of run-ids to process") @@ -74,6 +137,7 @@ def main(): sys.exit("Error: No run-ids provided") results = [] + detailed_files_created = [] for run_id in run_ids: # Look for weathergen.*.err files (e.g., 
weathergen.part1.388004.err) log_dir = args.logs_base_dir / run_id @@ -90,6 +154,16 @@ def main(): "loss_avg_mean": metrics.get("loss_avg_mean"), } results.append(row) + + # Extract detailed metrics for this run + # Create output file with "detailed" suffix before the extension + output_stem = args.output.stem + output_suffix = args.output.suffix + detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") + count = extract_detailed_metrics(run_id, args.shared_work_dir, detailed_output) + if count > 0: + detailed_files_created.append((detailed_output, count)) + print(f"Extracted {count} detailed metric entries for {run_id}") if not results: sys.exit("No data extracted") @@ -100,6 +174,12 @@ def main(): args.output.parent.mkdir(parents=True, exist_ok=True) df.write_parquet(args.output) df.write_csv(args.output.with_suffix(".csv")) + + print(f"\nSummary:") + print(f" - Extracted {len(results)} run summaries to {args.output}") + if detailed_files_created: + total_detailed = sum(count for _, count in detailed_files_created) + print(f" - Extracted {total_detailed} detailed metric entries to {detailed_files_created[0][0]}") if __name__ == "__main__": From bca6d3d73b9b96837420a6194e63dcc021528c7f Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 16:03:37 +0200 Subject: [PATCH 38/76] Remove overall time logging --- src/weathergen/run_train.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 391b220c4..0f6a5aee9 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -21,7 +21,6 @@ import weathergen.common.config as config import weathergen.utils.cli as cli -from weathergen.utils.distributed import is_root from weathergen.common.logger import init_loggers from weathergen.train.trainer import Trainer @@ -123,7 +122,6 @@ def run_continue(args): Note: All model configurations are set in the function body. 
""" - t_overall_start = time.time() cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -155,18 +153,6 @@ def run_continue(args): traceback.print_exc() if cf.world_size == 1: pdb.post_mortem(tb) - finally: - # Log overall time (only on root rank) - if is_root(): - t_overall_end = time.time() - overall_time = t_overall_end - t_overall_start - trainer.train_logger.log_metrics( - "train", - { - "overall_time_seconds": overall_time, - }, - ) - logger.info(f"Training completed. Overall time: {overall_time / 3600:.2f} hours") def run_train(args): @@ -209,18 +195,6 @@ def run_train(args): traceback.print_exc() if cf.world_size == 1: pdb.post_mortem(tb) - finally: - # Log overall time (only on root rank) - if is_root(): - t_end = time.time() - overall_time = t_end - t_start - trainer.train_logger.log_metrics( - "train", - { - "overall_time_seconds": overall_time, - }, - ) - logger.info(f"Training completed. Overall time: {overall_time / 3600:.2f} hours") if __name__ == "__main__": From 94368118675239fdc8275b0192c1a773f6afcc6d Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 16:06:36 +0200 Subject: [PATCH 39/76] Cleanup trainer --- src/weathergen/train/trainer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 45f40bead..ac0d63577 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -354,6 +354,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl if self.world_size_original is None: mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: + # to avoid zero-division for small datasets len_per_rank = ( max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu @@ -398,7 +399,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl # log final model 
self.save_model(self.training_cfg.num_mini_epochs) - def validate_before_training(self): """ Perform validation before training (eg. to check validation pipeline or data normalization) @@ -559,12 +559,7 @@ def train(self, mini_epoch): logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") self._log(TRAIN) - time_before_advance = time.time() self.dataset.advance() - time_after_advance = time.time() - time_for_advance = time_after_advance - time_before_advance - if is_root(): - logger.info(f"Time to advance dataset after mini epoch {mini_epoch}: {time_for_advance} seconds") def validate(self, mini_epoch, mode_cfg, batch_size): """ @@ -751,7 +746,6 @@ def _log(self, stage: Stage): samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) if is_root(): - # plain logger if stage == VAL: self.train_logger.add_logs(stage, samples, losses_all, stddev_all) @@ -774,7 +768,7 @@ def _log(self, stage: Stage): "total_num_samples": samples, "average_samples_per_second": samples / elapsed_time if elapsed_time > 0 else 0, } - ) + ) loss_calculator.loss_hist = [] loss_calculator.losses_unweighted_hist = [] From 21c157522836643201ca2366ae65c8608fe7c44a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 16:56:31 +0200 Subject: [PATCH 40/76] Metrics extraction and plot generation scripts --- scripts/extract_scaling_data.py | 81 +++++--- scripts/generate_scaling_plots.py | 324 ++++++++++++++++++++++-------- 2 files changed, 287 insertions(+), 118 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 3df08ce1e..117507fec 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -60,27 +60,27 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No return None -def extract_detailed_metrics(run_id: str, shared_work_dir: Path, output_path: Path) -> int: +def extract_detailed_metrics(run_id: str, shared_work_dir: 
Path, num_nodes: int | None = None) -> list: """Extract detailed metrics pairing timing rows with preceding loss rows. For each row containing elapsed_training_time_seconds, pair it with the - preceding row containing loss metrics. Returns the number of detailed entries extracted. + preceding row containing loss metrics. Returns a list of detailed record DataFrames. """ metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" if not metrics_path.exists(): - return 0 + return [] try: df = pl.read_ndjson(metrics_path) if len(df) == 0: - return 0 + return [] # Find rows with elapsed_training_time_seconds (timing rows) timing_mask = pl.col("elapsed_training_time_seconds").is_not_null() timing_indices = df.with_row_index().filter(timing_mask).get_column("index").to_list() if len(timing_indices) == 0: - return 0 + return [] # Get all row indices with loss data loss_mask = pl.col("loss_avg_mean").is_not_null() @@ -96,31 +96,36 @@ def extract_detailed_metrics(run_id: str, shared_work_dir: Path, output_path: Pa continue # Get the last loss row before timing - loss_row = loss_rows_before.sort("index").tail(1) + loss_row = loss_rows_before.sort("index").tail(1).drop("index") + + # Get the timing row timing_row = df.with_row_index().filter(pl.col("index") == timing_idx).drop("index") - # Drop index column from loss_row for merging - loss_row = loss_row.drop("index") + # Select only the columns we need + timing_cols = ["elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second"] + timing_available_cols = [c for c in timing_cols if c in timing_row.columns] + timing_row = timing_row.select(timing_available_cols) + + # Keep only loss_avg_mean from loss_row + loss_row = loss_row.select("loss_avg_mean") # Merge loss and timing data - merged = loss_row.join(timing_row, how="cross") + merged = loss_row.hstack(timing_row) + + # Add run_id and num_nodes + merged = merged.with_columns(pl.lit(run_id).alias("run_id")) + if num_nodes is not 
None: + merged = merged.with_columns(pl.lit(num_nodes).alias("num_nodes")) + detailed_records.append(merged) - if len(detailed_records) == 0: - return 0 - - # Combine all records - detailed_df = pl.concat(detailed_records) - - # Write to output file - output_path.parent.mkdir(parents=True, exist_ok=True) - detailed_df.write_ndjson(output_path) - - return len(detailed_records) + return detailed_records except Exception as e: print(f"Error extracting detailed metrics for {run_id}: {e}") - return 0 + import traceback + traceback.print_exc() + return [] def main(): @@ -137,7 +142,7 @@ def main(): sys.exit("Error: No run-ids provided") results = [] - detailed_files_created = [] + all_detailed_records = [] for run_id in run_ids: # Look for weathergen.*.err files (e.g., weathergen.part1.388004.err) log_dir = args.logs_base_dir / run_id @@ -156,14 +161,10 @@ def main(): results.append(row) # Extract detailed metrics for this run - # Create output file with "detailed" suffix before the extension - output_stem = args.output.stem - output_suffix = args.output.suffix - detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") - count = extract_detailed_metrics(run_id, args.shared_work_dir, detailed_output) - if count > 0: - detailed_files_created.append((detailed_output, count)) - print(f"Extracted {count} detailed metric entries for {run_id}") + detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) + if detailed_records: + all_detailed_records.extend(detailed_records) + print(f"Extracted {len(detailed_records)} detailed metric entries for {run_id}") if not results: sys.exit("No data extracted") @@ -175,11 +176,25 @@ def main(): df.write_parquet(args.output) df.write_csv(args.output.with_suffix(".csv")) + # Write detailed metrics if any were collected + if all_detailed_records: + detailed_df = pl.concat(all_detailed_records) + # Reorder columns for clarity + desired_cols = ["run_id", "num_nodes", 
"elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second", "loss_avg_mean"] + available_cols = [c for c in desired_cols if c in detailed_df.columns] + detailed_df = detailed_df.select(available_cols) + + output_stem = args.output.stem + output_suffix = args.output.suffix + detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") + + detailed_df.write_parquet(detailed_output) + detailed_df.write_csv(detailed_output.with_suffix(".csv")) + print(f"\nSummary:") print(f" - Extracted {len(results)} run summaries to {args.output}") - if detailed_files_created: - total_detailed = sum(count for _, count in detailed_files_created) - print(f" - Extracted {total_detailed} detailed metric entries to {detailed_files_created[0][0]}") + if all_detailed_records: + print(f" - Extracted {len(all_detailed_records)} detailed metric entries to {detailed_output}") if __name__ == "__main__": diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index 70e95a402..13984d1f7 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -1,120 +1,274 @@ #!/usr/bin/env uv run python -"""Generate strong scaling plots from parquet data. Single file with subplots per metric.""" +"""Generate scaling plots from parquet/ndjson data using matplotlib only. 
+ +Two entrypoints: +- standard: plots run-level metrics vs num_nodes +- detailed: plots sample-level metrics vs total_num_samples +""" import argparse -import os from pathlib import Path +import matplotlib.pyplot as plt import polars as pl -import plotly.graph_objects as go -from plotly.subplots import make_subplots -def create_scaling_plots(df: pl.DataFrame, output_path: Path, metrics: list[str]): - """Create a single plot with subplots for each metric.""" +SCRIPT_DIR = Path(__file__).resolve().parent +VALID_IMAGE_SUFFIXES = {".png", ".pdf", ".svg", ".jpg", ".jpeg"} +PALETTE = [ + "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", + "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", +] + + +def resolve_input_path(path: Path) -> Path: + """Resolve relative input paths against cwd first, then the script directory.""" + if path.is_absolute(): + return path + + cwd_candidate = Path.cwd() / path + if cwd_candidate.exists(): + return cwd_candidate + + script_candidate = SCRIPT_DIR / path + if script_candidate.exists(): + return script_candidate + + return cwd_candidate + + +def resolve_output_path(path: Path) -> Path: + """Ensure the output path uses a supported image suffix.""" + if path.suffix.lower() in VALID_IMAGE_SUFFIXES: + return path + return path.with_suffix(".png") + + +def read_table(path: Path) -> pl.DataFrame: + """Read parquet or ndjson automatically.""" + try: + print("Read as parquet") + return pl.read_parquet(path) + except Exception: + print("Read as NDJSON") + return pl.read_ndjson(path) + + +def color_map_for_nodes(node_counts: list) -> dict: + return {node: PALETTE[i % len(PALETTE)] for i, node in enumerate(node_counts)} + + +def save_figure(fig: plt.Figure, output_path: Path) -> None: + output_path = resolve_output_path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {output_path}") + + +def plot_standard_scaling( + df: 
pl.DataFrame, + output_path: Path, + scaling_type: str, + metrics: list[str], + x_scale: str, + y_scale: str, +) -> None: + """Plot run-level scaling data vs num_nodes.""" + metric_labels = { + "training_time": "Training Time (seconds)", + "loss_avg_mean": "Average Loss", + } - # Count valid metrics valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] if not valid_metrics: print("No valid metrics to plot") return - n_metrics = len(valid_metrics) - - fig = make_subplots( - rows=n_metrics, cols=1, - subplot_titles=valid_metrics, - vertical_spacing=0.1, - ) + fig, axes = plt.subplots(len(valid_metrics), 1, figsize=(12, 6 * len(valid_metrics)), squeeze=False) for idx, metric in enumerate(valid_metrics): + ax = axes[idx][0] df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") + node_counts = df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] + colors = color_map_for_nodes(node_counts) - # Add scatter trace with lines and text labels - fig.add_trace( - go.Scatter( - x=df_plot["num_nodes"], - y=df_plot[metric], - mode="lines+markers+text", - text=df_plot["run_id"], - textposition="top center", - name=metric, - showlegend=False, - marker=dict(size=10, color="steelblue"), - line=dict(width=2), - ), - row=idx + 1, col=1, + ax.plot( + df_plot["num_nodes"], + df_plot[metric], + "o-", + color="steelblue", + markersize=8, ) - # Add optimal scaling reference line for training_time + for x, y, label in zip(df_plot["num_nodes"], df_plot[metric], df_plot["run_id"]): + ax.text(x, y, label, ha="center", va="bottom", fontsize=8) + if metric == "training_time" and "training_time" in df.columns: - # Find the 1-node training time one_node_data = df.filter(pl.col("num_nodes") == 1) if one_node_data.height > 0: t1 = one_node_data["training_time"].item() - # Create optimal scaling line: t1 / n for each n nodes = df_plot["num_nodes"].to_list() - optimal_y = [t1 / n for n in nodes] - 
fig.add_trace( - go.Scatter( - x=nodes, - y=optimal_y, - mode="lines", - name="Optimal scaling", - line=dict(width=1, color="red", dash="dash"), - showlegend=True, - ), - row=idx + 1, col=1, - ) - - fig.update_xaxes(title_text="Number of Nodes (log scale)", type="log", row=idx + 1, col=1) - fig.update_yaxes(title_text=metric, row=idx + 1, col=1) - - fig.update_layout( - height=400 * n_metrics, - title_text="Scaling Analysis", - title_x=0.5, - template="plotly_white", - ) - - fig.write_html(output_path) - print(f"Saved: {output_path}") + if scaling_type == "weak": + optimal_y = [t1 for _ in nodes] + elif scaling_type == "strong": + optimal_y = [t1 / n for n in nodes] + else: + raise ValueError(f"Invalid scaling type: {scaling_type}") + ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling") + ax.legend() + + ax.set_xscale(x_scale) + if y_scale == "log": + ax.set_yscale("log") + ax.set_xlabel("Number of Nodes") + ax.set_ylabel(metric_labels.get(metric, metric)) + ax.set_title(metric) + ax.grid(True, alpha=0.3) + fig.suptitle("Scaling Analysis", fontsize=16) + plt.tight_layout() + save_figure(fig, output_path) -def main(): - parser = argparse.ArgumentParser(description="Generate scaling plots from parquet data") - parser.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet file") - parser.add_argument("--output_dir", type=Path, default=Path("scaling_plots"), help="Output directory for HTML files") - parser.add_argument("--output_file_name", type=Path, default=Path("scaling_plots.html"), help="Output HTML file name") - parser.add_argument("--metrics", nargs="+", default=["training_time", "overall_time_seconds", "loss_avg_mean"], help="Metrics to plot") - parser.add_argument("--generate-dummy", action="store_true", help="Generate dummy test data") +def plot_detailed_scaling( + df: pl.DataFrame, + output_path: Path, + x_scale: str, + y_scale: str, +) -> None: + """Plot sample-level detailed scaling data vs 
total_num_samples.""" + required_cols = ["total_num_samples", "elapsed_training_time_seconds", "loss_avg_mean", "num_nodes"] + if not all(col in df.columns for col in required_cols): + print("Detailed metrics not available in this dataset") + print(f"Available columns: {df.columns}") + return + + df_plot = df.filter( + pl.col("total_num_samples").is_not_null() + & (pl.col("total_num_samples") > 0) + & pl.col("elapsed_training_time_seconds").is_not_null() + & pl.col("loss_avg_mean").is_not_null() + & pl.col("num_nodes").is_not_null() + ).sort("num_nodes", "total_num_samples") + + if len(df_plot) == 0: + print("No valid data for detailed scaling plots") + return + + node_counts = sorted(df_plot["num_nodes"].unique().to_list()) + colors = color_map_for_nodes(node_counts) + + fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True) + + ax = axes[0] + for node_count in node_counts: + df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort("total_num_samples") + ax.plot( + df_node["total_num_samples"], + df_node["elapsed_training_time_seconds"], + "o-", + color=colors[node_count], + markersize=6, + label=f"{node_count} nodes", + ) + ax.set_xscale(x_scale) + if y_scale == "log": + ax.set_yscale("log") + ax.set_ylabel("Elapsed Training Time (seconds)") + ax.set_title("Elapsed Training Time vs Samples") + ax.grid(True, alpha=0.3) + ax.legend(title="Node Count") + + ax = axes[1] + for node_count in node_counts: + df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort("total_num_samples") + ax.plot( + df_node["total_num_samples"], + df_node["loss_avg_mean"], + "o-", + color=colors[node_count], + markersize=6, + label=f"{node_count} nodes", + ) + ax.set_xscale(x_scale) + if y_scale == "log": + ax.set_yscale("log") + ax.set_xlabel("Total Number of Samples") + ax.set_ylabel("Average Loss") + ax.set_title("Loss vs Samples") + ax.grid(True, alpha=0.3) + ax.legend(title="Node Count") + + fig.suptitle("Detailed Scaling Analysis", fontsize=16) + 
plt.tight_layout() + save_figure(fig, output_path) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Generate scaling plots from parquet or NDJSON data") + subparsers = parser.add_subparsers(dest="mode", required=True) + + standard = subparsers.add_parser("standard", help="Plot run-level scaling metrics vs num_nodes") + standard.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") + standard.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") + standard.add_argument("--output", type=Path, default=None, help="Output image path") + standard.add_argument("--metrics", nargs="+", default=["training_time", "loss_avg_mean"], help="Metrics to plot") + standard.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") + standard.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") + + detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") + detailed.add_argument("--input", type=Path, default=Path("scaling_data_detailed.parquet"), help="Input detailed parquet/ndjson file") + detailed.add_argument("--output", type=Path, default=None, help="Output image path") + detailed.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") + detailed.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") + + return parser + + +def main() -> None: + parser = build_parser() args = parser.parse_args() - if args.generate_dummy: - print("Generating dummy test data...") - dummy_data = { - "run_id": ["run_1node", "run_2node", "run_4node", "run_8node", "run_16node"], - "num_nodes": [1, 2, 4, 8, 16], - "training_time": [1000, 520, 270, 140, 75], - "overall_time_seconds": [1100, 580, 310, 165, 90], - "loss_avg_mean": [0.45, 0.44, 0.44, 0.43, 0.43], - } - df = pl.DataFrame(dummy_data) - 
args.input.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(args.input) - print(f"Created dummy data: {args.input}") - - if not args.input.exists(): - print(f"Error: Input file not found: {args.input}") - print("Use --generate-dummy to create test data") + if args.mode == "standard": + input_path = resolve_input_path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return + + output_path = args.output or input_path.with_suffix(".png") + + print(f"Loading data from: {input_path}") + try: + df = read_table(input_path) + except Exception as e: + print("Error: Could not read input file as parquet or NDJSON") + print(str(e)) + return + print(f"Loaded {len(df)} rows") + plot_standard_scaling(df, output_path, args.type, args.metrics, args.x_scale, args.y_scale) + return + + if args.mode == "detailed": + input_path = resolve_input_path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return + + output_path = args.output or input_path.with_suffix(".png") + + print(f"Loading detailed data from: {input_path}") + try: + df = read_table(input_path) + except Exception as e: + print("Error: Could not read detailed file as parquet or NDJSON") + print(str(e)) + return + print(f"Loaded {len(df)} detailed rows") + plot_detailed_scaling(df, output_path, args.x_scale, args.y_scale) return - print(f"Loading data from: {args.input}") - df = pl.read_parquet(args.input) - print(f"Loaded {len(df)} rows") + raise ValueError(f"Unknown mode: {args.mode}") - args.output_dir.mkdir(parents=True, exist_ok=True) - create_scaling_plots(df, os.path.join(args.output_dir, args.output_file_name), args.metrics) if __name__ == "__main__": From e67616aed556b63f400dc0e4a624ece00f9fd7e0 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 24 Apr 2026 17:10:43 +0200 Subject: [PATCH 41/76] Add efficiency factor in plot --- scripts/generate_scaling_plots.py | 18 ++++++++++++++++++ 1 file changed, 
18 insertions(+) diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index 13984d1f7..10c1a55ef 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -115,6 +115,24 @@ def plot_standard_scaling( else: raise ValueError(f"Invalid scaling type: {scaling_type}") ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling") + + # Show per-point efficiency loss as a vertical line and factor label. + for x, y, y_opt in zip(nodes, df_plot[metric].to_list(), optimal_y): + if y_opt == 0: + continue + factor = y / y_opt + ax.vlines(x, y_opt, y, colors="gray", linestyles=":", linewidth=1, alpha=0.7) + y_mid = (y + y_opt) / 2 + ax.annotate( + f"{factor:.2f}x", + xy=(x, y_mid), + xytext=(4, 0), + textcoords="offset points", + fontsize=9, + fontweight="bold", + color="dimgray", + va="center", + ) ax.legend() ax.set_xscale(x_scale) From b1e4ea479caf16f62594535fc5cd09a92993336a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 27 Apr 2026 09:44:25 +0200 Subject: [PATCH 42/76] RM checkpoint and log metrics at last iteration --- src/weathergen/train/trainer.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ac0d63577..46d874c5c 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -544,20 +544,22 @@ def train(self, mini_epoch): self._log_collapse_metrics(TRAIN) # save model checkpoint (with designation _latest) - if bidx % self.train_logging.checkpoint == 0 and bidx > 0: - self.save_model(-1) + # if bidx % self.train_logging.checkpoint == 0 and bidx > 0: + # self.save_model(-1) self.cf.general.istep += 1 - torch.distributed.barrier() - if is_root(): - total_training_time = time.time() - self.t_training_start - self.train_logger.log_metrics("train", { - "completed_mini_epoch": mini_epoch, - "elapsed_time_mini_epoch": total_training_time, - }) - 
logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") - self._log(TRAIN) + # log metrics at last iteration (keep barrier for now) + if bidx == len(self.data_loader) - 1: + torch.distributed.barrier() + if is_root(): + total_training_time = time.time() - self.t_training_start + self.train_logger.log_metrics("train", { + "completed_mini_epoch": mini_epoch, + "elapsed_time_mini_epoch": total_training_time, + }) + logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") + self._log(TRAIN) self.dataset.advance() From c89fe20495ec5f5528733eeb465c07dfd1ade945 Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Mon, 27 Apr 2026 09:56:58 +0200 Subject: [PATCH 43/76] Detailed metrics --- scripts/extract_scaling_data.py | 81 +++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 3df08ce1e..117507fec 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -60,27 +60,27 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No return None -def extract_detailed_metrics(run_id: str, shared_work_dir: Path, output_path: Path) -> int: +def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int | None = None) -> list: """Extract detailed metrics pairing timing rows with preceding loss rows. For each row containing elapsed_training_time_seconds, pair it with the - preceding row containing loss metrics. Returns the number of detailed entries extracted. + preceding row containing loss metrics. Returns a list of detailed record DataFrames. 
""" metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" if not metrics_path.exists(): - return 0 + return [] try: df = pl.read_ndjson(metrics_path) if len(df) == 0: - return 0 + return [] # Find rows with elapsed_training_time_seconds (timing rows) timing_mask = pl.col("elapsed_training_time_seconds").is_not_null() timing_indices = df.with_row_index().filter(timing_mask).get_column("index").to_list() if len(timing_indices) == 0: - return 0 + return [] # Get all row indices with loss data loss_mask = pl.col("loss_avg_mean").is_not_null() @@ -96,31 +96,36 @@ def extract_detailed_metrics(run_id: str, shared_work_dir: Path, output_path: Pa continue # Get the last loss row before timing - loss_row = loss_rows_before.sort("index").tail(1) + loss_row = loss_rows_before.sort("index").tail(1).drop("index") + + # Get the timing row timing_row = df.with_row_index().filter(pl.col("index") == timing_idx).drop("index") - # Drop index column from loss_row for merging - loss_row = loss_row.drop("index") + # Select only the columns we need + timing_cols = ["elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second"] + timing_available_cols = [c for c in timing_cols if c in timing_row.columns] + timing_row = timing_row.select(timing_available_cols) + + # Keep only loss_avg_mean from loss_row + loss_row = loss_row.select("loss_avg_mean") # Merge loss and timing data - merged = loss_row.join(timing_row, how="cross") + merged = loss_row.hstack(timing_row) + + # Add run_id and num_nodes + merged = merged.with_columns(pl.lit(run_id).alias("run_id")) + if num_nodes is not None: + merged = merged.with_columns(pl.lit(num_nodes).alias("num_nodes")) + detailed_records.append(merged) - if len(detailed_records) == 0: - return 0 - - # Combine all records - detailed_df = pl.concat(detailed_records) - - # Write to output file - output_path.parent.mkdir(parents=True, exist_ok=True) - detailed_df.write_ndjson(output_path) - - return 
len(detailed_records) + return detailed_records except Exception as e: print(f"Error extracting detailed metrics for {run_id}: {e}") - return 0 + import traceback + traceback.print_exc() + return [] def main(): @@ -137,7 +142,7 @@ def main(): sys.exit("Error: No run-ids provided") results = [] - detailed_files_created = [] + all_detailed_records = [] for run_id in run_ids: # Look for weathergen.*.err files (e.g., weathergen.part1.388004.err) log_dir = args.logs_base_dir / run_id @@ -156,14 +161,10 @@ def main(): results.append(row) # Extract detailed metrics for this run - # Create output file with "detailed" suffix before the extension - output_stem = args.output.stem - output_suffix = args.output.suffix - detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") - count = extract_detailed_metrics(run_id, args.shared_work_dir, detailed_output) - if count > 0: - detailed_files_created.append((detailed_output, count)) - print(f"Extracted {count} detailed metric entries for {run_id}") + detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) + if detailed_records: + all_detailed_records.extend(detailed_records) + print(f"Extracted {len(detailed_records)} detailed metric entries for {run_id}") if not results: sys.exit("No data extracted") @@ -175,11 +176,25 @@ def main(): df.write_parquet(args.output) df.write_csv(args.output.with_suffix(".csv")) + # Write detailed metrics if any were collected + if all_detailed_records: + detailed_df = pl.concat(all_detailed_records) + # Reorder columns for clarity + desired_cols = ["run_id", "num_nodes", "elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second", "loss_avg_mean"] + available_cols = [c for c in desired_cols if c in detailed_df.columns] + detailed_df = detailed_df.select(available_cols) + + output_stem = args.output.stem + output_suffix = args.output.suffix + detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") + 
+ detailed_df.write_parquet(detailed_output) + detailed_df.write_csv(detailed_output.with_suffix(".csv")) + print(f"\nSummary:") print(f" - Extracted {len(results)} run summaries to {args.output}") - if detailed_files_created: - total_detailed = sum(count for _, count in detailed_files_created) - print(f" - Extracted {total_detailed} detailed metric entries to {detailed_files_created[0][0]}") + if all_detailed_records: + print(f" - Extracted {len(all_detailed_records)} detailed metric entries to {detailed_output}") if __name__ == "__main__": From b0bc6c2ce79b7cc7a926f373e44e2384c8c52082 Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Mon, 27 Apr 2026 12:32:39 +0200 Subject: [PATCH 44/76] Remove barrier and extra logging on last batch --- src/weathergen/train/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 46d874c5c..cd81dee9b 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -537,7 +537,7 @@ def train(self, mini_epoch): ) self._log_terminal(bidx, mini_epoch, TRAIN) - if bidx % self.train_logging.metrics == 0: + if bidx % self.train_logging.metrics == 0 or bidx == len(self.data_loader) - 1: self._log(TRAIN) # Log collapse metrics if self.collapse_monitor.should_log(self.cf.general.istep): @@ -551,7 +551,7 @@ def train(self, mini_epoch): # log metrics at last iteration (keep barrier for now) if bidx == len(self.data_loader) - 1: - torch.distributed.barrier() + # torch.distributed.barrier() if is_root(): total_training_time = time.time() - self.t_training_start self.train_logger.log_metrics("train", { From 133ee4c9f4b007e1c482b3c4d85fd61dea55efa8 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 27 Apr 2026 18:16:28 +0200 Subject: [PATCH 45/76] trainer code cleanup Co-authored-by: Copilot --- src/weathergen/train/trainer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git 
a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index cd81dee9b..f47837ef9 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -543,15 +543,11 @@ def train(self, mini_epoch): if self.collapse_monitor.should_log(self.cf.general.istep): self._log_collapse_metrics(TRAIN) - # save model checkpoint (with designation _latest) - # if bidx % self.train_logging.checkpoint == 0 and bidx > 0: - # self.save_model(-1) - self.cf.general.istep += 1 # log metrics at last iteration (keep barrier for now) if bidx == len(self.data_loader) - 1: - # torch.distributed.barrier() + torch.distributed.barrier() if is_root(): total_training_time = time.time() - self.t_training_start self.train_logger.log_metrics("train", { @@ -559,7 +555,6 @@ def train(self, mini_epoch): "elapsed_time_mini_epoch": total_training_time, }) logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") - self._log(TRAIN) self.dataset.advance() From ec665da3c9dd9272405418431570982865b4989b Mon Sep 17 00:00:00 2001 From: florianscheidl Date: Mon, 27 Apr 2026 21:41:49 +0200 Subject: [PATCH 46/76] Lower bound beta2 in adam Co-authored-by: Copilot --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f47837ef9..c6d1acbf0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -313,7 +313,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: fl # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) + beta2 = max(0.9, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2)) eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) self.optimizer = 
torch.optim.AdamW( From 20883110fea854ff4c33e9c95ea1eba890e2d87c Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Tue, 28 Apr 2026 17:48:01 +0200 Subject: [PATCH 47/76] update script for scaling plots, loss as separate entry point --- scripts/generate_scaling_plots.py | 84 +++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index 10c1a55ef..d1dacec19 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -72,11 +72,13 @@ def plot_standard_scaling( metrics: list[str], x_scale: str, y_scale: str, + y_metric: str, ) -> None: """Plot run-level scaling data vs num_nodes.""" metric_labels = { "training_time": "Training Time (seconds)", "loss_avg_mean": "Average Loss", + "normalized_throughput": "Normalized Throughput (T1 / T)", } valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] @@ -92,32 +94,58 @@ def plot_standard_scaling( node_counts = df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] colors = color_map_for_nodes(node_counts) + # Handle normalized_throughput metric + if y_metric == "normalized_throughput" and metric == "training_time": + # Calculate normalized throughput: T1 / T + one_node_data = df.filter(pl.col("num_nodes") == 1) + if one_node_data.height > 0: + t1 = one_node_data["training_time"].item() + # Create a new dataframe with normalized throughput + df_plot = df_plot.with_columns( + (t1 / pl.col("training_time")).alias("normalized_throughput") + ) + plot_y = df_plot["normalized_throughput"] + else: + print("Warning: No 1-node data found for normalized throughput calculation") + continue + else: + plot_y = df_plot[metric] + ax.plot( df_plot["num_nodes"], - df_plot[metric], + plot_y, "o-", color="steelblue", markersize=8, ) - for x, y, label in zip(df_plot["num_nodes"], df_plot[metric], df_plot["run_id"]): + for x, y, 
label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"]): ax.text(x, y, label, ha="center", va="bottom", fontsize=8) - if metric == "training_time" and "training_time" in df.columns: + if metric == "training_time" and y_metric == "time" and "training_time" in df.columns: one_node_data = df.filter(pl.col("num_nodes") == 1) if one_node_data.height > 0: t1 = one_node_data["training_time"].item() nodes = df_plot["num_nodes"].to_list() if scaling_type == "weak": - optimal_y = [t1 for _ in nodes] + if y_metric == "normalized_throughput": + # For normalized throughput, optimal is 1.0 (no speedup loss) + optimal_y = [1.0 for _ in nodes] + else: + optimal_y = [t1 for _ in nodes] elif scaling_type == "strong": - optimal_y = [t1 / n for n in nodes] + if y_metric == "normalized_throughput": + # For normalized throughput, optimal is n (linear speedup) + optimal_y = [float(n) for n in nodes] + else: + optimal_y = [t1 / n for n in nodes] else: raise ValueError(f"Invalid scaling type: {scaling_type}") ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling") # Show per-point efficiency loss as a vertical line and factor label. 
- for x, y, y_opt in zip(nodes, df_plot[metric].to_list(), optimal_y): + # Use plot_y (normalized throughput if applicable) instead of df_plot[metric] + for x, y, y_opt in zip(nodes, plot_y.to_list(), optimal_y): if y_opt == 0: continue factor = y / y_opt @@ -139,8 +167,11 @@ def plot_standard_scaling( if y_scale == "log": ax.set_yscale("log") ax.set_xlabel("Number of Nodes") - ax.set_ylabel(metric_labels.get(metric, metric)) - ax.set_title(metric) + if y_metric == "normalized_throughput" and metric == "training_time": + ax.set_ylabel("Normalized Throughput (T1 / T)") + else: + ax.set_ylabel(metric_labels.get(metric, metric)) + ax.set_title(metric if y_metric != "normalized_throughput" or metric != "training_time" else "Normalized Throughput") ax.grid(True, alpha=0.3) fig.suptitle("Scaling Analysis", fontsize=16) @@ -230,9 +261,17 @@ def build_parser() -> argparse.ArgumentParser: standard.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") standard.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") standard.add_argument("--output", type=Path, default=None, help="Output image path") - standard.add_argument("--metrics", nargs="+", default=["training_time", "loss_avg_mean"], help="Metrics to plot") - standard.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") + standard.add_argument("--y-scale", choices=["linear", "log"], default="linear", help="Y-axis scale") standard.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") + standard.add_argument("--y-metric", choices=["time", "normalized_throughput"], default="normalized_throughput", help="Y-axis metric: 'time' for time-to-solution or 'normalized_throughput' for T1/T") + + # Subparser for loss-only plots (separate entry point) + loss_only = subparsers.add_parser("loss", help="Plot loss metrics vs num_nodes (separate from throughput)") + 
loss_only.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") + loss_only.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") + loss_only.add_argument("--output", type=Path, default=None, help="Output image path") + loss_only.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") + loss_only.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") detailed.add_argument("--input", type=Path, default=Path("scaling_data_detailed.parquet"), help="Input detailed parquet/ndjson file") @@ -263,7 +302,30 @@ def main() -> None: print(str(e)) return print(f"Loaded {len(df)} rows") - plot_standard_scaling(df, output_path, args.type, args.metrics, args.x_scale, args.y_scale) + # Standard mode: only plot training_time with normalized throughput or time + metrics_to_plot = ["training_time"] + plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, args.y_metric) + return + + if args.mode == "loss": + input_path = resolve_input_path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return + + output_path = args.output or input_path.with_suffix(".loss.png") + + print(f"Loading data from: {input_path}") + try: + df = read_table(input_path) + except Exception as e: + print("Error: Could not read input file as parquet or NDJSON") + print(str(e)) + return + print(f"Loaded {len(df)} rows") + # Loss mode: only plot loss_avg_mean + metrics_to_plot = ["loss_avg_mean"] + plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, "time") return if args.mode == "detailed": From 5bd88d99bf654c5fe1693ccf2780552d34283f22 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Wed, 29 Apr 2026 09:12:25 +0200 Subject: [PATCH 48/76] 
specify nodes in scaling data script --- scripts/extract_scaling_data.py | 144 ++++++++++++++++---------------- 1 file changed, 70 insertions(+), 74 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index 117507fec..c98be77ee 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -6,7 +6,7 @@ import sys from pathlib import Path -import polars as pl +import pandas as pd def extract_num_nodes(err_log_path: Path) -> int | None: @@ -14,7 +14,6 @@ def extract_num_nodes(err_log_path: Path) -> int | None: return None try: content = err_log_path.read_text() - # Case-insensitive match for "Number of Nodes:" with flexible whitespace match = re.search(r"number\s+of\s+nodes\s*:\s*(\d+)", content, re.IGNORECASE) return int(match.group(1)) if match else None except Exception: @@ -23,7 +22,7 @@ def extract_num_nodes(err_log_path: Path) -> int | None: def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: """Extract metrics from NDJSON file with startup and training lines. - + Format: - Line 1: startup_time_seconds - Line 2+: loss_avg_mean, LossPhysical.loss_avg, etc. 
@@ -32,24 +31,27 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No if not metrics_path.exists(): return None try: - df = pl.read_ndjson(metrics_path) + df = pd.read_json(metrics_path, lines=True) if len(df) == 0: return None - + # Extract startup_time from first row (startup line) startup_time = None if "startup_time_seconds" in df.columns: - startup_time = df.select(pl.col("startup_time_seconds").first()).item() - + val = df["startup_time_seconds"].dropna() + startup_time = val.iloc[0] if len(val) > 0 else None + # Extract loss_avg_mean from last non-NaN training row loss_avg_mean = None if "loss_avg_mean" in df.columns: - loss_avg_mean = df.select(pl.col("loss_avg_mean").drop_nulls().last()).item() - - # Extract training for mini-epoch from last non-NaN row + val = df["loss_avg_mean"].dropna() + loss_avg_mean = val.iloc[-1] if len(val) > 0 else None + + # Extract training time for mini-epoch from last non-NaN row overall_training_time = None if "elapsed_time_mini_epoch" in df.columns: - overall_training_time = df.select(pl.col("elapsed_time_mini_epoch").drop_nulls().last()).item() + val = df["elapsed_time_mini_epoch"].dropna() + overall_training_time = val.iloc[-1] if len(val) > 0 else None return { "startup_time_seconds": startup_time, @@ -60,67 +62,61 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No return None -def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int | None = None) -> list: +def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int | None = None) -> list[pd.DataFrame]: """Extract detailed metrics pairing timing rows with preceding loss rows. - - For each row containing elapsed_training_time_seconds, pair it with the - preceding row containing loss metrics. Returns a list of detailed record DataFrames. + + For each row containing elapsed_training_time_seconds, pair it with the + preceding row containing loss metrics. 
Returns a list of DataFrames. """ metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" if not metrics_path.exists(): return [] - + try: - df = pl.read_ndjson(metrics_path) + df = pd.read_json(metrics_path, lines=True) if len(df) == 0: return [] - + # Find rows with elapsed_training_time_seconds (timing rows) - timing_mask = pl.col("elapsed_training_time_seconds").is_not_null() - timing_indices = df.with_row_index().filter(timing_mask).get_column("index").to_list() - - if len(timing_indices) == 0: + if "elapsed_training_time_seconds" not in df.columns: return [] - - # Get all row indices with loss data - loss_mask = pl.col("loss_avg_mean").is_not_null() - loss_rows_df = df.with_row_index().filter(loss_mask) - + timing_indices = df.index[df["elapsed_training_time_seconds"].notna()].tolist() + + if not timing_indices: + return [] + + # Find rows with loss data + if "loss_avg_mean" not in df.columns: + return [] + loss_indices = set(df.index[df["loss_avg_mean"].notna()].tolist()) + + timing_cols = ["elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second"] + detailed_records = [] - + for timing_idx in timing_indices: # Find the last loss row before this timing row - loss_rows_before = loss_rows_df.filter(pl.col("index") < timing_idx) - - if len(loss_rows_before) == 0: + loss_rows_before = [i for i in loss_indices if i < timing_idx] + if not loss_rows_before: continue - - # Get the last loss row before timing - loss_row = loss_rows_before.sort("index").tail(1).drop("index") - - # Get the timing row - timing_row = df.with_row_index().filter(pl.col("index") == timing_idx).drop("index") - - # Select only the columns we need - timing_cols = ["elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second"] - timing_available_cols = [c for c in timing_cols if c in timing_row.columns] - timing_row = timing_row.select(timing_available_cols) - - # Keep only loss_avg_mean from loss_row - loss_row = 
loss_row.select("loss_avg_mean") - - # Merge loss and timing data - merged = loss_row.hstack(timing_row) - - # Add run_id and num_nodes - merged = merged.with_columns(pl.lit(run_id).alias("run_id")) + + last_loss_idx = max(loss_rows_before) + + # Build record dict from loss row + timing row + record = {"run_id": run_id} if num_nodes is not None: - merged = merged.with_columns(pl.lit(num_nodes).alias("num_nodes")) - - detailed_records.append(merged) - + record["num_nodes"] = num_nodes + + record["loss_avg_mean"] = df.at[last_loss_idx, "loss_avg_mean"] + + for col in timing_cols: + if col in df.columns: + record[col] = df.at[timing_idx, col] + + detailed_records.append(pd.DataFrame([record])) + return detailed_records - + except Exception as e: print(f"Error extracting detailed metrics for {run_id}: {e}") import traceback @@ -143,14 +139,15 @@ def main(): results = [] all_detailed_records = [] + for run_id in run_ids: - # Look for weathergen.*.err files (e.g., weathergen.part1.388004.err) log_dir = args.logs_base_dir / run_id err_files = list(log_dir.glob("weathergen.*.err")) if log_dir.exists() else [] num_nodes = extract_num_nodes(err_files[0]) if err_files else None metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: continue + row = { "run_id": run_id, "num_nodes": num_nodes, @@ -159,8 +156,7 @@ def main(): "loss_avg_mean": metrics.get("loss_avg_mean"), } results.append(row) - - # Extract detailed metrics for this run + detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) if detailed_records: all_detailed_records.extend(detailed_records) @@ -169,28 +165,28 @@ def main(): if not results: sys.exit("No data extracted") - df = pl.DataFrame(results) + df = pd.DataFrame(results) if "num_nodes" in df.columns: - df = df.sort("num_nodes") + df = df.sort_values("num_nodes").reset_index(drop=True) + args.output.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(args.output) - 
df.write_csv(args.output.with_suffix(".csv")) - + df.to_parquet(args.output, index=False) + df.to_csv(args.output.with_suffix(".csv"), index=False) + # Write detailed metrics if any were collected if all_detailed_records: - detailed_df = pl.concat(all_detailed_records) - # Reorder columns for clarity + detailed_df = pd.concat(all_detailed_records, ignore_index=True) + desired_cols = ["run_id", "num_nodes", "elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second", "loss_avg_mean"] available_cols = [c for c in desired_cols if c in detailed_df.columns] - detailed_df = detailed_df.select(available_cols) - + detailed_df = detailed_df[available_cols] + output_stem = args.output.stem - output_suffix = args.output.suffix - detailed_output = args.output.with_name(f"{output_stem}_detailed{output_suffix}") - - detailed_df.write_parquet(detailed_output) - detailed_df.write_csv(detailed_output.with_suffix(".csv")) - + detailed_output = args.output.with_name(f"{output_stem}_detailed{args.output.suffix}") + + detailed_df.to_parquet(detailed_output, index=False) + detailed_df.to_csv(detailed_output.with_suffix(".csv"), index=False) + print(f"\nSummary:") print(f" - Extracted {len(results)} run summaries to {args.output}") if all_detailed_records: From 7e1ae1cf0d5984edcfa80f58588862bb7016a0f9 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Wed, 29 Apr 2026 17:13:49 +0200 Subject: [PATCH 49/76] Update extract scaling data --- scripts/extract_scaling_data.py | 117 +++++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 17 deletions(-) diff --git a/scripts/extract_scaling_data.py b/scripts/extract_scaling_data.py index c98be77ee..42e69b092 100644 --- a/scripts/extract_scaling_data.py +++ b/scripts/extract_scaling_data.py @@ -9,15 +9,32 @@ import pandas as pd -def extract_num_nodes(err_log_path: Path) -> int | None: - if not err_log_path.exists(): - return None - try: - content = err_log_path.read_text() - match = 
re.search(r"number\s+of\s+nodes\s*:\s*(\d+)", content, re.IGNORECASE) - return int(match.group(1)) if match else None - except Exception: +def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: + """Extract num_nodes from output.*.txt file in the run directory. + + Looks for 'nNodes' pattern in output files. + """ + run_log_dir = logs_base_dir / run_id + if not run_log_dir.exists(): return None + + # Look for output.*.txt files + output_files = list(run_log_dir.glob("output.*.txt")) + if not output_files: + # Fallback to err files if no output files found + output_files = list(run_log_dir.glob("weathergen.*.err")) + + for output_file in output_files: + try: + content = output_file.read_text() + # Look for nNodes pattern: "nNodes 128" (space-separated, as in NCCL logs) + match = re.search(r"nNodes\s+(\d+)", content, re.IGNORECASE) + if match: + return int(match.group(1)) + except Exception: + continue + + return None def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: @@ -124,26 +141,92 @@ def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int return [] +def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: + """Parse run-ids argument which can be: + 1. A list of run-ids (old format): ["run1", "run2"] -> [(None, "run1"), (None, "run2")] + 2. A dict mapping num_nodes to run-ids (new format): "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] + + Returns list of (num_nodes, run_id) tuples. 
+ """ + if len(run_ids_str) == 1: + # Check if it looks like a dict: "{key: value, ...}" + stripped = run_ids_str[0].strip() + if stripped.startswith("{") and stripped.endswith("}"): + # Parse as dict format: {num_nodes: run_id, ...} + import ast + try: + parsed = ast.literal_eval(stripped) + if isinstance(parsed, dict): + # Convert string keys to int if needed + result = [] + for k, v in parsed.items(): + key = int(k) if isinstance(k, str) and k.isdigit() else k + result.append((key, str(v))) + return result + except (ValueError, SyntaxError): + pass + + # Single run-id or comma-separated list + run_ids = [r.strip() for r in run_ids_str[0].split(",") if r.strip()] + return [(None, run_id) for run_id in run_ids] + + # Multiple arguments - treat as list of run-ids + return [(None, run_id) for run_id in run_ids_str] + + +def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: + """Extract num_nodes from output.*.txt file in the run directory. + + Looks for 'nNodes' pattern in output files. + """ + run_log_dir = logs_base_dir / run_id + if not run_log_dir.exists(): + return None + + # Look for output.*.txt files + output_files = list(run_log_dir.glob("output.*.txt")) + if not output_files: + # Fallback to err files if no output files found + output_files = list(run_log_dir.glob("weathergen.*.err")) + + for output_file in output_files: + try: + content = output_file.read_text() + # Look for nNodes pattern: "nNodes 128" (space-separated, as in NCCL logs) + match = re.search(r"nNodes\s+(\d+)", content, re.IGNORECASE) + if match: + return int(match.group(1)) + except Exception: + continue + + return None + + def main(): - parser = argparse.ArgumentParser(description="Extract strong scaling data from WeatherGenerator runs") - parser.add_argument("--run-ids", nargs="+", help="List of run-ids to process") + parser = argparse.ArgumentParser( + description="Extract strong scaling data from WeatherGenerator runs. 
" + "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict mapping num_nodes to run-ids " + "(--run-ids '{1: run1, 4: run2}'). If num_nodes is not provided in the dict, it will be extracted from output.*.txt files." + ) + parser.add_argument("--run-ids", nargs="+", help="Run-ids to process. Can be: (1) list: run1 run2 run3, or (2) dict: '{1: run1, 4: run2, 8: run3}'") parser.add_argument("--logs-base-dir", type=Path, default=Path("logs"), help="Base directory for run logs (default: logs relative to current dir)") parser.add_argument("--shared-work-dir", type=Path, default=Path("/e/scratch/weatherai/shared_work"), help="Base directory for shared work/results") parser.add_argument("--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path") args = parser.parse_args() - run_ids = args.run_ids - if not run_ids: + run_id_mapping = parse_run_ids(args.run_ids) + if not run_id_mapping: sys.exit("Error: No run-ids provided") results = [] all_detailed_records = [] - for run_id in run_ids: - log_dir = args.logs_base_dir / run_id - err_files = list(log_dir.glob("weathergen.*.err")) if log_dir.exists() else [] - num_nodes = extract_num_nodes(err_files[0]) if err_files else None + for num_nodes, run_id in run_id_mapping: + # If num_nodes not provided, extract from output.*.txt file + if num_nodes is None: + num_nodes = extract_num_nodes_from_output(run_id, args.logs_base_dir) + metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: continue @@ -160,7 +243,7 @@ def main(): detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) if detailed_records: all_detailed_records.extend(detailed_records) - print(f"Extracted {len(detailed_records)} detailed metric entries for {run_id}") + print(f"Extracted {len(detailed_records)} detailed metric entries for {run_id} ({num_nodes} nodes)") if not results: sys.exit("No data extracted") From b42432c6be7cd09e9eeb76332e1554f64a1ab9a2 Mon 
Sep 17 00:00:00 2001 From: florianscheidl Date: Thu, 30 Apr 2026 14:48:32 +0200 Subject: [PATCH 50/76] Add pyarrow --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 00103cb8c..429a6927f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ dependencies = [ "anemoi-datasets", "weathergen-common", "weathergen-evaluate", - "weathergen-readers-extra" + "weathergen-readers-extra", + "pyarrow>=23.0.1", ] From 6d5683b7434bfa72fbbb334c532f838fee566ab0 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Thu, 30 Apr 2026 17:51:09 +0200 Subject: [PATCH 51/76] Update script for scaling plots --- scripts/generate_scaling_plots.py | 166 ++++++++++++++++++++++++++---- 1 file changed, 144 insertions(+), 22 deletions(-) diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index d1dacec19..35b5f3346 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -65,6 +65,95 @@ def save_figure(fig: plt.Figure, output_path: Path) -> None: print(f"Saved: {output_path}") +def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: bool = False) -> None: + """Generate a PNG table image with scaling metrics from the parquet file. 
+ + Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) + """ + # Check if required columns exist + if "num_nodes" not in df.columns or "training_time" not in df.columns: + print("Warning: Required columns (num_nodes, training_time) not found in data") + return + + # Filter out rows with null values in required columns + df_filtered = df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + if len(df_filtered) == 0: + print("No valid data for scaling table") + return + + # Get the 1-node training time for ideal time calculation + one_node_data = df_filtered.filter(pl.col("num_nodes") == 1) + if one_node_data.height == 0: + print("Warning: No 1-node data found for ideal time calculation") + return + + t1 = one_node_data["training_time"].item() + + # Derive scaling type from input filename + input_name_lower = input_path.name.lower() + if "weak" in input_name_lower: + scaling_type = "Weak" + elif "strong" in input_name_lower: + scaling_type = "Strong" + else: + scaling_type = "Strong" # Default to strong + + # Build table data with proper formatting + has_run_id = "run_id" in df_filtered.columns + col_names = ["# Nodes", "Training Time (seconds)", "Ideal Time (seconds)", "Efficiency"] + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + + table_data = [] + for row in df_filtered.iter_rows(named=True): + num_nodes = row["num_nodes"] + training_time = row["training_time"] + + if num_nodes == 1: + ideal_time = "-" + efficiency = "-" + else: + if scaling_type == "Strong": + # Strong scaling: ideal time = t1 / num_nodes + ideal_val = t1 / num_nodes + efficiency_val = ideal_val / training_time + else: + # Weak scaling: ideal time = t1 (same work per node) + ideal_val = t1 + efficiency_val = min(1.0, t1 / training_time) + + ideal_time = f"{ideal_val:.2f}" + efficiency = f"{efficiency_val:.2f}" + + row_data = [] + if show_run_ids and has_run_id: + 
row_data.append(str(row.get("run_id", ""))) + row_data.extend([ + str(num_nodes), + f"{training_time:.2f}", + ideal_time, + efficiency + ]) + table_data.append(row_data) + + # Generate output filename: input_stem_table.csv + output_path = input_path.with_name(input_path.stem + "_table.csv") + + # Build DataFrame for CSV output + df_table_data = {} + for i, col in enumerate(col_names): + df_table_data[col] = [row[i] for row in table_data] + + df_table = pl.DataFrame(df_table_data) + + # Write to CSV + df_table.write_csv(output_path) + print(f"Saved scaling table: {output_path}") + + def plot_standard_scaling( df: pl.DataFrame, output_path: Path, @@ -73,12 +162,14 @@ def plot_standard_scaling( x_scale: str, y_scale: str, y_metric: str, + show_run_ids: bool = False, ) -> None: """Plot run-level scaling data vs num_nodes.""" metric_labels = { "training_time": "Training Time (seconds)", "loss_avg_mean": "Average Loss", - "normalized_throughput": "Normalized Throughput (T1 / T)", + "normalized_throughput": "Speedup", + "efficiency": "Scaling Efficiency", } valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] @@ -94,7 +185,7 @@ def plot_standard_scaling( node_counts = df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] colors = color_map_for_nodes(node_counts) - # Handle normalized_throughput metric + # Handle normalized_throughput and efficiency metrics if y_metric == "normalized_throughput" and metric == "training_time": # Calculate normalized throughput: T1 / T one_node_data = df.filter(pl.col("num_nodes") == 1) @@ -108,6 +199,25 @@ def plot_standard_scaling( else: print("Warning: No 1-node data found for normalized throughput calculation") continue + elif y_metric == "efficiency" and metric == "training_time": + # Calculate efficiency based on scaling type + one_node_data = df.filter(pl.col("num_nodes") == 1) + if one_node_data.height > 0: + t1 = one_node_data["training_time"].item() + 
if scaling_type == "strong": + # Strong scaling: efficiency = (t1 / num_nodes) / training_time + df_plot = df_plot.with_columns( + ((t1 / pl.col("num_nodes")) / pl.col("training_time")).alias("efficiency") + ) + else: + # Weak scaling: efficiency = min(1.0, t1 / training_time) + df_plot = df_plot.with_columns( + pl.min_horizontal(pl.lit(1.0), t1 / pl.col("training_time")).alias("efficiency") + ) + plot_y = df_plot["efficiency"] + else: + print("Warning: No 1-node data found for efficiency calculation") + continue else: plot_y = df_plot[metric] @@ -119,15 +229,19 @@ def plot_standard_scaling( markersize=8, ) - for x, y, label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"]): - ax.text(x, y, label, ha="center", va="bottom", fontsize=8) + if show_run_ids: + for x, y, label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"]): + ax.text(x, y, label, ha="center", va="bottom", fontsize=8) - if metric == "training_time" and y_metric == "time" and "training_time" in df.columns: + if metric == "training_time" and y_metric in ("time", "normalized_throughput", "efficiency") and "training_time" in df.columns: one_node_data = df.filter(pl.col("num_nodes") == 1) if one_node_data.height > 0: t1 = one_node_data["training_time"].item() nodes = df_plot["num_nodes"].to_list() - if scaling_type == "weak": + if y_metric == "efficiency": + # For efficiency, optimal is always 1.0 (100% efficiency) + optimal_y = [1.0 for _ in nodes] + elif scaling_type == "weak": if y_metric == "normalized_throughput": # For normalized throughput, optimal is 1.0 (no speedup loss) optimal_y = [1.0 for _ in nodes] @@ -152,11 +266,11 @@ def plot_standard_scaling( ax.vlines(x, y_opt, y, colors="gray", linestyles=":", linewidth=1, alpha=0.7) y_mid = (y + y_opt) / 2 ax.annotate( - f"{factor:.2f}x", + f"{factor:.2f}", xy=(x, y_mid), xytext=(4, 0), textcoords="offset points", - fontsize=9, + fontsize=14, fontweight="bold", color="dimgray", va="center", @@ -166,15 +280,16 @@ def 
plot_standard_scaling( ax.set_xscale(x_scale) if y_scale == "log": ax.set_yscale("log") - ax.set_xlabel("Number of Nodes") + ax.set_xlabel("Number of Nodes", fontsize=16) if y_metric == "normalized_throughput" and metric == "training_time": - ax.set_ylabel("Normalized Throughput (T1 / T)") + ax.set_ylabel("Speedup", fontsize=16) + elif y_metric == "efficiency" and metric == "training_time": + ax.set_ylabel("Scaling Efficiency", fontsize=16) else: - ax.set_ylabel(metric_labels.get(metric, metric)) - ax.set_title(metric if y_metric != "normalized_throughput" or metric != "training_time" else "Normalized Throughput") + ax.set_ylabel(metric_labels.get(metric, metric), fontsize=16) + ax.tick_params(axis='both', which='major', labelsize=14) ax.grid(True, alpha=0.3) - fig.suptitle("Scaling Analysis", fontsize=16) plt.tight_layout() save_figure(fig, output_path) @@ -223,8 +338,9 @@ def plot_detailed_scaling( ax.set_xscale(x_scale) if y_scale == "log": ax.set_yscale("log") - ax.set_ylabel("Elapsed Training Time (seconds)") - ax.set_title("Elapsed Training Time vs Samples") + ax.set_ylabel("Elapsed Training Time (seconds)", fontsize=16) + ax.set_title("Elapsed Training Time vs Samples", fontsize=16) + ax.tick_params(axis='both', which='major', labelsize=14) ax.grid(True, alpha=0.3) ax.legend(title="Node Count") @@ -242,9 +358,10 @@ def plot_detailed_scaling( ax.set_xscale(x_scale) if y_scale == "log": ax.set_yscale("log") - ax.set_xlabel("Total Number of Samples") - ax.set_ylabel("Average Loss") - ax.set_title("Loss vs Samples") + ax.set_xlabel("Total Number of Samples", fontsize=16) + ax.set_ylabel("Average Loss", fontsize=16) + ax.set_title("Loss vs Samples", fontsize=16) + ax.tick_params(axis='both', which='major', labelsize=14) ax.grid(True, alpha=0.3) ax.legend(title="Node Count") @@ -263,15 +380,16 @@ def build_parser() -> argparse.ArgumentParser: standard.add_argument("--output", type=Path, default=None, help="Output image path") standard.add_argument("--y-scale", 
choices=["linear", "log"], default="linear", help="Y-axis scale") standard.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") - standard.add_argument("--y-metric", choices=["time", "normalized_throughput"], default="normalized_throughput", help="Y-axis metric: 'time' for time-to-solution or 'normalized_throughput' for T1/T") + standard.add_argument("--y-metric", choices=["time", "normalized_throughput", "efficiency"], default="normalized_throughput", help="Y-axis metric: 'time' for time-to-solution, 'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency") + standard.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") - # Subparser for loss-only plots (separate entry point) loss_only = subparsers.add_parser("loss", help="Plot loss metrics vs num_nodes (separate from throughput)") loss_only.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") loss_only.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") loss_only.add_argument("--output", type=Path, default=None, help="Output image path") loss_only.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") loss_only.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") + loss_only.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") detailed.add_argument("--input", type=Path, default=Path("scaling_data_detailed.parquet"), help="Input detailed parquet/ndjson file") @@ -304,7 +422,9 @@ def main() -> None: print(f"Loaded {len(df)} rows") # Standard mode: only plot training_time with normalized throughput or time metrics_to_plot = ["training_time"] - plot_standard_scaling(df, output_path, args.type, 
metrics_to_plot, args.x_scale, args.y_scale, args.y_metric) + plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, args.y_metric, args.show_run_ids) + # Generate scaling table + generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids) return if args.mode == "loss": @@ -325,7 +445,9 @@ def main() -> None: print(f"Loaded {len(df)} rows") # Loss mode: only plot loss_avg_mean metrics_to_plot = ["loss_avg_mean"] - plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, "time") + plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, "time", args.show_run_ids) + # Generate scaling table + generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids) return if args.mode == "detailed": From 03962903fc64232537d9a527e1c3e81c5567104e Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 10:05:17 +0200 Subject: [PATCH 52/76] Update to generating scaling plots --- scripts/generate_scaling_plots.py | 379 +++++++++++++++++++++++++++--- 1 file changed, 340 insertions(+), 39 deletions(-) diff --git a/scripts/generate_scaling_plots.py b/scripts/generate_scaling_plots.py index 35b5f3346..89d1f1a39 100644 --- a/scripts/generate_scaling_plots.py +++ b/scripts/generate_scaling_plots.py @@ -4,6 +4,25 @@ Two entrypoints: - standard: plots run-level metrics vs num_nodes - detailed: plots sample-level metrics vs total_num_samples +- combined: generates a comparison table from separate strong and weak scaling input files + +Usage: + # Single scaling type (original behavior) + python generate_scaling_plots.py standard --type strong --input strong_data.parquet + + # Combined table from single file with both types + python generate_scaling_plots.py standard --type strong,weak --input data.parquet + + # Combined table from separate strong and weak input files (new) + python generate_scaling_plots.py combined \ + --strong-input 
strong_data.parquet \ + --weak-input weak_data.parquet + + # Loss plot + python generate_scaling_plots.py loss --type strong --input data.parquet + + # Detailed scaling plot + python generate_scaling_plots.py detailed --input detailed_data.parquet """ import argparse @@ -65,10 +84,11 @@ def save_figure(fig: plt.Figure, output_path: Path) -> None: print(f"Saved: {output_path}") -def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: bool = False) -> None: +def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: bool = False, scaling_types: list[str] = None) -> None: """Generate a PNG table image with scaling metrics from the parquet file. Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) + If scaling_types has multiple types, generates a combined table with columns per type. """ # Check if required columns exist if "num_nodes" not in df.columns or "training_time" not in df.columns: @@ -92,51 +112,96 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo t1 = one_node_data["training_time"].item() - # Derive scaling type from input filename - input_name_lower = input_path.name.lower() - if "weak" in input_name_lower: - scaling_type = "Weak" - elif "strong" in input_name_lower: - scaling_type = "Strong" - else: - scaling_type = "Strong" # Default to strong + # Determine scaling types to include + if scaling_types is None or len(scaling_types) == 0: + # Derive scaling type from input filename + input_name_lower = input_path.name.lower() + if "weak" in input_name_lower: + scaling_types = ["weak"] + elif "strong" in input_name_lower: + scaling_types = ["strong"] + else: + scaling_types = ["strong"] # Default to strong # Build table data with proper formatting has_run_id = "run_id" in df_filtered.columns - col_names = ["# Nodes", "Training Time (seconds)", "Ideal Time (seconds)", "Efficiency"] - if show_run_ids and has_run_id: - col_names.insert(0, "run_id") + + # Check if 
we're generating a combined table (multiple types) + is_combined = len(scaling_types) > 1 + + if is_combined: + # Combined table: columns per type + col_names = ["# Nodes"] + for stype in scaling_types: + col_names.extend([ + f"{stype.capitalize()} Training Time (seconds)", + f"{stype.capitalize()} Efficiency" + ]) + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + else: + # Single type table (original format) + scaling_type = scaling_types[0].capitalize() + col_names = ["# Nodes", "Training Time (seconds)", "Ideal Time (seconds)", "Efficiency"] + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") table_data = [] for row in df_filtered.iter_rows(named=True): num_nodes = row["num_nodes"] training_time = row["training_time"] - if num_nodes == 1: - ideal_time = "-" - efficiency = "-" + row_data = [] + if show_run_ids and has_run_id: + row_data.append(str(row.get("run_id", ""))) + + if is_combined: + # Combined table: add metrics for each type + row_data.append(str(num_nodes)) + for stype in scaling_types: + if num_nodes == 1: + efficiency = "-" + else: + if stype == "strong": + # Strong scaling: ideal time = t1 / num_nodes + ideal_val = t1 / num_nodes + efficiency_val = ideal_val / training_time + else: + # Weak scaling: ideal time = t1 (same work per node) + ideal_val = t1 + efficiency_val = min(1.0, t1 / training_time) + + efficiency = f"{efficiency_val:.2f}" + + row_data.extend([ + f"{training_time:.2f}", + efficiency + ]) else: - if scaling_type == "Strong": - # Strong scaling: ideal time = t1 / num_nodes - ideal_val = t1 / num_nodes - efficiency_val = ideal_val / training_time + # Single type table (original format) + scaling_type = scaling_types[0] + if num_nodes == 1: + ideal_time = "-" + efficiency = "-" else: - # Weak scaling: ideal time = t1 (same work per node) - ideal_val = t1 - efficiency_val = min(1.0, t1 / training_time) + if scaling_type == "strong": + # Strong scaling: ideal time = t1 / num_nodes + ideal_val = t1 / 
num_nodes + efficiency_val = ideal_val / training_time + else: + # Weak scaling: ideal time = t1 (same work per node) + ideal_val = t1 + efficiency_val = min(1.0, t1 / training_time) + + ideal_time = f"{ideal_val:.2f}" + efficiency = f"{efficiency_val:.2f}" - ideal_time = f"{ideal_val:.2f}" - efficiency = f"{efficiency_val:.2f}" + row_data.extend([ + f"{training_time:.2f}", + ideal_time, + efficiency + ]) - row_data = [] - if show_run_ids and has_run_id: - row_data.append(str(row.get("run_id", ""))) - row_data.extend([ - str(num_nodes), - f"{training_time:.2f}", - ideal_time, - efficiency - ]) table_data.append(row_data) # Generate output filename: input_stem_table.csv @@ -154,6 +219,171 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo print(f"Saved scaling table: {output_path}") +def generate_combined_scaling_table( + strong_df: pl.DataFrame, + weak_df: pl.DataFrame, + strong_path: Path, + weak_path: Path, + output_path: Path, + show_run_ids: bool = False +) -> None: + """Generate a combined table comparing strong and weak scaling from two separate input files. + + Rows: num_nodes + Columns: # Nodes, Strong Training Time, Strong Efficiency, Weak Training Time, Weak Efficiency + + Also generates a PNG visualization of the table. 
+ """ + # Validate required columns + for name, df in [("strong", strong_df), ("weak", weak_df)]: + if "num_nodes" not in df.columns or "training_time" not in df.columns: + print(f"Warning: Required columns (num_nodes, training_time) not found in {name} data") + return + + # Filter and sort both datasets + strong_filtered = strong_df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + weak_filtered = weak_df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + if len(strong_filtered) == 0 or len(weak_filtered) == 0: + print("No valid data for combined scaling table") + return + + # Get 1-node training times for efficiency calculation + strong_one_node = strong_filtered.filter(pl.col("num_nodes") == 1) + weak_one_node = weak_filtered.filter(pl.col("num_nodes") == 1) + + if strong_one_node.height == 0 or weak_one_node.height == 0: + print("Warning: No 1-node data found for efficiency calculation") + return + + t1_strong = strong_one_node["training_time"].item() + t1_weak = weak_one_node["training_time"].item() + + # Check for run_id in either dataset + has_run_id = "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns + + # Build column names + col_names = ["# Nodes", "Strong Training Time (seconds)", "Strong Efficiency", + "Weak Training Time (seconds)", "Weak Efficiency"] + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + + # Get all unique num_nodes from both datasets + all_nodes = sorted(set(strong_filtered["num_nodes"].to_list()) | set(weak_filtered["num_nodes"].to_list())) + + # Create lookup dictionaries for easy access + strong_lookup = {row["num_nodes"]: row["training_time"] for row in strong_filtered.iter_rows(named=True)} + weak_lookup = {row["num_nodes"]: row["training_time"] for row in weak_filtered.iter_rows(named=True)} + strong_run_id_lookup = {row["num_nodes"]: row["run_id"] for row in 
strong_filtered.iter_rows(named=True)} if "run_id" in strong_filtered.columns else {} + weak_run_id_lookup = {row["num_nodes"]: row["run_id"] for row in weak_filtered.iter_rows(named=True)} if "run_id" in weak_filtered.columns else {} + + table_data = [] + for num_nodes in all_nodes: + row_data = [] + + # Get run_id if available + if show_run_ids and has_run_id: + run_id = str(strong_run_id_lookup.get(num_nodes, weak_run_id_lookup.get(num_nodes, ""))) + row_data.append(run_id) + + # Add num_nodes + row_data.append(str(num_nodes)) + + # Strong scaling metrics + if num_nodes in strong_lookup: + training_time_strong = strong_lookup[num_nodes] + if num_nodes == 1: + efficiency_strong = "-" + else: + ideal_strong = t1_strong / num_nodes + efficiency_strong = f"{ideal_strong / training_time_strong:.2f}" + row_data.extend([f"{training_time_strong:.2f}", efficiency_strong]) + else: + row_data.extend(["-", "-"]) + + # Weak scaling metrics + if num_nodes in weak_lookup: + training_time_weak = weak_lookup[num_nodes] + if num_nodes == 1: + efficiency_weak = "-" + else: + ideal_weak = t1_weak # Weak scaling: ideal is same as 1-node time + efficiency_weak = f"{min(1.0, ideal_weak / training_time_weak):.2f}" + row_data.extend([f"{training_time_weak:.2f}", efficiency_weak]) + else: + row_data.extend(["-", "-"]) + + table_data.append(row_data) + + # Ensure output path has .csv suffix + if output_path.suffix.lower() != ".csv": + output_path = output_path.with_suffix(".csv") + + # Build DataFrame for CSV output + df_table_data = {} + for i, col in enumerate(col_names): + df_table_data[col] = [row[i] for row in table_data] + + df_table = pl.DataFrame(df_table_data) + + # Write to CSV + df_table.write_csv(output_path) + print(f"Saved scaling table CSV: {output_path}") + + # Generate PNG visualization of the table + png_path = output_path.with_suffix(".png") + _save_table_as_image(table_data, col_names, png_path) + print(f"Saved scaling table PNG: {png_path}") + + +def 
_save_table_as_image(table_data: list, col_names: list, output_path: Path) -> None: + """Save table data as a PNG image using matplotlib. + + Automatically sizes the figure to fit all content. + """ + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Calculate figure size based on content + num_cols = len(col_names) + num_rows = len(table_data) + 1 # +1 for header + + # Width: base + per-column width, Height: base + per-row height + fig_width = max(8, num_cols * 2.5) + fig_height = max(3, num_rows * 0.5) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + ax.axis('off') + + # Create table + table = ax.table( + cellText=table_data, + colLabels=col_names, + cellLoc='center', + loc='center', + colColours=['#2E5C8A'] * num_cols, + cellColours=[['#E8ECEF' if i % 2 == 0 else 'white' for _ in range(num_cols)] for i in range(len(table_data))] + ) + + # Style the table + table.auto_set_font_size(False) + table.set_fontsize(9) + table.auto_set_column_width(col=list(range(num_cols))) + + # Style header cells + for i in range(num_cols): + table[(0, i)].set_text_props(color='white', fontweight='bold') + + # Adjust layout and save + plt.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close(fig) + + def plot_standard_scaling( df: pl.DataFrame, output_path: Path, @@ -375,7 +605,7 @@ def build_parser() -> argparse.ArgumentParser: subparsers = parser.add_subparsers(dest="mode", required=True) standard = subparsers.add_parser("standard", help="Plot run-level scaling metrics vs num_nodes") - standard.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") + standard.add_argument("--type", required=True, help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table") standard.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") standard.add_argument("--output", type=Path, default=None, help="Output image path") 
standard.add_argument("--y-scale", choices=["linear", "log"], default="linear", help="Y-axis scale") @@ -384,13 +614,19 @@ def build_parser() -> argparse.ArgumentParser: standard.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") loss_only = subparsers.add_parser("loss", help="Plot loss metrics vs num_nodes (separate from throughput)") - loss_only.add_argument("--type", required=True, choices=["strong", "weak"], help="Scaling type") + loss_only.add_argument("--type", required=True, help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table") loss_only.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") loss_only.add_argument("--output", type=Path, default=None, help="Output image path") loss_only.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") loss_only.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") loss_only.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") + combined = subparsers.add_parser("combined", help="Generate combined table comparing strong and weak scaling from separate input files") + combined.add_argument("--strong-input", type=Path, required=True, help="Input parquet/ndjson file for strong scaling") + combined.add_argument("--weak-input", type=Path, required=True, help="Input parquet/ndjson file for weak scaling") + combined.add_argument("--output", type=Path, default=None, help="Output table path (CSV)") + combined.add_argument("--show-run-ids", action="store_true", help="Show run_id labels in the output table") + detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") detailed.add_argument("--input", type=Path, default=Path("scaling_data_detailed.parquet"), help="Input detailed parquet/ndjson file") 
detailed.add_argument("--output", type=Path, default=None, help="Output image path") @@ -420,11 +656,21 @@ def main() -> None: print(str(e)) return print(f"Loaded {len(df)} rows") + + # Parse scaling types from --type argument + scaling_types = [t.strip().lower() for t in args.type.split(",")] + for stype in scaling_types: + if stype not in ("strong", "weak"): + print(f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'") + return + # Standard mode: only plot training_time with normalized throughput or time metrics_to_plot = ["training_time"] - plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, args.y_metric, args.show_run_ids) + # Use the first type for plotting (or strong if combined) + plot_type = scaling_types[0] + plot_standard_scaling(df, output_path, plot_type, metrics_to_plot, args.x_scale, args.y_scale, args.y_metric, args.show_run_ids) # Generate scaling table - generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids) + generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types) return if args.mode == "loss": @@ -443,11 +689,21 @@ def main() -> None: print(str(e)) return print(f"Loaded {len(df)} rows") + + # Parse scaling types from --type argument + scaling_types = [t.strip().lower() for t in args.type.split(",")] + for stype in scaling_types: + if stype not in ("strong", "weak"): + print(f"Error: Invalid scaling type '{stype}'. 
Use 'strong', 'weak', or 'strong,weak'") + return + # Loss mode: only plot loss_avg_mean metrics_to_plot = ["loss_avg_mean"] - plot_standard_scaling(df, output_path, args.type, metrics_to_plot, args.x_scale, args.y_scale, "time", args.show_run_ids) + # Use the first type for plotting (or strong if combined) + plot_type = scaling_types[0] + plot_standard_scaling(df, output_path, plot_type, metrics_to_plot, args.x_scale, args.y_scale, "time", args.show_run_ids) # Generate scaling table - generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids) + generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types) return if args.mode == "detailed": @@ -469,6 +725,51 @@ def main() -> None: plot_detailed_scaling(df, output_path, args.x_scale, args.y_scale) return + if args.mode == "combined": + strong_path = resolve_input_path(args.strong_input) + weak_path = resolve_input_path(args.weak_input) + + if not strong_path.exists(): + print(f"Error: Strong scaling input file not found: {strong_path}") + return + if not weak_path.exists(): + print(f"Error: Weak scaling input file not found: {weak_path}") + return + + # Determine output path + if args.output: + output_path = args.output + if output_path.suffix.lower() not in VALID_IMAGE_SUFFIXES: + output_path = output_path.with_suffix(".csv") + else: + # Default output: strong_input_stem_combined_table.csv + output_path = strong_path.with_name(strong_path.stem + "_combined_table.csv") + + print(f"Loading strong scaling data from: {strong_path}") + try: + strong_df = read_table(strong_path) + except Exception as e: + print("Error: Could not read strong scaling input file") + print(str(e)) + return + print(f"Loaded {len(strong_df)} strong scaling rows") + + print(f"Loading weak scaling data from: {weak_path}") + try: + weak_df = read_table(weak_path) + except Exception as e: + print("Error: Could not read weak scaling input file") + print(str(e)) + return + print(f"Loaded 
{len(weak_df)} weak scaling rows") + + # Generate combined table + generate_combined_scaling_table( + strong_df, weak_df, strong_path, weak_path, output_path, + show_run_ids=args.show_run_ids + ) + return + raise ValueError(f"Unknown mode: {args.mode}") From 680026252c1623c7d6ede8ae92cc5fd94432dd6a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 10:52:52 +0200 Subject: [PATCH 53/76] Move scaling scripts to package --- packages/performance/README.md | 43 ++ packages/performance/__init__.py | 6 + packages/performance/pyproject.toml | 46 ++ .../performance/src/performance/__init__.py | 3 + .../src/performance}/extract_scaling_data.py | 79 +++- .../performance}/generate_scaling_plots.py | 396 ++++++++++++------ pyproject.toml | 8 +- 7 files changed, 423 insertions(+), 158 deletions(-) create mode 100644 packages/performance/README.md create mode 100644 packages/performance/__init__.py create mode 100644 packages/performance/pyproject.toml create mode 100644 packages/performance/src/performance/__init__.py rename {scripts => packages/performance/src/performance}/extract_scaling_data.py (86%) rename {scripts => packages/performance/src/performance}/generate_scaling_plots.py (76%) diff --git a/packages/performance/README.md b/packages/performance/README.md new file mode 100644 index 000000000..99dabfe4a --- /dev/null +++ b/packages/performance/README.md @@ -0,0 +1,43 @@ +# WeatherGenerator Performance Analysis Tools + +This package contains tools for extracting and analyzing scaling performance data from WeatherGenerator training runs. + +## Scripts + +### extract_scaling_data.py + +Extracts strong scaling metrics from WeatherGenerator training runs. + +```bash +extract_scaling_data --run-ids run1 run2 --logs-base-dir /path/to/logs --shared-work-dir /path/to/work +``` + +### generate_scaling_plots.py + +Generates scaling plots and tables from parquet/NDJSON data. 
+ +```bash +# Standard mode (single type) +generate_scaling_plots standard --type strong --input data.parquet + +# Combined mode (separate files) +generate_scaling_plots combined \ + --strong-input strong.parquet \ + --weak-input weak.parquet + +# Combined mode (single file with both types) +generate_scaling_plots standard --type strong,weak --input data.parquet +``` + +## Installation + +This package is part of the WeatherGenerator workspace. To install: + +```bash +# In the root WeatherGenerator directory +uv sync --extra performance +``` + +The scripts will be available as console scripts: +- `extract_scaling_data` +- `generate_scaling_plots` diff --git a/packages/performance/__init__.py b/packages/performance/__init__.py new file mode 100644 index 000000000..98802f911 --- /dev/null +++ b/packages/performance/__init__.py @@ -0,0 +1,6 @@ +"""Performance analysis tools for WeatherGenerator. + +This package contains tools for extracting and analyzing scaling performance data. +""" + +__version__ = "0.1.0" diff --git a/packages/performance/pyproject.toml b/packages/performance/pyproject.toml new file mode 100644 index 000000000..2838552cd --- /dev/null +++ b/packages/performance/pyproject.toml @@ -0,0 +1,46 @@ +[project] +name = "weathergen-performance" +version = "0.1.0" +description = "Performance analysis tools for WeatherGenerator" +readme = "README.md" +authors = [ + { name = "WeatherGenerator collaboration" } +] + +requires-python = ">=3.12,<3.13" +dependencies = [ + "polars~=1.25.2", + "pandas~=2.2", + "matplotlib", + "pyarrow>=23.0.1", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/performance"] + +[project.scripts] +extract_scaling_data = "performance.extract_scaling_data:main" +generate_scaling_plots = "performance.generate_scaling_plots:main" + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", + "F", + "UP", + "B", + "SIM", + "I", +] +ignore = [ + 
"SIM108", + "N817", + "E731", + "N812", +] diff --git a/packages/performance/src/performance/__init__.py b/packages/performance/src/performance/__init__.py new file mode 100644 index 000000000..fd2921767 --- /dev/null +++ b/packages/performance/src/performance/__init__.py @@ -0,0 +1,3 @@ +"""WeatherGenerator performance analysis package.""" + +__all__ = ["extract_scaling_data", "generate_scaling_plots"] diff --git a/scripts/extract_scaling_data.py b/packages/performance/src/performance/extract_scaling_data.py similarity index 86% rename from scripts/extract_scaling_data.py rename to packages/performance/src/performance/extract_scaling_data.py index 42e69b092..5994c6a62 100644 --- a/scripts/extract_scaling_data.py +++ b/packages/performance/src/performance/extract_scaling_data.py @@ -11,19 +11,19 @@ def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: """Extract num_nodes from output.*.txt file in the run directory. - + Looks for 'nNodes' pattern in output files. """ run_log_dir = logs_base_dir / run_id if not run_log_dir.exists(): return None - + # Look for output.*.txt files output_files = list(run_log_dir.glob("output.*.txt")) if not output_files: # Fallback to err files if no output files found output_files = list(run_log_dir.glob("weathergen.*.err")) - + for output_file in output_files: try: content = output_file.read_text() @@ -33,7 +33,7 @@ def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | Non return int(match.group(1)) except Exception: continue - + return None @@ -79,7 +79,9 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No return None -def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int | None = None) -> list[pd.DataFrame]: +def extract_detailed_metrics( + run_id: str, shared_work_dir: Path, num_nodes: int | None = None +) -> list[pd.DataFrame]: """Extract detailed metrics pairing timing rows with preceding loss rows. 
For each row containing elapsed_training_time_seconds, pair it with the @@ -107,7 +109,11 @@ def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int return [] loss_indices = set(df.index[df["loss_avg_mean"].notna()].tolist()) - timing_cols = ["elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second"] + timing_cols = [ + "elapsed_training_time_seconds", + "total_num_samples", + "average_samples_per_second", + ] detailed_records = [] @@ -137,6 +143,7 @@ def extract_detailed_metrics(run_id: str, shared_work_dir: Path, num_nodes: int except Exception as e: print(f"Error extracting detailed metrics for {run_id}: {e}") import traceback + traceback.print_exc() return [] @@ -145,7 +152,7 @@ def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: """Parse run-ids argument which can be: 1. A list of run-ids (old format): ["run1", "run2"] -> [(None, "run1"), (None, "run2")] 2. A dict mapping num_nodes to run-ids (new format): "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] - + Returns list of (num_nodes, run_id) tuples. """ if len(run_ids_str) == 1: @@ -154,6 +161,7 @@ def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: if stripped.startswith("{") and stripped.endswith("}"): # Parse as dict format: {num_nodes: run_id, ...} import ast + try: parsed = ast.literal_eval(stripped) if isinstance(parsed, dict): @@ -165,30 +173,30 @@ def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: return result except (ValueError, SyntaxError): pass - + # Single run-id or comma-separated list run_ids = [r.strip() for r in run_ids_str[0].split(",") if r.strip()] return [(None, run_id) for run_id in run_ids] - + # Multiple arguments - treat as list of run-ids return [(None, run_id) for run_id in run_ids_str] def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: """Extract num_nodes from output.*.txt file in the run directory. 
- + Looks for 'nNodes' pattern in output files. """ run_log_dir = logs_base_dir / run_id if not run_log_dir.exists(): return None - + # Look for output.*.txt files output_files = list(run_log_dir.glob("output.*.txt")) if not output_files: # Fallback to err files if no output files found output_files = list(run_log_dir.glob("weathergen.*.err")) - + for output_file in output_files: try: content = output_file.read_text() @@ -198,7 +206,7 @@ def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | Non return int(match.group(1)) except Exception: continue - + return None @@ -208,10 +216,26 @@ def main(): "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict mapping num_nodes to run-ids " "(--run-ids '{1: run1, 4: run2}'). If num_nodes is not provided in the dict, it will be extracted from output.*.txt files." ) - parser.add_argument("--run-ids", nargs="+", help="Run-ids to process. Can be: (1) list: run1 run2 run3, or (2) dict: '{1: run1, 4: run2, 8: run3}'") - parser.add_argument("--logs-base-dir", type=Path, default=Path("logs"), help="Base directory for run logs (default: logs relative to current dir)") - parser.add_argument("--shared-work-dir", type=Path, default=Path("/e/scratch/weatherai/shared_work"), help="Base directory for shared work/results") - parser.add_argument("--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path") + parser.add_argument( + "--run-ids", + nargs="+", + help="Run-ids to process. 
Can be: (1) list: run1 run2 run3, or (2) dict: '{1: run1, 4: run2, 8: run3}'", + ) + parser.add_argument( + "--logs-base-dir", + type=Path, + default=Path("logs"), + help="Base directory for run logs (default: logs relative to current dir)", + ) + parser.add_argument( + "--shared-work-dir", + type=Path, + default=Path("/e/scratch/weatherai/shared_work"), + help="Base directory for shared work/results", + ) + parser.add_argument( + "--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path" + ) args = parser.parse_args() @@ -226,7 +250,7 @@ def main(): # If num_nodes not provided, extract from output.*.txt file if num_nodes is None: num_nodes = extract_num_nodes_from_output(run_id, args.logs_base_dir) - + metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) if metrics is None: continue @@ -243,7 +267,9 @@ def main(): detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) if detailed_records: all_detailed_records.extend(detailed_records) - print(f"Extracted {len(detailed_records)} detailed metric entries for {run_id} ({num_nodes} nodes)") + print( + f"Extracted {len(detailed_records)} detailed metric entries for {run_id} ({num_nodes} nodes)" + ) if not results: sys.exit("No data extracted") @@ -260,7 +286,14 @@ def main(): if all_detailed_records: detailed_df = pd.concat(all_detailed_records, ignore_index=True) - desired_cols = ["run_id", "num_nodes", "elapsed_training_time_seconds", "total_num_samples", "average_samples_per_second", "loss_avg_mean"] + desired_cols = [ + "run_id", + "num_nodes", + "elapsed_training_time_seconds", + "total_num_samples", + "average_samples_per_second", + "loss_avg_mean", + ] available_cols = [c for c in desired_cols if c in detailed_df.columns] detailed_df = detailed_df[available_cols] @@ -270,10 +303,12 @@ def main(): detailed_df.to_parquet(detailed_output, index=False) detailed_df.to_csv(detailed_output.with_suffix(".csv"), index=False) - 
print(f"\nSummary:") + print("\nSummary:") print(f" - Extracted {len(results)} run summaries to {args.output}") if all_detailed_records: - print(f" - Extracted {len(all_detailed_records)} detailed metric entries to {detailed_output}") + print( + f" - Extracted {len(all_detailed_records)} detailed metric entries to {detailed_output}" + ) if __name__ == "__main__": diff --git a/scripts/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py similarity index 76% rename from scripts/generate_scaling_plots.py rename to packages/performance/src/performance/generate_scaling_plots.py index 89d1f1a39..4bad53e3a 100644 --- a/scripts/generate_scaling_plots.py +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -1,28 +1,28 @@ #!/usr/bin/env uv run python """Generate scaling plots from parquet/ndjson data using matplotlib only. -Two entrypoints: +Entry points: - standard: plots run-level metrics vs num_nodes - detailed: plots sample-level metrics vs total_num_samples - combined: generates a comparison table from separate strong and weak scaling input files Usage: # Single scaling type (original behavior) - python generate_scaling_plots.py standard --type strong --input strong_data.parquet + python -m performance.generate_scaling_plots standard --type strong --input strong_data.parquet # Combined table from single file with both types - python generate_scaling_plots.py standard --type strong,weak --input data.parquet + python -m performance.generate_scaling_plots standard --type strong,weak --input data.parquet # Combined table from separate strong and weak input files (new) - python generate_scaling_plots.py combined \ + python -m performance.generate_scaling_plots combined \ --strong-input strong_data.parquet \ --weak-input weak_data.parquet # Loss plot - python generate_scaling_plots.py loss --type strong --input data.parquet + python -m performance.generate_scaling_plots loss --type strong --input data.parquet # Detailed 
scaling plot - python generate_scaling_plots.py detailed --input detailed_data.parquet + python -m performance.generate_scaling_plots detailed --input detailed_data.parquet """ import argparse @@ -34,8 +34,16 @@ SCRIPT_DIR = Path(__file__).resolve().parent VALID_IMAGE_SUFFIXES = {".png", ".pdf", ".svg", ".jpg", ".jpeg"} PALETTE = [ - "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", - "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", ] @@ -84,9 +92,11 @@ def save_figure(fig: plt.Figure, output_path: Path) -> None: print(f"Saved: {output_path}") -def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: bool = False, scaling_types: list[str] = None) -> None: +def generate_scaling_table( + df: pl.DataFrame, input_path: Path, show_run_ids: bool = False, scaling_types: list[str] = None +) -> None: """Generate a PNG table image with scaling metrics from the parquet file. - + Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) If scaling_types has multiple types, generates a combined table with columns per type. 
""" @@ -94,24 +104,24 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo if "num_nodes" not in df.columns or "training_time" not in df.columns: print("Warning: Required columns (num_nodes, training_time) not found in data") return - + # Filter out rows with null values in required columns df_filtered = df.filter( pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() ).sort("num_nodes") - + if len(df_filtered) == 0: print("No valid data for scaling table") return - + # Get the 1-node training time for ideal time calculation one_node_data = df_filtered.filter(pl.col("num_nodes") == 1) if one_node_data.height == 0: print("Warning: No 1-node data found for ideal time calculation") return - + t1 = one_node_data["training_time"].item() - + # Determine scaling types to include if scaling_types is None or len(scaling_types) == 0: # Derive scaling type from input filename @@ -122,21 +132,23 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo scaling_types = ["strong"] else: scaling_types = ["strong"] # Default to strong - + # Build table data with proper formatting has_run_id = "run_id" in df_filtered.columns - + # Check if we're generating a combined table (multiple types) is_combined = len(scaling_types) > 1 - + if is_combined: # Combined table: columns per type col_names = ["# Nodes"] for stype in scaling_types: - col_names.extend([ - f"{stype.capitalize()} Training Time (seconds)", - f"{stype.capitalize()} Efficiency" - ]) + col_names.extend( + [ + f"{stype.capitalize()} Training Time (seconds)", + f"{stype.capitalize()} Efficiency", + ] + ) if show_run_ids and has_run_id: col_names.insert(0, "run_id") else: @@ -145,16 +157,16 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo col_names = ["# Nodes", "Training Time (seconds)", "Ideal Time (seconds)", "Efficiency"] if show_run_ids and has_run_id: col_names.insert(0, "run_id") - + table_data = [] for row 
in df_filtered.iter_rows(named=True): num_nodes = row["num_nodes"] training_time = row["training_time"] - + row_data = [] if show_run_ids and has_run_id: row_data.append(str(row.get("run_id", ""))) - + if is_combined: # Combined table: add metrics for each type row_data.append(str(num_nodes)) @@ -170,13 +182,10 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo # Weak scaling: ideal time = t1 (same work per node) ideal_val = t1 efficiency_val = min(1.0, t1 / training_time) - + efficiency = f"{efficiency_val:.2f}" - - row_data.extend([ - f"{training_time:.2f}", - efficiency - ]) + + row_data.extend([f"{training_time:.2f}", efficiency]) else: # Single type table (original format) scaling_type = scaling_types[0] @@ -192,28 +201,24 @@ def generate_scaling_table(df: pl.DataFrame, input_path: Path, show_run_ids: boo # Weak scaling: ideal time = t1 (same work per node) ideal_val = t1 efficiency_val = min(1.0, t1 / training_time) - + ideal_time = f"{ideal_val:.2f}" efficiency = f"{efficiency_val:.2f}" - - row_data.extend([ - f"{training_time:.2f}", - ideal_time, - efficiency - ]) - + + row_data.extend([f"{training_time:.2f}", ideal_time, efficiency]) + table_data.append(row_data) - + # Generate output filename: input_stem_table.csv output_path = input_path.with_name(input_path.stem + "_table.csv") - + # Build DataFrame for CSV output df_table_data = {} for i, col in enumerate(col_names): df_table_data[col] = [row[i] for row in table_data] - + df_table = pl.DataFrame(df_table_data) - + # Write to CSV df_table.write_csv(output_path) print(f"Saved scaling table: {output_path}") @@ -225,13 +230,13 @@ def generate_combined_scaling_table( strong_path: Path, weak_path: Path, output_path: Path, - show_run_ids: bool = False + show_run_ids: bool = False, ) -> None: """Generate a combined table comparing strong and weak scaling from two separate input files. 
- + Rows: num_nodes Columns: # Nodes, Strong Training Time, Strong Efficiency, Weak Training Time, Weak Efficiency - + Also generates a PNG visualization of the table. """ # Validate required columns @@ -239,61 +244,80 @@ def generate_combined_scaling_table( if "num_nodes" not in df.columns or "training_time" not in df.columns: print(f"Warning: Required columns (num_nodes, training_time) not found in {name} data") return - + # Filter and sort both datasets strong_filtered = strong_df.filter( pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() ).sort("num_nodes") - + weak_filtered = weak_df.filter( pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() ).sort("num_nodes") - + if len(strong_filtered) == 0 or len(weak_filtered) == 0: print("No valid data for combined scaling table") return - + # Get 1-node training times for efficiency calculation strong_one_node = strong_filtered.filter(pl.col("num_nodes") == 1) weak_one_node = weak_filtered.filter(pl.col("num_nodes") == 1) - + if strong_one_node.height == 0 or weak_one_node.height == 0: print("Warning: No 1-node data found for efficiency calculation") return - + t1_strong = strong_one_node["training_time"].item() t1_weak = weak_one_node["training_time"].item() - + # Check for run_id in either dataset has_run_id = "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns - + # Build column names - col_names = ["# Nodes", "Strong Training Time (seconds)", "Strong Efficiency", - "Weak Training Time (seconds)", "Weak Efficiency"] + col_names = [ + "# Nodes", + "Strong Training Time (seconds)", + "Strong Efficiency", + "Weak Training Time (seconds)", + "Weak Efficiency", + ] if show_run_ids and has_run_id: col_names.insert(0, "run_id") - + # Get all unique num_nodes from both datasets - all_nodes = sorted(set(strong_filtered["num_nodes"].to_list()) | set(weak_filtered["num_nodes"].to_list())) - + all_nodes = sorted( + set(strong_filtered["num_nodes"].to_list()) | 
set(weak_filtered["num_nodes"].to_list()) + ) + # Create lookup dictionaries for easy access - strong_lookup = {row["num_nodes"]: row["training_time"] for row in strong_filtered.iter_rows(named=True)} - weak_lookup = {row["num_nodes"]: row["training_time"] for row in weak_filtered.iter_rows(named=True)} - strong_run_id_lookup = {row["num_nodes"]: row["run_id"] for row in strong_filtered.iter_rows(named=True)} if "run_id" in strong_filtered.columns else {} - weak_run_id_lookup = {row["num_nodes"]: row["run_id"] for row in weak_filtered.iter_rows(named=True)} if "run_id" in weak_filtered.columns else {} - + strong_lookup = { + row["num_nodes"]: row["training_time"] for row in strong_filtered.iter_rows(named=True) + } + weak_lookup = { + row["num_nodes"]: row["training_time"] for row in weak_filtered.iter_rows(named=True) + } + strong_run_id_lookup = ( + {row["num_nodes"]: row["run_id"] for row in strong_filtered.iter_rows(named=True)} + if "run_id" in strong_filtered.columns + else {} + ) + weak_run_id_lookup = ( + {row["num_nodes"]: row["run_id"] for row in weak_filtered.iter_rows(named=True)} + if "run_id" in weak_filtered.columns + else {} + ) + table_data = [] for num_nodes in all_nodes: row_data = [] - + # Get run_id if available if show_run_ids and has_run_id: run_id = str(strong_run_id_lookup.get(num_nodes, weak_run_id_lookup.get(num_nodes, ""))) row_data.append(run_id) - + # Add num_nodes row_data.append(str(num_nodes)) - + # Strong scaling metrics if num_nodes in strong_lookup: training_time_strong = strong_lookup[num_nodes] @@ -305,7 +329,7 @@ def generate_combined_scaling_table( row_data.extend([f"{training_time_strong:.2f}", efficiency_strong]) else: row_data.extend(["-", "-"]) - + # Weak scaling metrics if num_nodes in weak_lookup: training_time_weak = weak_lookup[num_nodes] @@ -317,24 +341,24 @@ def generate_combined_scaling_table( row_data.extend([f"{training_time_weak:.2f}", efficiency_weak]) else: row_data.extend(["-", "-"]) - + 
table_data.append(row_data) - + # Ensure output path has .csv suffix if output_path.suffix.lower() != ".csv": output_path = output_path.with_suffix(".csv") - + # Build DataFrame for CSV output df_table_data = {} for i, col in enumerate(col_names): df_table_data[col] = [row[i] for row in table_data] - + df_table = pl.DataFrame(df_table_data) - + # Write to CSV df_table.write_csv(output_path) print(f"Saved scaling table CSV: {output_path}") - + # Generate PNG visualization of the table png_path = output_path.with_suffix(".png") _save_table_as_image(table_data, col_names, png_path) @@ -343,44 +367,47 @@ def generate_combined_scaling_table( def _save_table_as_image(table_data: list, col_names: list, output_path: Path) -> None: """Save table data as a PNG image using matplotlib. - + Automatically sizes the figure to fit all content. """ output_path.parent.mkdir(parents=True, exist_ok=True) - + # Calculate figure size based on content num_cols = len(col_names) num_rows = len(table_data) + 1 # +1 for header - + # Width: base + per-column width, Height: base + per-row height fig_width = max(8, num_cols * 2.5) fig_height = max(3, num_rows * 0.5) - + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - ax.axis('off') - + ax.axis("off") + # Create table table = ax.table( cellText=table_data, colLabels=col_names, - cellLoc='center', - loc='center', - colColours=['#2E5C8A'] * num_cols, - cellColours=[['#E8ECEF' if i % 2 == 0 else 'white' for _ in range(num_cols)] for i in range(len(table_data))] + cellLoc="center", + loc="center", + colColours=["#2E5C8A"] * num_cols, + cellColours=[ + ["#E8ECEF" if i % 2 == 0 else "white" for _ in range(num_cols)] + for i in range(len(table_data)) + ], ) - + # Style the table table.auto_set_font_size(False) table.set_fontsize(9) table.auto_set_column_width(col=list(range(num_cols))) - + # Style header cells for i in range(num_cols): - table[(0, i)].set_text_props(color='white', fontweight='bold') - + table[(0, 
i)].set_text_props(color="white", fontweight="bold") + # Adjust layout and save plt.tight_layout() - fig.savefig(output_path, dpi=150, bbox_inches='tight') + fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) @@ -402,17 +429,23 @@ def plot_standard_scaling( "efficiency": "Scaling Efficiency", } - valid_metrics = [m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0] + valid_metrics = [ + m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0 + ] if not valid_metrics: print("No valid metrics to plot") return - fig, axes = plt.subplots(len(valid_metrics), 1, figsize=(12, 6 * len(valid_metrics)), squeeze=False) + fig, axes = plt.subplots( + len(valid_metrics), 1, figsize=(12, 6 * len(valid_metrics)), squeeze=False + ) for idx, metric in enumerate(valid_metrics): ax = axes[idx][0] df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") - node_counts = df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] + node_counts = ( + df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] + ) colors = color_map_for_nodes(node_counts) # Handle normalized_throughput and efficiency metrics @@ -442,7 +475,9 @@ def plot_standard_scaling( else: # Weak scaling: efficiency = min(1.0, t1 / training_time) df_plot = df_plot.with_columns( - pl.min_horizontal(pl.lit(1.0), t1 / pl.col("training_time")).alias("efficiency") + pl.min_horizontal(pl.lit(1.0), t1 / pl.col("training_time")).alias( + "efficiency" + ) ) plot_y = df_plot["efficiency"] else: @@ -460,10 +495,14 @@ def plot_standard_scaling( ) if show_run_ids: - for x, y, label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"]): + for x, y, label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"], strict=False): ax.text(x, y, label, ha="center", va="bottom", fontsize=8) - if metric == "training_time" and y_metric in ("time", "normalized_throughput", 
"efficiency") and "training_time" in df.columns: + if ( + metric == "training_time" + and y_metric in ("time", "normalized_throughput", "efficiency") + and "training_time" in df.columns + ): one_node_data = df.filter(pl.col("num_nodes") == 1) if one_node_data.height > 0: t1 = one_node_data["training_time"].item() @@ -489,7 +528,7 @@ def plot_standard_scaling( # Show per-point efficiency loss as a vertical line and factor label. # Use plot_y (normalized throughput if applicable) instead of df_plot[metric] - for x, y, y_opt in zip(nodes, plot_y.to_list(), optimal_y): + for x, y, y_opt in zip(nodes, plot_y.to_list(), optimal_y, strict=False): if y_opt == 0: continue factor = y / y_opt @@ -517,7 +556,7 @@ def plot_standard_scaling( ax.set_ylabel("Scaling Efficiency", fontsize=16) else: ax.set_ylabel(metric_labels.get(metric, metric), fontsize=16) - ax.tick_params(axis='both', which='major', labelsize=14) + ax.tick_params(axis="both", which="major", labelsize=14) ax.grid(True, alpha=0.3) plt.tight_layout() @@ -531,7 +570,12 @@ def plot_detailed_scaling( y_scale: str, ) -> None: """Plot sample-level detailed scaling data vs total_num_samples.""" - required_cols = ["total_num_samples", "elapsed_training_time_seconds", "loss_avg_mean", "num_nodes"] + required_cols = [ + "total_num_samples", + "elapsed_training_time_seconds", + "loss_avg_mean", + "num_nodes", + ] if not all(col in df.columns for col in required_cols): print("Detailed metrics not available in this dataset") print(f"Available columns: {df.columns}") @@ -570,7 +614,7 @@ def plot_detailed_scaling( ax.set_yscale("log") ax.set_ylabel("Elapsed Training Time (seconds)", fontsize=16) ax.set_title("Elapsed Training Time vs Samples", fontsize=16) - ax.tick_params(axis='both', which='major', labelsize=14) + ax.tick_params(axis="both", which="major", labelsize=14) ax.grid(True, alpha=0.3) ax.legend(title="Node Count") @@ -591,7 +635,7 @@ def plot_detailed_scaling( ax.set_xlabel("Total Number of Samples", fontsize=16) 
ax.set_ylabel("Average Loss", fontsize=16) ax.set_title("Loss vs Samples", fontsize=16) - ax.tick_params(axis='both', which='major', labelsize=14) + ax.tick_params(axis="both", which="major", labelsize=14) ax.grid(True, alpha=0.3) ax.legend(title="Node Count") @@ -601,37 +645,95 @@ def plot_detailed_scaling( def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Generate scaling plots from parquet or NDJSON data") + parser = argparse.ArgumentParser( + description="Generate scaling plots from parquet or NDJSON data" + ) subparsers = parser.add_subparsers(dest="mode", required=True) standard = subparsers.add_parser("standard", help="Plot run-level scaling metrics vs num_nodes") - standard.add_argument("--type", required=True, help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table") - standard.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") + standard.add_argument( + "--type", + required=True, + help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", + ) + standard.add_argument( + "--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file" + ) standard.add_argument("--output", type=Path, default=None, help="Output image path") - standard.add_argument("--y-scale", choices=["linear", "log"], default="linear", help="Y-axis scale") - standard.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") - standard.add_argument("--y-metric", choices=["time", "normalized_throughput", "efficiency"], default="normalized_throughput", help="Y-axis metric: 'time' for time-to-solution, 'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency") - standard.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") - - loss_only = subparsers.add_parser("loss", help="Plot loss metrics vs num_nodes (separate from 
throughput)") - loss_only.add_argument("--type", required=True, help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table") - loss_only.add_argument("--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file") + standard.add_argument( + "--y-scale", choices=["linear", "log"], default="linear", help="Y-axis scale" + ) + standard.add_argument( + "--x-scale", choices=["linear", "log"], default="log", help="X-axis scale" + ) + standard.add_argument( + "--y-metric", + choices=["time", "normalized_throughput", "efficiency"], + default="normalized_throughput", + help="Y-axis metric: 'time' for time-to-solution, 'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency", + ) + standard.add_argument( + "--show-run-ids", + action="store_true", + help="Show run_id labels on the plot and in the output table", + ) + + loss_only = subparsers.add_parser( + "loss", help="Plot loss metrics vs num_nodes (separate from throughput)" + ) + loss_only.add_argument( + "--type", + required=True, + help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", + ) + loss_only.add_argument( + "--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file" + ) loss_only.add_argument("--output", type=Path, default=None, help="Output image path") - loss_only.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") - loss_only.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") - loss_only.add_argument("--show-run-ids", action="store_true", help="Show run_id labels on the plot and in the output table") + loss_only.add_argument( + "--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale" + ) + loss_only.add_argument( + "--x-scale", choices=["linear", "log"], default="log", help="X-axis scale" + ) + loss_only.add_argument( + "--show-run-ids", + action="store_true", + help="Show run_id labels on the plot 
and in the output table", + ) - combined = subparsers.add_parser("combined", help="Generate combined table comparing strong and weak scaling from separate input files") - combined.add_argument("--strong-input", type=Path, required=True, help="Input parquet/ndjson file for strong scaling") - combined.add_argument("--weak-input", type=Path, required=True, help="Input parquet/ndjson file for weak scaling") + combined = subparsers.add_parser( + "combined", + help="Generate combined table comparing strong and weak scaling from separate input files", + ) + combined.add_argument( + "--strong-input", + type=Path, + required=True, + help="Input parquet/ndjson file for strong scaling", + ) + combined.add_argument( + "--weak-input", type=Path, required=True, help="Input parquet/ndjson file for weak scaling" + ) combined.add_argument("--output", type=Path, default=None, help="Output table path (CSV)") - combined.add_argument("--show-run-ids", action="store_true", help="Show run_id labels in the output table") + combined.add_argument( + "--show-run-ids", action="store_true", help="Show run_id labels in the output table" + ) detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") - detailed.add_argument("--input", type=Path, default=Path("scaling_data_detailed.parquet"), help="Input detailed parquet/ndjson file") + detailed.add_argument( + "--input", + type=Path, + default=Path("scaling_data_detailed.parquet"), + help="Input detailed parquet/ndjson file", + ) detailed.add_argument("--output", type=Path, default=None, help="Output image path") - detailed.add_argument("--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale") - detailed.add_argument("--x-scale", choices=["linear", "log"], default="log", help="X-axis scale") + detailed.add_argument( + "--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale" + ) + detailed.add_argument( + "--x-scale", choices=["linear", "log"], default="log", help="X-axis 
scale" + ) return parser @@ -656,21 +758,34 @@ def main() -> None: print(str(e)) return print(f"Loaded {len(df)} rows") - + # Parse scaling types from --type argument scaling_types = [t.strip().lower() for t in args.type.split(",")] for stype in scaling_types: if stype not in ("strong", "weak"): - print(f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'") + print( + f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'" + ) return - + # Standard mode: only plot training_time with normalized throughput or time metrics_to_plot = ["training_time"] # Use the first type for plotting (or strong if combined) plot_type = scaling_types[0] - plot_standard_scaling(df, output_path, plot_type, metrics_to_plot, args.x_scale, args.y_scale, args.y_metric, args.show_run_ids) + plot_standard_scaling( + df, + output_path, + plot_type, + metrics_to_plot, + args.x_scale, + args.y_scale, + args.y_metric, + args.show_run_ids, + ) # Generate scaling table - generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types) + generate_scaling_table( + df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types + ) return if args.mode == "loss": @@ -689,21 +804,34 @@ def main() -> None: print(str(e)) return print(f"Loaded {len(df)} rows") - + # Parse scaling types from --type argument scaling_types = [t.strip().lower() for t in args.type.split(",")] for stype in scaling_types: if stype not in ("strong", "weak"): - print(f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'") + print( + f"Error: Invalid scaling type '{stype}'. 
Use 'strong', 'weak', or 'strong,weak'" + ) return - + # Loss mode: only plot loss_avg_mean metrics_to_plot = ["loss_avg_mean"] # Use the first type for plotting (or strong if combined) plot_type = scaling_types[0] - plot_standard_scaling(df, output_path, plot_type, metrics_to_plot, args.x_scale, args.y_scale, "time", args.show_run_ids) + plot_standard_scaling( + df, + output_path, + plot_type, + metrics_to_plot, + args.x_scale, + args.y_scale, + "time", + args.show_run_ids, + ) # Generate scaling table - generate_scaling_table(df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types) + generate_scaling_table( + df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types + ) return if args.mode == "detailed": @@ -728,7 +856,7 @@ def main() -> None: if args.mode == "combined": strong_path = resolve_input_path(args.strong_input) weak_path = resolve_input_path(args.weak_input) - + if not strong_path.exists(): print(f"Error: Strong scaling input file not found: {strong_path}") return @@ -765,14 +893,12 @@ def main() -> None: # Generate combined table generate_combined_scaling_table( - strong_df, weak_df, strong_path, weak_path, output_path, - show_run_ids=args.show_run_ids + strong_df, weak_df, strong_path, weak_path, output_path, show_run_ids=args.show_run_ids ) return raise ValueError(f"Unknown mode: {args.mode}") - if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 429a6927f..513fb988b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "weathergen-common", "weathergen-evaluate", "weathergen-readers-extra", - "pyarrow>=23.0.1", + ] @@ -76,6 +76,10 @@ dev = [ # aarch64: gpu [project.optional-dependencies] +performance = [ + "weathergen-performance", +] + cpu = [ 'torch==2.6.0', ] @@ -229,6 +233,7 @@ weathergen-common = { workspace = true } weathergen-evaluate = { workspace = true } weathergen-metrics = { workspace = true } weathergen-readers-extra = { workspace = true 
} +weathergen-performance = { workspace = true } flash-attn = [ @@ -273,5 +278,6 @@ members = [ "packages/readers_extra", # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. + "packages/performance", ] From c5af276320063b90bbf4467ba4de0a5873f01963 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:03:33 +0200 Subject: [PATCH 54/76] init refactor --- packages/performance/__init__.py | 6 ------ packages/performance/src/performance/__init__.py | 4 +--- 2 files changed, 1 insertion(+), 9 deletions(-) delete mode 100644 packages/performance/__init__.py diff --git a/packages/performance/__init__.py b/packages/performance/__init__.py deleted file mode 100644 index 98802f911..000000000 --- a/packages/performance/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Performance analysis tools for WeatherGenerator. - -This package contains tools for extracting and analyzing scaling performance data. -""" - -__version__ = "0.1.0" diff --git a/packages/performance/src/performance/__init__.py b/packages/performance/src/performance/__init__.py index fd2921767..7105b9746 100644 --- a/packages/performance/src/performance/__init__.py +++ b/packages/performance/src/performance/__init__.py @@ -1,3 +1 @@ -"""WeatherGenerator performance analysis package.""" - -__all__ = ["extract_scaling_data", "generate_scaling_plots"] +"""WeatherGenerator performance analysis tools.""" From e760c13712363a1978f68e2d69707488818cec71 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:07:50 +0200 Subject: [PATCH 55/76] Setup and linting --- packages/performance/pyproject.toml | 13 +- .../src/performance/extract_scaling_data.py | 13 +- .../src/performance/generate_scaling_plots.py | 134 ++++++++++++++---- src/weathergen/train/trainer.py | 28 ++-- 4 files changed, 132 insertions(+), 56 deletions(-) diff --git a/packages/performance/pyproject.toml b/packages/performance/pyproject.toml index 2838552cd..933323e3f 
100644 --- a/packages/performance/pyproject.toml +++ b/packages/performance/pyproject.toml @@ -3,9 +3,6 @@ name = "weathergen-performance" version = "0.1.0" description = "Performance analysis tools for WeatherGenerator" readme = "README.md" -authors = [ - { name = "WeatherGenerator collaboration" } -] requires-python = ">=3.12,<3.13" dependencies = [ @@ -26,9 +23,6 @@ packages = ["src/performance"] extract_scaling_data = "performance.extract_scaling_data:main" generate_scaling_plots = "performance.generate_scaling_plots:main" -[tool.ruff] -line-length = 100 - [tool.ruff.lint] select = [ "E", @@ -37,10 +31,5 @@ select = [ "B", "SIM", "I", -] -ignore = [ - "SIM108", - "N817", - "E731", - "N812", + "N" ] diff --git a/packages/performance/src/performance/extract_scaling_data.py b/packages/performance/src/performance/extract_scaling_data.py index 5994c6a62..a7e74bf1f 100644 --- a/packages/performance/src/performance/extract_scaling_data.py +++ b/packages/performance/src/performance/extract_scaling_data.py @@ -234,7 +234,10 @@ def main(): help="Base directory for shared work/results", ) parser.add_argument( - "--output", type=Path, default=Path("scaling_data.parquet"), help="Output parquet file path" + "--output", + type=Path, + default=Path("scaling_data.parquet"), + help="Output parquet file path", ) args = parser.parse_args() @@ -264,7 +267,9 @@ def main(): } results.append(row) - detailed_records = extract_detailed_metrics(run_id, args.shared_work_dir, num_nodes) + detailed_records = extract_detailed_metrics( + run_id, args.shared_work_dir, num_nodes + ) if detailed_records: all_detailed_records.extend(detailed_records) print( @@ -298,7 +303,9 @@ def main(): detailed_df = detailed_df[available_cols] output_stem = args.output.stem - detailed_output = args.output.with_name(f"{output_stem}_detailed{args.output.suffix}") + detailed_output = args.output.with_name( + f"{output_stem}_detailed{args.output.suffix}" + ) detailed_df.to_parquet(detailed_output, index=False) 
detailed_df.to_csv(detailed_output.with_suffix(".csv"), index=False) diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py index 4bad53e3a..195a10dfe 100644 --- a/packages/performance/src/performance/generate_scaling_plots.py +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -93,7 +93,10 @@ def save_figure(fig: plt.Figure, output_path: Path) -> None: def generate_scaling_table( - df: pl.DataFrame, input_path: Path, show_run_ids: bool = False, scaling_types: list[str] = None + df: pl.DataFrame, + input_path: Path, + show_run_ids: bool = False, + scaling_types: list[str] = None, ) -> None: """Generate a PNG table image with scaling metrics from the parquet file. @@ -154,7 +157,12 @@ def generate_scaling_table( else: # Single type table (original format) scaling_type = scaling_types[0].capitalize() - col_names = ["# Nodes", "Training Time (seconds)", "Ideal Time (seconds)", "Efficiency"] + col_names = [ + "# Nodes", + "Training Time (seconds)", + "Ideal Time (seconds)", + "Efficiency", + ] if show_run_ids and has_run_id: col_names.insert(0, "run_id") @@ -242,7 +250,9 @@ def generate_combined_scaling_table( # Validate required columns for name, df in [("strong", strong_df), ("weak", weak_df)]: if "num_nodes" not in df.columns or "training_time" not in df.columns: - print(f"Warning: Required columns (num_nodes, training_time) not found in {name} data") + print( + f"Warning: Required columns (num_nodes, training_time) not found in {name} data" + ) return # Filter and sort both datasets @@ -270,7 +280,9 @@ def generate_combined_scaling_table( t1_weak = weak_one_node["training_time"].item() # Check for run_id in either dataset - has_run_id = "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns + has_run_id = ( + "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns + ) # Build column names col_names = [ @@ -285,18 +297,24 @@ 
def generate_combined_scaling_table( # Get all unique num_nodes from both datasets all_nodes = sorted( - set(strong_filtered["num_nodes"].to_list()) | set(weak_filtered["num_nodes"].to_list()) + set(strong_filtered["num_nodes"].to_list()) + | set(weak_filtered["num_nodes"].to_list()) ) # Create lookup dictionaries for easy access strong_lookup = { - row["num_nodes"]: row["training_time"] for row in strong_filtered.iter_rows(named=True) + row["num_nodes"]: row["training_time"] + for row in strong_filtered.iter_rows(named=True) } weak_lookup = { - row["num_nodes"]: row["training_time"] for row in weak_filtered.iter_rows(named=True) + row["num_nodes"]: row["training_time"] + for row in weak_filtered.iter_rows(named=True) } strong_run_id_lookup = ( - {row["num_nodes"]: row["run_id"] for row in strong_filtered.iter_rows(named=True)} + { + row["num_nodes"]: row["run_id"] + for row in strong_filtered.iter_rows(named=True) + } if "run_id" in strong_filtered.columns else {} ) @@ -312,7 +330,11 @@ def generate_combined_scaling_table( # Get run_id if available if show_run_ids and has_run_id: - run_id = str(strong_run_id_lookup.get(num_nodes, weak_run_id_lookup.get(num_nodes, ""))) + run_id = str( + strong_run_id_lookup.get( + num_nodes, weak_run_id_lookup.get(num_nodes, "") + ) + ) row_data.append(run_id) # Add num_nodes @@ -430,7 +452,9 @@ def plot_standard_scaling( } valid_metrics = [ - m for m in metrics if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0 + m + for m in metrics + if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0 ] if not valid_metrics: print("No valid metrics to plot") @@ -444,7 +468,9 @@ def plot_standard_scaling( ax = axes[idx][0] df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") node_counts = ( - df_plot["num_nodes"].unique().to_list() if "num_nodes" in df_plot.columns else [] + df_plot["num_nodes"].unique().to_list() + if "num_nodes" in df_plot.columns + else [] ) colors = 
color_map_for_nodes(node_counts) @@ -460,7 +486,9 @@ def plot_standard_scaling( ) plot_y = df_plot["normalized_throughput"] else: - print("Warning: No 1-node data found for normalized throughput calculation") + print( + "Warning: No 1-node data found for normalized throughput calculation" + ) continue elif y_metric == "efficiency" and metric == "training_time": # Calculate efficiency based on scaling type @@ -470,14 +498,16 @@ def plot_standard_scaling( if scaling_type == "strong": # Strong scaling: efficiency = (t1 / num_nodes) / training_time df_plot = df_plot.with_columns( - ((t1 / pl.col("num_nodes")) / pl.col("training_time")).alias("efficiency") + ((t1 / pl.col("num_nodes")) / pl.col("training_time")).alias( + "efficiency" + ) ) else: # Weak scaling: efficiency = min(1.0, t1 / training_time) df_plot = df_plot.with_columns( - pl.min_horizontal(pl.lit(1.0), t1 / pl.col("training_time")).alias( - "efficiency" - ) + pl.min_horizontal( + pl.lit(1.0), t1 / pl.col("training_time") + ).alias("efficiency") ) plot_y = df_plot["efficiency"] else: @@ -495,7 +525,9 @@ def plot_standard_scaling( ) if show_run_ids: - for x, y, label in zip(df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"], strict=False): + for x, y, label in zip( + df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"], strict=False + ): ax.text(x, y, label, ha="center", va="bottom", fontsize=8) if ( @@ -528,11 +560,21 @@ def plot_standard_scaling( # Show per-point efficiency loss as a vertical line and factor label. 
# Use plot_y (normalized throughput if applicable) instead of df_plot[metric] - for x, y, y_opt in zip(nodes, plot_y.to_list(), optimal_y, strict=False): + for x, y, y_opt in zip( + nodes, plot_y.to_list(), optimal_y, strict=False + ): if y_opt == 0: continue factor = y / y_opt - ax.vlines(x, y_opt, y, colors="gray", linestyles=":", linewidth=1, alpha=0.7) + ax.vlines( + x, + y_opt, + y, + colors="gray", + linestyles=":", + linewidth=1, + alpha=0.7, + ) y_mid = (y + y_opt) / 2 ax.annotate( f"{factor:.2f}", @@ -600,7 +642,9 @@ def plot_detailed_scaling( ax = axes[0] for node_count in node_counts: - df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort("total_num_samples") + df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort( + "total_num_samples" + ) ax.plot( df_node["total_num_samples"], df_node["elapsed_training_time_seconds"], @@ -620,7 +664,9 @@ def plot_detailed_scaling( ax = axes[1] for node_count in node_counts: - df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort("total_num_samples") + df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort( + "total_num_samples" + ) ax.plot( df_node["total_num_samples"], df_node["loss_avg_mean"], @@ -650,14 +696,19 @@ def build_parser() -> argparse.ArgumentParser: ) subparsers = parser.add_subparsers(dest="mode", required=True) - standard = subparsers.add_parser("standard", help="Plot run-level scaling metrics vs num_nodes") + standard = subparsers.add_parser( + "standard", help="Plot run-level scaling metrics vs num_nodes" + ) standard.add_argument( "--type", required=True, help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", ) standard.add_argument( - "--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file" + "--input", + type=Path, + default=Path("scaling_data.parquet"), + help="Input parquet/ndjson file", ) standard.add_argument("--output", type=Path, default=None, help="Output image path") 
standard.add_argument( @@ -687,9 +738,14 @@ def build_parser() -> argparse.ArgumentParser: help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", ) loss_only.add_argument( - "--input", type=Path, default=Path("scaling_data.parquet"), help="Input parquet/ndjson file" + "--input", + type=Path, + default=Path("scaling_data.parquet"), + help="Input parquet/ndjson file", + ) + loss_only.add_argument( + "--output", type=Path, default=None, help="Output image path" ) - loss_only.add_argument("--output", type=Path, default=None, help="Output image path") loss_only.add_argument( "--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale" ) @@ -713,14 +769,23 @@ def build_parser() -> argparse.ArgumentParser: help="Input parquet/ndjson file for strong scaling", ) combined.add_argument( - "--weak-input", type=Path, required=True, help="Input parquet/ndjson file for weak scaling" + "--weak-input", + type=Path, + required=True, + help="Input parquet/ndjson file for weak scaling", ) - combined.add_argument("--output", type=Path, default=None, help="Output table path (CSV)") combined.add_argument( - "--show-run-ids", action="store_true", help="Show run_id labels in the output table" + "--output", type=Path, default=None, help="Output table path (CSV)" + ) + combined.add_argument( + "--show-run-ids", + action="store_true", + help="Show run_id labels in the output table", ) - detailed = subparsers.add_parser("detailed", help="Plot sample-level detailed scaling metrics") + detailed = subparsers.add_parser( + "detailed", help="Plot sample-level detailed scaling metrics" + ) detailed.add_argument( "--input", type=Path, @@ -871,7 +936,9 @@ def main() -> None: output_path = output_path.with_suffix(".csv") else: # Default output: strong_input_stem_combined_table.csv - output_path = strong_path.with_name(strong_path.stem + "_combined_table.csv") + output_path = strong_path.with_name( + strong_path.stem + "_combined_table.csv" + ) print(f"Loading 
strong scaling data from: {strong_path}") try: @@ -893,7 +960,12 @@ def main() -> None: # Generate combined table generate_combined_scaling_table( - strong_df, weak_df, strong_path, weak_path, output_path, show_run_ids=args.show_run_ids + strong_df, + weak_df, + strong_path, + weak_path, + output_path, + show_run_ids=args.show_run_ids, ) return diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ff81e2f8e..02b661f0f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -241,8 +241,9 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") - def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: float | None = None): - + def run( + self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: float | None = None + ): # general initalization self.init(cf, devices) cf = self.cf @@ -566,11 +567,16 @@ def train(self, mini_epoch): torch.distributed.barrier() if is_root(): total_training_time = time.time() - self.t_training_start - self.train_logger.log_metrics("train", { - "completed_mini_epoch": mini_epoch, - "elapsed_time_mini_epoch": total_training_time, - }) - logger.info(f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds") + self.train_logger.log_metrics( + "train", + { + "completed_mini_epoch": mini_epoch, + "elapsed_time_mini_epoch": total_training_time, + }, + ) + logger.info( + f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds" + ) self.dataset.advance() @@ -775,12 +781,14 @@ def _log(self, stage: Stage): lr=self.lr_scheduler.get_lr(), ) self.train_logger.log_metrics( - "train", + "train", { "elapsed_training_time_seconds": elapsed_time, "total_num_samples": samples, - "average_samples_per_second": samples / elapsed_time if elapsed_time > 0 else 0, - } + 
"average_samples_per_second": samples / elapsed_time + if elapsed_time > 0 + else 0, + }, ) loss_calculator.loss_hist = [] From 556106ebb6e24d9820a0924189842f3e912bddc3 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:17:15 +0200 Subject: [PATCH 56/76] Updated plot generation script --- .../src/performance/generate_scaling_plots.py | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py index 195a10dfe..de42c97f6 100644 --- a/packages/performance/src/performance/generate_scaling_plots.py +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -166,21 +166,24 @@ def generate_scaling_table( if show_run_ids and has_run_id: col_names.insert(0, "run_id") - table_data = [] + table_rows = [] for row in df_filtered.iter_rows(named=True): num_nodes = row["num_nodes"] training_time = row["training_time"] - row_data = [] - if show_run_ids and has_run_id: - row_data.append(str(row.get("run_id", ""))) - if is_combined: # Combined table: add metrics for each type - row_data.append(str(num_nodes)) + row_data = {} + if show_run_ids and has_run_id: + row_data["run_id"] = str(row.get("run_id", "")) + row_data["# Nodes"] = str(num_nodes) + for stype in scaling_types: + time_col = f"{stype.capitalize()} Training Time (seconds)" + eff_col = f"{stype.capitalize()} Efficiency" + row_data[time_col] = f"{training_time:.2f}" if num_nodes == 1: - efficiency = "-" + row_data[eff_col] = "-" else: if stype == "strong": # Strong scaling: ideal time = t1 / num_nodes @@ -191,15 +194,18 @@ def generate_scaling_table( ideal_val = t1 efficiency_val = min(1.0, t1 / training_time) - efficiency = f"{efficiency_val:.2f}" - - row_data.extend([f"{training_time:.2f}", efficiency]) + row_data[eff_col] = f"{efficiency_val:.2f}" else: # Single type table (original format) scaling_type = scaling_types[0] + row_data 
= {} + if show_run_ids and has_run_id: + row_data["run_id"] = str(row.get("run_id", "")) + row_data["# Nodes"] = str(num_nodes) + row_data["Training Time (seconds)"] = f"{training_time:.2f}" if num_nodes == 1: - ideal_time = "-" - efficiency = "-" + row_data["Ideal Time (seconds)"] = "-" + row_data["Efficiency"] = "-" else: if scaling_type == "strong": # Strong scaling: ideal time = t1 / num_nodes @@ -210,22 +216,18 @@ def generate_scaling_table( ideal_val = t1 efficiency_val = min(1.0, t1 / training_time) - ideal_time = f"{ideal_val:.2f}" - efficiency = f"{efficiency_val:.2f}" - - row_data.extend([f"{training_time:.2f}", ideal_time, efficiency]) + row_data["Ideal Time (seconds)"] = f"{ideal_val:.2f}" + row_data["Efficiency"] = f"{efficiency_val:.2f}" - table_data.append(row_data) + table_rows.append(row_data) # Generate output filename: input_stem_table.csv output_path = input_path.with_name(input_path.stem + "_table.csv") - # Build DataFrame for CSV output - df_table_data = {} - for i, col in enumerate(col_names): - df_table_data[col] = [row[i] for row in table_data] - - df_table = pl.DataFrame(df_table_data) + # Build DataFrame for CSV output from named columns. 
+ df_table = pl.DataFrame( + [{col: row.get(col, "-") for col in col_names} for row in table_rows] + ) # Write to CSV df_table.write_csv(output_path) @@ -324,9 +326,9 @@ def generate_combined_scaling_table( else {} ) - table_data = [] + table_rows = [] for num_nodes in all_nodes: - row_data = [] + row_data = {} # Get run_id if available if show_run_ids and has_run_id: @@ -335,47 +337,47 @@ def generate_combined_scaling_table( num_nodes, weak_run_id_lookup.get(num_nodes, "") ) ) - row_data.append(run_id) + row_data["run_id"] = run_id # Add num_nodes - row_data.append(str(num_nodes)) + row_data["# Nodes"] = str(num_nodes) # Strong scaling metrics if num_nodes in strong_lookup: training_time_strong = strong_lookup[num_nodes] + row_data["Strong Training Time (seconds)"] = f"{training_time_strong:.2f}" if num_nodes == 1: - efficiency_strong = "-" + row_data["Strong Efficiency"] = "-" else: ideal_strong = t1_strong / num_nodes - efficiency_strong = f"{ideal_strong / training_time_strong:.2f}" - row_data.extend([f"{training_time_strong:.2f}", efficiency_strong]) + row_data["Strong Efficiency"] = f"{ideal_strong / training_time_strong:.2f}" else: - row_data.extend(["-", "-"]) + row_data["Strong Training Time (seconds)"] = "-" + row_data["Strong Efficiency"] = "-" # Weak scaling metrics if num_nodes in weak_lookup: training_time_weak = weak_lookup[num_nodes] + row_data["Weak Training Time (seconds)"] = f"{training_time_weak:.2f}" if num_nodes == 1: - efficiency_weak = "-" + row_data["Weak Efficiency"] = "-" else: ideal_weak = t1_weak # Weak scaling: ideal is same as 1-node time - efficiency_weak = f"{min(1.0, ideal_weak / training_time_weak):.2f}" - row_data.extend([f"{training_time_weak:.2f}", efficiency_weak]) + row_data["Weak Efficiency"] = f"{min(1.0, ideal_weak / training_time_weak):.2f}" else: - row_data.extend(["-", "-"]) + row_data["Weak Training Time (seconds)"] = "-" + row_data["Weak Efficiency"] = "-" - table_data.append(row_data) + table_rows.append(row_data) # 
Ensure output path has .csv suffix if output_path.suffix.lower() != ".csv": output_path = output_path.with_suffix(".csv") - # Build DataFrame for CSV output - df_table_data = {} - for i, col in enumerate(col_names): - df_table_data[col] = [row[i] for row in table_data] - - df_table = pl.DataFrame(df_table_data) + # Build DataFrame for CSV output from named columns. + df_table = pl.DataFrame( + [{col: row.get(col, "-") for col in col_names} for row in table_rows] + ) # Write to CSV df_table.write_csv(output_path) @@ -383,8 +385,11 @@ def generate_combined_scaling_table( # Generate PNG visualization of the table png_path = output_path.with_suffix(".png") - _save_table_as_image(table_data, col_names, png_path) - print(f"Saved scaling table PNG: {png_path}") + _save_table_as_image( + [[row.get(col, "-") for col in col_names] for row in table_rows], + col_names, + png_path, + ) def _save_table_as_image(table_data: list, col_names: list, output_path: Path) -> None: From b93813da3934b87cbd6b3d616811e0c9afc9c7e1 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:27:14 +0200 Subject: [PATCH 57/76] Update readme --- packages/performance/README.md | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/packages/performance/README.md b/packages/performance/README.md index 99dabfe4a..af1554d94 100644 --- a/packages/performance/README.md +++ b/packages/performance/README.md @@ -2,42 +2,42 @@ This package contains tools for extracting and analyzing scaling performance data from WeatherGenerator training runs. +## Installation + +Install the optional performance tools: + +```bash +uv sync --extra performance +``` + ## Scripts ### extract_scaling_data.py -Extracts strong scaling metrics from WeatherGenerator training runs. +Extracts scaling metrics from WeatherGenerator training runs and writes parquet output. 
```bash -extract_scaling_data --logs-dir /path/to/logs --work-dir /path/to/work +extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet ``` ### generate_scaling_plots.py -Generates scaling plots and tables from parquet/NDJSON data. +Generates scaling plots and tables from parquet/NDJSON data using named columns from the input files. ```bash -# Standard mode (single type) -generate_scaling_plots standard --type strong --input data.parquet - -# Combined mode (separate files) -generate_scaling_plots combined \ - --strong-input strong.parquet \ - --weak-input weak.parquet - -# Combined mode (single file with both types) -generate_scaling_plots standard --type strong,weak --input data.parquet +generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log ``` -## Installation +## Suggested workflow -This package is part of the WeatherGenerator workspace. To install: +1. Extract the scaling data into a parquet file (on your HPC). +2. Copy the parquet file to your local machine. +3. Generate plots from the parquet file. + +Example: ```bash -# In the root WeatherGenerator directory -uv sync --extra performance +extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet +scp user@remote:/path/to/scaling.parquet . 
+generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log ``` - -The scripts will be available as console scripts: -- `extract_scaling_data` -- `generate_scaling_plots` From b2fe866cef9513706b97fc41e84b369ec29a309a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:27:50 +0200 Subject: [PATCH 58/76] Fewer diffs --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 513fb988b..24ea62a73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,7 @@ dependencies = [ "anemoi-datasets", "weathergen-common", "weathergen-evaluate", - "weathergen-readers-extra", - + "weathergen-readers-extra" ] From c5745148210e9d4e3d85b7633f960be9393d4241 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:28:10 +0200 Subject: [PATCH 59/76] no gitignore changes --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1ce3f78ee..9962276a0 100644 --- a/.gitignore +++ b/.gitignore @@ -228,4 +228,3 @@ output models results reports -.hermes/ From dad5462d6a670ccec2bcb57e841cb93ec3dde2a6 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:41:08 +0200 Subject: [PATCH 60/76] Refactor logging and move time for mini epoch logging outside loop --- src/weathergen/train/trainer.py | 47 +++++++++++----------------- src/weathergen/utils/train_logger.py | 11 ++++++- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 02b661f0f..e2c9d1b92 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -554,31 +554,31 @@ def train(self, mini_epoch): ) self._log_terminal(bidx, mini_epoch, TRAIN) - if bidx % self.train_logging.metrics == 0 or bidx == len(self.data_loader) - 1: + if bidx % self.train_logging.metrics == 0: self._log(TRAIN) # Log collapse metrics if 
self.collapse_monitor.should_log(self.cf.general.istep): self._log_collapse_metrics(TRAIN) - self.cf.general.istep += 1 - - # log metrics at last iteration (keep barrier for now) - if bidx == len(self.data_loader) - 1: - torch.distributed.barrier() - if is_root(): - total_training_time = time.time() - self.t_training_start - self.train_logger.log_metrics( - "train", - { - "completed_mini_epoch": mini_epoch, - "elapsed_time_mini_epoch": total_training_time, - }, - ) - logger.info( - f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds" - ) + # save model checkpoint (with designation _latest) + if bidx % self.train_logging.checkpoint == 0 and bidx > 0: + self.save_model(-1) + self.cf.general.istep += 1 self.dataset.advance() + + if is_root(): + total_training_time = time.time() - self.t_training_start + self.train_logger.log_metrics( + "train", + { + "completed_mini_epoch": mini_epoch, + "elapsed_time_mini_epoch": total_training_time, + }, + ) + logger.info( + f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds" + ) def validate(self, mini_epoch, mode_cfg, batch_size): """ @@ -779,16 +779,7 @@ def _log(self, stage: Stage): stddev_all, avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), - ) - self.train_logger.log_metrics( - "train", - { - "elapsed_training_time_seconds": elapsed_time, - "total_num_samples": samples, - "average_samples_per_second": samples / elapsed_time - if elapsed_time > 0 - else 0, - }, + elapsed_training_time_seconds=elapsed_time, ) loss_calculator.loss_hist = [] diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 5f1550e42..2c248f5c1 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -102,9 +102,10 @@ def add_logs( stddev_all: dict, avg_loss: list[float] = None, lr: float = None, + elapsed_training_time_seconds: float | None = None, ) -> None: """ - Log training or validation data + Log training or validation 
data. """ metrics: dict[str, float] = dict(num_samples=samples) @@ -112,6 +113,14 @@ def add_logs( metrics["loss_avg_mean"] = np.nanmean(avg_loss) metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) + if elapsed_training_time_seconds is not None: + metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds + metrics["total_num_samples"] = samples + metrics["average_samples_per_second"] = ( + samples / elapsed_training_time_seconds + if elapsed_training_time_seconds > 0 + else 0 + ) for key, value in losses_all.items(): metrics[key] = np.nanmean(value) From 4f11519929599fb3d0b1856ad935de6a7f11fcc0 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:42:53 +0200 Subject: [PATCH 61/76] Formatting and style fixes --- .../performance/src/performance/generate_scaling_plots.py | 8 ++++++-- src/weathergen/train/trainer.py | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py index de42c97f6..91e7b01ab 100644 --- a/packages/performance/src/performance/generate_scaling_plots.py +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -350,7 +350,9 @@ def generate_combined_scaling_table( row_data["Strong Efficiency"] = "-" else: ideal_strong = t1_strong / num_nodes - row_data["Strong Efficiency"] = f"{ideal_strong / training_time_strong:.2f}" + row_data["Strong Efficiency"] = ( + f"{ideal_strong / training_time_strong:.2f}" + ) else: row_data["Strong Training Time (seconds)"] = "-" row_data["Strong Efficiency"] = "-" @@ -363,7 +365,9 @@ def generate_combined_scaling_table( row_data["Weak Efficiency"] = "-" else: ideal_weak = t1_weak # Weak scaling: ideal is same as 1-node time - row_data["Weak Efficiency"] = f"{min(1.0, ideal_weak / training_time_weak):.2f}" + row_data["Weak Efficiency"] = ( + f"{min(1.0, ideal_weak / training_time_weak):.2f}" + ) else: row_data["Weak 
Training Time (seconds)"] = "-" row_data["Weak Efficiency"] = "-" diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e2c9d1b92..32fb82aeb 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -364,7 +364,6 @@ def run( if self.world_size_original is None: mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: - # to avoid zero-division for small datasets len_per_rank = ( max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu @@ -565,8 +564,9 @@ def train(self, mini_epoch): self.save_model(-1) self.cf.general.istep += 1 + self.dataset.advance() - + if is_root(): total_training_time = time.time() - self.t_training_start self.train_logger.log_metrics( @@ -770,7 +770,6 @@ def _log(self, stage: Stage): self.train_logger.add_logs(stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: - # Log elapsed training time and throughput metrics with every metric log elapsed_time = time.time() - self.t_training_start self.train_logger.add_logs( stage, From b02b38f60b6c0fdac475586fa910495cbd419eac Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:44:06 +0200 Subject: [PATCH 62/76] Update config --- config/config_era5_georing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml index 82638d146..7ef76620e 100644 --- a/config/config_era5_georing.yml +++ b/config/config_era5_georing.yml @@ -146,8 +146,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] - num_mini_epochs: 1 - samples_per_mini_epoch: 1024 + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 shuffle: True start_date: 2014-01-01T00:00 From 55d82190f76fb1cf380df5e81f87923abba55b1f Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:45:35 +0200 Subject: [PATCH 63/76] Avoid duplicate 
metrics --- src/weathergen/utils/train_logger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 2c248f5c1..941605ff0 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -115,7 +115,6 @@ def add_logs( metrics["num_samples"] = int(samples) if elapsed_training_time_seconds is not None: metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds - metrics["total_num_samples"] = samples metrics["average_samples_per_second"] = ( samples / elapsed_training_time_seconds if elapsed_training_time_seconds > 0 From 904713d76e11480cc3026bf5a0196940bdd12a51 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 11:57:28 +0200 Subject: [PATCH 64/76] Fix lint issues --- .../src/performance/extract_scaling_data.py | 63 ++++++++---------- .../src/performance/generate_scaling_plots.py | 66 +++++++++++-------- 2 files changed, 66 insertions(+), 63 deletions(-) diff --git a/packages/performance/src/performance/extract_scaling_data.py b/packages/performance/src/performance/extract_scaling_data.py index a7e74bf1f..370be4403 100644 --- a/packages/performance/src/performance/extract_scaling_data.py +++ b/packages/performance/src/performance/extract_scaling_data.py @@ -1,5 +1,10 @@ #!/usr/bin/env uv run python -"""Extract strong scaling data from WeatherGenerator runs. Outputs parquet with run_id, num_nodes, training_time, overall_time_seconds, loss_avg_mean.""" +"""Extract strong scaling data from WeatherGenerator runs. + +Outputs parquet with: +- run_id, num_nodes, training_time +- overall_time_seconds, loss_avg_mean +""" import argparse import re @@ -150,8 +155,11 @@ def extract_detailed_metrics( def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: """Parse run-ids argument which can be: - 1. A list of run-ids (old format): ["run1", "run2"] -> [(None, "run1"), (None, "run2")] - 2. 
A dict mapping num_nodes to run-ids (new format): "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] + + 1. A list of run-ids (old format): + ["run1", "run2"] -> [(None, "run1"), (None, "run2")] + 2. A dict mapping num_nodes to run-ids (new format): + "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] Returns list of (num_nodes, run_id) tuples. """ @@ -182,44 +190,23 @@ def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: return [(None, run_id) for run_id in run_ids_str] -def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: - """Extract num_nodes from output.*.txt file in the run directory. - - Looks for 'nNodes' pattern in output files. - """ - run_log_dir = logs_base_dir / run_id - if not run_log_dir.exists(): - return None - - # Look for output.*.txt files - output_files = list(run_log_dir.glob("output.*.txt")) - if not output_files: - # Fallback to err files if no output files found - output_files = list(run_log_dir.glob("weathergen.*.err")) - - for output_file in output_files: - try: - content = output_file.read_text() - # Look for nNodes pattern: "nNodes 128" (space-separated, as in NCCL logs) - match = re.search(r"nNodes\s+(\d+)", content, re.IGNORECASE) - if match: - return int(match.group(1)) - except Exception: - continue - - return None - - def main(): parser = argparse.ArgumentParser( - description="Extract strong scaling data from WeatherGenerator runs. " - "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict mapping num_nodes to run-ids " - "(--run-ids '{1: run1, 4: run2}'). If num_nodes is not provided in the dict, it will be extracted from output.*.txt files." + description=( + "Extract strong scaling data from WeatherGenerator runs. " + "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict " + "mapping num_nodes to run-ids (--run-ids '{1: run1, 4: run2}'). " + "If num_nodes is not provided in the dict, it will be extracted " + "from output.*.txt files." 
+ ) ) parser.add_argument( "--run-ids", nargs="+", - help="Run-ids to process. Can be: (1) list: run1 run2 run3, or (2) dict: '{1: run1, 4: run2, 8: run3}'", + help=( + "Run-ids to process. Can be: (1) list: run1 run2 run3, or " + "(2) dict: '{1: run1, 4: run2, 8: run3}'" + ), ) parser.add_argument( "--logs-base-dir", @@ -273,7 +260,8 @@ def main(): if detailed_records: all_detailed_records.extend(detailed_records) print( - f"Extracted {len(detailed_records)} detailed metric entries for {run_id} ({num_nodes} nodes)" + f"Extracted {len(detailed_records)} detailed metric entries " + f"for {run_id} ({num_nodes} nodes)" ) if not results: @@ -314,7 +302,8 @@ def main(): print(f" - Extracted {len(results)} run summaries to {args.output}") if all_detailed_records: print( - f" - Extracted {len(all_detailed_records)} detailed metric entries to {detailed_output}" + f" - Extracted {len(all_detailed_records)} detailed metric entries " + f"to {detailed_output}" ) diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py index 91e7b01ab..b33073b82 100644 --- a/packages/performance/src/performance/generate_scaling_plots.py +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -4,25 +4,30 @@ Entry points: - standard: plots run-level metrics vs num_nodes - detailed: plots sample-level metrics vs total_num_samples -- combined: generates a comparison table from separate strong and weak scaling input files +- combined: generates a comparison table from separate strong and weak + scaling input files Usage: # Single scaling type (original behavior) - python -m performance.generate_scaling_plots standard --type strong --input strong_data.parquet - + python -m performance.generate_scaling_plots standard --type strong \ + --input strong_data.parquet + # Combined table from single file with both types - python -m performance.generate_scaling_plots standard --type strong,weak --input 
data.parquet - + python -m performance.generate_scaling_plots standard \ + --type strong,weak --input data.parquet + # Combined table from separate strong and weak input files (new) python -m performance.generate_scaling_plots combined \ --strong-input strong_data.parquet \ --weak-input weak_data.parquet - + # Loss plot - python -m performance.generate_scaling_plots loss --type strong --input data.parquet - + python -m performance.generate_scaling_plots loss --type strong \ + --input data.parquet + # Detailed scaling plot - python -m performance.generate_scaling_plots detailed --input detailed_data.parquet + python -m performance.generate_scaling_plots detailed \ + --input detailed_data.parquet """ import argparse @@ -101,7 +106,8 @@ def generate_scaling_table( """Generate a PNG table image with scaling metrics from the parquet file. Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) - If scaling_types has multiple types, generates a combined table with columns per type. + If scaling_types has multiple types, generates a combined table with + columns per type. """ # Check if required columns exist if "num_nodes" not in df.columns or "training_time" not in df.columns: @@ -242,10 +248,12 @@ def generate_combined_scaling_table( output_path: Path, show_run_ids: bool = False, ) -> None: - """Generate a combined table comparing strong and weak scaling from two separate input files. + """Generate a combined table comparing strong and weak scaling from two + separate input files. Rows: num_nodes - Columns: # Nodes, Strong Training Time, Strong Efficiency, Weak Training Time, Weak Efficiency + Columns: # Nodes, Strong Training Time, Strong Efficiency, + Weak Training Time, Weak Efficiency Also generates a PNG visualization of the table. 
""" @@ -253,7 +261,8 @@ def generate_combined_scaling_table( for name, df in [("strong", strong_df), ("weak", weak_df)]: if "num_nodes" not in df.columns or "training_time" not in df.columns: print( - f"Warning: Required columns (num_nodes, training_time) not found in {name} data" + f"Warning: Required columns (num_nodes, training_time) not found " + f"in {name} data" ) return @@ -476,12 +485,7 @@ def plot_standard_scaling( for idx, metric in enumerate(valid_metrics): ax = axes[idx][0] df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") - node_counts = ( - df_plot["num_nodes"].unique().to_list() - if "num_nodes" in df_plot.columns - else [] - ) - colors = color_map_for_nodes(node_counts) + # node_counts removed - was only used for colors which is also removed # Handle normalized_throughput and efficiency metrics if y_metric == "normalized_throughput" and metric == "training_time": @@ -496,7 +500,8 @@ def plot_standard_scaling( plot_y = df_plot["normalized_throughput"] else: print( - "Warning: No 1-node data found for normalized throughput calculation" + "Warning: No 1-node data found for normalized throughput " + "calculation" ) continue elif y_metric == "efficiency" and metric == "training_time": @@ -567,8 +572,9 @@ def plot_standard_scaling( raise ValueError(f"Invalid scaling type: {scaling_type}") ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling") - # Show per-point efficiency loss as a vertical line and factor label. - # Use plot_y (normalized throughput if applicable) instead of df_plot[metric] + # Show per-point efficiency loss as a vertical line and factor + # label. 
Use plot_y (normalized throughput if applicable) instead + # of df_plot[metric] for x, y, y_opt in zip( nodes, plot_y.to_list(), optimal_y, strict=False ): @@ -730,7 +736,10 @@ def build_parser() -> argparse.ArgumentParser: "--y-metric", choices=["time", "normalized_throughput", "efficiency"], default="normalized_throughput", - help="Y-axis metric: 'time' for time-to-solution, 'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency", + help=( + "Y-axis metric: 'time' for time-to-solution, " + "'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency" + ), ) standard.add_argument( "--show-run-ids", @@ -769,7 +778,10 @@ def build_parser() -> argparse.ArgumentParser: combined = subparsers.add_parser( "combined", - help="Generate combined table comparing strong and weak scaling from separate input files", + help=( + "Generate combined table comparing strong and weak scaling " + "from separate input files" + ), ) combined.add_argument( "--strong-input", @@ -838,7 +850,8 @@ def main() -> None: for stype in scaling_types: if stype not in ("strong", "weak"): print( - f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'" + f"Error: Invalid scaling type '{stype}'. Use 'strong', " + "'weak', or 'strong,weak'" ) return @@ -884,7 +897,8 @@ def main() -> None: for stype in scaling_types: if stype not in ("strong", "weak"): print( - f"Error: Invalid scaling type '{stype}'. Use 'strong', 'weak', or 'strong,weak'" + f"Error: Invalid scaling type '{stype}'. 
Use 'strong', " + "'weak', or 'strong,weak'" ) return From 9ecd54400801056cfb209f63927b0aab2f51ad5c Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Mon, 4 May 2026 12:08:42 +0200 Subject: [PATCH 65/76] t_training in __init__ --- src/weathergen/train/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 32fb82aeb..bf14a4cb7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -86,6 +86,7 @@ def __init__(self, train_logging: Config): self.batch_size_test_per_gpu = -1 self.collapse_monitor: CollapseMonitor | None = None self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker() + self.t_training_start: float = 0 def get_batch_size_total(self, batch_size_per_gpu) -> int: """ From 9f02dc1e4122c9db522f51cbb0208356e5e5a971 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 8 May 2026 12:54:48 +0200 Subject: [PATCH 66/76] Renamed metric --- packages/performance/src/performance/extract_scaling_data.py | 4 ++-- src/weathergen/train/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/performance/src/performance/extract_scaling_data.py b/packages/performance/src/performance/extract_scaling_data.py index 370be4403..7af725ef6 100644 --- a/packages/performance/src/performance/extract_scaling_data.py +++ b/packages/performance/src/performance/extract_scaling_data.py @@ -71,8 +71,8 @@ def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | No # Extract training time for mini-epoch from last non-NaN row overall_training_time = None - if "elapsed_time_mini_epoch" in df.columns: - val = df["elapsed_time_mini_epoch"].dropna() + if "training_time_after_mini_epoch_seconds" in df.columns: + val = df["training_time_after_mini_epoch_seconds"].dropna() overall_training_time = val.iloc[-1] if len(val) > 0 else None return { diff --git a/src/weathergen/train/trainer.py 
b/src/weathergen/train/trainer.py index bf14a4cb7..eda83335e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -574,7 +574,7 @@ def train(self, mini_epoch): "train", { "completed_mini_epoch": mini_epoch, - "elapsed_time_mini_epoch": total_training_time, + "training_time_after_mini_epoch_seconds": total_training_time, }, ) logger.info( From 0785e3b7270e72c0a7e968910ded71d2f48ca625 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 17:16:15 +0200 Subject: [PATCH 67/76] mv performance package --- packages/performance/README.md | 43 - packages/performance/pyproject.toml | 35 - .../performance/src/performance/__init__.py | 1 - .../src/performance/extract_scaling_data.py | 311 ------ .../src/performance/generate_scaling_plots.py | 999 ------------------ 5 files changed, 1389 deletions(-) delete mode 100644 packages/performance/README.md delete mode 100644 packages/performance/pyproject.toml delete mode 100644 packages/performance/src/performance/__init__.py delete mode 100644 packages/performance/src/performance/extract_scaling_data.py delete mode 100644 packages/performance/src/performance/generate_scaling_plots.py diff --git a/packages/performance/README.md b/packages/performance/README.md deleted file mode 100644 index af1554d94..000000000 --- a/packages/performance/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# WeatherGenerator Performance Analysis Tools - -This package contains tools for extracting and analyzing scaling performance data from WeatherGenerator training runs. - -## Installation - -Install the optional performance tools: - -```bash -uv sync --extra performance -``` - -## Scripts - -### extract_scaling_data.py - -Extracts scaling metrics from WeatherGenerator training runs and writes parquet output. 
- -```bash -extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet -``` - -### generate_scaling_plots.py - -Generates scaling plots and tables from parquet/NDJSON data using named columns from the input files. - -```bash -generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log -``` - -## Suggested workflow - -1. Extract the scaling data into a parquet file (on your HPC). -2. Copy the parquet file to your local machine. -3. Generate plots from the parquet file. - -Example: - -```bash -extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet -scp user@remote:/path/to/scaling.parquet . -generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log -``` diff --git a/packages/performance/pyproject.toml b/packages/performance/pyproject.toml deleted file mode 100644 index 933323e3f..000000000 --- a/packages/performance/pyproject.toml +++ /dev/null @@ -1,35 +0,0 @@ -[project] -name = "weathergen-performance" -version = "0.1.0" -description = "Performance analysis tools for WeatherGenerator" -readme = "README.md" - -requires-python = ">=3.12,<3.13" -dependencies = [ - "polars~=1.25.2", - "pandas~=2.2", - "matplotlib", - "pyarrow>=23.0.1", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src/performance"] - -[project.scripts] -extract_scaling_data = "performance.extract_scaling_data:main" -generate_scaling_plots = "performance.generate_scaling_plots:main" - -[tool.ruff.lint] -select = [ - "E", - "F", - "UP", - "B", - "SIM", - "I", - "N" -] diff --git a/packages/performance/src/performance/__init__.py b/packages/performance/src/performance/__init__.py deleted file mode 100644 index 7105b9746..000000000 --- a/packages/performance/src/performance/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""WeatherGenerator performance analysis tools.""" diff --git a/packages/performance/src/performance/extract_scaling_data.py 
b/packages/performance/src/performance/extract_scaling_data.py deleted file mode 100644 index 7af725ef6..000000000 --- a/packages/performance/src/performance/extract_scaling_data.py +++ /dev/null @@ -1,311 +0,0 @@ -#!/usr/bin/env uv run python -"""Extract strong scaling data from WeatherGenerator runs. - -Outputs parquet with: -- run_id, num_nodes, training_time -- overall_time_seconds, loss_avg_mean -""" - -import argparse -import re -import sys -from pathlib import Path - -import pandas as pd - - -def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: - """Extract num_nodes from output.*.txt file in the run directory. - - Looks for 'nNodes' pattern in output files. - """ - run_log_dir = logs_base_dir / run_id - if not run_log_dir.exists(): - return None - - # Look for output.*.txt files - output_files = list(run_log_dir.glob("output.*.txt")) - if not output_files: - # Fallback to err files if no output files found - output_files = list(run_log_dir.glob("weathergen.*.err")) - - for output_file in output_files: - try: - content = output_file.read_text() - # Look for nNodes pattern: "nNodes 128" (space-separated, as in NCCL logs) - match = re.search(r"nNodes\s+(\d+)", content, re.IGNORECASE) - if match: - return int(match.group(1)) - except Exception: - continue - - return None - - -def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: - """Extract metrics from NDJSON file with startup and training lines. - - Format: - - Line 1: startup_time_seconds - - Line 2+: loss_avg_mean, LossPhysical.loss_avg, etc. 
- """ - metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" - if not metrics_path.exists(): - return None - try: - df = pd.read_json(metrics_path, lines=True) - if len(df) == 0: - return None - - # Extract startup_time from first row (startup line) - startup_time = None - if "startup_time_seconds" in df.columns: - val = df["startup_time_seconds"].dropna() - startup_time = val.iloc[0] if len(val) > 0 else None - - # Extract loss_avg_mean from last non-NaN training row - loss_avg_mean = None - if "loss_avg_mean" in df.columns: - val = df["loss_avg_mean"].dropna() - loss_avg_mean = val.iloc[-1] if len(val) > 0 else None - - # Extract training time for mini-epoch from last non-NaN row - overall_training_time = None - if "training_time_after_mini_epoch_seconds" in df.columns: - val = df["training_time_after_mini_epoch_seconds"].dropna() - overall_training_time = val.iloc[-1] if len(val) > 0 else None - - return { - "startup_time_seconds": startup_time, - "training_time": overall_training_time, - "loss_avg_mean": loss_avg_mean, - } - except Exception: - return None - - -def extract_detailed_metrics( - run_id: str, shared_work_dir: Path, num_nodes: int | None = None -) -> list[pd.DataFrame]: - """Extract detailed metrics pairing timing rows with preceding loss rows. - - For each row containing elapsed_training_time_seconds, pair it with the - preceding row containing loss metrics. Returns a list of DataFrames. 
- """ - metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" - if not metrics_path.exists(): - return [] - - try: - df = pd.read_json(metrics_path, lines=True) - if len(df) == 0: - return [] - - # Find rows with elapsed_training_time_seconds (timing rows) - if "elapsed_training_time_seconds" not in df.columns: - return [] - timing_indices = df.index[df["elapsed_training_time_seconds"].notna()].tolist() - - if not timing_indices: - return [] - - # Find rows with loss data - if "loss_avg_mean" not in df.columns: - return [] - loss_indices = set(df.index[df["loss_avg_mean"].notna()].tolist()) - - timing_cols = [ - "elapsed_training_time_seconds", - "total_num_samples", - "average_samples_per_second", - ] - - detailed_records = [] - - for timing_idx in timing_indices: - # Find the last loss row before this timing row - loss_rows_before = [i for i in loss_indices if i < timing_idx] - if not loss_rows_before: - continue - - last_loss_idx = max(loss_rows_before) - - # Build record dict from loss row + timing row - record = {"run_id": run_id} - if num_nodes is not None: - record["num_nodes"] = num_nodes - - record["loss_avg_mean"] = df.at[last_loss_idx, "loss_avg_mean"] - - for col in timing_cols: - if col in df.columns: - record[col] = df.at[timing_idx, col] - - detailed_records.append(pd.DataFrame([record])) - - return detailed_records - - except Exception as e: - print(f"Error extracting detailed metrics for {run_id}: {e}") - import traceback - - traceback.print_exc() - return [] - - -def parse_run_ids(run_ids_str: list[str]) -> list[tuple[int | None, str]]: - """Parse run-ids argument which can be: - - 1. A list of run-ids (old format): - ["run1", "run2"] -> [(None, "run1"), (None, "run2")] - 2. A dict mapping num_nodes to run-ids (new format): - "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] - - Returns list of (num_nodes, run_id) tuples. 
- """ - if len(run_ids_str) == 1: - # Check if it looks like a dict: "{key: value, ...}" - stripped = run_ids_str[0].strip() - if stripped.startswith("{") and stripped.endswith("}"): - # Parse as dict format: {num_nodes: run_id, ...} - import ast - - try: - parsed = ast.literal_eval(stripped) - if isinstance(parsed, dict): - # Convert string keys to int if needed - result = [] - for k, v in parsed.items(): - key = int(k) if isinstance(k, str) and k.isdigit() else k - result.append((key, str(v))) - return result - except (ValueError, SyntaxError): - pass - - # Single run-id or comma-separated list - run_ids = [r.strip() for r in run_ids_str[0].split(",") if r.strip()] - return [(None, run_id) for run_id in run_ids] - - # Multiple arguments - treat as list of run-ids - return [(None, run_id) for run_id in run_ids_str] - - -def main(): - parser = argparse.ArgumentParser( - description=( - "Extract strong scaling data from WeatherGenerator runs. " - "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict " - "mapping num_nodes to run-ids (--run-ids '{1: run1, 4: run2}'). " - "If num_nodes is not provided in the dict, it will be extracted " - "from output.*.txt files." - ) - ) - parser.add_argument( - "--run-ids", - nargs="+", - help=( - "Run-ids to process. 
Can be: (1) list: run1 run2 run3, or " - "(2) dict: '{1: run1, 4: run2, 8: run3}'" - ), - ) - parser.add_argument( - "--logs-base-dir", - type=Path, - default=Path("logs"), - help="Base directory for run logs (default: logs relative to current dir)", - ) - parser.add_argument( - "--shared-work-dir", - type=Path, - default=Path("/e/scratch/weatherai/shared_work"), - help="Base directory for shared work/results", - ) - parser.add_argument( - "--output", - type=Path, - default=Path("scaling_data.parquet"), - help="Output parquet file path", - ) - - args = parser.parse_args() - - run_id_mapping = parse_run_ids(args.run_ids) - if not run_id_mapping: - sys.exit("Error: No run-ids provided") - - results = [] - all_detailed_records = [] - - for num_nodes, run_id in run_id_mapping: - # If num_nodes not provided, extract from output.*.txt file - if num_nodes is None: - num_nodes = extract_num_nodes_from_output(run_id, args.logs_base_dir) - - metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) - if metrics is None: - continue - - row = { - "run_id": run_id, - "num_nodes": num_nodes, - "startup_time_seconds": metrics.get("startup_time_seconds"), - "training_time": metrics.get("training_time"), - "loss_avg_mean": metrics.get("loss_avg_mean"), - } - results.append(row) - - detailed_records = extract_detailed_metrics( - run_id, args.shared_work_dir, num_nodes - ) - if detailed_records: - all_detailed_records.extend(detailed_records) - print( - f"Extracted {len(detailed_records)} detailed metric entries " - f"for {run_id} ({num_nodes} nodes)" - ) - - if not results: - sys.exit("No data extracted") - - df = pd.DataFrame(results) - if "num_nodes" in df.columns: - df = df.sort_values("num_nodes").reset_index(drop=True) - - args.output.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(args.output, index=False) - df.to_csv(args.output.with_suffix(".csv"), index=False) - - # Write detailed metrics if any were collected - if all_detailed_records: - detailed_df = 
pd.concat(all_detailed_records, ignore_index=True) - - desired_cols = [ - "run_id", - "num_nodes", - "elapsed_training_time_seconds", - "total_num_samples", - "average_samples_per_second", - "loss_avg_mean", - ] - available_cols = [c for c in desired_cols if c in detailed_df.columns] - detailed_df = detailed_df[available_cols] - - output_stem = args.output.stem - detailed_output = args.output.with_name( - f"{output_stem}_detailed{args.output.suffix}" - ) - - detailed_df.to_parquet(detailed_output, index=False) - detailed_df.to_csv(detailed_output.with_suffix(".csv"), index=False) - - print("\nSummary:") - print(f" - Extracted {len(results)} run summaries to {args.output}") - if all_detailed_records: - print( - f" - Extracted {len(all_detailed_records)} detailed metric entries " - f"to {detailed_output}" - ) - - -if __name__ == "__main__": - main() diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py deleted file mode 100644 index b33073b82..000000000 --- a/packages/performance/src/performance/generate_scaling_plots.py +++ /dev/null @@ -1,999 +0,0 @@ -#!/usr/bin/env uv run python -"""Generate scaling plots from parquet/ndjson data using matplotlib only. 
- -Entry points: -- standard: plots run-level metrics vs num_nodes -- detailed: plots sample-level metrics vs total_num_samples -- combined: generates a comparison table from separate strong and weak - scaling input files - -Usage: - # Single scaling type (original behavior) - python -m performance.generate_scaling_plots standard --type strong \ - --input strong_data.parquet - - # Combined table from single file with both types - python -m performance.generate_scaling_plots standard \ - --type strong,weak --input data.parquet - - # Combined table from separate strong and weak input files (new) - python -m performance.generate_scaling_plots combined \ - --strong-input strong_data.parquet \ - --weak-input weak_data.parquet - - # Loss plot - python -m performance.generate_scaling_plots loss --type strong \ - --input data.parquet - - # Detailed scaling plot - python -m performance.generate_scaling_plots detailed \ - --input detailed_data.parquet -""" - -import argparse -from pathlib import Path - -import matplotlib.pyplot as plt -import polars as pl - -SCRIPT_DIR = Path(__file__).resolve().parent -VALID_IMAGE_SUFFIXES = {".png", ".pdf", ".svg", ".jpg", ".jpeg"} -PALETTE = [ - "#1f77b4", - "#ff7f0e", - "#2ca02c", - "#d62728", - "#9467bd", - "#8c564b", - "#e377c2", - "#7f7f7f", - "#bcbd22", - "#17becf", -] - - -def resolve_input_path(path: Path) -> Path: - """Resolve relative input paths against cwd first, then the script directory.""" - if path.is_absolute(): - return path - - cwd_candidate = Path.cwd() / path - if cwd_candidate.exists(): - return cwd_candidate - - script_candidate = SCRIPT_DIR / path - if script_candidate.exists(): - return script_candidate - - return cwd_candidate - - -def resolve_output_path(path: Path) -> Path: - """Ensure the output path uses a supported image suffix.""" - if path.suffix.lower() in VALID_IMAGE_SUFFIXES: - return path - return path.with_suffix(".png") - - -def read_table(path: Path) -> pl.DataFrame: - """Read parquet or ndjson 
automatically.""" - try: - print("Read as parquet") - return pl.read_parquet(path) - except Exception: - print("Read as NDJSON") - return pl.read_ndjson(path) - - -def color_map_for_nodes(node_counts: list) -> dict: - return {node: PALETTE[i % len(PALETTE)] for i, node in enumerate(node_counts)} - - -def save_figure(fig: plt.Figure, output_path: Path) -> None: - output_path = resolve_output_path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(output_path, dpi=150, bbox_inches="tight") - plt.close(fig) - print(f"Saved: {output_path}") - - -def generate_scaling_table( - df: pl.DataFrame, - input_path: Path, - show_run_ids: bool = False, - scaling_types: list[str] = None, -) -> None: - """Generate a PNG table image with scaling metrics from the parquet file. - - Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) - If scaling_types has multiple types, generates a combined table with - columns per type. - """ - # Check if required columns exist - if "num_nodes" not in df.columns or "training_time" not in df.columns: - print("Warning: Required columns (num_nodes, training_time) not found in data") - return - - # Filter out rows with null values in required columns - df_filtered = df.filter( - pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() - ).sort("num_nodes") - - if len(df_filtered) == 0: - print("No valid data for scaling table") - return - - # Get the 1-node training time for ideal time calculation - one_node_data = df_filtered.filter(pl.col("num_nodes") == 1) - if one_node_data.height == 0: - print("Warning: No 1-node data found for ideal time calculation") - return - - t1 = one_node_data["training_time"].item() - - # Determine scaling types to include - if scaling_types is None or len(scaling_types) == 0: - # Derive scaling type from input filename - input_name_lower = input_path.name.lower() - if "weak" in input_name_lower: - scaling_types = ["weak"] - elif "strong" in 
input_name_lower: - scaling_types = ["strong"] - else: - scaling_types = ["strong"] # Default to strong - - # Build table data with proper formatting - has_run_id = "run_id" in df_filtered.columns - - # Check if we're generating a combined table (multiple types) - is_combined = len(scaling_types) > 1 - - if is_combined: - # Combined table: columns per type - col_names = ["# Nodes"] - for stype in scaling_types: - col_names.extend( - [ - f"{stype.capitalize()} Training Time (seconds)", - f"{stype.capitalize()} Efficiency", - ] - ) - if show_run_ids and has_run_id: - col_names.insert(0, "run_id") - else: - # Single type table (original format) - scaling_type = scaling_types[0].capitalize() - col_names = [ - "# Nodes", - "Training Time (seconds)", - "Ideal Time (seconds)", - "Efficiency", - ] - if show_run_ids and has_run_id: - col_names.insert(0, "run_id") - - table_rows = [] - for row in df_filtered.iter_rows(named=True): - num_nodes = row["num_nodes"] - training_time = row["training_time"] - - if is_combined: - # Combined table: add metrics for each type - row_data = {} - if show_run_ids and has_run_id: - row_data["run_id"] = str(row.get("run_id", "")) - row_data["# Nodes"] = str(num_nodes) - - for stype in scaling_types: - time_col = f"{stype.capitalize()} Training Time (seconds)" - eff_col = f"{stype.capitalize()} Efficiency" - row_data[time_col] = f"{training_time:.2f}" - if num_nodes == 1: - row_data[eff_col] = "-" - else: - if stype == "strong": - # Strong scaling: ideal time = t1 / num_nodes - ideal_val = t1 / num_nodes - efficiency_val = ideal_val / training_time - else: - # Weak scaling: ideal time = t1 (same work per node) - ideal_val = t1 - efficiency_val = min(1.0, t1 / training_time) - - row_data[eff_col] = f"{efficiency_val:.2f}" - else: - # Single type table (original format) - scaling_type = scaling_types[0] - row_data = {} - if show_run_ids and has_run_id: - row_data["run_id"] = str(row.get("run_id", "")) - row_data["# Nodes"] = str(num_nodes) - 
row_data["Training Time (seconds)"] = f"{training_time:.2f}" - if num_nodes == 1: - row_data["Ideal Time (seconds)"] = "-" - row_data["Efficiency"] = "-" - else: - if scaling_type == "strong": - # Strong scaling: ideal time = t1 / num_nodes - ideal_val = t1 / num_nodes - efficiency_val = ideal_val / training_time - else: - # Weak scaling: ideal time = t1 (same work per node) - ideal_val = t1 - efficiency_val = min(1.0, t1 / training_time) - - row_data["Ideal Time (seconds)"] = f"{ideal_val:.2f}" - row_data["Efficiency"] = f"{efficiency_val:.2f}" - - table_rows.append(row_data) - - # Generate output filename: input_stem_table.csv - output_path = input_path.with_name(input_path.stem + "_table.csv") - - # Build DataFrame for CSV output from named columns. - df_table = pl.DataFrame( - [{col: row.get(col, "-") for col in col_names} for row in table_rows] - ) - - # Write to CSV - df_table.write_csv(output_path) - print(f"Saved scaling table: {output_path}") - - -def generate_combined_scaling_table( - strong_df: pl.DataFrame, - weak_df: pl.DataFrame, - strong_path: Path, - weak_path: Path, - output_path: Path, - show_run_ids: bool = False, -) -> None: - """Generate a combined table comparing strong and weak scaling from two - separate input files. - - Rows: num_nodes - Columns: # Nodes, Strong Training Time, Strong Efficiency, - Weak Training Time, Weak Efficiency - - Also generates a PNG visualization of the table. 
- """ - # Validate required columns - for name, df in [("strong", strong_df), ("weak", weak_df)]: - if "num_nodes" not in df.columns or "training_time" not in df.columns: - print( - f"Warning: Required columns (num_nodes, training_time) not found " - f"in {name} data" - ) - return - - # Filter and sort both datasets - strong_filtered = strong_df.filter( - pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() - ).sort("num_nodes") - - weak_filtered = weak_df.filter( - pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() - ).sort("num_nodes") - - if len(strong_filtered) == 0 or len(weak_filtered) == 0: - print("No valid data for combined scaling table") - return - - # Get 1-node training times for efficiency calculation - strong_one_node = strong_filtered.filter(pl.col("num_nodes") == 1) - weak_one_node = weak_filtered.filter(pl.col("num_nodes") == 1) - - if strong_one_node.height == 0 or weak_one_node.height == 0: - print("Warning: No 1-node data found for efficiency calculation") - return - - t1_strong = strong_one_node["training_time"].item() - t1_weak = weak_one_node["training_time"].item() - - # Check for run_id in either dataset - has_run_id = ( - "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns - ) - - # Build column names - col_names = [ - "# Nodes", - "Strong Training Time (seconds)", - "Strong Efficiency", - "Weak Training Time (seconds)", - "Weak Efficiency", - ] - if show_run_ids and has_run_id: - col_names.insert(0, "run_id") - - # Get all unique num_nodes from both datasets - all_nodes = sorted( - set(strong_filtered["num_nodes"].to_list()) - | set(weak_filtered["num_nodes"].to_list()) - ) - - # Create lookup dictionaries for easy access - strong_lookup = { - row["num_nodes"]: row["training_time"] - for row in strong_filtered.iter_rows(named=True) - } - weak_lookup = { - row["num_nodes"]: row["training_time"] - for row in weak_filtered.iter_rows(named=True) - } - strong_run_id_lookup = ( - 
{ - row["num_nodes"]: row["run_id"] - for row in strong_filtered.iter_rows(named=True) - } - if "run_id" in strong_filtered.columns - else {} - ) - weak_run_id_lookup = ( - {row["num_nodes"]: row["run_id"] for row in weak_filtered.iter_rows(named=True)} - if "run_id" in weak_filtered.columns - else {} - ) - - table_rows = [] - for num_nodes in all_nodes: - row_data = {} - - # Get run_id if available - if show_run_ids and has_run_id: - run_id = str( - strong_run_id_lookup.get( - num_nodes, weak_run_id_lookup.get(num_nodes, "") - ) - ) - row_data["run_id"] = run_id - - # Add num_nodes - row_data["# Nodes"] = str(num_nodes) - - # Strong scaling metrics - if num_nodes in strong_lookup: - training_time_strong = strong_lookup[num_nodes] - row_data["Strong Training Time (seconds)"] = f"{training_time_strong:.2f}" - if num_nodes == 1: - row_data["Strong Efficiency"] = "-" - else: - ideal_strong = t1_strong / num_nodes - row_data["Strong Efficiency"] = ( - f"{ideal_strong / training_time_strong:.2f}" - ) - else: - row_data["Strong Training Time (seconds)"] = "-" - row_data["Strong Efficiency"] = "-" - - # Weak scaling metrics - if num_nodes in weak_lookup: - training_time_weak = weak_lookup[num_nodes] - row_data["Weak Training Time (seconds)"] = f"{training_time_weak:.2f}" - if num_nodes == 1: - row_data["Weak Efficiency"] = "-" - else: - ideal_weak = t1_weak # Weak scaling: ideal is same as 1-node time - row_data["Weak Efficiency"] = ( - f"{min(1.0, ideal_weak / training_time_weak):.2f}" - ) - else: - row_data["Weak Training Time (seconds)"] = "-" - row_data["Weak Efficiency"] = "-" - - table_rows.append(row_data) - - # Ensure output path has .csv suffix - if output_path.suffix.lower() != ".csv": - output_path = output_path.with_suffix(".csv") - - # Build DataFrame for CSV output from named columns. 
- df_table = pl.DataFrame( - [{col: row.get(col, "-") for col in col_names} for row in table_rows] - ) - - # Write to CSV - df_table.write_csv(output_path) - print(f"Saved scaling table CSV: {output_path}") - - # Generate PNG visualization of the table - png_path = output_path.with_suffix(".png") - _save_table_as_image( - [[row.get(col, "-") for col in col_names] for row in table_rows], - col_names, - png_path, - ) - - -def _save_table_as_image(table_data: list, col_names: list, output_path: Path) -> None: - """Save table data as a PNG image using matplotlib. - - Automatically sizes the figure to fit all content. - """ - output_path.parent.mkdir(parents=True, exist_ok=True) - - # Calculate figure size based on content - num_cols = len(col_names) - num_rows = len(table_data) + 1 # +1 for header - - # Width: base + per-column width, Height: base + per-row height - fig_width = max(8, num_cols * 2.5) - fig_height = max(3, num_rows * 0.5) - - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - ax.axis("off") - - # Create table - table = ax.table( - cellText=table_data, - colLabels=col_names, - cellLoc="center", - loc="center", - colColours=["#2E5C8A"] * num_cols, - cellColours=[ - ["#E8ECEF" if i % 2 == 0 else "white" for _ in range(num_cols)] - for i in range(len(table_data)) - ], - ) - - # Style the table - table.auto_set_font_size(False) - table.set_fontsize(9) - table.auto_set_column_width(col=list(range(num_cols))) - - # Style header cells - for i in range(num_cols): - table[(0, i)].set_text_props(color="white", fontweight="bold") - - # Adjust layout and save - plt.tight_layout() - fig.savefig(output_path, dpi=150, bbox_inches="tight") - plt.close(fig) - - -def plot_standard_scaling( - df: pl.DataFrame, - output_path: Path, - scaling_type: str, - metrics: list[str], - x_scale: str, - y_scale: str, - y_metric: str, - show_run_ids: bool = False, -) -> None: - """Plot run-level scaling data vs num_nodes.""" - metric_labels = { - "training_time": "Training Time 
(seconds)", - "loss_avg_mean": "Average Loss", - "normalized_throughput": "Speedup", - "efficiency": "Scaling Efficiency", - } - - valid_metrics = [ - m - for m in metrics - if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0 - ] - if not valid_metrics: - print("No valid metrics to plot") - return - - fig, axes = plt.subplots( - len(valid_metrics), 1, figsize=(12, 6 * len(valid_metrics)), squeeze=False - ) - - for idx, metric in enumerate(valid_metrics): - ax = axes[idx][0] - df_plot = df.filter(pl.col(metric).is_not_null()).sort("num_nodes") - # node_counts removed - was only used for colors which is also removed - - # Handle normalized_throughput and efficiency metrics - if y_metric == "normalized_throughput" and metric == "training_time": - # Calculate normalized throughput: T1 / T - one_node_data = df.filter(pl.col("num_nodes") == 1) - if one_node_data.height > 0: - t1 = one_node_data["training_time"].item() - # Create a new dataframe with normalized throughput - df_plot = df_plot.with_columns( - (t1 / pl.col("training_time")).alias("normalized_throughput") - ) - plot_y = df_plot["normalized_throughput"] - else: - print( - "Warning: No 1-node data found for normalized throughput " - "calculation" - ) - continue - elif y_metric == "efficiency" and metric == "training_time": - # Calculate efficiency based on scaling type - one_node_data = df.filter(pl.col("num_nodes") == 1) - if one_node_data.height > 0: - t1 = one_node_data["training_time"].item() - if scaling_type == "strong": - # Strong scaling: efficiency = (t1 / num_nodes) / training_time - df_plot = df_plot.with_columns( - ((t1 / pl.col("num_nodes")) / pl.col("training_time")).alias( - "efficiency" - ) - ) - else: - # Weak scaling: efficiency = min(1.0, t1 / training_time) - df_plot = df_plot.with_columns( - pl.min_horizontal( - pl.lit(1.0), t1 / pl.col("training_time") - ).alias("efficiency") - ) - plot_y = df_plot["efficiency"] - else: - print("Warning: No 1-node data found for efficiency 
calculation") - continue - else: - plot_y = df_plot[metric] - - ax.plot( - df_plot["num_nodes"], - plot_y, - "o-", - color="steelblue", - markersize=8, - ) - - if show_run_ids: - for x, y, label in zip( - df_plot["num_nodes"], plot_y.to_list(), df_plot["run_id"], strict=False - ): - ax.text(x, y, label, ha="center", va="bottom", fontsize=8) - - if ( - metric == "training_time" - and y_metric in ("time", "normalized_throughput", "efficiency") - and "training_time" in df.columns - ): - one_node_data = df.filter(pl.col("num_nodes") == 1) - if one_node_data.height > 0: - t1 = one_node_data["training_time"].item() - nodes = df_plot["num_nodes"].to_list() - if y_metric == "efficiency": - # For efficiency, optimal is always 1.0 (100% efficiency) - optimal_y = [1.0 for _ in nodes] - elif scaling_type == "weak": - if y_metric == "normalized_throughput": - # For normalized throughput, optimal is 1.0 (no speedup loss) - optimal_y = [1.0 for _ in nodes] - else: - optimal_y = [t1 for _ in nodes] - elif scaling_type == "strong": - if y_metric == "normalized_throughput": - # For normalized throughput, optimal is n (linear speedup) - optimal_y = [float(n) for n in nodes] - else: - optimal_y = [t1 / n for n in nodes] - else: - raise ValueError(f"Invalid scaling type: {scaling_type}") - ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling") - - # Show per-point efficiency loss as a vertical line and factor - # label. 
Use plot_y (normalized throughput if applicable) instead - # of df_plot[metric] - for x, y, y_opt in zip( - nodes, plot_y.to_list(), optimal_y, strict=False - ): - if y_opt == 0: - continue - factor = y / y_opt - ax.vlines( - x, - y_opt, - y, - colors="gray", - linestyles=":", - linewidth=1, - alpha=0.7, - ) - y_mid = (y + y_opt) / 2 - ax.annotate( - f"{factor:.2f}", - xy=(x, y_mid), - xytext=(4, 0), - textcoords="offset points", - fontsize=14, - fontweight="bold", - color="dimgray", - va="center", - ) - ax.legend() - - ax.set_xscale(x_scale) - if y_scale == "log": - ax.set_yscale("log") - ax.set_xlabel("Number of Nodes", fontsize=16) - if y_metric == "normalized_throughput" and metric == "training_time": - ax.set_ylabel("Speedup", fontsize=16) - elif y_metric == "efficiency" and metric == "training_time": - ax.set_ylabel("Scaling Efficiency", fontsize=16) - else: - ax.set_ylabel(metric_labels.get(metric, metric), fontsize=16) - ax.tick_params(axis="both", which="major", labelsize=14) - ax.grid(True, alpha=0.3) - - plt.tight_layout() - save_figure(fig, output_path) - - -def plot_detailed_scaling( - df: pl.DataFrame, - output_path: Path, - x_scale: str, - y_scale: str, -) -> None: - """Plot sample-level detailed scaling data vs total_num_samples.""" - required_cols = [ - "total_num_samples", - "elapsed_training_time_seconds", - "loss_avg_mean", - "num_nodes", - ] - if not all(col in df.columns for col in required_cols): - print("Detailed metrics not available in this dataset") - print(f"Available columns: {df.columns}") - return - - df_plot = df.filter( - pl.col("total_num_samples").is_not_null() - & (pl.col("total_num_samples") > 0) - & pl.col("elapsed_training_time_seconds").is_not_null() - & pl.col("loss_avg_mean").is_not_null() - & pl.col("num_nodes").is_not_null() - ).sort("num_nodes", "total_num_samples") - - if len(df_plot) == 0: - print("No valid data for detailed scaling plots") - return - - node_counts = sorted(df_plot["num_nodes"].unique().to_list()) - 
colors = color_map_for_nodes(node_counts) - - fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True) - - ax = axes[0] - for node_count in node_counts: - df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort( - "total_num_samples" - ) - ax.plot( - df_node["total_num_samples"], - df_node["elapsed_training_time_seconds"], - "o-", - color=colors[node_count], - markersize=6, - label=f"{node_count} nodes", - ) - ax.set_xscale(x_scale) - if y_scale == "log": - ax.set_yscale("log") - ax.set_ylabel("Elapsed Training Time (seconds)", fontsize=16) - ax.set_title("Elapsed Training Time vs Samples", fontsize=16) - ax.tick_params(axis="both", which="major", labelsize=14) - ax.grid(True, alpha=0.3) - ax.legend(title="Node Count") - - ax = axes[1] - for node_count in node_counts: - df_node = df_plot.filter(pl.col("num_nodes") == node_count).sort( - "total_num_samples" - ) - ax.plot( - df_node["total_num_samples"], - df_node["loss_avg_mean"], - "o-", - color=colors[node_count], - markersize=6, - label=f"{node_count} nodes", - ) - ax.set_xscale(x_scale) - if y_scale == "log": - ax.set_yscale("log") - ax.set_xlabel("Total Number of Samples", fontsize=16) - ax.set_ylabel("Average Loss", fontsize=16) - ax.set_title("Loss vs Samples", fontsize=16) - ax.tick_params(axis="both", which="major", labelsize=14) - ax.grid(True, alpha=0.3) - ax.legend(title="Node Count") - - fig.suptitle("Detailed Scaling Analysis", fontsize=16) - plt.tight_layout() - save_figure(fig, output_path) - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Generate scaling plots from parquet or NDJSON data" - ) - subparsers = parser.add_subparsers(dest="mode", required=True) - - standard = subparsers.add_parser( - "standard", help="Plot run-level scaling metrics vs num_nodes" - ) - standard.add_argument( - "--type", - required=True, - help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", - ) - standard.add_argument( - 
"--input", - type=Path, - default=Path("scaling_data.parquet"), - help="Input parquet/ndjson file", - ) - standard.add_argument("--output", type=Path, default=None, help="Output image path") - standard.add_argument( - "--y-scale", choices=["linear", "log"], default="linear", help="Y-axis scale" - ) - standard.add_argument( - "--x-scale", choices=["linear", "log"], default="log", help="X-axis scale" - ) - standard.add_argument( - "--y-metric", - choices=["time", "normalized_throughput", "efficiency"], - default="normalized_throughput", - help=( - "Y-axis metric: 'time' for time-to-solution, " - "'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency" - ), - ) - standard.add_argument( - "--show-run-ids", - action="store_true", - help="Show run_id labels on the plot and in the output table", - ) - - loss_only = subparsers.add_parser( - "loss", help="Plot loss metrics vs num_nodes (separate from throughput)" - ) - loss_only.add_argument( - "--type", - required=True, - help="Scaling type(s): 'strong', 'weak', or 'strong,weak' for combined table", - ) - loss_only.add_argument( - "--input", - type=Path, - default=Path("scaling_data.parquet"), - help="Input parquet/ndjson file", - ) - loss_only.add_argument( - "--output", type=Path, default=None, help="Output image path" - ) - loss_only.add_argument( - "--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale" - ) - loss_only.add_argument( - "--x-scale", choices=["linear", "log"], default="log", help="X-axis scale" - ) - loss_only.add_argument( - "--show-run-ids", - action="store_true", - help="Show run_id labels on the plot and in the output table", - ) - - combined = subparsers.add_parser( - "combined", - help=( - "Generate combined table comparing strong and weak scaling " - "from separate input files" - ), - ) - combined.add_argument( - "--strong-input", - type=Path, - required=True, - help="Input parquet/ndjson file for strong scaling", - ) - combined.add_argument( - "--weak-input", - 
type=Path, - required=True, - help="Input parquet/ndjson file for weak scaling", - ) - combined.add_argument( - "--output", type=Path, default=None, help="Output table path (CSV)" - ) - combined.add_argument( - "--show-run-ids", - action="store_true", - help="Show run_id labels in the output table", - ) - - detailed = subparsers.add_parser( - "detailed", help="Plot sample-level detailed scaling metrics" - ) - detailed.add_argument( - "--input", - type=Path, - default=Path("scaling_data_detailed.parquet"), - help="Input detailed parquet/ndjson file", - ) - detailed.add_argument("--output", type=Path, default=None, help="Output image path") - detailed.add_argument( - "--y-scale", choices=["linear", "log"], default="log", help="Y-axis scale" - ) - detailed.add_argument( - "--x-scale", choices=["linear", "log"], default="log", help="X-axis scale" - ) - - return parser - - -def main() -> None: - parser = build_parser() - args = parser.parse_args() - - if args.mode == "standard": - input_path = resolve_input_path(args.input) - if not input_path.exists(): - print(f"Error: Input file not found: {input_path}") - return - - output_path = args.output or input_path.with_suffix(".png") - - print(f"Loading data from: {input_path}") - try: - df = read_table(input_path) - except Exception as e: - print("Error: Could not read input file as parquet or NDJSON") - print(str(e)) - return - print(f"Loaded {len(df)} rows") - - # Parse scaling types from --type argument - scaling_types = [t.strip().lower() for t in args.type.split(",")] - for stype in scaling_types: - if stype not in ("strong", "weak"): - print( - f"Error: Invalid scaling type '{stype}'. 
Use 'strong', " - "'weak', or 'strong,weak'" - ) - return - - # Standard mode: only plot training_time with normalized throughput or time - metrics_to_plot = ["training_time"] - # Use the first type for plotting (or strong if combined) - plot_type = scaling_types[0] - plot_standard_scaling( - df, - output_path, - plot_type, - metrics_to_plot, - args.x_scale, - args.y_scale, - args.y_metric, - args.show_run_ids, - ) - # Generate scaling table - generate_scaling_table( - df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types - ) - return - - if args.mode == "loss": - input_path = resolve_input_path(args.input) - if not input_path.exists(): - print(f"Error: Input file not found: {input_path}") - return - - output_path = args.output or input_path.with_suffix(".loss.png") - - print(f"Loading data from: {input_path}") - try: - df = read_table(input_path) - except Exception as e: - print("Error: Could not read input file as parquet or NDJSON") - print(str(e)) - return - print(f"Loaded {len(df)} rows") - - # Parse scaling types from --type argument - scaling_types = [t.strip().lower() for t in args.type.split(",")] - for stype in scaling_types: - if stype not in ("strong", "weak"): - print( - f"Error: Invalid scaling type '{stype}'. 
Use 'strong', " - "'weak', or 'strong,weak'" - ) - return - - # Loss mode: only plot loss_avg_mean - metrics_to_plot = ["loss_avg_mean"] - # Use the first type for plotting (or strong if combined) - plot_type = scaling_types[0] - plot_standard_scaling( - df, - output_path, - plot_type, - metrics_to_plot, - args.x_scale, - args.y_scale, - "time", - args.show_run_ids, - ) - # Generate scaling table - generate_scaling_table( - df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types - ) - return - - if args.mode == "detailed": - input_path = resolve_input_path(args.input) - if not input_path.exists(): - print(f"Error: Input file not found: {input_path}") - return - - output_path = args.output or input_path.with_suffix(".png") - - print(f"Loading detailed data from: {input_path}") - try: - df = read_table(input_path) - except Exception as e: - print("Error: Could not read detailed file as parquet or NDJSON") - print(str(e)) - return - print(f"Loaded {len(df)} detailed rows") - plot_detailed_scaling(df, output_path, args.x_scale, args.y_scale) - return - - if args.mode == "combined": - strong_path = resolve_input_path(args.strong_input) - weak_path = resolve_input_path(args.weak_input) - - if not strong_path.exists(): - print(f"Error: Strong scaling input file not found: {strong_path}") - return - if not weak_path.exists(): - print(f"Error: Weak scaling input file not found: {weak_path}") - return - - # Determine output path - if args.output: - output_path = args.output - if output_path.suffix.lower() not in VALID_IMAGE_SUFFIXES: - output_path = output_path.with_suffix(".csv") - else: - # Default output: strong_input_stem_combined_table.csv - output_path = strong_path.with_name( - strong_path.stem + "_combined_table.csv" - ) - - print(f"Loading strong scaling data from: {strong_path}") - try: - strong_df = read_table(strong_path) - except Exception as e: - print("Error: Could not read strong scaling input file") - print(str(e)) - return - 
print(f"Loaded {len(strong_df)} strong scaling rows") - - print(f"Loading weak scaling data from: {weak_path}") - try: - weak_df = read_table(weak_path) - except Exception as e: - print("Error: Could not read weak scaling input file") - print(str(e)) - return - print(f"Loaded {len(weak_df)} weak scaling rows") - - # Generate combined table - generate_combined_scaling_table( - strong_df, - weak_df, - strong_path, - weak_path, - output_path, - show_run_ids=args.show_run_ids, - ) - return - - raise ValueError(f"Unknown mode: {args.mode}") - - -if __name__ == "__main__": - main() From e982776a12c1c5d0f2bb952a274882ed06ba16aa Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 17:22:20 +0200 Subject: [PATCH 68/76] Plot losses against elapsed training time via --x-axis flag - Replace --x_type ('step'/'reltime') with --x-axis column selector ('samples', 'elapsed_training_time') - Add x_axis param to plot_loss_avg (previously hardcoded num_samples) - Add friendly x-axis labels: 'elapsed training time [s]' when plotting against elapsed_training_time_seconds - plot_lr, plot_loss_avg, plot_loss_per_stream, plot_loss_per_run all now respect x_axis; xlabel is auto-derived from column name - Remove dead x_type parameter from plot_loss_per_stream --- src/weathergen/utils/plot_training.py | 60 +++++++++++++++++---------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 5678cbbc2..1f26d4fb9 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -250,7 +250,7 @@ def plot_lr( plot_dir : Path directory to save the plots x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") """ prop_cycle = plt.rcParams["axes.prop_cycle"] colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "y", "m"] @@ -284,7 +284,10 @@ 
def plot_lr( plt.yscale("log") plt.title("learning rate") plt.ylabel("lr") - plt.xlabel(x_axis) + x_label = ( + "elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -311,6 +314,7 @@ def plot_loss_avg( runs_data, runs_active, stage=TRAIN, + x_axis: str = "samples", x_scale_log=False, legend_outside: bool = False, legend_font_size: str = "x-small", @@ -322,10 +326,14 @@ def plot_loss_avg( _fig = plt.figure(figsize=(10, 7), dpi=PLOT_DPI_VALUE) + # x-axis label: "elapsed_training_time" -> "elapsed training time [s]", else "step" + x_label = "elapsed training time [s]" if "elapsed_training_time" in x_axis else "step" + legend_str = [] for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): run_data_stage = run_data.train if stage == TRAIN else run_data.val - x_vals = np.array(run_data_stage["num_samples"]) + x_col = next(filter(lambda c: x_axis in c, run_data_stage.columns)) + x_vals = np.array(run_data_stage[x_col]) y_vals = np.array(run_data_stage["loss_avg_mean"]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) @@ -347,7 +355,7 @@ def plot_loss_avg( plt.xscale("log") plt.title("average loss") plt.ylabel("loss") - plt.xlabel("step") + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -379,10 +387,9 @@ def plot_loss_per_stream( channels: list[str], forecast_steps: list[int], x_axis: str = "samples", - x_type: str = "step", + x_scale_log: bool = False, x_lim: list[float] | None = None, y_lim: list[float] | None = None, - x_scale_log: bool = False, legend_outside: bool = False, legend_font_size: str = "x-small", legend_num_columns: int = 3, @@ -408,9 +415,7 @@ def plot_loss_per_stream( errs : list list of errors to plot (e.g. 
mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") - x_type : str - x-axis type (options: "step", "reltime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -525,7 +530,13 @@ def plot_loss_per_stream( title_loss = ".".join(title_col.split(".")[:-1]) plt.title(title_loss + " (" + ", ".join(modes) + ")") plt.ylabel(err) - plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") + # x-axis label: "elapsed_training_time" -> friendly name, else use column as-is + x_label = ( + "elapsed training time [s]" + if "elapsed_training_time" in x_axis + else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -596,7 +607,7 @@ def plot_loss_per_run( errs : list list of errors to plot (e.g. mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -666,7 +677,10 @@ def plot_loss_per_run( plt.xscale("log") plt.grid(True, which="both", ls="-") plt.ylabel("loss") - plt.xlabel("samples") + x_label = ( + "elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -794,13 +808,13 @@ def plot_train(args=None): help="x-lim for per-stream plots", ) parser.add_argument( - "--x_type", + "--x-axis", "-x", - dest="x_type", - default="step", + dest="x_axis", + default="samples", type=str, - choices=["step", "reltime"], - help="Type of x-axis used in plots. 
Options: 'step' or 'reltime'", + choices=["samples", "elapsed_training_time"], + help="X-axis column for plots: 'samples' (default) or 'elapsed_training_time'", ) parser.add_argument( "--log-x", @@ -862,9 +876,7 @@ def plot_train(args=None): model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None out_dir = Path(args.output_dir) streams = list(args.streams) - x_types_valid = ["step"] # TODO: add "reltime" support when fix available - if args.x_type not in x_types_valid: - raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}") + x_axis = args.x_axis # Post-processing default logic for config from YAML-file if args.fd is None and args.fy is None: @@ -924,6 +936,7 @@ def plot_train(args=None): runs_data, runs_active, plot_dir=out_dir, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -937,6 +950,7 @@ def plot_train(args=None): runs_data, runs_active, stage=TRAIN, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -953,7 +967,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -972,7 +986,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -991,7 +1005,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, From 367454d22b749d4b11dfe606591be54b65478b7e Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 
Jun 2026 19:26:47 +0200 Subject: [PATCH 69/76] Fewer changes --- src/weathergen/utils/plot_training.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 1f26d4fb9..059ea57bf 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -250,7 +250,7 @@ def plot_lr( plot_dir : Path directory to save the plots x_axis : str - x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") + x-axis strings used in the column names (options: "samples", "dtime") """ prop_cycle = plt.rcParams["axes.prop_cycle"] colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "y", "m"] @@ -284,10 +284,7 @@ def plot_lr( plt.yscale("log") plt.title("learning rate") plt.ylabel("lr") - x_label = ( - "elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis - ) - plt.xlabel(x_label) + plt.xlabel(x_axis) plt.tight_layout() _add_legend( legend_str, From d0f851ced04b6e5ca07ee16d70f3feb5367e41dc Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:05:32 +0200 Subject: [PATCH 70/76] rm configs --- config/config_era5_georing.yml | 289 --------------------------- config/streams/era5_georing/era5.yml | 36 ---- config/streams/era5_georing/geos.yml | 77 ------- 3 files changed, 402 deletions(-) delete mode 100644 config/config_era5_georing.yml delete mode 100644 config/streams/era5_georing/era5.yml delete mode 100644 config/streams/era5_georing/geos.yml diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml deleted file mode 100644 index 7ef76620e..000000000 --- a/config/config_era5_georing.yml +++ /dev/null @@ -1,289 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -embed_orientation: "channels" -embed_unembed_mode: "block" -embed_dropout_rate: 0.1 - -ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 -ae_local_num_heads: 8 -ae_local_dropout_rate: 0.1 -ae_local_with_qk_lnorm: True - -ae_local_num_queries: 1 -ae_local_queries_per_cell: False -ae_adapter_num_heads: 16 -ae_adapter_embed: 128 -ae_adapter_with_qk_lnorm: True -ae_adapter_with_residual: True -ae_adapter_dropout_rate: 0.1 - -ae_global_dim_embed: 2048 -ae_global_num_blocks: 4 -ae_global_num_heads: 32 -ae_global_dropout_rate: 0.1 -ae_global_with_qk_lnorm: True -# TODO: switching to < 1 triggers triton-related issues. -# See https://github.com/ecmwf/WeatherGenerator/issues/1050 -ae_global_att_dense_rate: 1.0 -ae_global_block_factor: 64 -ae_global_mlp_hidden_factor: 2 -ae_global_trailing_layer_norm: False - -ae_aggregation_num_blocks: 0 -ae_aggregation_num_heads: 32 -ae_aggregation_dropout_rate: 0.1 -ae_aggregation_with_qk_lnorm: True -ae_aggregation_att_dense_rate: 1.0 -ae_aggregation_block_factor: 64 -ae_aggregation_mlp_hidden_factor: 2 - -decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear -pred_adapter_kv: False -pred_self_attention: True -pred_dyadic_dims: False -pred_mlp_adaln: True -num_class_tokens: 0 -num_register_tokens: 0 - -# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then -# one is training an auto-encoder -fe_num_blocks: 6 -fe_num_heads: 16 -fe_dropout_rate: 0.1 -fe_with_qk_lnorm: True -fe_layer_norm_after_blocks: [] # Index starts at 0. 
Thus, [3] adds a LayerNorm after the fourth layer -fe_impute_latent_noise_std: 0.0 # 1e-4 -# currently fixed to 1.0 (due to limitations with flex_attention and triton) -forecast_att_dense_rate: 1.0 - -healpix_level: 5 - -# Use 2D RoPE instead of traditional global positional encoding -# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) -# When False: uses traditional pe_global positional encoding -rope_2D: False - -with_mixed_precision: True -with_flash_attention: True -compile_model: False -with_fsdp: True -attention_dtype: bf16 -mixed_precision_dtype: bf16 -mlp_norm_eps: 1e-5 -norm_eps: 1e-4 - -latent_noise_kl_weight: 0.0 # 1e-5 -latent_noise_gamma: 2.0 -latent_noise_saturate_encodings: 5 -latent_noise_use_additive_noise: False -latent_noise_deterministic_latents: True - -freeze_modules: "" -load_chkpt: {} - -norm_type: "LayerNorm" -qk_norm_type: null # if null, defaults to norm_type - -##################################### - -streams_directory: "./config/streams/era5_georing/" -streams: ??? - -# type of zarr_store -zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore - -general: - - # mutable parameters - istep: 0 - rank: ??? - world_size: ??? - - # local_rank, - # with_ddp, - # data_path_*, - # model_path, - # run_path, - # path_shared_ - - multiprocessing_method: "fork" - - desc: "" - run_id: ??? - run_history: [] - -# logging frequency in the training loop (in number of batches) -train_logging: - terminal: 16 - metrics: 16 - checkpoint: 256 - log_grad_norms: False - -# parameters for data loading -data_loading : - - num_workers: 12 - rng_seed: ??? - repeat_data_in_mini_epoch : True - - # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with - # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. - # If this happens, you can disable the flag, but performance will drop on GH200. 
- memory_pinning: True - - -# config for training -training_config: - - # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking"] - - num_mini_epochs: 64 - samples_per_mini_epoch: 4096 - shuffle: True - - start_date: 2014-01-01T00:00 - end_date: 2022-12-31T00:00 - - time_window_step: 01:00:00 - time_window_len: 06:00:00 - - learning_rate_scheduling : - lr_start: 1e-6 - lr_max: 5e-5 - lr_final_decay: 1e-6 - lr_final: 0.0 - num_steps_warmup: 1024 - num_steps_cooldown: 512 - policy_warmup: "cosine" - policy_decay: "constant" - policy_cooldown: "linear" - parallel_scaling_policy: "sqrt" - - optimizer: - grad_clip: 1.0 - weight_decay: 0.1 - log_grad_norms: False - adamw : - # parameters are scaled by number of DDP workers - beta1 : 0.975 - beta2 : 0.9875 - eps : 2e-08 - - losses : { - "physical": { - type: LossPhysical, - loss_fcts: { "mse": { }, }, - }, - } - - model_input: { - "forecasting" : { - # masking strategy: "random", "healpix", "forecast" - masking_strategy: "forecast", - }, - } - - forecast : - time_step: 06:00:00 - num_steps: 3 - offset: 1 - policy: "fixed" - - -# validation config; full validation config is merge of training and validation config -validation_config: - - samples_per_mini_epoch: 1 - shuffle: True - - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T00:00 - - # whether to track the exponential moving average of weights for validation - validate_with_ema: - enabled : True - ema_ramp_up_ratio: 0.09 - ema_halflife_in_thousands: 1e-3 - - # parameters for validation samples that are written to disk - output : { - # number of samples that are written - num_samples: 0, - # write samples in normalized model space - normalized_samples: False, - # output streams to write; default all - streams: null, - } - - # run validation before training starts (mainly for model development) - validate_before_training: False - - -# validation config; full validation config is merge of training and validation config -test_config: - 
- samples_per_mini_epoch: 1 - shuffle: False - - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T00:00 - - # whether to track the exponential moving average of weights for validation - validate_with_ema: - enabled : True - ema_ramp_up_ratio: 0.09 - ema_halflife_in_thousands: 1e-3 - - # parameters for validation samples that are written to disk - output : { - # number of samples that are written - num_samples: 0, - # write samples in normalized model space - normalized_samples: False, - # output streams to write; default all - streams: null, - } - - # run validation before training starts (mainly for model development) - validate_before_training: False - - -# test config; full test config is merge of validation and test config -# test config is used by default when running inference - -# Tags for experiment tracking -# These tags will be logged in MLFlow along with completed runs for train, eval, val -# The tags are free-form, with the following rules: -# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries -# - tags should not duplicate existing config entries. -# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags -# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) -wgtags: - # The name of the organization of the person running the experiment. - # This may be autofilled in the future. Expected values are lowercase strings - # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" - org: null - # The Github issue corresponding to this run (number such as 1234) - # Github issues are the central point when running experiment and contain - # links to hedgedocs, code branches, pull requests etc. - # It is recommended to associate a run with a Github issue. - issue: null - # The name of the experiment. This is a distinctive codename for the experiment campaign being run. 
- # This is expected to be the primary tag for comparing experiments in MLFlow, along with the - # issue number. - # Expected values are lowercase strings with no spaces, just underscores: - # Examples: "rollout_ablation_grid" - exp: null - # *** Experiment-specific tags *** - # All extra tags (including lists, dictionaries, etc.) are treated - # as strings by mlflow, so treat all extra tags as simple string key: value pairs. - grid: null diff --git a/config/streams/era5_georing/era5.yml b/config/streams/era5_georing/era5.yml deleted file mode 100644 index 186b8d862..000000000 --- a/config/streams/era5_georing/era5.yml +++ /dev/null @@ -1,36 +0,0 @@ -# (C) Copyright 2024 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -ERA5 : - type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] - stream_id : 0 - source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] - target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] - loss_weight : 1. 
- location_weight : cosine_latitude - token_size : 8 - tokenize_spacetime : True - max_num_targets: -1 - embed : - net : transformer - num_tokens : 1 - num_heads : 8 - dim_embed : 256 - num_blocks : 2 - embed_target_coords : - net : linear - dim_embed : 512 - target_readout : - num_layers : 2 - num_heads : 4 - # sampling_rate : 0.2 - pred_head : - ens_size : 1 - num_layers : 1 diff --git a/config/streams/era5_georing/geos.yml b/config/streams/era5_georing/geos.yml deleted file mode 100644 index e6cc89442..000000000 --- a/config/streams/era5_georing/geos.yml +++ /dev/null @@ -1,77 +0,0 @@ - -METEOSAT_SEVIRI : - type : obs - stream_id : 2 - # filenames : ['observations-od-ai-0001-201311-202505-msg-combined-seviri-o256-v1.zarr'] - filenames : ['observations-file-2014-2024-seviri-h512-v5.zarr'] - loss_weight : 1.0 - token_size : 128 - tokenize_spacetime : True - max_num_targets: 65536 - embed : - net : transformer - num_tokens : 1 - num_heads : 2 - dim_embed : 256 - num_blocks : 2 - embed_target_coords : - net : linear - dim_embed : 512 - target_readout : - num_layers : 2 - num_heads : 4 - pred_head : - ens_size : 1 - num_layers : 1 - - -GOES_ABI : - type : obs - stream_id : 3 - # filenames : ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] - filenames : ['observations-file-2017-2024-abi-goes16-IR-h512-v2.zarr'] - loss_weight : 1.0 - token_size : 128 - tokenize_spacetime : True - max_num_targets: 65536 - embed : - net : transformer - num_tokens : 1 - num_heads : 2 - dim_embed : 256 - num_blocks : 2 - embed_target_coords : - net : linear - dim_embed : 512 - target_readout : - num_layers : 2 - num_heads : 4 - pred_head : - ens_size : 1 - num_layers : 1 - - -HIMAWARI_AHI : - type : obs - stream_id : 4 - # filenames : ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr'] - filenames : ['observations-file-2015-2022-himawari8-IR-h512-v1.zarr'] - loss_weight : 1.0 - token_size : 128 - tokenize_spacetime : True - max_num_targets: 65536 - embed : - net : 
transformer - num_tokens : 1 - num_heads : 2 - dim_embed : 256 - num_blocks : 2 - embed_target_coords : - net : linear - dim_embed : 512 - target_readout : - num_layers : 2 - num_heads : 4 - pred_head : - ens_size : 1 - num_layers : 1 From 2f25ee8bbdc4a1ef559c62f38c341b08bc398126 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:05:42 +0200 Subject: [PATCH 71/76] Remove startup time --- src/weathergen/train/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f3a23544d..ff33bc6ce 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -381,12 +381,6 @@ def run( # run validation before training if requested self.validate_before_training() - # Log startup time - if is_root() and t_start is not None: - startup_time = time.time() - t_start - self.train_logger.log_metrics("train", {"startup_time_seconds": startup_time}) - logger.info(f"Startup time: {startup_time:.2f} seconds") - # training loop self.t_training_start = time.time() From 198b542e0f222685b221d21d6b258be34d143231 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:11:54 +0200 Subject: [PATCH 72/76] Remove startup time --- src/weathergen/run_train.py | 4 +--- src/weathergen/train/trainer.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 0f6a5aee9..af9037a1e 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -161,8 +161,6 @@ def run_train(args): Note: All model configurations are set in the function body. 
""" - t_start = time.time() - cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -189,7 +187,7 @@ def run_train(args): trainer = Trainer(cf.train_logging) try: - trainer.run(cf, devices, t_start=t_start) + trainer.run(cf, devices) except Exception: extype, value, tb = sys.exc_info() traceback.print_exc() diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ff33bc6ce..1ea3e1571 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -243,7 +243,7 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): logger.info(f"Finished inference run with id: {cf.general.run_id}") def run( - self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: float | None = None + self, cf, devices, run_id_contd=None, mini_epoch_contd=None ): # general initalization self.init(cf, devices) From 0de866068d6172b57298479c2919cf90b0d164f8 Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:14:57 +0200 Subject: [PATCH 73/76] Formatting and removed time per epoch --- src/weathergen/run_train.py | 1 + src/weathergen/train/trainer.py | 17 +---------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index af9037a1e..7995b5864 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -161,6 +161,7 @@ def run_train(args): Note: All model configurations are set in the function body. 
""" + cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 1ea3e1571..a31dd57a7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -242,9 +242,7 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") - def run( - self, cf, devices, run_id_contd=None, mini_epoch_contd=None - ): + def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # general initalization self.init(cf, devices) cf = self.cf @@ -565,19 +563,6 @@ def train(self, mini_epoch): self.dataset.advance() - if is_root(): - total_training_time = time.time() - self.t_training_start - self.train_logger.log_metrics( - "train", - { - "completed_mini_epoch": mini_epoch, - "training_time_after_mini_epoch_seconds": total_training_time, - }, - ) - logger.info( - f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds" - ) - def validate(self, mini_epoch, mode_cfg, batch_size): """ Perform validation / test computation as specified by mode_cfg From 39f80751f1b8793bc14e8a31ec9aa1766898891a Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:16:53 +0200 Subject: [PATCH 74/76] Undo pyproject change --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 24ea62a73..03c9dbb31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,10 +75,6 @@ dev = [ # aarch64: gpu [project.optional-dependencies] -performance = [ - "weathergen-performance", -] - cpu = [ 'torch==2.6.0', ] From 61ad9cfe2488393a25330b842dce7440f3cbf89d Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 Jun 2026 21:20:39 +0200 Subject: [PATCH 75/76] ploting changes wip --- src/weathergen/utils/plot_training.py | 53 ++++++++++++++++----------- 1 file 
changed, 32 insertions(+), 21 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 5678cbbc2..059ea57bf 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -311,6 +311,7 @@ def plot_loss_avg( runs_data, runs_active, stage=TRAIN, + x_axis: str = "samples", x_scale_log=False, legend_outside: bool = False, legend_font_size: str = "x-small", @@ -322,10 +323,14 @@ def plot_loss_avg( _fig = plt.figure(figsize=(10, 7), dpi=PLOT_DPI_VALUE) + # x-axis label: "elapsed_training_time" -> "elapsed training time [s]", else "step" + x_label = "elapsed training time [s]" if "elapsed_training_time" in x_axis else "step" + legend_str = [] for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): run_data_stage = run_data.train if stage == TRAIN else run_data.val - x_vals = np.array(run_data_stage["num_samples"]) + x_col = next(filter(lambda c: x_axis in c, run_data_stage.columns)) + x_vals = np.array(run_data_stage[x_col]) y_vals = np.array(run_data_stage["loss_avg_mean"]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) @@ -347,7 +352,7 @@ def plot_loss_avg( plt.xscale("log") plt.title("average loss") plt.ylabel("loss") - plt.xlabel("step") + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -379,10 +384,9 @@ def plot_loss_per_stream( channels: list[str], forecast_steps: list[int], x_axis: str = "samples", - x_type: str = "step", + x_scale_log: bool = False, x_lim: list[float] | None = None, y_lim: list[float] | None = None, - x_scale_log: bool = False, legend_outside: bool = False, legend_font_size: str = "x-small", legend_num_columns: int = 3, @@ -408,9 +412,7 @@ def plot_loss_per_stream( errs : list list of errors to plot (e.g. 
mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") - x_type : str - x-axis type (options: "step", "reltime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -525,7 +527,13 @@ def plot_loss_per_stream( title_loss = ".".join(title_col.split(".")[:-1]) plt.title(title_loss + " (" + ", ".join(modes) + ")") plt.ylabel(err) - plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") + # x-axis label: "elapsed_training_time" -> friendly name, else use column as-is + x_label = ( + "elapsed training time [s]" + if "elapsed_training_time" in x_axis + else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -596,7 +604,7 @@ def plot_loss_per_run( errs : list list of errors to plot (e.g. mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -666,7 +674,10 @@ def plot_loss_per_run( plt.xscale("log") plt.grid(True, which="both", ls="-") plt.ylabel("loss") - plt.xlabel("samples") + x_label = ( + "elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -794,13 +805,13 @@ def plot_train(args=None): help="x-lim for per-stream plots", ) parser.add_argument( - "--x_type", + "--x-axis", "-x", - dest="x_type", - default="step", + dest="x_axis", + default="samples", type=str, - choices=["step", "reltime"], - help="Type of x-axis used in plots. 
Options: 'step' or 'reltime'", + choices=["samples", "elapsed_training_time"], + help="X-axis column for plots: 'samples' (default) or 'elapsed_training_time'", ) parser.add_argument( "--log-x", @@ -862,9 +873,7 @@ def plot_train(args=None): model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None out_dir = Path(args.output_dir) streams = list(args.streams) - x_types_valid = ["step"] # TODO: add "reltime" support when fix available - if args.x_type not in x_types_valid: - raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}") + x_axis = args.x_axis # Post-processing default logic for config from YAML-file if args.fd is None and args.fy is None: @@ -924,6 +933,7 @@ def plot_train(args=None): runs_data, runs_active, plot_dir=out_dir, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -937,6 +947,7 @@ def plot_train(args=None): runs_data, runs_active, stage=TRAIN, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -953,7 +964,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -972,7 +983,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -991,7 +1002,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, From e099f8487926d0b71c5da946321ae9831ba2a5fb Mon Sep 17 00:00:00 2001 From: Florian Scheidl Date: Fri, 12 
Jun 2026 21:21:22 +0200 Subject: [PATCH 76/76] Undo pyproject changes --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03c9dbb31..00103cb8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,7 +228,6 @@ weathergen-common = { workspace = true } weathergen-evaluate = { workspace = true } weathergen-metrics = { workspace = true } weathergen-readers-extra = { workspace = true } -weathergen-performance = { workspace = true } flash-attn = [ @@ -273,6 +272,5 @@ members = [ "packages/readers_extra", # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. - "packages/performance", ]