diff --git a/.agents/skills/commit-conventions/SKILL.md b/.agents/skills/commit-conventions/SKILL.md index 65894854a8..2e356537b9 100644 --- a/.agents/skills/commit-conventions/SKILL.md +++ b/.agents/skills/commit-conventions/SKILL.md @@ -115,7 +115,7 @@ delegation instead of hardcoded model names. Key changes: - Create .opencode/command/ with review-pr, create-pr -- Replace Opus/Sonnet/Haiku with deep/unspecified-high/quick +- Replace hardcoded model routing with platform-native review routing - Add expert subagent consultation patterns ``` diff --git a/.agents/skills/review-pr/MIGRATION.md b/.agents/skills/review-pr/MIGRATION.md new file mode 100644 index 0000000000..473a24376f --- /dev/null +++ b/.agents/skills/review-pr/MIGRATION.md @@ -0,0 +1,347 @@ +# `/review-pr` Refactor & Update Plan (Codex / OpenCode / Claude) + +Last updated: 2026-03-31 Owner: AI workflow maintainers Scope: +`.agents/skills/review-pr/*`, `.opencode/{command,data}/review-pr*`, +`.claude/{commands,data}/review-pr*` + +______________________________________________________________________ + +## 1. Goals and Non-Goals + +### Goals + +1. Rebuild `/review-pr` taxonomy and templates for better coverage of new AReaL modules. +1. Keep template count manageable while preserving depth for high-risk PRs. +1. Eliminate cross-platform drift through a single semantic source of truth. +1. Reduce false positives by prioritizing path-based triggers. + +### Non-Goals + +1. Changing `/review-pr` into a file-mutating workflow (must remain read-only). +1. Adding one-off niche templates for every micro-feature. +1. Rewriting all platform command orchestration logic in one iteration. + +______________________________________________________________________ + +## 2. Current State Summary (Problem Statement) + +### What works + +1. Existing workflow phases are clear (context -> analysis -> planning -> delegated + review -> summary). +1. Platform-specific command wrappers already exist and are functional. +1. 
Risk-first review behavior (CRITICAL/HIGH/MEDIUM/LOW) is already enforced. + +### What is broken or stale + +1. Domain/signal matrix lags behind repository evolution (`tree_attn`, `vllm_ext`, + service stack, `archon_weight_sync`). +1. Taxonomy is overly fragmented in some areas and missing in others. +1. Three-platform data files are near-duplicates and easy to drift. +1. Trigger rules rely heavily on generic keywords in places where path triggers are + safer. + +______________________________________________________________________ + +## 3. Target Information Architecture + +Use a **two-layer taxonomy**: + +- **L1 Domain (12 domains)**: template-level planning unit. +- **L2 Signals (2-9 per domain)**: detection and checklist specialization unit. + +This avoids both extremes: too many tiny templates or too coarse single-pass review. + +### 3.1 L1 Domains and L2 Signals + +1. **Distributed Runtime** + - L2: `archon_core`, `archon_parallel`, `process_group`, `fsdp_core`, + `megatron_core`, `collectives`, `mesh_dtensor`, `activation_checkpointing`, + `weight_sync` +1. **Model Compute & Attention** + - L2: `tree_attn`, `sdpa_varlen`, `sp_cp_attention_mask`, `triton_kernel`, + `archon_model_family`, `archon_attention_stack`, `archon_moe_modeling` +1. **Inference Backend & Serving** + - L2: `vllm_ext`, `remote_inference_backend`, `request_lifecycle` +1. **Service Orchestration** + - L2: `service_routing_dataflow`, `session_consistency` +1. **Workflow & Trainer Contract** + - L2: `workflow_engine_boundary`, `dataset_surface`, `async_contract`, + `weight_version_contract` +1. **API & Config Compatibility** + - L2: `dataclass_schema`, `cli_compat`, `backward_compat`, + `project_dependency_config` +1. **Numerics & Tensor Semantics** + - L2: `shape_dtype`, `numerical_stability`, `reward_surface`, `compile_dynamo`, + `mixed_precision_fp8` +1. **Checkpoint & Recovery** + - L2: `dcp_consistency`, `optimizer_state`, `resume_compat` +1. 
**Launcher & Infrastructure** + - L2: `launcher_resource_match`, `scheduler_contract`, `rpc_transport`, + `runtime_image_config` +1. **Low-Risk Hygiene** + - L2: `tests_docs_config`, `logging_import_security`, `project_docs_metadata` +1. **Harness & Agent Infrastructure** + - L2: `skill_definition`, `platform_command_data`, `agent_registry_config` +1. **CI/CD & Release Automation** + - L2: `github_workflow_jobs`, `runner_provisioning`, `release_delivery` + +______________________________________________________________________ + +## 4. Detection Strategy (Path-first) + +### 4.1 Trigger priority + +1. **Path trigger (primary)** +1. **Keyword trigger (secondary, only to refine existing path hits)** +1. **Linkage trigger (auto-add dependent checks)** + +### 4.2 Canonical path mapping (v1) + +1. `areal/models/tree_attn/**` -> Model Compute & Attention (HIGH/MEDIUM by touched + files) +1. `areal/engine/vllm_ext/**` -> Inference Backend & Serving (HIGH) +1. `areal/experimental/models/archon/**` + `areal/experimental/engine/archon_engine.py` + \+ `areal/experimental/engine/archon_checkpoint.py` -> Distributed Runtime + Model + Compute & Attention (CRITICAL/HIGH) +1. `areal/experimental/agent_service/**` -> Service Orchestration (HIGH) +1. `areal/experimental/inference_service/**` -> Service Orchestration (HIGH) +1. `areal/experimental/engine/archon_weight_sync.py` -> Distributed Runtime (CRITICAL) +1. `areal/infra/rpc/**` -> Launcher & Infrastructure + Distributed Runtime (HIGH) +1. `areal/workflow/**`, `areal/trainer/**` -> Workflow & Trainer Contract (HIGH/MEDIUM) +1. `areal/api/**` -> API & Config Compatibility (MEDIUM) +1. `areal/utils/{saver.py,recover.py,async_checkpoint.py}` + `*checkpoint*.py` -> + Checkpoint & Recovery (CRITICAL/HIGH) +1. `areal/experimental/models/archon/activation_checkpoint.py` -> Distributed Runtime + (MEDIUM/HIGH) +1. `areal/experimental/models/archon/compile.py` -> Numerics & Tensor Semantics + + Distributed Runtime (MEDIUM) +1. 
`pyproject.toml`, `uv.lock` -> API & Config Compatibility (MEDIUM)
1. `Dockerfile`, `.dockerignore` -> Launcher & Infrastructure (HIGH)
1. `.agents/**`, `.claude/**`, `.opencode/**`, `.codex/**`, `AGENTS.md`, `CLAUDE.md` ->
   Harness & Agent Infrastructure (MEDIUM/HIGH)
1. `.github/workflows/**` -> CI/CD & Release Automation (HIGH/CRITICAL)
1. `docs/build_all.sh`, `docs/generate_cli_docs.py`, `.github/PULL_REQUEST_TEMPLATE.md`,
   `README.md`, `CONTRIBUTING.md` -> Low-Risk Hygiene (LOW/MEDIUM)

### 4.3 Must-not-regress coverage (from current review-pr)

The migration must preserve the existing high-risk framework coverage already present
today:

1. `areal/experimental/models/archon/**` + `areal/experimental/engine/archon_engine.py`
   \+ `areal/experimental/engine/archon_checkpoint.py` -> Distributed Runtime + Model
   Compute & Attention
1. `areal/engine/fsdp_utils/**` + `areal/engine/fsdp_engine.py` -> Distributed Runtime
1. `areal/engine/megatron_engine.py` + `areal/engine/megatron_utils/**` -> Distributed
   Runtime
1. `areal/trainer/**` -> Workflow & Trainer Contract
1. `areal/reward/**` -> Numerics & Tensor Semantics + Workflow & Trainer Contract
1. `areal/dataset/**` -> Workflow & Trainer Contract + API & Config Compatibility

### 4.4 Noise control rules

1. No repo-wide standalone triggering for `current_platform`, `RTensor`, or `fp8`.
1. These keywords only refine severity/checklists after domain has been path-selected.
1. Cap task fan-out: max one primary template per domain, plus one general logic pass.

______________________________________________________________________

## 5. Template Strategy

Maintain **12 domain templates** (one per L1), each with signal-specific checklists,
plus **1 universal logic pass**.

### Template inventory (v1)

1. Distributed Runtime Review
1. Model Compute & Attention Review
1. Inference Backend & Serving Review
1. Service Orchestration Review
1. Workflow & Trainer Contract Review
1. 
API & Config Compatibility Review +1. Numerics & Tensor Semantics Review +1. Checkpoint & Recovery Review +1. Launcher & Infrastructure Review +1. Low-Risk Hygiene Review +1. Harness & Agent Infrastructure Review +1. CI/CD & Release Automation Review + +### Mandatory universal pass + +- Always run one lightweight **General Logic & Boundary** pass for non-doc PRs. + +______________________________________________________________________ + +## 6. Severity Mapping + +1. **CRITICAL**: distributed invariants, checkpoint correctness, core weight sync + safety. +1. **HIGH**: service orchestration, inference backend runtime, trainer/workflow + contracts. +1. **MEDIUM**: API compatibility, numerics/tensor semantics in bounded scope. +1. **LOW**: docs/tests/config-only hygiene. + +Rule: domain default severity can be escalated by L2 signal combinations (e.g., +`mesh_dtensor` + `weight_sync`). + +______________________________________________________________________ + +## 7. Cross-Domain Linkage Rules (Auto-appended checks) + +1. `tree_attn` -> also append Numerics & Tensor Semantics checks. +1. `archon_core` or `archon_parallel` -> also append Model Compute & Attention checks. +1. `archon_model_family` or `archon_moe_modeling` -> also append Numerics & Tensor + Semantics checks. +1. `reward_surface` -> also append Workflow & Trainer Contract checks. +1. `compile_dynamo` -> also append Distributed Runtime checks. +1. `vllm_ext` -> also append Launcher & Infrastructure checks. +1. Service Orchestration changes -> also append Workflow & Trainer async-contract + checks. +1. `archon_weight_sync` -> also append DTensor + process-group + checkpoint interaction + checks. +1. RPC transport changes -> also append Distributed Runtime synchronization checks. +1. `mixed_precision_fp8` + Distributed Runtime -> also append mesh + weight-sync + compatibility checks. +1. `runtime_image_config` -> also append Inference Backend & Serving checks. +1. 
`project_dependency_config` -> also append API & Config Compatibility checks.
1. `github_workflow_jobs` or `release_delivery` -> also append Launcher & Infrastructure
   checks.
1. `skill_definition` or `platform_command_data` -> also append Low-Risk Hygiene checks.

______________________________________________________________________

## 8. Three-Platform Synchronization Model

### 8.1 Source-of-truth

Semantic content is authored only in:

1. `.agents/skills/review-pr/references/review-pr-domains-and-signals.md`
1. `.agents/skills/review-pr/references/review-pr-templates.md`

Canonical semantic scope includes:

1. taxonomy (L1/L2), path/linkage rules, and severity rules
1. checklist bodies and change-analysis vocabulary

Wrapper-specific scope (non-canonical) includes:

1. command syntax and frontmatter fields
1. shell snippets, import/include syntax, and runtime-specific routing policies
1. OpenCode-only and Claude-only execution options

### 8.2 Derived outputs

Generated/derived files:

1. `.opencode/data/review-pr-domains-and-signals.md`
1. `.opencode/data/review-pr-templates.md`
1. `.claude/data/review-pr-domains-and-signals.md`
1. `.claude/data/review-pr-templates.md`

### 8.3 Mechanical sync (definition)

"Mechanical sync" means deterministic conversion, not manual copy/paste:

1. Read canonical `.agents` files.
1. Emit OpenCode and Claude data copies with the same generic review-depth vocabulary
   (`comprehensive`, `targeted`, `basic`).
1. Keep all runtime routing and platform execution choices in the wrapper command files
   only.
1. Preserve section order and checklist content exactly.

______________________________________________________________________

## 9. Implementation Plan

### Phase A (Foundation)

1. Replace the old review-pr taxonomy with L1/L2 domain/signal references in canonical
   `.agents` references.
1. Rebuild template file into 12 domain templates + universal logic pass.
1. 
Add linkage rules and fan-out caps. + +### Phase B (Platform sync) + +1. Add sync script at `.agents/skills/review-pr/sync_review_pr_refs.py`. +1. Regenerate `.opencode/data/*` and `.claude/data/*`. +1. Add CI check: fail if derived files differ from regeneration output. + +### Phase C (Command layer alignment) + +1. Update `.opencode/command/review-pr.md` wording to reflect new domains/signals. +1. Update `.claude/commands/review-pr.md` wording similarly. +1. Keep orchestration differences platform-specific (task categories vs model routing). + +### Phase D (Validation) + +1. Run **classification lane** with `/review-pr --quick` against representative PRs + from: + - `tree_attn` + - `vllm_ext` + - `agent_service/inference_service` + - `archon_weight_sync` +1. Measure in classification lane: + - expected detected domains/signals + - expected severity + - false positive rate + - missing high-risk findings +1. Run **delegation lane** with full `/review-pr` (non-quick) on the same fixtures. +1. Measure in delegation lane: + - false positive rate + - missing high-risk findings + - number of spawned review tasks per PR + +______________________________________________________________________ + +## 10. Plan Review (Critical self-review) + +### Strengths + +1. Improves coverage for newly introduced high-risk modules. +1. Shrinks maintenance overhead by moving to domain templates. +1. Prevents cross-platform drift via deterministic derivation. + +### Risks + +1. Over-broad domains can dilute checklist quality. +1. First migration may temporarily shift severity distributions. +1. Command docs can lag behind taxonomy if sync discipline is not enforced. + +### Mitigations + +1. Keep L2 signals explicit and path-anchored. +1. Use golden PR cases for pre/post comparison. +1. Add CI consistency gate for derived platform data files. +1. Require wrapper wording updates in the same migration PR as taxonomy changes. 
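
The mitigation list above leans on the Phase B sync script plus CI gate. A minimal
sketch of that script follows, using the canonical and derived paths from Section 8;
the pass-through `derive` step, the `--check` flag, and the error message are
assumptions for illustration, not the final interface:

```python
#!/usr/bin/env python3
"""Sketch of .agents/skills/review-pr/sync_review_pr_refs.py (hypothetical)."""
from pathlib import Path
import sys

# Canonical sources (Section 8.1) and derived platform copies (Section 8.2).
CANONICAL = Path(".agents/skills/review-pr/references")
DERIVED_ROOTS = [Path(".opencode/data"), Path(".claude/data")]
FILES = [
    "review-pr-domains-and-signals.md",
    "review-pr-templates.md",
]


def derive(canonical_text: str) -> str:
    """Mechanical sync: deterministic, content-preserving conversion.

    Section order and checklist bodies pass through verbatim; the canonical
    files already use the generic review-depth vocabulary, and platform
    routing lives only in the wrapper command files.
    """
    return canonical_text


def sync(check_only: bool = False) -> bool:
    clean = True
    for name in FILES:
        expected = derive((CANONICAL / name).read_text())
        for root in DERIVED_ROOTS:
            target = root / name
            if check_only:
                # CI consistency gate: fail when a derived copy is stale.
                if not target.exists() or target.read_text() != expected:
                    print(f"stale derived file: {target}")
                    clean = False
            else:
                target.parent.mkdir(parents=True, exist_ok=True)
                target.write_text(expected)
    return clean


if __name__ == "__main__":
    sys.exit(0 if sync(check_only="--check" in sys.argv) else 1)
```

Running the script with no arguments regenerates both derived copies; running it with
`--check` implements the CI consistency gate, exiting non-zero whenever a derived file
differs from regeneration output.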
+ +______________________________________________________________________ + +## 11. Acceptance Criteria + +1. Canonical references express all 12 domains + L2 signals and linkage rules. +1. Derived OpenCode/Claude data files regenerate with zero manual edits. +1. No mixed old/new label vocabulary remains after regeneration. +1. Representative PRs trigger expected domains: + - `tree_attn` PR -> Model Compute & Attention (+ Numerics linkage) + - `vllm_ext` PR -> Inference Backend & Serving (+ Launcher linkage) + - service-stack PR -> Service Orchestration (+ Workflow linkage) + - `archon_weight_sync` PR -> Distributed Runtime (CRITICAL) +1. Legacy high-risk coverage remains intact (Archon/FSDP/Megatron/Reward/Dataset). +1. Task fan-out remains bounded (no uncontrolled template explosion). + +______________________________________________________________________ + +## 12. Out of Scope for v1 (defer) + +1. Fully unifying command orchestration syntax across OpenCode and Claude. +1. Creating standalone domain types for every backend keyword. +1. Automatically posting review comments to GitHub. + +______________________________________________________________________ + +## 13. Recommended Next Action + +Land **Phase A + B + C** in one migration PR (taxonomy + sync tooling + wrapper +alignment), then run **Phase D** validation immediately using fixed fixtures. diff --git a/.agents/skills/review-pr/SKILL.md b/.agents/skills/review-pr/SKILL.md index aad0d2a642..671823676d 100644 --- a/.agents/skills/review-pr/SKILL.md +++ b/.agents/skills/review-pr/SKILL.md @@ -22,7 +22,7 @@ PR. ## Reference Files -- `references/review-pr-change-types.md` +- `references/review-pr-domains-and-signals.md` - `references/review-pr-templates.md` ## Workflow @@ -36,10 +36,10 @@ PR. ### Phase 2: Change analysis -1. Classify changed files using `references/review-pr-change-types.md`. +1. Classify changed files using `references/review-pr-domains-and-signals.md`. 1. 
Determine the highest overall risk level: `CRITICAL`, `HIGH`, `MEDIUM`, or `LOW`. 1. Build a `CHANGE_ANALYSIS_REPORT` that lists: - - detected change types + - detected domains/signals - risk level - affected files - related frameworks @@ -96,7 +96,8 @@ Use this structure: ```markdown CHANGE_ANALYSIS_REPORT: -- detected_types: [...] +- detected_domains: [...] +- detected_signals: [...] - risk_level: ... - affected_files: [...] - related_frameworks: [...] diff --git a/.agents/skills/review-pr/references/review-pr-change-types.md b/.agents/skills/review-pr/references/review-pr-change-types.md deleted file mode 100644 index 81bb2d165b..0000000000 --- a/.agents/skills/review-pr/references/review-pr-change-types.md +++ /dev/null @@ -1,149 +0,0 @@ -# PR Review: Change Type Detection Reference - -This file contains the change type detection tables for PR review. Referenced by: -`.agents/skills/review-pr/SKILL.md` - -______________________________________________________________________ - -## CRITICAL Level (Requires `deep` category) - -| Change Type | File Path Pattern | Code Pattern | -| ---------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------- | -| **ARCHON_CORE** | `areal/experimental/models/archon/` | - | -| **ARCHON_PARALLEL** | `parallel_dims.py` | `ArchonParallelDims`, `_build_mesh`, `DeviceMesh` | -| **ARCHON_MOE** | `archon/moe/` | `router`, `grouped_experts`, `TokenReorderer`, `grouped_mm` | -| **ARCHON_PARALLELIZE** | `qwen*/infra/parallelize.py` | `apply_moe_ep_tp`, `apply_tp`, `apply_cp` | -| **ARCHON_ENGINE** | `areal/experimental/engine/archon_engine.py` | `ArchonEngine` | -| **FSDP_CORE** | `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py` | `FSDP`, `FullyShardedDataParallel`, `fully_shard` | -| **MEGATRON_CORE** | `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine` | -| **DCP_CHECKPOINT** | - | `DCP`, 
`DistributedCheckpoint`, `dcp.save`, `dcp.load` | - -## HIGH Level (Recommend `deep` category) - -| Change Type | File Path Pattern | Code Pattern | -| --------------------- | ----------------- | -------------------------------------------------------------------------------- | -| **DISTRIBUTED_COMM** | - | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.` | -| **DTENSOR** | - | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` | -| **MOE_LAYER** | `moe/` | `expert`, `token_dispatch`, `grouped_mm`, `MoE` | -| **EP_ETP** | - | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp` | -| **TENSOR_PARALLEL** | - | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module` | -| **SEQUENCE_PARALLEL** | - | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size` | -| **ASYNC_CONCURRENT** | - | `async def`, `await`, `asyncio`, `threading.Lock`, `aiofiles` | -| **TRAINER_CORE** | `areal/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train` | - -## MEDIUM Level (Use `unspecified-high` category) - -| Change Type | File Path Pattern | Code Pattern | -| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ | -| **TENSOR_OPS** | - | `.view(`, `.reshape(`, `dtype=`, `.detach()`, `no_grad`, `.contiguous()` | -| **NUMERICAL** | - | `log(`, `softmax`, `cross_entropy`, `eps=`, `.clamp(`, `nan`, `inf` | -| **WORKFLOW_ENGINE** | `areal/workflow/`, `areal/engine/` | `arun_episode`, `agenerate`, `RolloutWorkflow` | -| **API_CONFIG** | `areal/api/` | `@dataclass`, `__post_init__`, `field(` | -| **COMPILE** | - | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` | -| **ACTIVATION_CKPT** | `activation_checkpoint.py` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | -| **CHECKPOINT_RECOVERY** | 
`areal/utils/saver.py`, `areal/utils/recover.py`, `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/async_checkpoint.py` | `state_dict`, `load_state_dict`, `checkpoint`, `AsyncCheckpointManager` | -| **REWARD** | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` | -| **DATASET** | `areal/dataset/` | `get_*_dataset`, `DataLoader`, `IterableDataset` | -| **LAUNCHER_SCHEDULER** | `areal/infra/launcher/`, `areal/infra/scheduler/`, `areal/infra/rpc/` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` | -| **ATTENTION** | `attention/`, `attention/sdpa.py`, `attention/varlen.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` | - -## LOW Level (Use `quick` category) - -| Change Type | File Path Pattern | Code Pattern | -| --------------- | ---------------------------- | ------------ | -| **TESTS** | `tests/`, `*_test.py` | - | -| **DOCS** | `docs/`, `*.md` | - | -| **CONFIG_ONLY** | `*.yaml`, `*.json`, `*.toml` | - | - -______________________________________________________________________ - -## Framework-Specific Risk Identification - -### Archon Risks (When ARCHON\_\* types detected) - -- **Device mesh dimension mismatch**: mesh dimension names don't correspond to placement -- **EP constraint violation**: `ep_size` must divide `num_experts`, and - `dp_shard * cp * (tp if etp==1 else 1) % ep == 0` -- **ETP configuration error**: `etp` must be 1 or equal to `tp` -- **Token alignment error**: `grouped_mm` requires token count aligned to 8/16/32 -- **All-to-All split/combine mismatch**: dispatch and combine split configs inconsistent -- **DTensor/Local tensor conversion missing**: need `.to_local()` or - `DTensor.from_local()` -- **torch.compile dynamic shape marking missing**: missing `mark_dynamic` calls -- **AC application order error**: must be after TP/CP, before FSDP -- **Ulysses SP configuration**: CP uses Ulysses implementation, not Ring Attention -- **dp_shard_mod_ep mesh usage**: MoE experts must use `dp_shard_mod_ep` mesh for 
FSDP - -### FSDP Risks (When FSDP\_\* types detected) - -- **Shard/reshard timing error**: premature or delayed sharding operations -- **EP mesh interaction issue**: should use `dp_shard_mod_ep` not `dp_shard` for MoE -- **Gradient divide factor calculation**: incorrect relationship with world size -- **State dict save/load inconsistency**: mixing sharded vs full modes -- **Optimizer state handling**: aggregation and distribution of sharded state -- **DCP compatibility**: ensure DCP save/load works with FSDP2 - -### Megatron Risks (When MEGATRON\_\* types detected) - -- **Pipeline stage splitting error**: unbalanced layer distribution -- **Micro-batch scheduling issues**: pipeline bubble handling -- **Weight sharding and sync**: tied weights handling -- **AC interaction**: checkpointing under pipeline parallelism - -### DCP/Checkpoint Risks (When DCP_CHECKPOINT or CHECKPOINT_RECOVERY detected) - -- **Distributed checkpoint consistency**: all ranks must participate in save/load -- **State dict key mismatch**: keys must match between save and load -- **Optimizer state compatibility**: ensure optimizer state is correctly - sharded/gathered -- **Version compatibility**: old checkpoints should load in new code -- **Storage backend compatibility**: ensure storage backend (filesystem, S3, etc.) 
is - compatible - -______________________________________________________________________ - -## Risk Linkage Rules - -| Detected Change | Auto-Linked Review | -| --------------------------- | ------------------------------------------------------ | -| EP changes | FSDP interaction check, dp_shard_mod_ep mesh check | -| ETP changes | TP + EP combination check, mesh dimension check | -| Megatron changes | Pipeline + AC check | -| Distributed comm changes | Process group + sync check | -| SEQUENCE_PARALLEL changes | TP combination + Attention mask check, Ulysses check | -| CHECKPOINT_RECOVERY changes | FSDP state dict check, DCP compatibility check | -| DCP_CHECKPOINT changes | FSDP2 integration check, distributed consistency check | -| COMPILE changes | Performance regression + FSDP/TP interaction check | -| REWARD changes | Workflow interaction check, AsyncRewardWrapper check | -| LAUNCHER_SCHEDULER changes | Resource config + parallel strategy match check | -| TRAINER_CORE changes | Engine lifecycle + workflow integration check | -| ARCHON_ENGINE changes | DCP checkpoint + parallel dims check | - -______________________________________________________________________ - -## Core Framework Paths (Requires `deep` category) - -**Archon Core**: - -- `areal/experimental/models/archon/` (entire directory) -- `areal/experimental/engine/archon_engine.py` -- `areal/experimental/engine/archon_checkpoint.py` - -**FSDP Core**: - -- `areal/engine/fsdp_utils/` -- `areal/engine/fsdp_engine.py` - -**Megatron Core**: - -- `areal/engine/megatron_engine.py` -- `areal/engine/megatron_utils/megatron.py` -- `areal/engine/megatron_utils/checkpointer.py` - -**Trainer Core**: - -- `areal/trainer/` - -**Training Engine Core** (excludes FSDP/Megatron which have their own categories): - -- `areal/engine/` (except `fsdp_engine.py`, `megatron_engine.py`) diff --git a/.agents/skills/review-pr/references/review-pr-domains-and-signals.md 
b/.agents/skills/review-pr/references/review-pr-domains-and-signals.md new file mode 100644 index 0000000000..a835419ef9 --- /dev/null +++ b/.agents/skills/review-pr/references/review-pr-domains-and-signals.md @@ -0,0 +1,253 @@ +# PR Review: Domain & Signal Detection Reference + +This file contains the canonical change-domain and signal detection tables for PR +review. Referenced by: `.agents/skills/review-pr/SKILL.md` + +______________________________________________________________________ + +## Severity-to-Review-Depth Mapping + +- **CRITICAL**: use `comprehensive` review depth +- **HIGH**: use `comprehensive` review depth +- **MEDIUM**: use `targeted` review depth +- **LOW**: use `basic` review depth + +______________________________________________________________________ + +## L1 Domains and L2 Signals + +## Domain 1: Distributed Runtime (CRITICAL/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `archon_core` | `areal/experimental/engine/archon_engine.py`, `areal/experimental/engine/archon_checkpoint.py` | `ArchonEngine`, `ArchonCheckpointManager`, `archon` | +| `archon_parallel` | `areal/experimental/models/archon/`, `parallel_dims.py`, `parallelize.py` | `ArchonParallelDims`, `_build_mesh`, `apply_moe_ep_tp`, `apply_tp`, `apply_cp`, `ExpertTensorParallel`, `etp`, `parallelize_module`, `ColwiseParallel`, `RowwiseParallel` | +| `process_group` | `areal/engine/fsdp_utils/`, `areal/engine/megatron_utils/`, `areal/experimental/engine/` | `new_group`, `ProcessGroup`, `dist.get_rank(` | +| `fsdp_core` | `areal/engine/fsdp_engine.py`, `areal/engine/fsdp_utils/` | `FSDP`, `fully_shard`, `FullyShardedDataParallel` | +| `megatron_core` | 
`areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine`, `pipeline`, `micro-batch` | +| `collectives` | `areal/engine/`, `areal/infra/rpc/` | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `broadcast`, `barrier` | +| `mesh_dtensor` | `areal/experimental/models/archon/`, `areal/engine/fsdp_utils/` | `DeviceMesh`, `DTensor`, `Shard(`, `Replicate(`, `distribute_tensor` | +| `activation_checkpointing` | `areal/experimental/models/archon/activation_checkpoint.py`, `areal/models/`, `areal/engine/` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | +| `weight_sync` | `areal/experimental/engine/archon_weight_sync.py`, `areal/api/engine_api.py`, `areal/engine/` | `WeightUpdateMeta`, `set_version`, `update_weights` | + +## Domain 2: Model Compute & Attention (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| `tree_attn` | `areal/models/tree_attn/` | `TreeAttention`, `tree_attn`, `TreeNode`, `tree` | +| `sdpa_varlen` | `attention/sdpa.py`, `attention/varlen.py`, `areal/models/tree_attn/` | `sdpa`, `flash_attn`, `varlen`, `causal_mask` | +| `sp_cp_attention_mask` | `areal/models/tree_attn/`, `areal/experimental/models/archon/attention/` | `SequenceParallel`, `context_parallel`, `mask` | +| `triton_kernel` | `areal/models/tree_attn/triton_kernel.py` | `triton`, `kernel`, `autotune` | +| `archon_model_family` | `areal/experimental/models/archon/model_spec.py`, `areal/experimental/models/archon/qwen*/` | `ModelSpec`, `register_model_spec`, `supported_model_types`, `state_dict_adapter`, `rope` | +| `archon_attention_stack` | `areal/experimental/models/archon/attention/`, 
`areal/experimental/models/archon/ulysses.py` | `ulysses_slice_inputs`, `ulysses_gather_output`, `gather_seq_scatter_heads`, `sdpa`, `varlen` |
+| `archon_moe_modeling` | `areal/experimental/models/archon/moe/`, `areal/experimental/models/archon/expert_parallel.py`, `areal/experimental/models/archon/moe_weight_converter.py` | `TokenChoiceTopKRouter`, `RouterGateLinear`, `GroupedExperts`, `MoEWeightConverter`, `expert_parallel` |
+
+## Domain 3: Inference Backend & Serving (HIGH)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| -------------------------- | -------------------------------------------------------------- | ---------------------------------------------------------------- |
+| `vllm_ext` | `areal/engine/vllm_ext/` | `areal_vllm_server`, `vllm_worker_extension`, `pause_generation` |
+| `remote_inference_backend` | `areal/engine/vllm_remote.py`, `areal/engine/sglang_remote.py` | `vllm`, `sglang`, `OpenAI`, `request`, `response` |
+| `request_lifecycle` | `areal/engine/`, `areal/infra/launcher/` | `enqueue`, `dequeue`, `cancel`, `timeout` |
+
+## Domain 4: Service Orchestration (HIGH)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------- |
+| `service_routing_dataflow` | `areal/experimental/agent_service/gateway/`, `areal/experimental/agent_service/router/`, `areal/experimental/inference_service/data_proxy/`, `areal/experimental/inference_service/controller/` | `route`, `gateway`, `router`, `DataProxy`, `controller`, `batch` |
+| `session_consistency` | `areal/experimental/agent_service/`, `areal/experimental/inference_service/` | `session`, `affinity`, `history`, `state` |
+
+## Domain 5: Workflow & Trainer Contract (HIGH/MEDIUM)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| -------------------------- | ----------------------------------------------------------------------------------------------- | --------------------------------------------------- |
+| `workflow_engine_boundary` | `areal/workflow/`, `areal/trainer/`, `areal/engine/` | `RolloutWorkflow`, `arun_episode`, `agenerate` |
+| `dataset_surface` | `areal/dataset/` | `DataLoader`, `IterableDataset`, `get_*_dataset` |
+| `async_contract` | `areal/workflow/`, `areal/experimental/agent_service/`, `areal/experimental/inference_service/` | `async def`, `await`, `aiofiles`, `asyncio` |
+| `weight_version_contract` | `areal/api/engine_api.py`, `areal/workflow/`, `areal/trainer/` | `WeightUpdateMeta`, `set_version`, `weight version` |
+
+## Domain 6: API & Config Compatibility (MEDIUM)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| --------------------------- | ------------------------------------- | -------------------------------------------------------------------------------- |
+| `dataclass_schema` | `areal/api/` | `@dataclass`, `field(`, `__post_init__` |
+| `cli_compat` | `areal/api/cli_args.py` | `Literal`, `help`, `default` |
+| `backward_compat` | `areal/api/`, `areal/infra/launcher/` | `deprecated`, `compat`, `version` |
+| `project_dependency_config` | `pyproject.toml`, `uv.lock` | `requires-python`, `dependencies`, `optional-dependencies`, `build-system`, `uv` |
+
+## Domain 7: Numerics & Tensor Semantics (MEDIUM)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| --------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------- |
+| `shape_dtype` | `areal/engine/`, `areal/models/`, `areal/trainer/` | `.view(`, `.reshape(`, `dtype=`, `.contiguous(` |
+| `numerical_stability` | `areal/engine/`, `areal/reward/`, `areal/utils/functional/` | `log(`, `softmax`, `eps=`, `.clamp(`, `nan`, `inf` |
+| `reward_surface` | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` |
+| `compile_dynamo` | `areal/experimental/models/archon/compile.py`, `areal/models/`, `areal/engine/` | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` |
+| `mixed_precision_fp8` | `areal/engine/megatron_utils/fp8/`, `areal/experimental/models/archon/` | `fp8`, `bf16`, `fp16`, `mixed precision` |
+
+## Domain 8: Checkpoint & Recovery (CRITICAL/HIGH)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| ----------------- | ------------------------------------------------------------------- | ----------------------------------------------- |
+| `dcp_consistency` | `areal/utils/async_checkpoint.py`, `areal/engine/**/checkpoint*.py` | `dcp.save`, `dcp.load`, `DistributedCheckpoint` |
+| `optimizer_state` | `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/saver.py` | `optimizer state`, `state_dict` |
+| `resume_compat` | `areal/utils/recover.py`, `areal/utils/saver.py` | `resume`, `load_state_dict`, `migration` |
+
+## Domain 9: Launcher & Infrastructure (HIGH/MEDIUM)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| ------------------------- | ---------------------------------------------------------------------- | --------------------------------------------------------- |
+| `launcher_resource_match` | `areal/infra/launcher/` | `LaunchConfig`, `RayLauncher`, `SlurmLauncher` |
+| `scheduler_contract` | `areal/infra/scheduler/`, `areal/scheduler/` | `Scheduler`, `placement`, `resource` |
+| `rpc_transport` | `areal/infra/rpc/`, `areal/experimental/inference_service/data_proxy/` | `RTensor`, `serialize`, `rpc`, `fetch` |
+| `runtime_image_config` | `Dockerfile`, `.dockerignore` | `FROM`, `ARG`, `RUN`, `ENV`, `COPY`, `uv sync`, `VARIANT` |
+
+## Domain 10: Low-Risk Hygiene (LOW)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
+| `tests_docs_config` | `tests/`, `docs/`, `*.md`, `*.yaml`, `*.json`, `*.toml` | - |
+| `logging_import_security` | `areal/`, `examples/` | `getLogger`, `print(`, `import *`, `api_key`, `token`, `password` |
+| `project_docs_metadata` | `docs/build_all.sh`, `docs/generate_cli_docs.py`, `docs/en/`, `docs/zh/`, `README.md`, `CONTRIBUTING.md`, `.github/PULL_REQUEST_TEMPLATE.md`, `.github/ISSUE_TEMPLATE/` | `jupyter-book`, `generate_cli_docs`, `build_all`, `_build`, `checklist`, `template`, `contributing`, `usage` |
+
+## Domain 11: Harness & Agent Infrastructure (MEDIUM/HIGH)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| ----------------------- | ----------------------------------------------------------------------------- | ------------------------------------------------------------ |
+| `skill_definition` | `.agents/skills/**/SKILL.md`, `.agents/skills/**/references/` | `description:`, `## Workflow`, `## Reference Files`, `skill` |
+| `platform_command_data` | `.claude/commands/`, `.claude/data/`, `.opencode/command/`, `.opencode/data/` | `@.`, `/review-pr`, `/create-pr`, `data/`, `task(` |
+| `agent_registry_config` | `.codex/config.toml`, `.codex/agents/`, `AGENTS.md`, `CLAUDE.md` | `agents`, `skills`, `registry`, `subagent`, `config.toml` |
+
+## Domain 12: CI/CD & Release Automation (HIGH/CRITICAL)
+
+| L2 Signal | File Path Pattern | Code Pattern |
+| ---------------------- | -------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------- |
+| `github_workflow_jobs` | `.github/workflows/*.yml` | `jobs:`, `runs-on:`, `needs:`, `if:`, `workflow_dispatch` |
+| `runner_provisioning` | `.github/workflows/bake-gcp-image.yml`, `.github/workflows/runner-heartbeat.yml` | `gcp`, `runner`, `image`, `heartbeat` |
+| `release_delivery` | `.github/workflows/build-docker-image.yml`, `.github/workflows/tag-release-image.yml`, `.github/workflows/deploy-docs.yml` | `docker`, `tag`, `release`, `pages`, `publish` |
+
+______________________________________________________________________
+
+## Must-Not-Regress Core Coverage
+
+The refactor must preserve these existing review surfaces:
+
+- Archon core: `areal/experimental/models/archon/`,
+  `areal/experimental/engine/archon_engine.py`
+- FSDP core: `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py`
+- Megatron core: `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/`
+- Reward: `areal/reward/`
+- Dataset: `areal/dataset/`
+- Trainer: `areal/trainer/`
+- Harness: `.agents/`, `.claude/`, `.opencode/`, `.codex/`
+- CI/CD and release: `.github/workflows/`, `Dockerfile`, `pyproject.toml`
+
+______________________________________________________________________
+
+## Cross-Domain Linkage Rules
+
+| Detected Signal | Auto-Linked Review |
+| ---------------------------------------------- | --------------------------------------------------- |
+| `archon_core` or `archon_parallel` | Model Compute & Attention checks |
+| `archon_model_family` or `archon_moe_modeling` | Numerics & Tensor Semantics checks |
+| `tree_attn` | Numerics & Tensor Semantics checks |
+| `reward_surface` | Workflow & Trainer Contract checks |
+| `compile_dynamo` | Distributed Runtime checks |
+| `vllm_ext` | Launcher & Infrastructure checks |
+| `service_routing_dataflow` | Workflow & Trainer async-contract checks |
+| `weight_sync` | DTensor/process-group/checkpoint interaction checks |
+| `rpc_transport` | Distributed Runtime synchronization checks |
+| `mixed_precision_fp8` + Distributed Runtime | mesh + weight-sync compatibility checks |
+| `runtime_image_config` | Inference Backend & Serving checks |
+| `project_dependency_config` | API & Config Compatibility checks |
+| `github_workflow_jobs` or `release_delivery` | Launcher & Infrastructure checks |
+| `skill_definition` or `platform_command_data` | Low-Risk Hygiene checks |
+
+______________________________________________________________________
+
+## Risk Identification Guidance
+
+### Distributed Runtime Risks
+
+- Archon mesh construction or parallel-dims mismatch
+- EP/TP/CP application order errors in Archon parallelization
+- Activation checkpoint placement violating TP/CP/FSDP ordering assumptions
+- Archon engine lifecycle drift around distributed setup and checkpoint boundaries
+- Collective call order mismatch across ranks
+- Wrong process-group scope in rank-sensitive logic
+- Mesh dimension mismatch and invalid DTensor placement
+- Weight version drift between rollout and training workers
+
+### Model Compute & Attention Risks
+
+- Attention mask inconsistency under TP/SP/CP paths
+- Tree attention index/routing mismatch
+- Archon model-family registration or per-family wiring drift
+- Archon MoE router/expert behavior diverging from weight-conversion expectations
+- Archon Ulysses slicing/gather semantics mismatching attention layout assumptions
+- Kernel assumptions violating dtype/shape invariants
+- Sequence packing alignment errors
+
+### Service Orchestration Risks
+
+- Session affinity or history drift across gateway/router/data proxy
+- Async message handling holes and dropped tasks
+- Controller/worker lifecycle desynchronization
+
+### Inference Backend & Serving Risks
+
+- Request lifecycle inconsistencies (enqueue/cancel/timeout)
+- Worker state transitions leaving requests stranded
+- Backend extension hooks drifting from runtime expectations
+
+### Workflow & Trainer Contract Risks
+
+- Workflow-engine contract drift across async boundaries
+- Weight version handshake mismatch between rollout and train
+- Trainer lifecycle transition inconsistencies
+
+### API & Config Compatibility Risks
+
+- Breaking config/schema changes without migration path
+- Dataclass or CLI default changes altering behavior silently
+- Missing validation for newly introduced fields
+- Dependency or build-system pin changes breaking supported environments
+
+### Numerics & Tensor Semantics Risks
+
+- Silent shape/dtype mismatch under distributed paths
+- Unstable numerical operations in loss/reward logic
+- torch.compile or dynamo guard changes breaking graph assumptions
+- Mixed-precision interaction regressions
+
+### Checkpoint & Recovery Risks
+
+- Partial-rank checkpoint participation
+- Incompatible state key evolution
+- Resume path breaking optimizer/model synchronization
+
+### Launcher & Infrastructure Risks
+
+- Resource assignment mismatching parallel strategy assumptions
+- RPC transport metadata loss (shape/dtype/device)
+- Startup/shutdown ordering races across processes
+- Runtime image or build-arg drift from supported inference/training variants
+
+### Low-Risk Hygiene Risks
+
+- Docs/config drift from actual runtime behavior
+- Logging or import hygiene regressions
+- Sensitive data exposure in logs or config
+- Documentation/build scripts or project templates drifting from actual workflow
+
+### Harness & Agent Infrastructure Risks
+
+- Skill and command docs drifting from actual platform behavior
+- Cross-platform data files falling out of sync with canonical references
+- Agent registry/config changes breaking expert routing or command discovery
+
+### CI/CD & Release Automation Risks
+
+- Workflow trigger or job dependency changes skipping required validation
+- Runner provisioning drift causing flaky or non-reproducible CI
+- Release or docs deployment jobs publishing the wrong artifacts
diff --git a/.agents/skills/review-pr/references/review-pr-templates.md b/.agents/skills/review-pr/references/review-pr-templates.md
index 1ca4717e7e..f33a7a76f5 100644
--- a/.agents/skills/review-pr/references/review-pr-templates.md
+++ b/.agents/skills/review-pr/references/review-pr-templates.md
@@ -1,424 +1,320 @@
-# PR Review: Task Templates Reference
+# PR Review: Domain Templates Reference
 
-This file contains the review task templates for PR review. Referenced by:
+This file contains canonical domain templates for PR review. Referenced by:
 `.agents/skills/review-pr/SKILL.md`
 
 ______________________________________________________________________
 
-## Framework-Specific Review Task Templates
+## Template Selection Rules
 
-### Archon Tasks \[deep\]
+1. Select templates by detected L1 domains and L2 signals.
+1. Use at most one primary template per domain.
+1. Always include **General Logic & Boundary** for non-doc/config-only PRs.
+1. Apply cross-domain linkage checks from `review-pr-domains-and-signals.md`.
 
-**Task: Archon EP/ETP Strategy Correctness Review**
+______________________________________________________________________
 
-```
-Checklist:
-- ExpertParallel, TensorParallel, ExpertTensorParallel placement implementation
-- Placement dimension matching with mesh dimensions
-- Placement list length in _partition_fn
-- all_to_all communication autograd compatibility
-- ReordererSequenceParallel token index conversion
-```
+## Universal Template
 
-**Task: ArchonParallelDims Configuration Validation**
+### General Logic & Boundary
 
 ```
+Applicable: Any non-doc/config-only change
 Checklist:
-- ETP constraint: etp=1 (TP borrowed by EP) vs etp=tp (independent TP) logic
-- Mesh construction: _build_mesh_with_ep() dimension order and names
-- EP/TP/CP combination validity verification
-- dp_shard * cp * (tp if etp==1 else 1) % ep == 0 constraint
+- Boundary condition correctness (empty inputs, singleton, max-size)
+- Conditional logic correctness (branch inversion, short-circuit mistakes)
+- Error-path behavior (exceptions propagated with actionable context)
+- Return-value consistency across code paths
+- No newly introduced hidden behavior changes
 ```
 
-**Task: MoE Layer Implementation Correctness**
-
-```
-Checklist:
-- TokenReorderer and router separation correctness
-- grouped_mm token alignment (8/16/32)
-- Expert weight 3D tensor sharding
-- Load balancing loss calculation
-```
-
-**Task: Model Parallelization Application Order**
-
-```
-Checklist:
-- apply_moe_ep_tp() strategy selection logic
-- FSDP wrap order (EP -> TP -> AC -> FSDP)
-- torch.compile dynamic shape marking
-- Explicit prefetching configuration
-```
-
-### FSDP Tasks \[deep / unspecified-high\]
-
-**Task: FSDP Core Correctness \[deep\]**
-
-```
-Checklist:
-- Shard/reshard operation timing and correctness
-- ShardedTensor and DTensor conversion
-- Mixed precision (param_dtype vs reduce_dtype)
-```
+______________________________________________________________________
 
-**Task: FSDP Interaction with Other Parallel Strategies \[deep\]**
+## Domain 1 Template: Distributed Runtime Review \[comprehensive\]
 
 ```
+Applicable signals: archon_core, archon_parallel, process_group, fsdp_core, megatron_core, collectives, mesh_dtensor, activation_checkpointing, weight_sync
 Checklist:
-- FSDP must be applied after TP/CP/EP
-- Use dp_shard_mod_ep mesh in EP scenarios
-- Gradient divide factor relationship with world size
+- Archon engine and checkpoint lifecycle remain aligned with distributed runtime assumptions
+- FSDP and Megatron engine invariants still match process-group, sharding, and pipeline assumptions
+- Archon parallel-dims and mesh construction still match downstream placement logic
+- Process-group creation/usage/cleanup is rank-consistent
+- Collective operations are called by all required ranks in consistent order
+- DeviceMesh dimensions and DTensor placements are correct for each path
+- Activation checkpoint placement remains compatible with parallel and sharding order requirements
+- Local/global tensor conversion boundaries are explicit and correct
+- Weight version propagation and update ordering are deterministic
+- No debug-only barriers left in hot path
 ```
 
-**Task: FSDP State Management \[unspecified-high\]**
+## Domain 2 Template: Model Compute & Attention Review \[comprehensive\]
 
 ```
+Applicable signals: tree_attn, sdpa_varlen, sp_cp_attention_mask, triton_kernel, archon_model_family, archon_attention_stack, archon_moe_modeling
 Checklist:
-- state_dict save/load sharded vs full mode
-- Optimizer state sharding and aggregation
-- Checkpoint compatibility
+- Attention mask semantics preserved under TP/SP/CP
+- Archon model-family registration and per-family wiring remain internally consistent
+- Archon attention/Ulysses slicing and gather paths preserve layout assumptions
+- Archon MoE router, grouped experts, and weight-conversion interfaces remain aligned
+- Tree attention index/order invariants are maintained
+- Kernel assumptions on dtype/shape/contiguity are satisfied
+- No silent behavior change in sequence packing/unpacking
+- Tensor layouts remain compatible with downstream modules
 ```
 
-### Megatron Tasks \[deep\]
-
-**Task: Pipeline Parallelism Correctness**
+## Domain 3 Template: Inference Backend & Serving Review \[comprehensive\]
 
 ```
+Applicable signals: vllm_ext, remote_inference_backend, request_lifecycle
 Checklist:
-- Stage splitting correctness and balance
-- Micro-batch scheduling
-- Pipeline flush and bubble handling
+- Request lifecycle (enqueue, execution, cancellation, timeout) is coherent
+- Worker state transitions are safe under concurrency
+- Backend-specific extension points stay API-compatible
+- Error handling does not strand in-flight requests
+- Versioning/weight-update interactions are explicit and safe
 ```
 
-**Task: Megatron Model Sharding**
+## Domain 4 Template: Service Orchestration Review \[comprehensive\]
 
 ```
+Applicable signals: service_routing_dataflow, session_consistency
 Checklist:
-- Weight sharding and synchronization
-- Tied weights handling
-- Embedding/output layer parallel strategy
+- Gateway/router/data-proxy routing rules are deterministic
+- Session affinity and history consistency are preserved
+- Controller/worker coordination has no lost-update window
+- Async boundaries avoid blocking operations in critical paths
+- Failure/retry behavior does not duplicate or drop work
 ```
 
-### DCP/Checkpoint Tasks \[deep\]
-
-**Task: Distributed Checkpoint Correctness**
+## Domain 5 Template: Workflow & Trainer Contract Review \[comprehensive\]
 
 ```
+Applicable signals: workflow_engine_boundary, dataset_surface, async_contract, weight_version_contract
 Checklist:
-- All ranks participate in DCP save/load operations
-- State dict keys match between save and load
-- No tensor shape/dtype mismatches
-- Storage backend compatibility (filesystem, S3)
-- Checkpoint versioning and migration
+- RolloutWorkflow and Engine interfaces remain contract-compatible
+- Dataset/output structure still matches workflow and trainer consumption expectations
+- Async flow uses await consistently and avoids sync I/O in async paths
+- Weight update/version handshake is preserved end-to-end
+- Trainer lifecycle transitions are valid for all execution branches
+- Call ordering assumptions across trainer/workflow/engine are unchanged or justified
```
 
-**Task: FSDP2 + DCP Integration**
+## Domain 6 Template: API & Config Compatibility Review \[targeted\]
 
 ```
+Applicable signals: dataclass_schema, cli_compat, backward_compat, project_dependency_config
 Checklist:
-- FSDP2 state dict options (full vs sharded)
-- Optimizer state handling with DCP
-- Async checkpointing correctness
-- Checkpoint resumption logic
+- Public API signature and default value changes are intentional and compatible
+- Dataclass validation remains complete and informative
+- CLI options preserve expected compatibility semantics
+- New fields include safe defaults or explicit migration handling
+- Breaking changes are documented and scoped
+- Dependency and build-system changes remain compatible with supported environments
 ```
 
-### Trainer Tasks \[deep\]
-
-**Task: Trainer Core Logic**
+## Domain 7 Template: Numerics & Tensor Semantics Review \[targeted\]
 
 ```
+Applicable signals: shape_dtype, numerical_stability, reward_surface, compile_dynamo, mixed_precision_fp8
 Checklist:
-- PPOTrainer/SFTTrainer initialization correctness
-- Workflow registration and invocation
-- Engine lifecycle management
-- Distributed training coordination
+- Tensor shape/dtype transitions are explicit and internally consistent
+- Numerical stability is protected (log/division/softmax/clamp paths)
+- Reward-side numerical behavior remains consistent with workflow consumption expectations
+- torch.compile / dynamo assumptions still hold for dynamic shapes and distributed execution
+- Mixed-precision behavior is correct for forward + backward + reduce paths
+- In-place and view/reshape operations do not corrupt gradient flow
+- Device placement and dtype combinations remain legal across code paths
 ```
 
-______________________________________________________________________
-
-## General Review Task Templates
-
-### Logic and Boundary Conditions \[deep\]
+## Domain 8 Template: Checkpoint & Recovery Review \[comprehensive\]
 
 ```
-Applicable: Any non-doc/config changes
+Applicable signals: dcp_consistency, optimizer_state, resume_compat
 Checklist:
-- Conditional logic errors (if/else inversion, boundary condition omission, short-circuit issues)
-- Loop errors (off-by-one, infinite loops, early exit, iterator invalidation)
-- Missing null/None/empty list handling
-- Type mismatch or implicit type conversion issues
-- Improper exception handling (swallowing exceptions, wrong exception type, return in finally)
-- Return value errors (wrong type, missing return, inconsistent multi-path returns)
-- Boolean expression errors (De Morgan's law violation, precedence errors)
+- Save/load requires and enforces all-rank participation where needed
+- State dict naming/structure is stable or migration-safe
+- Optimizer state sharding/gather behavior is consistent
+- Resume path restores model + optimizer + version state coherently
+- Async checkpoint behavior preserves ordering and durability assumptions
 ```
 
-### Concurrency and Async \[deep\]
+## Domain 9 Template: Launcher & Infrastructure Review \[targeted\]
 
 ```
-Applicable: ASYNC_CONCURRENT type detected
+Applicable signals: launcher_resource_match, scheduler_contract, rpc_transport, runtime_image_config
 Checklist:
-- Race conditions
-- Deadlock risks (inconsistent lock ordering, nested locks)
-- Non-thread-safe access to shared state
-- Missing await in async code
-- Blocking calls in async functions (should use executor)
-- Resource leaks (file handles, network connections, GPU memory not released)
-- State inconsistency (dirty state after partial update failure)
-- Improper context manager usage
-- Signal handling and graceful shutdown issues
+- Resource assignment matches declared parallel strategy assumptions
+- Scheduler decisions preserve required placement/affinity constraints
+- RPC serialization/deserialization keeps shape/dtype/device semantics
+- Transport retries/timeouts do not violate idempotency expectations
+- Cross-process startup/shutdown ordering is robust
+- Runtime image and build configuration remain aligned with supported variants
 ```
 
-### Tensor Shape and Data Type \[deep\]
+## Domain 10 Template: Low-Risk Hygiene Review \[basic\]
 
 ```
-Applicable: TENSOR_OPS type detected with complex tensor operations
+Applicable signals: tests_docs_config, logging_import_security, project_docs_metadata
 Checklist:
-- Tensor shape mismatch (dimension errors, broadcast errors)
-- Batch dimension handling errors (missing batch dim, wrong dimension order)
-- Sequence length and padding handling (missing mask, padding token in computation)
-- Index out of bounds risk (dynamic indexing, negative indexing)
-- dtype mismatch (fp16/fp32/bf16 mixing, integer overflow)
-- Device placement errors (tensor on wrong device, CPU/GPU mixed operations)
-- Gradient-related issues (missing detach, missing no_grad context, gradient accumulation errors)
-- view/reshape contiguity requirements
-- In-place operation effects on gradient computation
+- Tests/docs/config edits are internally consistent and non-misleading
+- Logging follows project conventions and avoids sensitive leakage
+- No wildcard imports or obvious dependency hygiene regressions
+- No accidental secrets/keys/tokens introduced
+- Docs build scripts and project templates stay aligned with real contributor workflow
 ```
 
-### Numerical Stability \[unspecified-high\]
+## Domain 11 Template: Harness & Agent Infrastructure Review \[targeted\]
 
 ```
-Applicable: NUMERICAL type detected
+Applicable signals: skill_definition, platform_command_data, agent_registry_config
 Checklist:
-- Numerical precision issues (floating point precision loss, accumulated errors)
-- Numerical stability (log(0), division by zero, exp overflow, softmax stability)
-- Numerical issues in loss function computation
-- Gradient vanishing/exploding risks
-- Scaling issues in mixed precision training
+- Canonical skills and derived platform data remain structurally aligned
+- Command docs still point to the correct data files and execution model
+- Agent registry/config changes preserve command discovery and expert routing
+- Cross-platform mirrors are regenerated after canonical changes
 ```
 
-### Tensor Parallel (TP) Correctness \[deep\]
+## Domain 12 Template: CI/CD & Release Automation Review \[comprehensive\]
 
 ```
-Applicable: TENSOR_PARALLEL or DISTRIBUTED_COMM type detected
+Applicable signals: github_workflow_jobs, runner_provisioning, release_delivery
 Checklist:
-- Missing or misplaced all-reduce
-- Missing or misplaced all-gather
-- Reduce handling after weight sharding (column/row sharding)
-- Input Replicate / output Partial DTensor semantics
-- scatter/gather correctness in Sequence Parallel (SP)
-- TP group communication correctness
+- Workflow triggers, job dependencies, and permissions still enforce required validation
+- Runner/image provisioning remains reproducible and compatible with job expectations
+- Release, docker, and docs deployment jobs publish the intended artifacts only
+- CI changes do not silently skip tests, formatting, or release gates
 ```
 
-### Communication and Synchronization \[unspecified-high\]
+______________________________________________________________________
 
-```
-Applicable: DISTRIBUTED_COMM type detected
-Checklist:
-- Process group usage errors
-- Device mesh configuration errors
-- Improper barrier placement
-- Unnecessary synchronization operations (GPU-CPU sync)
-- Collective communication order dependencies
-```
+
+## Signal-Specific Add-On Checklists
 
-### API Compatibility \[unspecified-high\]
+Use these only when corresponding L2 signals are detected.
 
-```
-Applicable: API_CONFIG type detected
-Checklist:
-- Function signature changes (parameter add/delete/rename/reorder)
-- Return type changes
-- Default value changes causing behavior changes
-- Breaking changes to public APIs
-- Deprecated API usage
-- Class/module rename or move
-```
-
-### Configuration and Parameter Validation \[unspecified-high\]
+### `tree_attn` Add-On \[comprehensive\]
 
 ```
-Applicable: API_CONFIG type detected with dataclass
-Checklist:
-- New config items missing validation (__post_init__ validation)
-- Unreasonable config default values
-- Missing parameter range checks
-- Unhandled dependencies between config items
-- Hydra/CLI compatibility issues
-- Backward compatibility of env vars/config files
-- Incorrect dataclass field types
+- Node/edge indexing is deterministic and shape-safe
+- Tree traversal order matches attention mask semantics
+- FSDP/Megatron/Archon variant modules remain behaviorally aligned
 ```
 
-### Workflow and Engine Interaction \[unspecified-high\]
+### `vllm_ext` Add-On \[comprehensive\]
 
 ```
-Applicable: WORKFLOW_ENGINE type detected
-Checklist:
-- RolloutWorkflow.arun_episode async correctness
-- InferenceEngine.agenerate call patterns
-- Weight version management (set_version/update_weights/WeightUpdateMeta)
-- Tensor output format ([batch, seq_len, ...] convention)
-- concat_padded_tensors usage correctness
-- AsyncRewardWrapper wrapping requirements
+- Server and worker extension hooks still match upstream expectations
+- Request pause/resume/cancel semantics remain coherent
+- Integration-specific monkey-patches are scoped and guarded
 ```
 
-### Activation Checkpointing (AC) \[unspecified-high\]
+### `archon_model_family` Add-On \[comprehensive\]
 
 ```
-Applicable: ACTIVATION_CKPT type detected
-Checklist:
-- AC application order (must after TP/CP, before FSDP)
-- Selective AC op registration correctness
-- AC config validation logic
-- Compatibility with torch.compile
+- ModelSpec registration stays unique and complete for supported model types
+- Per-family model/args/spec/state-adapter wiring remains consistent
+- Pipelining hooks and model-part boundaries stay compatible with runtime assumptions
 ```
 
-### Performance Regression Risk \[unspecified-high\]
+### `archon_moe_modeling` Add-On \[comprehensive\]
 
 ```
-Applicable: Any non-doc changes, especially TENSOR_OPS, DISTRIBUTED_COMM
-Checklist:
-- Unnecessary GPU-CPU sync (.item(), .tolist(), printing tensors)
-- Memory allocation pattern changes (potential OOM)
-- Communication volume increase
-- Computational complexity changes
-- torch.compile compatibility breakage
-- Unnecessary tensor copies
+- Router top-k, gating dtype, and expert grouping semantics remain coherent
+- GroupedExperts layout and token reordering assumptions still match expert execution
+- MoE weight conversion paths stay consistent with runtime sharding expectations
 ```
 
-### Context-Aware Review \[unspecified-high\]
+### `service_routing_dataflow` Add-On \[comprehensive\]
 
 ```
-Applicable: Any code changes
-Checklist:
-- Read git blame and history of modified code
-- Check for accidental rollback of previous fixes
-- Check for breaking previously established patterns or conventions
-- Check if changes violate code comments
-- Check for violations of TODO/FIXME constraints
-- Check for ignored NOTE/WARNING comments
+- Route selection and fallback ordering are deterministic
+- Data proxy transformations preserve payload integrity
+- Session-key partitioning logic is collision-safe
 ```
 
-### Sequence Parallel (SP/CP) Correctness \[deep\]
+### `remote_inference_backend` Add-On \[targeted\]
 
 ```
-Applicable: sequence_parallel, context_parallel, SP, CP
-Checklist:
-- scatter/gather operation correctness
-- Attention mask handling under SP
-- Position encoding sharding
-- KV cache handling under CP
-- Combination correctness with TP
+- Remote backend request/response semantics remain consistent across supported engines
+- Backend-specific transport options do not change lifecycle expectations silently
+- Shared request payload assumptions remain compatible across remote backends
 ```
 
-### Checkpoint and Recovery \[unspecified-high\]
+### `weight_sync` Add-On \[comprehensive\]
 
 ```
-Applicable: areal/utils/saver.py, areal/utils/recover.py, state_dict, checkpoint
-Checklist:
-- Checkpoint save/load completeness
-- Distributed checkpoint consistency
-- Version compatibility (can old checkpoints load)
-- Recovery logic correctness
-- Optimizer state handling
+- Versioned updates are monotonic and race-safe
+- Broadcast/all-gather points are aligned with consumer expectations
+- Local caching behavior cannot serve stale weights indefinitely
 ```
 
-### Reward Function Correctness \[unspecified-high\]
+### `activation_checkpointing` Add-On \[targeted\]
 
 ```
-Applicable: areal/reward/ directory
-Checklist:
-- Reward function signature matches (prompt, completions, prompt_ids, completion_ids, **data)
-- Deterministic computation (same input produces same output)
-- Blocking calls wrapped with AsyncRewardWrapper
-- Numerical range reasonableness
-- Edge case handling (empty input, abnormal answers)
+- Checkpoint wrappers are applied in a parallelism-safe order
+- Selective checkpoint policies still cover the intended modules only
+- Activation recompute paths do not break sharding or sequence-parallel assumptions
 ```
 
-### Dataset Loader Correctness \[unspecified-high\]
+### `reward_surface` Add-On \[targeted\]
 
 ```
-Applicable: areal/dataset/ directory
-Checklist:
-- Data format validation (messages, answer, image_path fields)
-- Tokenizer compatibility
-- max_length truncation logic
-- Distributed sampling correctness
-- Memory efficiency (avoid loading all data at once)
+- AsyncRewardWrapper-facing reward interfaces remain contract-compatible
+- Reward outputs keep expected shape, dtype, and per-sample semantics
+- Workflow assumptions about reward timing and batching remain valid
 ```
 
-### Launcher and Scheduler Configuration \[unspecified-high\]
+### `compile_dynamo` Add-On \[targeted\]
 
 ```
-Applicable: areal/infra/launcher/, areal/infra/scheduler/, areal/infra/rpc/ directories
-Checklist:
-- Resource config reasonableness (GPU count, memory)
-- Process group config matches parallel strategy
-- Environment variable passing correctness
-- Container/image config compatibility
-- Slurm/Ray specific configurations
+- torch.compile and dynamo guards still tolerate expected dynamic-shape inputs
+- fullgraph and mark_dynamic choices remain compatible with distributed execution paths
+- Compile-specific changes do not silently alter runtime fallback behavior
 ```
 
-### torch.compile Compatibility \[unspecified-high\]
+### `rpc_transport` Add-On \[targeted\]
 
 ```
-Applicable: COMPILE type detected or hot path code modified
-Checklist:
-- Dynamic shape mark_dynamic marking
-- Graph break risks (Python control flow, data-dependent branches)
-- Unsupported operations (some in-place ops)
-- fullgraph=True compatibility
-- Interaction with FSDP/TP
+- RTensor conversion is reversible and metadata-complete
+- Batch fetch/request framing preserves ordering and boundaries
+- Retry logic does not replay non-idempotent actions incorrectly
 ```
 
-### Documentation Format Check \[quick\]
+### `runtime_image_config` Add-On \[targeted\]
 
 ```
-Applicable: DOCS type detected
-Checklist:
-- Markdown format correctness
-- Internal link validity
-- Code example correctness
+- Docker base image and build args still match supported backend variants
+- Layer ordering preserves expected cache and dependency behavior
+- Image contents remain aligned with runtime assumptions documented in the repo
 ```
 
-### Test Coverage Check \[quick\]
+### `project_dependency_config` Add-On \[targeted\]
 
 ```
-Applicable: TESTS type detected
-Checklist:
-- Test cases cover main paths
-- Boundary condition tests
-- Error handling tests
+- Python/version constraints and extras remain internally consistent
+- Lockfile changes match the intended dependency update scope
+- Build backend/tooling changes do not break install or publish workflows
 ```
 
-### Logging and Metrics \[quick\]
+### `github_workflow_jobs` Add-On \[comprehensive\]
 
 ```
-Applicable: logging, stats_tracker, StatsLogger
-Checklist:
-- Use areal.utils.logging.getLogger not print
-- Structured metrics sent via stats_tracker
-- Reasonable log levels (no DEBUG on hot paths)
-- Sensitive info not logged
+- Workflow triggers and job graph still run required validation paths
+- Required secrets/permissions are scoped correctly
+- Matrix or conditional changes do not silently skip critical jobs
 ```
 
-### Import and Dependencies \[quick\]
+### `project_docs_metadata` Add-On \[basic\]
 
 ```
-Applicable: Any Python file changes
-Checklist:
-- Avoid wildcard imports (from x import *)
-- Correct third-party vs internal import grouping
-- Heavy optional deps inside functions
-- Circular import risks
+- Docs build entrypoints and contributor-facing metadata remain mutually consistent
+- Public templates and contributor instructions still match the actual workflow
+- Build/preview guidance still points to the supported commands
 ```
 
-### Security and Sensitive Information \[quick\]
+### `skill_definition` / `platform_command_data` Add-On \[targeted\]
 
 ```
-Applicable: Config files, environment variables, API calls
-Checklist:
-- No hardcoded keys/tokens/passwords
-- Sensitive info not committed to repo
-- API endpoints configurable
-- Error messages don't leak sensitive details
+- Canonical and derived review-pr data files stay in sync after edits
+- Command/import paths remain correct after file moves or renames
+- Wrapper-specific routing stays out of canonical reference files
 ```
diff --git a/.agents/skills/review-pr/sync_review_pr_refs.py b/.agents/skills/review-pr/sync_review_pr_refs.py
new file mode 100644
index 0000000000..9f85d1539e
--- /dev/null
+++ b/.agents/skills/review-pr/sync_review_pr_refs.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+"""Sync review-pr reference data from canonical .agents files.
+
+Canonical source:
+  - .agents/skills/review-pr/references/review-pr-domains-and-signals.md
+  - .agents/skills/review-pr/references/review-pr-templates.md
+
+Derived targets:
+  - .opencode/data/review-pr-domains-and-signals.md
+  - .opencode/data/review-pr-templates.md
+  - .claude/data/review-pr-domains-and-signals.md
+  - .claude/data/review-pr-templates.md
+
+Usage:
+    python .agents/skills/review-pr/sync_review_pr_refs.py --write
+    python .agents/skills/review-pr/sync_review_pr_refs.py --check
+"""
+
+from __future__ import annotations
+
+import difflib
+import sys
+from collections.abc import Callable
+from dataclasses import dataclass
+from pathlib import Path
+
+Transform = Callable[[str], str]
+
+
+@dataclass(frozen=True)
+class SyncSpec:
+    src: Path
+    dst: Path
+    transform: Transform
+
+
+def transform_for_opencode(text: str) -> str:
+    return text.replace(
+        "`.agents/skills/review-pr/SKILL.md`", "`.opencode/command/review-pr.md`"
+    )
+
+
+def transform_for_claude(text: str) -> str:
+    out = text
+    out = out.replace(
+        "`.agents/skills/review-pr/SKILL.md`", "`.claude/commands/review-pr.md`"
+    )
+    return out
+
+
+def find_repo_root(start: Path) -> Path:
+    cur = start.resolve()
+    for parent in [cur, *cur.parents]:
+        if (parent / ".git").exists():
+            return parent
+    raise RuntimeError("Unable to locate repository root (missing .git)")
+
+
+def normalized(text: str) -> str:
+    norm = text.replace("\r\n", "\n").replace("\r", "\n")
+    if not norm.endswith("\n"):
+        norm += "\n"
+    return norm
+
+
+def sync_one(spec: SyncSpec, check_only: bool) -> tuple[bool, str]:
+    try:
+        source_text = normalized(spec.src.read_text(encoding="utf-8"))
+    except OSError as exc:
+        raise RuntimeError(f"Cannot read canonical source: {spec.src}") from exc
+    expected = normalized(spec.transform(source_text))
+
+    existing = ""
+    if spec.dst.exists():
+        try:
+            existing = normalized(spec.dst.read_text(encoding="utf-8"))
+        except OSError as exc:
+            raise RuntimeError(f"Cannot read sync target: {spec.dst}") from exc
+
+    if existing == expected:
+        return False, ""
+
+    if check_only:
+        diff = "".join(
+            difflib.unified_diff(
+                existing.splitlines(keepends=True),
+                expected.splitlines(keepends=True),
+                fromfile=str(spec.dst),
+                tofile=f"{spec.dst} (expected)",
+            )
+        )
+        return True, diff
+
+    spec.dst.parent.mkdir(parents=True, exist_ok=True)
+    try:
+        _ = spec.dst.write_text(expected, encoding="utf-8")
+    except OSError as exc:
+        raise RuntimeError(f"Cannot write sync target: {spec.dst}") from exc
+    return True, ""
+
+
+def build_specs(repo_root: Path) -> list[SyncSpec]:
+    canonical_dir = repo_root / ".agents/skills/review-pr/references"
+    domains_and_signals = canonical_dir / "review-pr-domains-and-signals.md"
+    templates = canonical_dir / "review-pr-templates.md"
+
+    return [
+        SyncSpec(
+            domains_and_signals,
+            repo_root / ".opencode/data/review-pr-domains-and-signals.md",
+            transform_for_opencode,
+        ),
+        SyncSpec(
+            templates,
+            repo_root / ".opencode/data/review-pr-templates.md",
+            transform_for_opencode,
+        ),
+        SyncSpec(
+            domains_and_signals,
+            repo_root / ".claude/data/review-pr-domains-and-signals.md",
+            transform_for_claude,
+        ),
+        SyncSpec(
templates, + repo_root / ".claude/data/review-pr-templates.md", + transform_for_claude, + ), + ] + + +def parse_mode(argv: list[str]) -> str: + if "-h" in argv or "--help" in argv: + print("usage: sync_review_pr_refs.py [--write | --check]") + print() + print("Sync /review-pr reference files across platforms") + print() + print("options:") + print(" --write Write derived files") + print(" --check Check derived files are up to date") + raise SystemExit(0) + + modes = [arg for arg in argv if arg in {"--write", "--check"}] + if len(modes) != 1: + print( + "error: exactly one mode is required: --write or --check", file=sys.stderr + ) + raise SystemExit(2) + return modes[0] + + +def main() -> int: + mode = parse_mode(sys.argv[1:]) + check_only = mode == "--check" + write_mode = mode == "--write" + repo_root = find_repo_root(Path(__file__)) + specs = build_specs(repo_root) + + changed_any = False + diffs: list[str] = [] + + for spec in specs: + changed, diff = sync_one(spec, check_only=check_only) + changed_any = changed_any or changed + if changed and check_only and diff: + diffs.append(diff) + if changed and write_mode: + print(f"updated: {spec.dst}") + if not changed and write_mode: + print(f"up-to-date: {spec.dst}") + + if check_only: + if changed_any: + print( + "/review-pr reference files are out of sync. 
Run with --write.", + file=sys.stderr, + ) + for d in diffs: + print(d, file=sys.stderr) + return 1 + print("/review-pr reference files are in sync.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.claude/commands/review-pr.md b/.claude/commands/review-pr.md index 50d8737a64..9cf88d0ff9 100644 --- a/.claude/commands/review-pr.md +++ b/.claude/commands/review-pr.md @@ -1,12 +1,12 @@ --- name: review-pr -description: Intelligent PR code review with dynamic agent allocation based on change types +description: Intelligent PR code review with dynamic agent allocation based on domains and signals allowed-tools: Read, Grep, Glob, Bash, Task --- -@.claude/data/review-pr-change-types.md @.claude/data/review-pr-templates.md +@.claude/data/review-pr-domains-and-signals.md @.claude/data/review-pr-templates.md # PR Code Review (Dynamic Agent Allocation) @@ -33,7 +33,7 @@ targeted review tasks based on PR changes. Phase 1: Deep PR Analysis [Haiku + Sonnet] |- 1.0 PR Status Check [Haiku] |- 1.1 Get PR Summary [Haiku] - +- 1.2-1.4 Change Type Detection [Sonnet] ++- 1.2-1.4 Domain/Signal Detection [Sonnet] | Phase 2: Dynamic Agent Planning [Sonnet] | @@ -66,32 +66,33 @@ Check if PR should be reviewed: Get basic PR info: title, description, modified files, change summary. -### 1.2 Change Type Detection \[Sonnet\] +### 1.2 Domain & Signal Detection \[Sonnet\] -Analyze each file change, detecting change types by risk level. +Analyze each file change, detecting L1 domains and L2 signals by risk level. 
-**Reference**: See `review-pr-change-types.md` for complete detection tables: +**Reference**: See `review-pr-domains-and-signals.md` for complete domain tables: -- CRITICAL level types (Archon, FSDP, Megatron, DCP) -- HIGH level types (distributed comm, DTensor, MoE, TP/EP/CP) -- MEDIUM level types (tensor ops, workflow, API, compile) -- LOW level types (tests, docs, config) +- L1 domains (Distributed Runtime, Model Compute & Attention, Inference Backend & + Serving, etc.) +- L2 signals per domain +- cross-domain linkage rules -### 1.3 Framework-Specific Risk Identification +### 1.3 Domain-Specific Risk Identification -Based on detected types, identify corresponding risks. +Based on detected domains/signals, identify corresponding risks and linked checks. -**Reference**: See `review-pr-change-types.md` for risk lists per framework. +**Reference**: See `review-pr-domains-and-signals.md` for risk lists per domain. ### 1.4 Output Change Analysis Report ``` CHANGE_ANALYSIS_REPORT: -- detected_types: [ARCHON_PARALLEL, EP_ETP, FSDP_CORE, ...] +- detected_domains: [Distributed Runtime, Model Compute & Attention, ...] +- detected_signals: [weight_sync, tree_attn, ...] - risk_level: CRITICAL | HIGH | MEDIUM | LOW - affected_files: [file1.py, file2.py, ...] - identified_risks: [risk1, risk2, ...] -- related_frameworks: [archon, fsdp, megatron, ...] +- related_frameworks: [archon, fsdp, megatron, vllm, service-stack, ...] ``` ______________________________________________________________________ @@ -102,28 +103,31 @@ ______________________________________________________________________ 1. **Generate tasks by risk area**: Each high-risk area gets a dedicated task 1. **Merge related changes**: Interdependent changes can be merged -1. **Model selection**: CRITICAL/HIGH -> Opus, MEDIUM -> Sonnet, LOW -> Haiku +1. **Review depth selection**: CRITICAL/HIGH -> `comprehensive`, MEDIUM -> `targeted`, + LOW -> `basic` +1. 
**Model routing**: `comprehensive` -> Opus, `targeted` -> Sonnet, `basic` -> Haiku 1. **Minimum coverage**: Even simple changes get at least 1 basic review task ### 2.2 Task Template Selection -Based on detected change types, select appropriate review task templates. +Based on detected domains/signals, select appropriate review task templates. **Reference**: See `review-pr-templates.md` for complete task templates: -- Framework-specific tasks (Archon, FSDP, Megatron, DCP, Trainer) -- General tasks (Logic, Concurrency, Tensor, Numerical, TP, etc.) +- Domain templates (Distributed Runtime, Model Compute & Attention, Inference Backend & + Serving, etc.) +- Universal + signal-specific add-on templates ### 2.3 Output Review Task List ``` GENERATED_REVIEW_TASKS: -1. [Opus] Task Name - - Reason: XXX change type detected +1. [comprehensive -> Opus] Task Name + - Reason: XXX domain/signal detected - Checklist: [...] - Focus files: [...] -2. [Sonnet] Task Name +2. [targeted -> Sonnet] Task Name - Reason: ... ... 
``` @@ -155,13 +159,13 @@ findings: suggestion: "Fix suggestion" ``` -### 3.3 Review Depth by Model +### 3.3 Review Depth Mapping -| Model | Requirements | -| ---------- | -------------------------------------------------------------------------- | -| **Opus** | Complete context, cross-file traces, verify parallel strategy interactions | -| **Sonnet** | Changed code + direct callers/callees, type signature consistency | -| **Haiku** | Format and basic correctness only | +| Review Depth | Model | Requirements | +| ----------------- | ------ | -------------------------------------------------------------------------- | +| **comprehensive** | Opus | Complete context, cross-file traces, verify parallel strategy interactions | +| **targeted** | Sonnet | Changed code + direct callers/callees, type signature consistency | +| **basic** | Haiku | Format and basic correctness only | ______________________________________________________________________ @@ -184,7 +188,8 @@ ______________________________________________________________________ ## PR Overview - **Title**: PR title -- **Detected Change Types**: [...] +- **Detected Domains**: [...] +- **Detected Signals**: [...] 
- **Risk Level**: CRITICAL | HIGH | MEDIUM | LOW - **Generated Review Tasks**: N @@ -225,13 +230,13 @@ ______________________________________________________________________ ## Dynamic Generation Examples -| PR Type | Detected Types | Generated Tasks | -| -------------- | ------------------------------------- | --------------- | -| Docs only | \[DOCS\] | 1 Haiku | -| Config only | \[CONFIG_ONLY\] | 1-2 Haiku | -| Single bug fix | \[TENSOR_OPS\] | 2-4 Sonnet | -| Archon core | \[ARCHON\_\*, EP_ETP, DTENSOR\] | 4-8 Opus | -| Cross-domain | \[WORKFLOW_ENGINE, FSDP_CORE, TESTS\] | 5-10 mixed | +| PR Type | Detected Domains/Signals | Generated Tasks | +| -------------- | --------------------------------------------------- | ------------------------------- | +| Docs only | \[Low-Risk Hygiene / tests_docs_config\] | 1 basic -> Haiku | +| Config only | \[API & Config Compatibility / dataclass_schema\] | 1-2 basic/targeted | +| Single bug fix | \[Numerics & Tensor Semantics / shape_dtype\] | 2-4 targeted -> Sonnet | +| Archon core | \[Distributed Runtime / mesh_dtensor, weight_sync\] | 4-8 comprehensive -> Opus | +| Cross-domain | \[Workflow & Trainer + Distributed + Hygiene\] | 5-10 mixed review depths/models | ______________________________________________________________________ @@ -262,26 +267,24 @@ ______________________________________________________________________ Location: .claude/commands/review-pr.md Invocation: /review-pr Related files: - - .claude/data/review-pr-change-types.md: Change type detection tables +- .claude/data/review-pr-domains-and-signals.md: Domain and signal detection tables - .claude/data/review-pr-templates.md: Review task templates ## Structure - Main file (this): workflow and phases, @imports data files -- data/review-pr-change-types.md: detection tables +- data/review-pr-domains-and-signals.md: domain and signal detection tables - data/review-pr-templates.md: task templates ## How to Update -### Adding New Change Types -Edit 
.claude/data/review-pr-change-types.md: -1. Add to appropriate level table (CRITICAL/HIGH/MEDIUM/LOW) -2. Add framework risks if applicable +### Adding New Domains or Signals +Edit `.agents/skills/review-pr/references/review-pr-domains-and-signals.md`, then regenerate +the derived data files with `python3 .agents/skills/review-pr/sync_review_pr_refs.py --write`. ### Adding New Task Templates -Edit .claude/data/review-pr-templates.md: -1. Add to framework-specific or general section -2. Include checklist +Edit `.agents/skills/review-pr/references/review-pr-templates.md`, then regenerate the +derived data files with `python3 .agents/skills/review-pr/sync_review_pr_refs.py --write`. ### Adjusting Model Selection Modify "Model Configuration" table in this file. diff --git a/.claude/data/review-pr-change-types.md b/.claude/data/review-pr-change-types.md deleted file mode 100644 index 0fd0584072..0000000000 --- a/.claude/data/review-pr-change-types.md +++ /dev/null @@ -1,149 +0,0 @@ -# PR Review: Change Type Detection Reference - -This file contains the change type detection tables for PR review. 
Referenced by: -`.claude/commands/review-pr.md` - -______________________________________________________________________ - -## CRITICAL Level (Must use Opus) - -| Change Type | File Path Pattern | Code Pattern | -| ---------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------- | -| **ARCHON_CORE** | `areal/experimental/models/archon/` | - | -| **ARCHON_PARALLEL** | `parallel_dims.py` | `ArchonParallelDims`, `_build_mesh`, `DeviceMesh` | -| **ARCHON_MOE** | `archon/moe/` | `router`, `grouped_experts`, `TokenReorderer`, `grouped_mm` | -| **ARCHON_PARALLELIZE** | `qwen*/infra/parallelize.py` | `apply_moe_ep_tp`, `apply_tp`, `apply_cp` | -| **ARCHON_ENGINE** | `areal/experimental/engine/archon_engine.py` | `ArchonEngine` | -| **FSDP_CORE** | `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py` | `FSDP`, `FullyShardedDataParallel`, `fully_shard` | -| **MEGATRON_CORE** | `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine` | -| **DCP_CHECKPOINT** | - | `DCP`, `DistributedCheckpoint`, `dcp.save`, `dcp.load` | - -## HIGH Level (Recommend Opus) - -| Change Type | File Path Pattern | Code Pattern | -| --------------------- | ----------------- | -------------------------------------------------------------------------------- | -| **DISTRIBUTED_COMM** | - | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.` | -| **DTENSOR** | - | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` | -| **MOE_LAYER** | `moe/` | `expert`, `token_dispatch`, `grouped_mm`, `MoE` | -| **EP_ETP** | - | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp` | -| **TENSOR_PARALLEL** | - | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module` | -| **SEQUENCE_PARALLEL** | - | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size` | -| **ASYNC_CONCURRENT** | - | `async def`, `await`, `asyncio`, 
`threading.Lock`, `aiofiles` | -| **TRAINER_CORE** | `areal/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train` | - -## MEDIUM Level (Use Sonnet) - -| Change Type | File Path Pattern | Code Pattern | -| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ | -| **TENSOR_OPS** | - | `.view(`, `.reshape(`, `dtype=`, `.detach()`, `no_grad`, `.contiguous()` | -| **NUMERICAL** | - | `log(`, `softmax`, `cross_entropy`, `eps=`, `.clamp(`, `nan`, `inf` | -| **WORKFLOW_ENGINE** | `areal/workflow/`, `areal/engine/` | `arun_episode`, `agenerate`, `RolloutWorkflow` | -| **API_CONFIG** | `areal/api/` | `@dataclass`, `__post_init__`, `field(` | -| **COMPILE** | - | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` | -| **ACTIVATION_CKPT** | `activation_checkpoint.py` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | -| **CHECKPOINT_RECOVERY** | `areal/utils/saver.py`, `areal/utils/recover.py`, `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/async_checkpoint.py` | `state_dict`, `load_state_dict`, `checkpoint`, `AsyncCheckpointManager` | -| **REWARD** | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` | -| **DATASET** | `areal/dataset/` | `get_*_dataset`, `DataLoader`, `IterableDataset` | -| **LAUNCHER_SCHEDULER** | `areal/infra/launcher/`, `areal/infra/scheduler/`, `areal/infra/rpc/` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` | -| **ATTENTION** | `attention/`, `attention/sdpa.py`, `attention/varlen.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` | - -## LOW Level (Use Haiku) - -| Change Type | File Path Pattern | Code Pattern | -| --------------- | ---------------------------- | ------------ | -| **TESTS** | `tests/`, `*_test.py` | - | -| **DOCS** | `docs/`, `*.md` | - | -| **CONFIG_ONLY** | `*.yaml`, `*.json`, `*.toml` | - | 
- -______________________________________________________________________ - -## Framework-Specific Risk Identification - -### Archon Risks (When ARCHON\_\* types detected) - -- **Device mesh dimension mismatch**: mesh dimension names don't correspond to placement -- **EP constraint violation**: `ep_size` must divide `num_experts`, and - `dp_shard * cp * (tp if etp==1 else 1) % ep == 0` -- **ETP configuration error**: `etp` must be 1 or equal to `tp` -- **Token alignment error**: `grouped_mm` requires token count aligned to 8/16/32 -- **All-to-All split/combine mismatch**: dispatch and combine split configs inconsistent -- **DTensor/Local tensor conversion missing**: need `.to_local()` or - `DTensor.from_local()` -- **torch.compile dynamic shape marking missing**: missing `mark_dynamic` calls -- **AC application order error**: must be after TP/CP, before FSDP -- **Ulysses SP configuration**: CP uses Ulysses implementation, not Ring Attention -- **dp_shard_mod_ep mesh usage**: MoE experts must use `dp_shard_mod_ep` mesh for FSDP - -### FSDP Risks (When FSDP\_\* types detected) - -- **Shard/reshard timing error**: premature or delayed sharding operations -- **EP mesh interaction issue**: should use `dp_shard_mod_ep` not `dp_shard` for MoE -- **Gradient divide factor calculation**: incorrect relationship with world size -- **State dict save/load inconsistency**: mixing sharded vs full modes -- **Optimizer state handling**: aggregation and distribution of sharded state -- **DCP compatibility**: ensure DCP save/load works with FSDP2 - -### Megatron Risks (When MEGATRON\_\* types detected) - -- **Pipeline stage splitting error**: unbalanced layer distribution -- **Micro-batch scheduling issues**: pipeline bubble handling -- **Weight sharding and sync**: tied weights handling -- **AC interaction**: checkpointing under pipeline parallelism - -### DCP/Checkpoint Risks (When DCP_CHECKPOINT or CHECKPOINT_RECOVERY detected) - -- **Distributed checkpoint consistency**: all ranks 
must participate in save/load -- **State dict key mismatch**: keys must match between save and load -- **Optimizer state compatibility**: ensure optimizer state is correctly - sharded/gathered -- **Version compatibility**: old checkpoints should load in new code -- **Storage backend compatibility**: ensure storage backend (filesystem, S3, etc.) is - compatible - -______________________________________________________________________ - -## Risk Linkage Rules - -| Detected Change | Auto-Linked Review | -| --------------------------- | ------------------------------------------------------ | -| EP changes | FSDP interaction check, dp_shard_mod_ep mesh check | -| ETP changes | TP + EP combination check, mesh dimension check | -| Megatron changes | Pipeline + AC check | -| Distributed comm changes | Process group + sync check | -| SEQUENCE_PARALLEL changes | TP combination + Attention mask check, Ulysses check | -| CHECKPOINT_RECOVERY changes | FSDP state dict check, DCP compatibility check | -| DCP_CHECKPOINT changes | FSDP2 integration check, distributed consistency check | -| COMPILE changes | Performance regression + FSDP/TP interaction check | -| REWARD changes | Workflow interaction check, AsyncRewardWrapper check | -| LAUNCHER_SCHEDULER changes | Resource config + parallel strategy match check | -| TRAINER_CORE changes | Engine lifecycle + workflow integration check | -| ARCHON_ENGINE changes | DCP checkpoint + parallel dims check | - -______________________________________________________________________ - -## Core Framework Paths (Must Use Opus) - -**Archon Core**: - -- `areal/experimental/models/archon/` (entire directory) -- `areal/experimental/engine/archon_engine.py` -- `areal/experimental/engine/archon_checkpoint.py` - -**FSDP Core**: - -- `areal/engine/fsdp_utils/` -- `areal/engine/fsdp_engine.py` - -**Megatron Core**: - -- `areal/engine/megatron_engine.py` -- `areal/engine/megatron_utils/megatron.py` -- `areal/engine/megatron_utils/checkpointer.py` - 
-**Trainer Core**: - -- `areal/trainer/` - -**Training Engine Core** (excludes FSDP/Megatron which have their own categories): - -- `areal/engine/` (except `fsdp_engine.py`, `megatron_engine.py`) diff --git a/.claude/data/review-pr-domains-and-signals.md b/.claude/data/review-pr-domains-and-signals.md new file mode 100644 index 0000000000..7587d6dda9 --- /dev/null +++ b/.claude/data/review-pr-domains-and-signals.md @@ -0,0 +1,253 @@ +# PR Review: Domain & Signal Detection Reference + +This file contains the canonical change-domain and signal detection tables for PR +review. Referenced by: `.claude/commands/review-pr.md` + +______________________________________________________________________ + +## Severity-to-Review-Depth Mapping + +- **CRITICAL**: use `comprehensive` review depth +- **HIGH**: use `comprehensive` review depth +- **MEDIUM**: use `targeted` review depth +- **LOW**: use `basic` review depth + +______________________________________________________________________ + +## L1 Domains and L2 Signals + +## Domain 1: Distributed Runtime (CRITICAL/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `archon_core` | `areal/experimental/engine/archon_engine.py`, `areal/experimental/engine/archon_checkpoint.py` | `ArchonEngine`, `ArchonCheckpointManager`, `archon` | +| `archon_parallel` | `areal/experimental/models/archon/`, `parallel_dims.py`, `parallelize.py` | `ArchonParallelDims`, `_build_mesh`, `apply_moe_ep_tp`, `apply_tp`, `apply_cp`, `ExpertTensorParallel`, `etp`, `parallelize_module`, `ColwiseParallel`, `RowwiseParallel` | +| `process_group` | `areal/engine/fsdp_utils/`, `areal/engine/megatron_utils/`, `areal/experimental/engine/` | `new_group`, 
`ProcessGroup`, `dist.get_rank(` | +| `fsdp_core` | `areal/engine/fsdp_engine.py`, `areal/engine/fsdp_utils/` | `FSDP`, `fully_shard`, `FullyShardedDataParallel` | +| `megatron_core` | `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine`, `pipeline`, `micro-batch` | +| `collectives` | `areal/engine/`, `areal/infra/rpc/` | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `broadcast`, `barrier` | +| `mesh_dtensor` | `areal/experimental/models/archon/`, `areal/engine/fsdp_utils/` | `DeviceMesh`, `DTensor`, `Shard(`, `Replicate(`, `distribute_tensor` | +| `activation_checkpointing` | `areal/experimental/models/archon/activation_checkpoint.py`, `areal/models/`, `areal/engine/` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | +| `weight_sync` | `areal/experimental/engine/archon_weight_sync.py`, `areal/api/engine_api.py`, `areal/engine/` | `WeightUpdateMeta`, `set_version`, `update_weights` | + +## Domain 2: Model Compute & Attention (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| `tree_attn` | `areal/models/tree_attn/` | `TreeAttention`, `tree_attn`, `TreeNode`, `tree` | +| `sdpa_varlen` | `attention/sdpa.py`, `attention/varlen.py`, `areal/models/tree_attn/` | `sdpa`, `flash_attn`, `varlen`, `causal_mask` | +| `sp_cp_attention_mask` | `areal/models/tree_attn/`, `areal/experimental/models/archon/attention/` | `SequenceParallel`, `context_parallel`, `mask` | +| `triton_kernel` | `areal/models/tree_attn/triton_kernel.py` | `triton`, `kernel`, `autotune` | +| `archon_model_family` | `areal/experimental/models/archon/model_spec.py`, `areal/experimental/models/archon/qwen*/` | `ModelSpec`, 
`register_model_spec`, `supported_model_types`, `state_dict_adapter`, `rope` | +| `archon_attention_stack` | `areal/experimental/models/archon/attention/`, `areal/experimental/models/archon/ulysses.py` | `ulysses_slice_inputs`, `ulysses_gather_output`, `gather_seq_scatter_heads`, `sdpa`, `varlen` | +| `archon_moe_modeling` | `areal/experimental/models/archon/moe/`, `areal/experimental/models/archon/expert_parallel.py`, `areal/experimental/models/archon/moe_weight_converter.py` | `TokenChoiceTopKRouter`, `RouterGateLinear`, `GroupedExperts`, `MoEWeightConverter`, `expert_parallel` | + +## Domain 3: Inference Backend & Serving (HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | -------------------------------------------------------------- | ---------------------------------------------------------------- | +| `vllm_ext` | `areal/engine/vllm_ext/` | `areal_vllm_server`, `vllm_worker_extension`, `pause_generation` | +| `remote_inference_backend` | `areal/engine/vllm_remote.py`, `areal/engine/sglang_remote.py` | `vllm`, `sglang`, `OpenAI`, `request`, `response` | +| `request_lifecycle` | `areal/engine/`, `areal/infra/launcher/` | `enqueue`, `dequeue`, `cancel`, `timeout` | + +## Domain 4: Service Orchestration (HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `service_routing_dataflow` | `areal/experimental/agent_service/gateway/`, `areal/experimental/agent_service/router/`, `areal/experimental/inference_service/data_proxy/`, `areal/experimental/inference_service/controller/` | `route`, `gateway`, `router`, `DataProxy`, `controller`, `batch` | +| `session_consistency` | `areal/experimental/agent_service/`, 
`areal/experimental/inference_service/` | `session`, `affinity`, `history`, `state` | + +## Domain 5: Workflow & Trainer Contract (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ----------------------------------------------------------------------------------------------- | --------------------------------------------------- | +| `workflow_engine_boundary` | `areal/workflow/`, `areal/trainer/`, `areal/engine/` | `RolloutWorkflow`, `arun_episode`, `agenerate` | +| `dataset_surface` | `areal/dataset/` | `DataLoader`, `IterableDataset`, `get_*_dataset` | +| `async_contract` | `areal/workflow/`, `areal/experimental/agent_service/`, `areal/experimental/inference_service/` | `async def`, `await`, `aiofiles`, `asyncio` | +| `weight_version_contract` | `areal/api/engine_api.py`, `areal/workflow/`, `areal/trainer/` | `WeightUpdateMeta`, `set_version`, `weight version` | + +## Domain 6: API & Config Compatibility (MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| --------------------------- | ------------------------------------- | -------------------------------------------------------------------------------- | +| `dataclass_schema` | `areal/api/` | `@dataclass`, `field(`, `__post_init__` | +| `cli_compat` | `areal/api/cli_args.py` | `Literal`, `help`, `default` | +| `backward_compat` | `areal/api/`, `areal/infra/launcher/` | `deprecated`, `compat`, `version` | +| `project_dependency_config` | `pyproject.toml`, `uv.lock` | `requires-python`, `dependencies`, `optional-dependencies`, `build-system`, `uv` | + +## Domain 7: Numerics & Tensor Semantics (MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| --------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------- | +| `shape_dtype` | `areal/engine/`, `areal/models/`, `areal/trainer/` | `.view(`, `.reshape(`, `dtype=`, `.contiguous(` | +| 
`numerical_stability` | `areal/engine/`, `areal/reward/`, `areal/utils/functional/` | `log(`, `softmax`, `eps=`, `.clamp(`, `nan`, `inf` | +| `reward_surface` | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` | +| `compile_dynamo` | `areal/experimental/models/archon/compile.py`, `areal/models/`, `areal/engine/` | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` | +| `mixed_precision_fp8` | `areal/engine/megatron_utils/fp8/`, `areal/experimental/models/archon/` | `fp8`, `bf16`, `fp16`, `mixed precision` | + +## Domain 8: Checkpoint & Recovery (CRITICAL/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| ----------------- | ------------------------------------------------------------------- | ----------------------------------------------- | +| `dcp_consistency` | `areal/utils/async_checkpoint.py`, `areal/engine/**/checkpoint*.py` | `dcp.save`, `dcp.load`, `DistributedCheckpoint` | +| `optimizer_state` | `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/saver.py` | `optimizer state`, `state_dict` | +| `resume_compat` | `areal/utils/recover.py`, `areal/utils/saver.py` | `resume`, `load_state_dict`, `migration` | + +## Domain 9: Launcher & Infrastructure (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------- | ---------------------------------------------------------------------- | --------------------------------------------------------- | +| `launcher_resource_match` | `areal/infra/launcher/` | `LaunchConfig`, `RayLauncher`, `SlurmLauncher` | +| `scheduler_contract` | `areal/infra/scheduler/`, `areal/scheduler/` | `Scheduler`, `placement`, `resource` | +| `rpc_transport` | `areal/infra/rpc/`, `areal/experimental/inference_service/data_proxy/` | `RTensor`, `serialize`, `rpc`, `fetch` | +| `runtime_image_config` | `Dockerfile`, `.dockerignore` | `FROM`, `ARG`, `RUN`, `ENV`, `COPY`, `uv sync`, `VARIANT` | + +## Domain 10: Low-Risk Hygiene (LOW) + +| L2 Signal | File Path Pattern | Code 
Pattern | +| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ | +| `tests_docs_config` | `tests/`, `docs/`, `*.md`, `*.yaml`, `*.json`, `*.toml` | - | +| `logging_import_security` | `areal/`, `examples/` | `getLogger`, `print(`, `import *`, `api_key`, `token`, `password` | +| `project_docs_metadata` | `docs/build_all.sh`, `docs/generate_cli_docs.py`, `docs/en/`, `docs/zh/`, `README.md`, `CONTRIBUTING.md`, `.github/PULL_REQUEST_TEMPLATE.md`, `.github/ISSUE_TEMPLATE/` | `jupyter-book`, `generate_cli_docs`, `build_all`, `_build`, `checklist`, `template`, `contributing`, `usage` | + +## Domain 11: Harness & Agent Infrastructure (MEDIUM/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| ----------------------- | ----------------------------------------------------------------------------- | ------------------------------------------------------------ | +| `skill_definition` | `.agents/skills/**/SKILL.md`, `.agents/skills/**/references/` | `description:`, `## Workflow`, `## Reference Files`, `skill` | +| `platform_command_data` | `.claude/commands/`, `.claude/data/`, `.opencode/command/`, `.opencode/data/` | `@.`, `/review-pr`, `/create-pr`, `data/`, `task(` | +| `agent_registry_config` | `.codex/config.toml`, `.codex/agents/`, `AGENTS.md`, `CLAUDE.md` | `agents`, `skills`, `registry`, `subagent`, `config.toml` | + +## Domain 12: CI/CD & Release Automation (HIGH/CRITICAL) + +| L2 Signal | File Path Pattern | Code Pattern | +| ---------------------- | -------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------- | +| `github_workflow_jobs` | `.github/workflows/*.yml` | `jobs:`, `runs-on:`, 
`needs:`, `if:`, `workflow_dispatch` | +| `runner_provisioning` | `.github/workflows/bake-gcp-image.yml`, `.github/workflows/runner-heartbeat.yml` | `gcp`, `runner`, `image`, `heartbeat` | +| `release_delivery` | `.github/workflows/build-docker-image.yml`, `.github/workflows/tag-release-image.yml`, `.github/workflows/deploy-docs.yml` | `docker`, `tag`, `release`, `pages`, `publish` | + +______________________________________________________________________ + +## Must-Not-Regress Core Coverage + +The refactor must preserve these existing review surfaces: + +- Archon core: `areal/experimental/models/archon/`, + `areal/experimental/engine/archon_engine.py` +- FSDP core: `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py` +- Megatron core: `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` +- Reward: `areal/reward/` +- Dataset: `areal/dataset/` +- Trainer: `areal/trainer/` +- Harness: `.agents/`, `.claude/`, `.opencode/`, `.codex/` +- CI/CD and release: `.github/workflows/`, `Dockerfile`, `pyproject.toml` + +______________________________________________________________________ + +## Cross-Domain Linkage Rules + +| Detected Signal | Auto-Linked Review | +| ---------------------------------------------- | --------------------------------------------------- | +| `archon_core` or `archon_parallel` | Model Compute & Attention checks | +| `archon_model_family` or `archon_moe_modeling` | Numerics & Tensor Semantics checks | +| `tree_attn` | Numerics & Tensor Semantics checks | +| `reward_surface` | Workflow & Trainer Contract checks | +| `compile_dynamo` | Distributed Runtime checks | +| `vllm_ext` | Launcher & Infrastructure checks | +| `service_routing_dataflow` | Workflow & Trainer async-contract checks | +| `weight_sync` | DTensor/process-group/checkpoint interaction checks | +| `rpc_transport` | Distributed Runtime synchronization checks | +| `mixed_precision_fp8` + Distributed Runtime | mesh + weight-sync compatibility checks | +| 
`runtime_image_config` | Inference Backend & Serving checks | +| `project_dependency_config` | API & Config Compatibility checks | +| `github_workflow_jobs` or `release_delivery` | Launcher & Infrastructure checks | +| `skill_definition` or `platform_command_data` | Low-Risk Hygiene checks | + +______________________________________________________________________ + +## Risk Identification Guidance + +### Distributed Runtime Risks + +- Archon mesh construction or parallel-dims mismatch +- EP/TP/CP application order errors in Archon parallelization +- Activation checkpoint placement violating TP/CP/FSDP ordering assumptions +- Archon engine lifecycle drift around distributed setup and checkpoint boundaries +- Collective call order mismatch across ranks +- Wrong process-group scope in rank-sensitive logic +- Mesh dimension mismatch and invalid DTensor placement +- Weight version drift between rollout and training workers + +### Model Compute & Attention Risks + +- Attention mask inconsistency under TP/SP/CP paths +- Tree attention index/routing mismatch +- Archon model-family registration or per-family wiring drift +- Archon MoE router/expert behavior diverging from weight-conversion expectations +- Archon Ulysses slicing/gather semantics mismatching attention layout assumptions +- Kernel assumptions violating dtype/shape invariants +- Sequence packing alignment errors + +### Service Orchestration Risks + +- Session affinity or history drift across gateway/router/data proxy +- Async message handling holes and dropped tasks +- Controller/worker lifecycle desynchronization + +### Inference Backend & Serving Risks + +- Request lifecycle inconsistencies (enqueue/cancel/timeout) +- Worker state transitions leaving requests stranded +- Backend extension hooks drifting from runtime expectations + +### Workflow & Trainer Contract Risks + +- Workflow-engine contract drift across async boundaries +- Weight version handshake mismatch between rollout and train +- Trainer 
lifecycle transition inconsistencies + +### API & Config Compatibility Risks + +- Breaking config/schema changes without migration path +- Dataclass or CLI default changes altering behavior silently +- Missing validation for newly introduced fields +- Dependency or build-system pin changes breaking supported environments + +### Numerics & Tensor Semantics Risks + +- Silent shape/dtype mismatch under distributed paths +- Unstable numerical operations in loss/reward logic +- torch.compile or dynamo guard changes breaking graph assumptions +- Mixed-precision interaction regressions + +### Checkpoint & Recovery Risks + +- Partial-rank checkpoint participation +- Incompatible state key evolution +- Resume path breaking optimizer/model synchronization + +### Launcher & Infrastructure Risks + +- Resource assignment mismatching parallel strategy assumptions +- RPC transport metadata loss (shape/dtype/device) +- Startup/shutdown ordering races across processes +- Runtime image or build-arg drift from supported inference/training variants + +### Low-Risk Hygiene Risks + +- Docs/config drift from actual runtime behavior +- Logging or import hygiene regressions +- Sensitive data exposure in logs or config +- Documentation/build scripts or project templates drifting from actual workflow + +### Harness & Agent Infrastructure Risks + +- Skill and command docs drifting from actual platform behavior +- Cross-platform data files falling out of sync with canonical references +- Agent registry/config changes breaking expert routing or command discovery + +### CI/CD & Release Automation Risks + +- Workflow trigger or job dependency changes skipping required validation +- Runner provisioning drift causing flaky or non-reproducible CI +- Release or docs deployment jobs publishing the wrong artifacts diff --git a/.claude/data/review-pr-templates.md b/.claude/data/review-pr-templates.md index d3aa11455b..2a45defca9 100644 --- a/.claude/data/review-pr-templates.md +++ 
b/.claude/data/review-pr-templates.md @@ -1,424 +1,320 @@ -# PR Review: Task Templates Reference +# PR Review: Domain Templates Reference -This file contains the review task templates for PR review. Referenced by: +This file contains canonical domain templates for PR review. Referenced by: `.claude/commands/review-pr.md` ______________________________________________________________________ -## Framework-Specific Review Task Templates +## Template Selection Rules -### Archon Tasks \[Opus\] +1. Select templates by detected L1 domains and L2 signals. +1. Use at most one primary template per domain. +1. Always include **General Logic & Boundary** for non-doc/config-only PRs. +1. Apply cross-domain linkage checks from `review-pr-domains-and-signals.md`. -**Task: Archon EP/ETP Strategy Correctness Review** +______________________________________________________________________ -``` -Checklist: -- ExpertParallel, TensorParallel, ExpertTensorParallel placement implementation -- Placement dimension matching with mesh dimensions -- Placement list length in _partition_fn -- all_to_all communication autograd compatibility -- ReordererSequenceParallel token index conversion -``` +## Universal Template -**Task: ArchonParallelDims Configuration Validation** +### General Logic & Boundary ``` +Applicable: Any non-doc/config-only change Checklist: -- ETP constraint: etp=1 (TP borrowed by EP) vs etp=tp (independent TP) logic -- Mesh construction: _build_mesh_with_ep() dimension order and names -- EP/TP/CP combination validity verification -- dp_shard * cp * (tp if etp==1 else 1) % ep == 0 constraint +- Boundary condition correctness (empty inputs, singleton, max-size) +- Conditional logic correctness (branch inversion, short-circuit mistakes) +- Error-path behavior (exceptions propagated with actionable context) +- Return-value consistency across code paths +- No newly introduced hidden behavior changes ``` -**Task: MoE Layer Implementation Correctness** - -``` -Checklist: -- 
TokenReorderer and router separation correctness -- grouped_mm token alignment (8/16/32) -- Expert weight 3D tensor sharding -- Load balancing loss calculation -``` - -**Task: Model Parallelization Application Order** - -``` -Checklist: -- apply_moe_ep_tp() strategy selection logic -- FSDP wrap order (EP -> TP -> AC -> FSDP) -- torch.compile dynamic shape marking -- Explicit prefetching configuration -``` - -### FSDP Tasks \[Opus/Sonnet\] - -**Task: FSDP Core Correctness \[Opus\]** - -``` -Checklist: -- Shard/reshard operation timing and correctness -- ShardedTensor and DTensor conversion -- Mixed precision (param_dtype vs reduce_dtype) -``` +______________________________________________________________________ -**Task: FSDP Interaction with Other Parallel Strategies \[Opus\]** +## Domain 1 Template: Distributed Runtime Review \[comprehensive\] ``` +Applicable signals: archon_core, archon_parallel, process_group, fsdp_core, megatron_core, collectives, mesh_dtensor, activation_checkpointing, weight_sync Checklist: -- FSDP must be applied after TP/CP/EP -- Use dp_shard_mod_ep mesh in EP scenarios -- Gradient divide factor relationship with world size +- Archon engine and checkpoint lifecycle remain aligned with distributed runtime assumptions +- FSDP and Megatron engine invariants still match process-group, sharding, and pipeline assumptions +- Archon parallel-dims and mesh construction still match downstream placement logic +- Process-group creation/usage/cleanup is rank-consistent +- Collective operations are called by all required ranks in consistent order +- DeviceMesh dimensions and DTensor placements are correct for each path +- Activation checkpoint placement remains compatible with parallel and sharding order requirements +- Local/global tensor conversion boundaries are explicit and correct +- Weight version propagation and update ordering are deterministic +- No debug-only barriers left in hot path ``` -**Task: FSDP State Management \[Sonnet\]** +## Domain 
2 Template: Model Compute & Attention Review \[comprehensive\] ``` +Applicable signals: tree_attn, sdpa_varlen, sp_cp_attention_mask, triton_kernel, archon_model_family, archon_attention_stack, archon_moe_modeling Checklist: -- state_dict save/load sharded vs full mode -- Optimizer state sharding and aggregation -- Checkpoint compatibility +- Attention mask semantics preserved under TP/SP/CP +- Archon model-family registration and per-family wiring remain internally consistent +- Archon attention/Ulysses slicing and gather paths preserve layout assumptions +- Archon MoE router, grouped experts, and weight-conversion interfaces remain aligned +- Tree attention index/order invariants are maintained +- Kernel assumptions on dtype/shape/contiguity are satisfied +- No silent behavior change in sequence packing/unpacking +- Tensor layouts remain compatible with downstream modules ``` -### Megatron Tasks \[Opus\] - -**Task: Pipeline Parallelism Correctness** +## Domain 3 Template: Inference Backend & Serving Review \[comprehensive\] ``` +Applicable signals: vllm_ext, remote_inference_backend, request_lifecycle Checklist: -- Stage splitting correctness and balance -- Micro-batch scheduling -- Pipeline flush and bubble handling +- Request lifecycle (enqueue, execution, cancellation, timeout) is coherent +- Worker state transitions are safe under concurrency +- Backend-specific extension points stay API-compatible +- Error handling does not strand in-flight requests +- Versioning/weight-update interactions are explicit and safe ``` -**Task: Megatron Model Sharding** +## Domain 4 Template: Service Orchestration Review \[comprehensive\] ``` +Applicable signals: service_routing_dataflow, session_consistency Checklist: -- Weight sharding and synchronization -- Tied weights handling -- Embedding/output layer parallel strategy +- Gateway/router/data-proxy routing rules are deterministic +- Session affinity and history consistency are preserved +- Controller/worker coordination has 
no lost-update window +- Async boundaries avoid blocking operations in critical paths +- Failure/retry behavior does not duplicate or drop work ``` -### DCP/Checkpoint Tasks \[Opus\] - -**Task: Distributed Checkpoint Correctness** +## Domain 5 Template: Workflow & Trainer Contract Review \[comprehensive\] ``` +Applicable signals: workflow_engine_boundary, dataset_surface, async_contract, weight_version_contract Checklist: -- All ranks participate in DCP save/load operations -- State dict keys match between save and load -- No tensor shape/dtype mismatches -- Storage backend compatibility (filesystem, S3) -- Checkpoint versioning and migration +- RolloutWorkflow and Engine interfaces remain contract-compatible +- Dataset/output structure still matches workflow and trainer consumption expectations +- Async flow uses await consistently and avoids sync I/O in async paths +- Weight update/version handshake is preserved end-to-end +- Trainer lifecycle transitions are valid for all execution branches +- Call ordering assumptions across trainer/workflow/engine are unchanged or justified ``` -**Task: FSDP2 + DCP Integration** +## Domain 6 Template: API & Config Compatibility Review \[targeted\] ``` +Applicable signals: dataclass_schema, cli_compat, backward_compat, project_dependency_config Checklist: -- FSDP2 state dict options (full vs sharded) -- Optimizer state handling with DCP -- Async checkpointing correctness -- Checkpoint resumption logic +- Public API signature and default value changes are intentional and compatible +- Dataclass validation remains complete and informative +- CLI options preserve expected compatibility semantics +- New fields include safe defaults or explicit migration handling +- Breaking changes are documented and scoped +- Dependency and build-system changes remain compatible with supported environments ``` -### Trainer Tasks \[Opus\] - -**Task: Trainer Core Logic** +## Domain 7 Template: Numerics & Tensor Semantics Review \[targeted\] ``` 
+Applicable signals: shape_dtype, numerical_stability, reward_surface, compile_dynamo, mixed_precision_fp8 Checklist: -- PPOTrainer/SFTTrainer initialization correctness -- Workflow registration and invocation -- Engine lifecycle management -- Distributed training coordination +- Tensor shape/dtype transitions are explicit and internally consistent +- Numerical stability is protected (log/division/softmax/clamp paths) +- Reward-side numerical behavior remains consistent with workflow consumption expectations +- torch.compile / dynamo assumptions still hold for dynamic shapes and distributed execution +- Mixed-precision behavior is correct for forward + backward + reduce paths +- In-place and view/reshape operations do not corrupt gradient flow +- Device placement and dtype combinations remain legal across code paths ``` -______________________________________________________________________ - -## General Review Task Templates - -### Logic and Boundary Conditions \[Opus\] +## Domain 8 Template: Checkpoint & Recovery Review \[comprehensive\] ``` -Applicable: Any non-doc/config changes +Applicable signals: dcp_consistency, optimizer_state, resume_compat Checklist: -- Conditional logic errors (if/else inversion, boundary condition omission, short-circuit issues) -- Loop errors (off-by-one, infinite loops, early exit, iterator invalidation) -- Missing null/None/empty list handling -- Type mismatch or implicit type conversion issues -- Improper exception handling (swallowing exceptions, wrong exception type, return in finally) -- Return value errors (wrong type, missing return, inconsistent multi-path returns) -- Boolean expression errors (De Morgan's law violation, precedence errors) +- Save/load requires and enforces all-rank participation where needed +- State dict naming/structure is stable or migration-safe +- Optimizer state sharding/gather behavior is consistent +- Resume path restores model + optimizer + version state coherently +- Async checkpoint behavior 
preserves ordering and durability assumptions ``` -### Concurrency and Async \[Opus\] +## Domain 9 Template: Launcher & Infrastructure Review \[targeted\] ``` -Applicable: ASYNC_CONCURRENT type detected +Applicable signals: launcher_resource_match, scheduler_contract, rpc_transport, runtime_image_config Checklist: -- Race conditions -- Deadlock risks (inconsistent lock ordering, nested locks) -- Non-thread-safe access to shared state -- Missing await in async code -- Blocking calls in async functions (should use executor) -- Resource leaks (file handles, network connections, GPU memory not released) -- State inconsistency (dirty state after partial update failure) -- Improper context manager usage -- Signal handling and graceful shutdown issues +- Resource assignment matches declared parallel strategy assumptions +- Scheduler decisions preserve required placement/affinity constraints +- RPC serialization/deserialization keeps shape/dtype/device semantics +- Transport retries/timeouts do not violate idempotency expectations +- Cross-process startup/shutdown ordering is robust +- Runtime image and build configuration remain aligned with supported variants ``` -### Tensor Shape and Data Type \[Opus\] +## Domain 10 Template: Low-Risk Hygiene Review \[basic\] ``` -Applicable: TENSOR_OPS type detected with complex tensor operations +Applicable signals: tests_docs_config, logging_import_security, project_docs_metadata Checklist: -- Tensor shape mismatch (dimension errors, broadcast errors) -- Batch dimension handling errors (missing batch dim, wrong dimension order) -- Sequence length and padding handling (missing mask, padding token in computation) -- Index out of bounds risk (dynamic indexing, negative indexing) -- dtype mismatch (fp16/fp32/bf16 mixing, integer overflow) -- Device placement errors (tensor on wrong device, CPU/GPU mixed operations) -- Gradient-related issues (missing detach, missing no_grad context, gradient accumulation errors) -- view/reshape 
contiguity requirements -- In-place operation effects on gradient computation +- Tests/docs/config edits are internally consistent and non-misleading +- Logging follows project conventions and avoids sensitive leakage +- No wildcard imports or obvious dependency hygiene regressions +- No accidental secrets/keys/tokens introduced +- Docs build scripts and project templates stay aligned with real contributor workflow ``` -### Numerical Stability \[Sonnet\] +## Domain 11 Template: Harness & Agent Infrastructure Review \[targeted\] ``` -Applicable: NUMERICAL type detected +Applicable signals: skill_definition, platform_command_data, agent_registry_config Checklist: -- Numerical precision issues (floating point precision loss, accumulated errors) -- Numerical stability (log(0), division by zero, exp overflow, softmax stability) -- Numerical issues in loss function computation -- Gradient vanishing/exploding risks -- Scaling issues in mixed precision training +- Canonical skills and derived platform data remain structurally aligned +- Command docs still point to the correct data files and execution model +- Agent registry/config changes preserve command discovery and expert routing +- Cross-platform mirrors are regenerated after canonical changes ``` -### Tensor Parallel (TP) Correctness \[Opus\] +## Domain 12 Template: CI/CD & Release Automation Review \[comprehensive\] ``` -Applicable: TENSOR_PARALLEL or DISTRIBUTED_COMM type detected +Applicable signals: github_workflow_jobs, runner_provisioning, release_delivery Checklist: -- Missing or misplaced all-reduce -- Missing or misplaced all-gather -- Reduce handling after weight sharding (column/row sharding) -- Input Replicate / output Partial DTensor semantics -- scatter/gather correctness in Sequence Parallel (SP) -- TP group communication correctness +- Workflow triggers, job dependencies, and permissions still enforce required validation +- Runner/image provisioning remains reproducible and compatible with job 
expectations +- Release, docker, and docs deployment jobs publish the intended artifacts only +- CI changes do not silently skip tests, formatting, or release gates ``` -### Communication and Synchronization \[Sonnet\] +______________________________________________________________________ -``` -Applicable: DISTRIBUTED_COMM type detected -Checklist: -- Process group usage errors -- Device mesh configuration errors -- Improper barrier placement -- Unnecessary synchronization operations (GPU-CPU sync) -- Collective communication order dependencies -``` +## Signal-Specific Add-On Checklists -### API Compatibility \[Sonnet\] +Use these only when corresponding L2 signals are detected. -``` -Applicable: API_CONFIG type detected -Checklist: -- Function signature changes (parameter add/delete/rename/reorder) -- Return type changes -- Default value changes causing behavior changes -- Breaking changes to public APIs -- Deprecated API usage -- Class/module rename or move -``` - -### Configuration and Parameter Validation \[Sonnet\] +### `tree_attn` Add-On \[comprehensive\] ``` -Applicable: API_CONFIG type detected with dataclass -Checklist: -- New config items missing validation (__post_init__ validation) -- Unreasonable config default values -- Missing parameter range checks -- Unhandled dependencies between config items -- Hydra/CLI compatibility issues -- Backward compatibility of env vars/config files -- Incorrect dataclass field types +- Node/edge indexing is deterministic and shape-safe +- Tree traversal order matches attention mask semantics +- FSDP/Megatron/Archon variant modules remain behaviorally aligned ``` -### Workflow and Engine Interaction \[Sonnet\] +### `vllm_ext` Add-On \[comprehensive\] ``` -Applicable: WORKFLOW_ENGINE type detected -Checklist: -- RolloutWorkflow.arun_episode async correctness -- InferenceEngine.agenerate call patterns -- Weight version management (set_version/update_weights/WeightUpdateMeta) -- Tensor output format ([batch, seq_len, ...] 
convention) -- concat_padded_tensors usage correctness -- AsyncRewardWrapper wrapping requirements +- Server and worker extension hooks still match upstream expectations +- Request pause/resume/cancel semantics remain coherent +- Integration-specific monkey-patches are scoped and guarded ``` -### Activation Checkpointing (AC) \[Sonnet\] +### `archon_model_family` Add-On \[comprehensive\] ``` -Applicable: ACTIVATION_CKPT type detected -Checklist: -- AC application order (must after TP/CP, before FSDP) -- Selective AC op registration correctness -- AC config validation logic -- Compatibility with torch.compile +- ModelSpec registration stays unique and complete for supported model types +- Per-family model/args/spec/state-adapter wiring remains consistent +- Pipelining hooks and model-part boundaries stay compatible with runtime assumptions ``` -### Performance Regression Risk \[Sonnet\] +### `archon_moe_modeling` Add-On \[comprehensive\] ``` -Applicable: Any non-doc changes, especially TENSOR_OPS, DISTRIBUTED_COMM -Checklist: -- Unnecessary GPU-CPU sync (.item(), .tolist(), printing tensors) -- Memory allocation pattern changes (potential OOM) -- Communication volume increase -- Computational complexity changes -- torch.compile compatibility breakage -- Unnecessary tensor copies +- Router top-k, gating dtype, and expert grouping semantics remain coherent +- GroupedExperts layout and token reordering assumptions still match expert execution +- MoE weight conversion paths stay consistent with runtime sharding expectations ``` -### Context-Aware Review \[Sonnet\] +### `service_routing_dataflow` Add-On \[comprehensive\] ``` -Applicable: Any code changes -Checklist: -- Read git blame and history of modified code -- Check for accidental rollback of previous fixes -- Check for breaking previously established patterns or conventions -- Check if changes violate code comments -- Check for violations of TODO/FIXME constraints -- Check for ignored NOTE/WARNING comments +- Route 
selection and fallback ordering are deterministic +- Data proxy transformations preserve payload integrity +- Session-key partitioning logic is collision-safe ``` -### Sequence Parallel (SP/CP) Correctness \[Opus\] +### `remote_inference_backend` Add-On \[targeted\] ``` -Applicable: sequence_parallel, context_parallel, SP, CP -Checklist: -- scatter/gather operation correctness -- Attention mask handling under SP -- Position encoding sharding -- KV cache handling under CP -- Combination correctness with TP +- Remote backend request/response semantics remain consistent across supported engines +- Backend-specific transport options do not change lifecycle expectations silently +- Shared request payload assumptions remain compatible across remote backends ``` -### Checkpoint and Recovery \[Sonnet\] +### `weight_sync` Add-On \[comprehensive\] ``` -Applicable: areal/utils/saver.py, areal/utils/recover.py, state_dict, checkpoint -Checklist: -- Checkpoint save/load completeness -- Distributed checkpoint consistency -- Version compatibility (can old checkpoints load) -- Recovery logic correctness -- Optimizer state handling +- Versioned updates are monotonic and race-safe +- Broadcast/all-gather points are aligned with consumer expectations +- Local caching behavior cannot serve stale weights indefinitely ``` -### Reward Function Correctness \[Sonnet\] +### `activation_checkpointing` Add-On \[targeted\] ``` -Applicable: areal/reward/ directory -Checklist: -- Reward function signature matches (prompt, completions, prompt_ids, completion_ids, **data) -- Deterministic computation (same input produces same output) -- Blocking calls wrapped with AsyncRewardWrapper -- Numerical range reasonableness -- Edge case handling (empty input, abnormal answers) +- Checkpoint wrappers are applied in a parallelism-safe order +- Selective checkpoint policies still cover the intended modules only +- Activation recompute paths do not break sharding or sequence-parallel assumptions ``` -### 
Dataset Loader Correctness \[Sonnet\] +### `reward_surface` Add-On \[targeted\] ``` -Applicable: areal/dataset/ directory -Checklist: -- Data format validation (messages, answer, image_path fields) -- Tokenizer compatibility -- max_length truncation logic -- Distributed sampling correctness -- Memory efficiency (avoid loading all data at once) +- AsyncRewardWrapper-facing reward interfaces remain contract-compatible +- Reward outputs keep expected shape, dtype, and per-sample semantics +- Workflow assumptions about reward timing and batching remain valid ``` -### Launcher and Scheduler Configuration \[Sonnet\] +### `compile_dynamo` Add-On \[targeted\] ``` -Applicable: areal/infra/launcher/, areal/infra/scheduler/, areal/infra/rpc/ directories -Checklist: -- Resource config reasonableness (GPU count, memory) -- Process group config matches parallel strategy -- Environment variable passing correctness -- Container/image config compatibility -- Slurm/Ray specific configurations +- torch.compile and dynamo guards still tolerate expected dynamic-shape inputs +- fullgraph and mark_dynamic choices remain compatible with distributed execution paths +- Compile-specific changes do not silently alter runtime fallback behavior ``` -### torch.compile Compatibility \[Sonnet\] +### `rpc_transport` Add-On \[targeted\] ``` -Applicable: COMPILE type detected or hot path code modified -Checklist: -- Dynamic shape mark_dynamic marking -- Graph break risks (Python control flow, data-dependent branches) -- Unsupported operations (some in-place ops) -- fullgraph=True compatibility -- Interaction with FSDP/TP +- RTensor conversion is reversible and metadata-complete +- Batch fetch/request framing preserves ordering and boundaries +- Retry logic does not replay non-idempotent actions incorrectly ``` -### Documentation Format Check \[Haiku\] +### `runtime_image_config` Add-On \[targeted\] ``` -Applicable: DOCS type detected -Checklist: -- Markdown format correctness -- Internal link 
validity -- Code example correctness +- Docker base image and build args still match supported backend variants +- Layer ordering preserves expected cache and dependency behavior +- Image contents remain aligned with runtime assumptions documented in the repo ``` -### Test Coverage Check \[Haiku\] +### `project_dependency_config` Add-On \[targeted\] ``` -Applicable: TESTS type detected -Checklist: -- Test cases cover main paths -- Boundary condition tests -- Error handling tests +- Python/version constraints and extras remain internally consistent +- Lockfile changes match the intended dependency update scope +- Build backend/tooling changes do not break install or publish workflows ``` -### Logging and Metrics \[Haiku\] +### `github_workflow_jobs` Add-On \[comprehensive\] ``` -Applicable: logging, stats_tracker, StatsLogger -Checklist: -- Use areal.utils.logging.getLogger not print -- Structured metrics sent via stats_tracker -- Reasonable log levels (no DEBUG on hot paths) -- Sensitive info not logged +- Workflow triggers and job graph still run required validation paths +- Required secrets/permissions are scoped correctly +- Matrix or conditional changes do not silently skip critical jobs ``` -### Import and Dependencies \[Haiku\] +### `project_docs_metadata` Add-On \[basic\] ``` -Applicable: Any Python file changes -Checklist: -- Avoid wildcard imports (from x import *) -- Correct third-party vs internal import grouping -- Heavy optional deps inside functions -- Circular import risks +- Docs build entrypoints and contributor-facing metadata remain mutually consistent +- Public templates and contributor instructions still match the actual workflow +- Build/preview guidance still points to the supported commands ``` -### Security and Sensitive Information \[Haiku\] +### `skill_definition` / `platform_command_data` Add-On \[targeted\] ``` -Applicable: Config files, environment variables, API calls -Checklist: -- No hardcoded keys/tokens/passwords -- Sensitive info 
not committed to repo -- API endpoints configurable -- Error messages don't leak sensitive details +- Canonical and derived review-pr data files stay in sync after edits +- Command/import paths remain correct after file moves or renames +- Wrapper-specific routing stays out of canonical reference files ``` diff --git a/.claude/skills/commit-conventions/SKILL.md b/.claude/skills/commit-conventions/SKILL.md index 851d07b992..52f1acd1bf 100644 --- a/.claude/skills/commit-conventions/SKILL.md +++ b/.claude/skills/commit-conventions/SKILL.md @@ -115,7 +115,7 @@ delegation instead of hardcoded model names. Key changes: - Create .opencode/command/ with review-pr, create-pr -- Replace Opus/Sonnet/Haiku with deep/unspecified-high/quick +- Replace hardcoded model routing with platform-native review routing - Add expert subagent consultation patterns ``` diff --git a/.opencode/command/review-pr.md b/.opencode/command/review-pr.md index 9a1a24eab3..2a787f4dd7 100644 --- a/.opencode/command/review-pr.md +++ b/.opencode/command/review-pr.md @@ -1,5 +1,5 @@ --- -description: Intelligent PR code review with dynamic agent allocation based on change types +description: Intelligent PR code review with dynamic agent allocation based on domains and signals --- # PR Code Review (Dynamic Agent Allocation) @@ -29,7 +29,7 @@ Current branch: !`git branch --show-current` The following data files contain detection tables and task templates (auto-included): -@.opencode/data/review-pr-change-types.md @.opencode/data/review-pr-templates.md +@.opencode/data/review-pr-domains-and-signals.md @.opencode/data/review-pr-templates.md ## Arguments @@ -53,7 +53,7 @@ The following data files contain detection tables and task templates (auto-inclu Phase 1: Deep PR Analysis |- 1.0 PR Status Check [quick] |- 1.1 Get PR Summary [quick] - +- 1.2-1.4 Change Type Detection [direct analysis] ++- 1.2-1.4 Domain/Signal Detection [direct analysis] | Phase 2: Dynamic Agent Planning [direct analysis] | @@ -69,7 +69,15 
@@ OpenCode uses `task()` with categories for delegating review work. **If `--quick` is set**: stop after Phase 1 and output `CHANGE_ANALYSIS_REPORT` only (do NOT delegate review tasks). -Otherwise, map risk levels to categories: +Otherwise, map canonical review depths to OpenCode categories: + +| Review Depth | OpenCode Routing | +| --------------- | ------------------------------------------------------------ | +| `comprehensive` | `deep` (and add `ultrabrain` in parallel for CRITICAL cases) | +| `targeted` | `unspecified-high` | +| `basic` | `quick` | + +Then map risk levels to review depths and categories: | Risk Level | Category | | ------------ | ---------------------------------------------------------------------------------------------- | @@ -122,34 +130,35 @@ Check if PR should be reviewed: Get basic PR info: title, description, modified files, change summary. -### 1.2 Change Type Detection +### 1.2 Domain & Signal Detection -Analyze each file change, detecting change types by risk level. +Analyze each file change, detecting L1 domains and L2 signals by risk level. -**Reference**: See `.opencode/data/review-pr-change-types.md` for complete detection +**Reference**: See `.opencode/data/review-pr-domains-and-signals.md` for complete domain tables: -- CRITICAL level types (Archon, FSDP, Megatron, DCP) -- HIGH level types (distributed comm, DTensor, MoE, TP/EP/CP) -- MEDIUM level types (tensor ops, workflow, API, compile) -- LOW level types (tests, docs, config) +- L1 domains (Distributed Runtime, Model Compute & Attention, Inference Backend & + Serving, etc.) +- L2 signals per domain +- cross-domain linkage rules -### 1.3 Framework-Specific Risk Identification +### 1.3 Domain-Specific Risk Identification -Based on detected types, identify corresponding risks. +Based on detected domains/signals, identify corresponding risks and linked checks. 
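The path-first detection that Phase 1 describes can be sketched in Python. The domain/signal patterns below are a small illustrative subset of the full tables in `review-pr-domains-and-signals.md`, not the canonical data, and the function is an assumption about how a detector might consume them:

```python
from fnmatch import fnmatch

# Illustrative subset of the L1 domain / L2 signal path triggers.
# The canonical tables live in review-pr-domains-and-signals.md; only a few
# entries are reproduced here to show the detection shape.
SIGNAL_PATTERNS = {
    ("Distributed Runtime", "fsdp_core"): [
        "areal/engine/fsdp_engine.py",
        "areal/engine/fsdp_utils/*",
    ],
    ("Model Compute & Attention", "tree_attn"): ["areal/models/tree_attn/*"],
    ("Low-Risk Hygiene", "tests_docs_config"): ["tests/*", "docs/*", "*.md"],
}


def detect(changed_files: list[str]) -> list[tuple[str, str]]:
    """Return sorted (domain, signal) pairs whose path triggers match."""
    hits = set()
    for path in changed_files:
        for key, patterns in SIGNAL_PATTERNS.items():
            if any(fnmatch(path, pattern) for pattern in patterns):
                hits.add(key)
    return sorted(hits)
```

Path triggers firing before keyword triggers is what keeps false positives down; a real implementation would presumably fall back to the code-pattern column only when no path trigger matches.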
-**Reference**: See `.opencode/data/review-pr-change-types.md` for risk lists per +**Reference**: See `.opencode/data/review-pr-domains-and-signals.md` for risk lists per framework. ### 1.4 Output Change Analysis Report ``` CHANGE_ANALYSIS_REPORT: -- detected_types: [ARCHON_PARALLEL, EP_ETP, FSDP_CORE, ...] +- detected_domains: [Distributed Runtime, Model Compute & Attention, ...] +- detected_signals: [weight_sync, tree_attn, ...] - risk_level: CRITICAL | HIGH | MEDIUM | LOW - affected_files: [file1.py, file2.py, ...] - identified_risks: [risk1, risk2, ...] -- related_frameworks: [archon, fsdp, megatron, ...] +- related_frameworks: [archon, fsdp, megatron, vllm, service-stack, ...] ``` ______________________________________________________________________ @@ -160,8 +169,10 @@ ______________________________________________________________________ 1. **Generate tasks by risk area**: Each high-risk area gets a dedicated task 1. **Merge related changes**: Interdependent changes can be merged -1. **Category selection**: CRITICAL/HIGH -> `deep`, MEDIUM -> `unspecified-high`, LOW -> - `quick` +1. **Review depth selection**: CRITICAL/HIGH -> `comprehensive`, MEDIUM -> `targeted`, + LOW -> `basic` +1. **Category routing**: `comprehensive` -> `deep`, `targeted` -> `unspecified-high`, + `basic` -> `quick` 1. **Minimum coverage**: Even simple changes get at least 1 basic review task 1. **Skill loading**: Include relevant skills for framework-specific reviews (e.g., `add-archon-model` for Archon changes, `debug-distributed` for distributed code) @@ -170,25 +181,26 @@ ______________________________________________________________________ ### 2.2 Task Template Selection -Based on detected change types, select appropriate review task templates. +Based on detected domains/signals, select appropriate review task templates. 
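The two-step routing above (risk level to canonical review depth, then depth to an OpenCode category) can be sketched as a pair of lookup tables. The names mirror the tables in this plan; the `route` helper itself is illustrative, not part of any platform API:

```python
# Two-step routing: risk level -> canonical review depth -> task() category.
DEPTH_BY_RISK = {
    "CRITICAL": "comprehensive",
    "HIGH": "comprehensive",
    "MEDIUM": "targeted",
    "LOW": "basic",
}

CATEGORY_BY_DEPTH = {
    "comprehensive": "deep",
    "targeted": "unspecified-high",
    "basic": "quick",
}


def route(risk_level: str) -> tuple[str, str]:
    """Return (review_depth, opencode_category) for a detected risk level."""
    depth = DEPTH_BY_RISK[risk_level]
    return depth, CATEGORY_BY_DEPTH[depth]
```

Keeping the depth layer canonical means only `CATEGORY_BY_DEPTH` is platform-specific: the Claude and Codex wrappers can supply their own second table without touching risk semantics. The parallel `ultrabrain` fan-out for CRITICAL cases would live outside this pure mapping.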
**Reference**: See `.opencode/data/review-pr-templates.md` for complete task templates: -- Framework-specific tasks (Archon, FSDP, Megatron, DCP, Trainer) -- General tasks (Logic, Concurrency, Tensor, Numerical, TP, etc.) +- Domain templates (Distributed Runtime, Model Compute & Attention, Inference Backend & + Serving, etc.) +- Universal + signal-specific add-on templates ### 2.3 Output Review Task List ``` GENERATED_REVIEW_TASKS: -1. [deep] Task Name - - Reason: XXX change type detected - - Skills: [skill1, skill2] // or [] if none - - Expert: archon-expert // or none - - Checklist: [...] - - Focus files: [...] - -2. [unspecified-high] Task Name +1. [comprehensive -> deep] Task Name + - Reason: XXX domain/signal detected + - Skills: [skill1, skill2] // or [] if none + - Expert: archon-expert // or none + - Checklist: [...] + - Focus files: [...] + +2. [targeted -> unspecified-high] Task Name - Reason: ... - Skills: [] ... @@ -208,7 +220,7 @@ ______________________________________________________________________ ### 3.2 Delegation Template -For each review task from Phase 2, delegate as: +For each review task from Phase 2, first map review depth to category, then delegate as: ``` task( @@ -251,13 +263,13 @@ task( ) ``` -### 3.3 Review Depth by Category +### 3.3 Review Depth Mapping -| Category | Requirements | -| -------------------- | -------------------------------------------------------------------------- | -| **deep** | Complete context, cross-file traces, verify parallel strategy interactions | -| **unspecified-high** | Changed code + direct callers/callees, type signature consistency | -| **quick** | Format and basic correctness only | +| Review Depth | Category | Requirements | +| ----------------- | ------------------ | -------------------------------------------------------------------------- | +| **comprehensive** | `deep` | Complete context, cross-file traces, verify parallel strategy interactions | +| **targeted** | `unspecified-high` | Changed code + 
direct callers/callees, type signature consistency | +| **basic** | `quick` | Format and basic correctness only | ______________________________________________________________________ @@ -280,7 +292,8 @@ ______________________________________________________________________ ## PR Overview - **Title**: PR title -- **Detected Change Types**: [...] +- **Detected Domains**: [...] +- **Detected Signals**: [...] - **Risk Level**: CRITICAL | HIGH | MEDIUM | LOW - **Generated Review Tasks**: N @@ -321,13 +334,13 @@ ______________________________________________________________________ ## Dynamic Generation Examples -| PR Type | Detected Types | Generated Tasks | -| -------------- | ------------------------------------- | --------------------------- | -| Docs only | \[DOCS\] | 1 quick | -| Config only | \[CONFIG_ONLY\] | 1-2 quick | -| Single bug fix | \[TENSOR_OPS\] | 2-4 unspecified-high | -| Archon core | \[ARCHON\_\*, EP_ETP, DTENSOR\] | 4-8 deep + expert subagents | -| Cross-domain | \[WORKFLOW_ENGINE, FSDP_CORE, TESTS\] | 5-10 mixed categories | +| PR Type | Detected Domains/Signals | Generated Tasks | +| -------------- | --------------------------------------------------- | -------------------------------------------- | +| Docs only | \[Low-Risk Hygiene / tests_docs_config\] | 1 basic -> quick | +| Config only | \[API & Config Compatibility / dataclass_schema\] | 1-2 targeted/basic | +| Single bug fix | \[Numerics & Tensor Semantics / shape_dtype\] | 2-4 targeted | +| Archon core | \[Distributed Runtime / mesh_dtensor, weight_sync\] | 4-8 comprehensive -> deep + expert subagents | +| Cross-domain | \[Workflow & Trainer + Distributed + Hygiene\] | 5-10 mixed review depths and categories | ______________________________________________________________________ @@ -359,12 +372,12 @@ Location: .opencode/command/review-pr.md Invocation: /review-pr Related files: - - .opencode/data/review-pr-change-types.md: Change type detection tables +- 
.opencode/data/review-pr-domains-and-signals.md: Domain and signal detection tables - .opencode/data/review-pr-templates.md: Review task templates ## Differences from Claude Code version -1. Model names (Opus/Sonnet/Haiku) -> task() categories (deep/unspecified-high/quick) +1. Claude model routing -> OpenCode task() categories via generic review depths 2. @import syntax removed -> uses @ file references for auto-inclusion 3. allowed-tools frontmatter removed 4. Added subtask: true to run as subtask (not pollute main context) @@ -375,17 +388,13 @@ Related files: ## How to Update -### Adding New Change Types -Edit .opencode/data/review-pr-change-types.md: -1. Add to appropriate level table (CRITICAL/HIGH/MEDIUM/LOW) -2. Add framework risks if applicable -3. Also update .claude/data/review-pr-change-types.md to keep in sync +### Adding New Domains or Signals +Edit `.agents/skills/review-pr/references/review-pr-domains-and-signals.md`, then regenerate +the derived data files with `python3 .agents/skills/review-pr/sync_review_pr_refs.py --write`. ### Adding New Task Templates -Edit .opencode/data/review-pr-templates.md: -1. Add to framework-specific or general section -2. Include checklist and category assignment -3. Also update .claude/data/review-pr-templates.md to keep in sync +Edit `.agents/skills/review-pr/references/review-pr-templates.md`, then regenerate the +derived data files with `python3 .agents/skills/review-pr/sync_review_pr_refs.py --write`. ### Adjusting Category Selection Modify "Delegation Strategy" table in this file. diff --git a/.opencode/data/review-pr-change-types.md b/.opencode/data/review-pr-change-types.md deleted file mode 100644 index 11f02b4471..0000000000 --- a/.opencode/data/review-pr-change-types.md +++ /dev/null @@ -1,149 +0,0 @@ -# PR Review: Change Type Detection Reference - -This file contains the change type detection tables for PR review. 
Referenced by: -`.opencode/command/review-pr.md` - -______________________________________________________________________ - -## CRITICAL Level (Requires `deep` category) - -| Change Type | File Path Pattern | Code Pattern | -| ---------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------- | -| **ARCHON_CORE** | `areal/experimental/models/archon/` | - | -| **ARCHON_PARALLEL** | `parallel_dims.py` | `ArchonParallelDims`, `_build_mesh`, `DeviceMesh` | -| **ARCHON_MOE** | `archon/moe/` | `router`, `grouped_experts`, `TokenReorderer`, `grouped_mm` | -| **ARCHON_PARALLELIZE** | `qwen*/infra/parallelize.py` | `apply_moe_ep_tp`, `apply_tp`, `apply_cp` | -| **ARCHON_ENGINE** | `areal/experimental/engine/archon_engine.py` | `ArchonEngine` | -| **FSDP_CORE** | `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py` | `FSDP`, `FullyShardedDataParallel`, `fully_shard` | -| **MEGATRON_CORE** | `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine` | -| **DCP_CHECKPOINT** | - | `DCP`, `DistributedCheckpoint`, `dcp.save`, `dcp.load` | - -## HIGH Level (Recommend `deep` category) - -| Change Type | File Path Pattern | Code Pattern | -| --------------------- | ----------------- | -------------------------------------------------------------------------------- | -| **DISTRIBUTED_COMM** | - | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.` | -| **DTENSOR** | - | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` | -| **MOE_LAYER** | `moe/` | `expert`, `token_dispatch`, `grouped_mm`, `MoE` | -| **EP_ETP** | - | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp` | -| **TENSOR_PARALLEL** | - | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module` | -| **SEQUENCE_PARALLEL** | - | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size` | -| **ASYNC_CONCURRENT** | - | `async def`, 
`await`, `asyncio`, `threading.Lock`, `aiofiles` | -| **TRAINER_CORE** | `areal/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train` | - -## MEDIUM Level (Use `unspecified-high` category) - -| Change Type | File Path Pattern | Code Pattern | -| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ | -| **TENSOR_OPS** | - | `.view(`, `.reshape(`, `dtype=`, `.detach()`, `no_grad`, `.contiguous()` | -| **NUMERICAL** | - | `log(`, `softmax`, `cross_entropy`, `eps=`, `.clamp(`, `nan`, `inf` | -| **WORKFLOW_ENGINE** | `areal/workflow/`, `areal/engine/` | `arun_episode`, `agenerate`, `RolloutWorkflow` | -| **API_CONFIG** | `areal/api/` | `@dataclass`, `__post_init__`, `field(` | -| **COMPILE** | - | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` | -| **ACTIVATION_CKPT** | `activation_checkpoint.py` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | -| **CHECKPOINT_RECOVERY** | `areal/utils/saver.py`, `areal/utils/recover.py`, `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/async_checkpoint.py` | `state_dict`, `load_state_dict`, `checkpoint`, `AsyncCheckpointManager` | -| **REWARD** | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` | -| **DATASET** | `areal/dataset/` | `get_*_dataset`, `DataLoader`, `IterableDataset` | -| **LAUNCHER_SCHEDULER** | `areal/infra/launcher/`, `areal/infra/scheduler/`, `areal/infra/rpc/` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` | -| **ATTENTION** | `attention/`, `attention/sdpa.py`, `attention/varlen.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` | - -## LOW Level (Use `quick` category) - -| Change Type | File Path Pattern | Code Pattern | -| --------------- | ---------------------------- | ------------ | -| **TESTS** | `tests/`, `*_test.py` | - | -| **DOCS** | `docs/`, `*.md` | - | -| 
**CONFIG_ONLY** | `*.yaml`, `*.json`, `*.toml` | - | - -______________________________________________________________________ - -## Framework-Specific Risk Identification - -### Archon Risks (When ARCHON\_\* types detected) - -- **Device mesh dimension mismatch**: mesh dimension names don't correspond to placement -- **EP constraint violation**: `ep_size` must divide `num_experts`, and - `dp_shard * cp * (tp if etp==1 else 1) % ep == 0` -- **ETP configuration error**: `etp` must be 1 or equal to `tp` -- **Token alignment error**: `grouped_mm` requires token count aligned to 8/16/32 -- **All-to-All split/combine mismatch**: dispatch and combine split configs inconsistent -- **DTensor/Local tensor conversion missing**: need `.to_local()` or - `DTensor.from_local()` -- **torch.compile dynamic shape marking missing**: missing `mark_dynamic` calls -- **AC application order error**: must be after TP/CP, before FSDP -- **Ulysses SP configuration**: CP uses Ulysses implementation, not Ring Attention -- **dp_shard_mod_ep mesh usage**: MoE experts must use `dp_shard_mod_ep` mesh for FSDP - -### FSDP Risks (When FSDP\_\* types detected) - -- **Shard/reshard timing error**: premature or delayed sharding operations -- **EP mesh interaction issue**: should use `dp_shard_mod_ep` not `dp_shard` for MoE -- **Gradient divide factor calculation**: incorrect relationship with world size -- **State dict save/load inconsistency**: mixing sharded vs full modes -- **Optimizer state handling**: aggregation and distribution of sharded state -- **DCP compatibility**: ensure DCP save/load works with FSDP2 - -### Megatron Risks (When MEGATRON\_\* types detected) - -- **Pipeline stage splitting error**: unbalanced layer distribution -- **Micro-batch scheduling issues**: pipeline bubble handling -- **Weight sharding and sync**: tied weights handling -- **AC interaction**: checkpointing under pipeline parallelism - -### DCP/Checkpoint Risks (When DCP_CHECKPOINT or CHECKPOINT_RECOVERY detected) - 
-- **Distributed checkpoint consistency**: all ranks must participate in save/load -- **State dict key mismatch**: keys must match between save and load -- **Optimizer state compatibility**: ensure optimizer state is correctly - sharded/gathered -- **Version compatibility**: old checkpoints should load in new code -- **Storage backend compatibility**: ensure storage backend (filesystem, S3, etc.) is - compatible - -______________________________________________________________________ - -## Risk Linkage Rules - -| Detected Change | Auto-Linked Review | -| --------------------------- | ------------------------------------------------------ | -| EP changes | FSDP interaction check, dp_shard_mod_ep mesh check | -| ETP changes | TP + EP combination check, mesh dimension check | -| Megatron changes | Pipeline + AC check | -| Distributed comm changes | Process group + sync check | -| SEQUENCE_PARALLEL changes | TP combination + Attention mask check, Ulysses check | -| CHECKPOINT_RECOVERY changes | FSDP state dict check, DCP compatibility check | -| DCP_CHECKPOINT changes | FSDP2 integration check, distributed consistency check | -| COMPILE changes | Performance regression + FSDP/TP interaction check | -| REWARD changes | Workflow interaction check, AsyncRewardWrapper check | -| LAUNCHER_SCHEDULER changes | Resource config + parallel strategy match check | -| TRAINER_CORE changes | Engine lifecycle + workflow integration check | -| ARCHON_ENGINE changes | DCP checkpoint + parallel dims check | - -______________________________________________________________________ - -## Core Framework Paths (Requires `deep` category) - -**Archon Core**: - -- `areal/experimental/models/archon/` (entire directory) -- `areal/experimental/engine/archon_engine.py` -- `areal/experimental/engine/archon_checkpoint.py` - -**FSDP Core**: - -- `areal/engine/fsdp_utils/` -- `areal/engine/fsdp_engine.py` - -**Megatron Core**: - -- `areal/engine/megatron_engine.py` -- 
`areal/engine/megatron_utils/megatron.py` -- `areal/engine/megatron_utils/checkpointer.py` - -**Trainer Core**: - -- `areal/trainer/` - -**Training Engine Core** (excludes FSDP/Megatron which have their own categories): - -- `areal/engine/` (except `fsdp_engine.py`, `megatron_engine.py`) diff --git a/.opencode/data/review-pr-domains-and-signals.md b/.opencode/data/review-pr-domains-and-signals.md new file mode 100644 index 0000000000..21c65a15af --- /dev/null +++ b/.opencode/data/review-pr-domains-and-signals.md @@ -0,0 +1,253 @@ +# PR Review: Domain & Signal Detection Reference + +This file contains the canonical change-domain and signal detection tables for PR +review. Referenced by: `.opencode/command/review-pr.md` + +______________________________________________________________________ + +## Severity-to-Review-Depth Mapping + +- **CRITICAL**: use `comprehensive` review depth +- **HIGH**: use `comprehensive` review depth +- **MEDIUM**: use `targeted` review depth +- **LOW**: use `basic` review depth + +______________________________________________________________________ + +## L1 Domains and L2 Signals + +## Domain 1: Distributed Runtime (CRITICAL/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `archon_core` | `areal/experimental/engine/archon_engine.py`, `areal/experimental/engine/archon_checkpoint.py` | `ArchonEngine`, `ArchonCheckpointManager`, `archon` | +| `archon_parallel` | `areal/experimental/models/archon/`, `parallel_dims.py`, `parallelize.py` | `ArchonParallelDims`, `_build_mesh`, `apply_moe_ep_tp`, `apply_tp`, `apply_cp`, `ExpertTensorParallel`, `etp`, `parallelize_module`, `ColwiseParallel`, `RowwiseParallel` | +| `process_group` | 
`areal/engine/fsdp_utils/`, `areal/engine/megatron_utils/`, `areal/experimental/engine/` | `new_group`, `ProcessGroup`, `dist.get_rank(` | +| `fsdp_core` | `areal/engine/fsdp_engine.py`, `areal/engine/fsdp_utils/` | `FSDP`, `fully_shard`, `FullyShardedDataParallel` | +| `megatron_core` | `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` | `MegatronEngine`, `pipeline`, `micro-batch` | +| `collectives` | `areal/engine/`, `areal/infra/rpc/` | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `broadcast`, `barrier` | +| `mesh_dtensor` | `areal/experimental/models/archon/`, `areal/engine/fsdp_utils/` | `DeviceMesh`, `DTensor`, `Shard(`, `Replicate(`, `distribute_tensor` | +| `activation_checkpointing` | `areal/experimental/models/archon/activation_checkpoint.py`, `areal/models/`, `areal/engine/` | `activation_checkpoint`, `checkpoint_wrapper`, `selective_checkpoint` | +| `weight_sync` | `areal/experimental/engine/archon_weight_sync.py`, `areal/api/engine_api.py`, `areal/engine/` | `WeightUpdateMeta`, `set_version`, `update_weights` | + +## Domain 2: Model Compute & Attention (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| `tree_attn` | `areal/models/tree_attn/` | `TreeAttention`, `tree_attn`, `TreeNode`, `tree` | +| `sdpa_varlen` | `attention/sdpa.py`, `attention/varlen.py`, `areal/models/tree_attn/` | `sdpa`, `flash_attn`, `varlen`, `causal_mask` | +| `sp_cp_attention_mask` | `areal/models/tree_attn/`, `areal/experimental/models/archon/attention/` | `SequenceParallel`, `context_parallel`, `mask` | +| `triton_kernel` | `areal/models/tree_attn/triton_kernel.py` | `triton`, `kernel`, `autotune` | +| `archon_model_family` | 
`areal/experimental/models/archon/model_spec.py`, `areal/experimental/models/archon/qwen*/` | `ModelSpec`, `register_model_spec`, `supported_model_types`, `state_dict_adapter`, `rope` | +| `archon_attention_stack` | `areal/experimental/models/archon/attention/`, `areal/experimental/models/archon/ulysses.py` | `ulysses_slice_inputs`, `ulysses_gather_output`, `gather_seq_scatter_heads`, `sdpa`, `varlen` | +| `archon_moe_modeling` | `areal/experimental/models/archon/moe/`, `areal/experimental/models/archon/expert_parallel.py`, `areal/experimental/models/archon/moe_weight_converter.py` | `TokenChoiceTopKRouter`, `RouterGateLinear`, `GroupedExperts`, `MoEWeightConverter`, `expert_parallel` | + +## Domain 3: Inference Backend & Serving (HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | -------------------------------------------------------------- | ---------------------------------------------------------------- | +| `vllm_ext` | `areal/engine/vllm_ext/` | `areal_vllm_server`, `vllm_worker_extension`, `pause_generation` | +| `remote_inference_backend` | `areal/engine/vllm_remote.py`, `areal/engine/sglang_remote.py` | `vllm`, `sglang`, `OpenAI`, `request`, `response` | +| `request_lifecycle` | `areal/engine/`, `areal/infra/launcher/` | `enqueue`, `dequeue`, `cancel`, `timeout` | + +## Domain 4: Service Orchestration (HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `service_routing_dataflow` | `areal/experimental/agent_service/gateway/`, `areal/experimental/agent_service/router/`, `areal/experimental/inference_service/data_proxy/`, `areal/experimental/inference_service/controller/` | `route`, `gateway`, `router`, `DataProxy`, 
`controller`, `batch` | +| `session_consistency` | `areal/experimental/agent_service/`, `areal/experimental/inference_service/` | `session`, `affinity`, `history`, `state` | + +## Domain 5: Workflow & Trainer Contract (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| -------------------------- | ----------------------------------------------------------------------------------------------- | --------------------------------------------------- | +| `workflow_engine_boundary` | `areal/workflow/`, `areal/trainer/`, `areal/engine/` | `RolloutWorkflow`, `arun_episode`, `agenerate` | +| `dataset_surface` | `areal/dataset/` | `DataLoader`, `IterableDataset`, `get_*_dataset` | +| `async_contract` | `areal/workflow/`, `areal/experimental/agent_service/`, `areal/experimental/inference_service/` | `async def`, `await`, `aiofiles`, `asyncio` | +| `weight_version_contract` | `areal/api/engine_api.py`, `areal/workflow/`, `areal/trainer/` | `WeightUpdateMeta`, `set_version`, `weight version` | + +## Domain 6: API & Config Compatibility (MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| --------------------------- | ------------------------------------- | -------------------------------------------------------------------------------- | +| `dataclass_schema` | `areal/api/` | `@dataclass`, `field(`, `__post_init__` | +| `cli_compat` | `areal/api/cli_args.py` | `Literal`, `help`, `default` | +| `backward_compat` | `areal/api/`, `areal/infra/launcher/` | `deprecated`, `compat`, `version` | +| `project_dependency_config` | `pyproject.toml`, `uv.lock` | `requires-python`, `dependencies`, `optional-dependencies`, `build-system`, `uv` | + +## Domain 7: Numerics & Tensor Semantics (MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| --------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------- | +| `shape_dtype` | `areal/engine/`, `areal/models/`, 
`areal/trainer/` | `.view(`, `.reshape(`, `dtype=`, `.contiguous(` | +| `numerical_stability` | `areal/engine/`, `areal/reward/`, `areal/utils/functional/` | `log(`, `softmax`, `eps=`, `.clamp(`, `nan`, `inf` | +| `reward_surface` | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` | +| `compile_dynamo` | `areal/experimental/models/archon/compile.py`, `areal/models/`, `areal/engine/` | `torch.compile`, `_dynamo`, `mark_dynamic`, `fullgraph` | +| `mixed_precision_fp8` | `areal/engine/megatron_utils/fp8/`, `areal/experimental/models/archon/` | `fp8`, `bf16`, `fp16`, `mixed precision` | + +## Domain 8: Checkpoint & Recovery (CRITICAL/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| ----------------- | ------------------------------------------------------------------- | ----------------------------------------------- | +| `dcp_consistency` | `areal/utils/async_checkpoint.py`, `areal/engine/**/checkpoint*.py` | `dcp.save`, `dcp.load`, `DistributedCheckpoint` | +| `optimizer_state` | `areal/engine/fsdp_utils/checkpoint.py`, `areal/utils/saver.py` | `optimizer state`, `state_dict` | +| `resume_compat` | `areal/utils/recover.py`, `areal/utils/saver.py` | `resume`, `load_state_dict`, `migration` | + +## Domain 9: Launcher & Infrastructure (HIGH/MEDIUM) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------- | ---------------------------------------------------------------------- | --------------------------------------------------------- | +| `launcher_resource_match` | `areal/infra/launcher/` | `LaunchConfig`, `RayLauncher`, `SlurmLauncher` | +| `scheduler_contract` | `areal/infra/scheduler/`, `areal/scheduler/` | `Scheduler`, `placement`, `resource` | +| `rpc_transport` | `areal/infra/rpc/`, `areal/experimental/inference_service/data_proxy/` | `RTensor`, `serialize`, `rpc`, `fetch` | +| `runtime_image_config` | `Dockerfile`, `.dockerignore` | `FROM`, `ARG`, `RUN`, `ENV`, `COPY`, `uv sync`, `VARIANT` | + +## 
Domain 10: Low-Risk Hygiene (LOW) + +| L2 Signal | File Path Pattern | Code Pattern | +| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ | +| `tests_docs_config` | `tests/`, `docs/`, `*.md`, `*.yaml`, `*.json`, `*.toml` | - | +| `logging_import_security` | `areal/`, `examples/` | `getLogger`, `print(`, `import *`, `api_key`, `token`, `password` | +| `project_docs_metadata` | `docs/build_all.sh`, `docs/generate_cli_docs.py`, `docs/en/`, `docs/zh/`, `README.md`, `CONTRIBUTING.md`, `.github/PULL_REQUEST_TEMPLATE.md`, `.github/ISSUE_TEMPLATE/` | `jupyter-book`, `generate_cli_docs`, `build_all`, `_build`, `checklist`, `template`, `contributing`, `usage` | + +## Domain 11: Harness & Agent Infrastructure (MEDIUM/HIGH) + +| L2 Signal | File Path Pattern | Code Pattern | +| ----------------------- | ----------------------------------------------------------------------------- | ------------------------------------------------------------ | +| `skill_definition` | `.agents/skills/**/SKILL.md`, `.agents/skills/**/references/` | `description:`, `## Workflow`, `## Reference Files`, `skill` | +| `platform_command_data` | `.claude/commands/`, `.claude/data/`, `.opencode/command/`, `.opencode/data/` | `@.`, `/review-pr`, `/create-pr`, `data/`, `task(` | +| `agent_registry_config` | `.codex/config.toml`, `.codex/agents/`, `AGENTS.md`, `CLAUDE.md` | `agents`, `skills`, `registry`, `subagent`, `config.toml` | + +## Domain 12: CI/CD & Release Automation (HIGH/CRITICAL) + +| L2 Signal | File Path Pattern | Code Pattern | +| ---------------------- | -------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------- | 
+| `github_workflow_jobs` | `.github/workflows/*.yml` | `jobs:`, `runs-on:`, `needs:`, `if:`, `workflow_dispatch` | +| `runner_provisioning` | `.github/workflows/bake-gcp-image.yml`, `.github/workflows/runner-heartbeat.yml` | `gcp`, `runner`, `image`, `heartbeat` | +| `release_delivery` | `.github/workflows/build-docker-image.yml`, `.github/workflows/tag-release-image.yml`, `.github/workflows/deploy-docs.yml` | `docker`, `tag`, `release`, `pages`, `publish` | + +______________________________________________________________________ + +## Must-Not-Regress Core Coverage + +The refactor must preserve these existing review surfaces: + +- Archon core: `areal/experimental/models/archon/`, + `areal/experimental/engine/archon_engine.py` +- FSDP core: `areal/engine/fsdp_utils/`, `areal/engine/fsdp_engine.py` +- Megatron core: `areal/engine/megatron_engine.py`, `areal/engine/megatron_utils/` +- Reward: `areal/reward/` +- Dataset: `areal/dataset/` +- Trainer: `areal/trainer/` +- Harness: `.agents/`, `.claude/`, `.opencode/`, `.codex/` +- CI/CD and release: `.github/workflows/`, `Dockerfile`, `pyproject.toml` + +______________________________________________________________________ + +## Cross-Domain Linkage Rules + +| Detected Signal | Auto-Linked Review | +| ---------------------------------------------- | --------------------------------------------------- | +| `archon_core` or `archon_parallel` | Model Compute & Attention checks | +| `archon_model_family` or `archon_moe_modeling` | Numerics & Tensor Semantics checks | +| `tree_attn` | Numerics & Tensor Semantics checks | +| `reward_surface` | Workflow & Trainer Contract checks | +| `compile_dynamo` | Distributed Runtime checks | +| `vllm_ext` | Launcher & Infrastructure checks | +| `service_routing_dataflow` | Workflow & Trainer async-contract checks | +| `weight_sync` | DTensor/process-group/checkpoint interaction checks | +| `rpc_transport` | Distributed Runtime synchronization checks | +| `mixed_precision_fp8` + 
Distributed Runtime | mesh + weight-sync compatibility checks | +| `runtime_image_config` | Inference Backend & Serving checks | +| `project_dependency_config` | API & Config Compatibility checks | +| `github_workflow_jobs` or `release_delivery` | Launcher & Infrastructure checks | +| `skill_definition` or `platform_command_data` | Low-Risk Hygiene checks | + +______________________________________________________________________ + +## Risk Identification Guidance + +### Distributed Runtime Risks + +- Archon mesh construction or parallel-dims mismatch +- EP/TP/CP application order errors in Archon parallelization +- Activation checkpoint placement violating TP/CP/FSDP ordering assumptions +- Archon engine lifecycle drift around distributed setup and checkpoint boundaries +- Collective call order mismatch across ranks +- Wrong process-group scope in rank-sensitive logic +- Mesh dimension mismatch and invalid DTensor placement +- Weight version drift between rollout and training workers + +### Model Compute & Attention Risks + +- Attention mask inconsistency under TP/SP/CP paths +- Tree attention index/routing mismatch +- Archon model-family registration or per-family wiring drift +- Archon MoE router/expert behavior diverging from weight-conversion expectations +- Archon Ulysses slicing/gather semantics mismatching attention layout assumptions +- Kernel assumptions violating dtype/shape invariants +- Sequence packing alignment errors + +### Service Orchestration Risks + +- Session affinity or history drift across gateway/router/data proxy +- Async message handling holes and dropped tasks +- Controller/worker lifecycle desynchronization + +### Inference Backend & Serving Risks + +- Request lifecycle inconsistencies (enqueue/cancel/timeout) +- Worker state transitions leaving requests stranded +- Backend extension hooks drifting from runtime expectations + +### Workflow & Trainer Contract Risks + +- Workflow-engine contract drift across async boundaries +- Weight 
version handshake mismatch between rollout and train +- Trainer lifecycle transition inconsistencies + +### API & Config Compatibility Risks + +- Breaking config/schema changes without migration path +- Dataclass or CLI default changes altering behavior silently +- Missing validation for newly introduced fields +- Dependency or build-system pin changes breaking supported environments + +### Numerics & Tensor Semantics Risks + +- Silent shape/dtype mismatch under distributed paths +- Unstable numerical operations in loss/reward logic +- torch.compile or dynamo guard changes breaking graph assumptions +- Mixed-precision interaction regressions + +### Checkpoint & Recovery Risks + +- Partial-rank checkpoint participation +- Incompatible state key evolution +- Resume path breaking optimizer/model synchronization + +### Launcher & Infrastructure Risks + +- Resource assignment mismatching parallel strategy assumptions +- RPC transport metadata loss (shape/dtype/device) +- Startup/shutdown ordering races across processes +- Runtime image or build-arg drift from supported inference/training variants + +### Low-Risk Hygiene Risks + +- Docs/config drift from actual runtime behavior +- Logging or import hygiene regressions +- Sensitive data exposure in logs or config +- Documentation/build scripts or project templates drifting from actual workflow + +### Harness & Agent Infrastructure Risks + +- Skill and command docs drifting from actual platform behavior +- Cross-platform data files falling out of sync with canonical references +- Agent registry/config changes breaking expert routing or command discovery + +### CI/CD & Release Automation Risks + +- Workflow trigger or job dependency changes skipping required validation +- Runner provisioning drift causing flaky or non-reproducible CI +- Release or docs deployment jobs publishing the wrong artifacts diff --git a/.opencode/data/review-pr-templates.md b/.opencode/data/review-pr-templates.md index cf19d48af8..7f2f3a9887 100644 
--- a/.opencode/data/review-pr-templates.md
+++ b/.opencode/data/review-pr-templates.md
@@ -1,424 +1,320 @@
-# PR Review: Task Templates Reference
+# PR Review: Domain Templates Reference
-This file contains the review task templates for PR review. Referenced by:
+This file contains canonical domain templates for PR review. Referenced by:
`.opencode/command/review-pr.md`
______________________________________________________________________
-## Framework-Specific Review Task Templates
+## Template Selection Rules
-### Archon Tasks \[deep\]
+1. Select templates by detected L1 domains and L2 signals.
+1. Use at most one primary template per domain.
+1. Always include **General Logic & Boundary** for any PR that is not docs/config-only.
+1. Apply cross-domain linkage checks from `review-pr-domains-and-signals.md`.
-**Task: Archon EP/ETP Strategy Correctness Review**
+______________________________________________________________________
-```
-Checklist:
-- ExpertParallel, TensorParallel, ExpertTensorParallel placement implementation
-- Placement dimension matching with mesh dimensions
-- Placement list length in _partition_fn
-- all_to_all communication autograd compatibility
-- ReordererSequenceParallel token index conversion
-```
+## Universal Template
-**Task: ArchonParallelDims Configuration Validation**
+### General Logic & Boundary
```
+Applicable: Any change that is not docs/config-only
Checklist:
-- ETP constraint: etp=1 (TP borrowed by EP) vs etp=tp (independent TP) logic
-- Mesh construction: _build_mesh_with_ep() dimension order and names
-- EP/TP/CP combination validity verification
-- dp_shard * cp * (tp if etp==1 else 1) % ep == 0 constraint
+- Boundary condition correctness (empty inputs, singleton, max-size)
+- Conditional logic correctness (branch inversion, short-circuit mistakes)
+- Error-path behavior (exceptions propagated with actionable context)
+- Return-value consistency across code paths
+- No newly introduced hidden behavior changes
```
-**Task: MoE Layer 
Implementation Correctness** - -``` -Checklist: -- TokenReorderer and router separation correctness -- grouped_mm token alignment (8/16/32) -- Expert weight 3D tensor sharding -- Load balancing loss calculation -``` - -**Task: Model Parallelization Application Order** - -``` -Checklist: -- apply_moe_ep_tp() strategy selection logic -- FSDP wrap order (EP -> TP -> AC -> FSDP) -- torch.compile dynamic shape marking -- Explicit prefetching configuration -``` - -### FSDP Tasks \[deep / unspecified-high\] - -**Task: FSDP Core Correctness \[deep\]** - -``` -Checklist: -- Shard/reshard operation timing and correctness -- ShardedTensor and DTensor conversion -- Mixed precision (param_dtype vs reduce_dtype) -``` +______________________________________________________________________ -**Task: FSDP Interaction with Other Parallel Strategies \[deep\]** +## Domain 1 Template: Distributed Runtime Review \[comprehensive\] ``` +Applicable signals: archon_core, archon_parallel, process_group, fsdp_core, megatron_core, collectives, mesh_dtensor, activation_checkpointing, weight_sync Checklist: -- FSDP must be applied after TP/CP/EP -- Use dp_shard_mod_ep mesh in EP scenarios -- Gradient divide factor relationship with world size +- Archon engine and checkpoint lifecycle remain aligned with distributed runtime assumptions +- FSDP and Megatron engine invariants still match process-group, sharding, and pipeline assumptions +- Archon parallel-dims and mesh construction still match downstream placement logic +- Process-group creation/usage/cleanup is rank-consistent +- Collective operations are called by all required ranks in consistent order +- DeviceMesh dimensions and DTensor placements are correct for each path +- Activation checkpoint placement remains compatible with parallel and sharding order requirements +- Local/global tensor conversion boundaries are explicit and correct +- Weight version propagation and update ordering are deterministic +- No debug-only barriers left in hot 
path ``` -**Task: FSDP State Management \[unspecified-high\]** +## Domain 2 Template: Model Compute & Attention Review \[comprehensive\] ``` +Applicable signals: tree_attn, sdpa_varlen, sp_cp_attention_mask, triton_kernel, archon_model_family, archon_attention_stack, archon_moe_modeling Checklist: -- state_dict save/load sharded vs full mode -- Optimizer state sharding and aggregation -- Checkpoint compatibility +- Attention mask semantics preserved under TP/SP/CP +- Archon model-family registration and per-family wiring remain internally consistent +- Archon attention/Ulysses slicing and gather paths preserve layout assumptions +- Archon MoE router, grouped experts, and weight-conversion interfaces remain aligned +- Tree attention index/order invariants are maintained +- Kernel assumptions on dtype/shape/contiguity are satisfied +- No silent behavior change in sequence packing/unpacking +- Tensor layouts remain compatible with downstream modules ``` -### Megatron Tasks \[deep\] - -**Task: Pipeline Parallelism Correctness** +## Domain 3 Template: Inference Backend & Serving Review \[comprehensive\] ``` +Applicable signals: vllm_ext, remote_inference_backend, request_lifecycle Checklist: -- Stage splitting correctness and balance -- Micro-batch scheduling -- Pipeline flush and bubble handling +- Request lifecycle (enqueue, execution, cancellation, timeout) is coherent +- Worker state transitions are safe under concurrency +- Backend-specific extension points stay API-compatible +- Error handling does not strand in-flight requests +- Versioning/weight-update interactions are explicit and safe ``` -**Task: Megatron Model Sharding** +## Domain 4 Template: Service Orchestration Review \[comprehensive\] ``` +Applicable signals: service_routing_dataflow, session_consistency Checklist: -- Weight sharding and synchronization -- Tied weights handling -- Embedding/output layer parallel strategy +- Gateway/router/data-proxy routing rules are deterministic +- Session affinity 
and history consistency are preserved +- Controller/worker coordination has no lost-update window +- Async boundaries avoid blocking operations in critical paths +- Failure/retry behavior does not duplicate or drop work ``` -### DCP/Checkpoint Tasks \[deep\] - -**Task: Distributed Checkpoint Correctness** +## Domain 5 Template: Workflow & Trainer Contract Review \[comprehensive\] ``` +Applicable signals: workflow_engine_boundary, dataset_surface, async_contract, weight_version_contract Checklist: -- All ranks participate in DCP save/load operations -- State dict keys match between save and load -- No tensor shape/dtype mismatches -- Storage backend compatibility (filesystem, S3) -- Checkpoint versioning and migration +- RolloutWorkflow and Engine interfaces remain contract-compatible +- Dataset/output structure still matches workflow and trainer consumption expectations +- Async flow uses await consistently and avoids sync I/O in async paths +- Weight update/version handshake is preserved end-to-end +- Trainer lifecycle transitions are valid for all execution branches +- Call ordering assumptions across trainer/workflow/engine are unchanged or justified ``` -**Task: FSDP2 + DCP Integration** +## Domain 6 Template: API & Config Compatibility Review \[targeted\] ``` +Applicable signals: dataclass_schema, cli_compat, backward_compat, project_dependency_config Checklist: -- FSDP2 state dict options (full vs sharded) -- Optimizer state handling with DCP -- Async checkpointing correctness -- Checkpoint resumption logic +- Public API signature and default value changes are intentional and compatible +- Dataclass validation remains complete and informative +- CLI options preserve expected compatibility semantics +- New fields include safe defaults or explicit migration handling +- Breaking changes are documented and scoped +- Dependency and build-system changes remain compatible with supported environments ``` -### Trainer Tasks \[deep\] - -**Task: Trainer Core Logic** +## 
Domain 7 Template: Numerics & Tensor Semantics Review \[targeted\] ``` +Applicable signals: shape_dtype, numerical_stability, reward_surface, compile_dynamo, mixed_precision_fp8 Checklist: -- PPOTrainer/SFTTrainer initialization correctness -- Workflow registration and invocation -- Engine lifecycle management -- Distributed training coordination +- Tensor shape/dtype transitions are explicit and internally consistent +- Numerical stability is protected (log/division/softmax/clamp paths) +- Reward-side numerical behavior remains consistent with workflow consumption expectations +- torch.compile / dynamo assumptions still hold for dynamic shapes and distributed execution +- Mixed-precision behavior is correct for forward + backward + reduce paths +- In-place and view/reshape operations do not corrupt gradient flow +- Device placement and dtype combinations remain legal across code paths ``` -______________________________________________________________________ - -## General Review Task Templates - -### Logic and Boundary Conditions \[deep\] +## Domain 8 Template: Checkpoint & Recovery Review \[comprehensive\] ``` -Applicable: Any non-doc/config changes +Applicable signals: dcp_consistency, optimizer_state, resume_compat Checklist: -- Conditional logic errors (if/else inversion, boundary condition omission, short-circuit issues) -- Loop errors (off-by-one, infinite loops, early exit, iterator invalidation) -- Missing null/None/empty list handling -- Type mismatch or implicit type conversion issues -- Improper exception handling (swallowing exceptions, wrong exception type, return in finally) -- Return value errors (wrong type, missing return, inconsistent multi-path returns) -- Boolean expression errors (De Morgan's law violation, precedence errors) +- Save/load requires and enforces all-rank participation where needed +- State dict naming/structure is stable or migration-safe +- Optimizer state sharding/gather behavior is consistent +- Resume path restores model + 
optimizer + version state coherently +- Async checkpoint behavior preserves ordering and durability assumptions ``` -### Concurrency and Async \[deep\] +## Domain 9 Template: Launcher & Infrastructure Review \[targeted\] ``` -Applicable: ASYNC_CONCURRENT type detected +Applicable signals: launcher_resource_match, scheduler_contract, rpc_transport, runtime_image_config Checklist: -- Race conditions -- Deadlock risks (inconsistent lock ordering, nested locks) -- Non-thread-safe access to shared state -- Missing await in async code -- Blocking calls in async functions (should use executor) -- Resource leaks (file handles, network connections, GPU memory not released) -- State inconsistency (dirty state after partial update failure) -- Improper context manager usage -- Signal handling and graceful shutdown issues +- Resource assignment matches declared parallel strategy assumptions +- Scheduler decisions preserve required placement/affinity constraints +- RPC serialization/deserialization keeps shape/dtype/device semantics +- Transport retries/timeouts do not violate idempotency expectations +- Cross-process startup/shutdown ordering is robust +- Runtime image and build configuration remain aligned with supported variants ``` -### Tensor Shape and Data Type \[deep\] +## Domain 10 Template: Low-Risk Hygiene Review \[basic\] ``` -Applicable: TENSOR_OPS type detected with complex tensor operations +Applicable signals: tests_docs_config, logging_import_security, project_docs_metadata Checklist: -- Tensor shape mismatch (dimension errors, broadcast errors) -- Batch dimension handling errors (missing batch dim, wrong dimension order) -- Sequence length and padding handling (missing mask, padding token in computation) -- Index out of bounds risk (dynamic indexing, negative indexing) -- dtype mismatch (fp16/fp32/bf16 mixing, integer overflow) -- Device placement errors (tensor on wrong device, CPU/GPU mixed operations) -- Gradient-related issues (missing detach, missing 
no_grad context, gradient accumulation errors) -- view/reshape contiguity requirements -- In-place operation effects on gradient computation +- Tests/docs/config edits are internally consistent and non-misleading +- Logging follows project conventions and avoids sensitive leakage +- No wildcard imports or obvious dependency hygiene regressions +- No accidental secrets/keys/tokens introduced +- Docs build scripts and project templates stay aligned with real contributor workflow ``` -### Numerical Stability \[unspecified-high\] +## Domain 11 Template: Harness & Agent Infrastructure Review \[targeted\] ``` -Applicable: NUMERICAL type detected +Applicable signals: skill_definition, platform_command_data, agent_registry_config Checklist: -- Numerical precision issues (floating point precision loss, accumulated errors) -- Numerical stability (log(0), division by zero, exp overflow, softmax stability) -- Numerical issues in loss function computation -- Gradient vanishing/exploding risks -- Scaling issues in mixed precision training +- Canonical skills and derived platform data remain structurally aligned +- Command docs still point to the correct data files and execution model +- Agent registry/config changes preserve command discovery and expert routing +- Cross-platform mirrors are regenerated after canonical changes ``` -### Tensor Parallel (TP) Correctness \[deep\] +## Domain 12 Template: CI/CD & Release Automation Review \[comprehensive\] ``` -Applicable: TENSOR_PARALLEL or DISTRIBUTED_COMM type detected +Applicable signals: github_workflow_jobs, runner_provisioning, release_delivery Checklist: -- Missing or misplaced all-reduce -- Missing or misplaced all-gather -- Reduce handling after weight sharding (column/row sharding) -- Input Replicate / output Partial DTensor semantics -- scatter/gather correctness in Sequence Parallel (SP) -- TP group communication correctness +- Workflow triggers, job dependencies, and permissions still enforce required validation +- 
Runner/image provisioning remains reproducible and compatible with job expectations +- Release, docker, and docs deployment jobs publish the intended artifacts only +- CI changes do not silently skip tests, formatting, or release gates ``` -### Communication and Synchronization \[unspecified-high\] +______________________________________________________________________ -``` -Applicable: DISTRIBUTED_COMM type detected -Checklist: -- Process group usage errors -- Device mesh configuration errors -- Improper barrier placement -- Unnecessary synchronization operations (GPU-CPU sync) -- Collective communication order dependencies -``` +## Signal-Specific Add-On Checklists -### API Compatibility \[unspecified-high\] +Use these only when corresponding L2 signals are detected. -``` -Applicable: API_CONFIG type detected -Checklist: -- Function signature changes (parameter add/delete/rename/reorder) -- Return type changes -- Default value changes causing behavior changes -- Breaking changes to public APIs -- Deprecated API usage -- Class/module rename or move -``` - -### Configuration and Parameter Validation \[unspecified-high\] +### `tree_attn` Add-On \[comprehensive\] ``` -Applicable: API_CONFIG type detected with dataclass -Checklist: -- New config items missing validation (__post_init__ validation) -- Unreasonable config default values -- Missing parameter range checks -- Unhandled dependencies between config items -- Hydra/CLI compatibility issues -- Backward compatibility of env vars/config files -- Incorrect dataclass field types +- Node/edge indexing is deterministic and shape-safe +- Tree traversal order matches attention mask semantics +- FSDP/Megatron/Archon variant modules remain behaviorally aligned ``` -### Workflow and Engine Interaction \[unspecified-high\] +### `vllm_ext` Add-On \[comprehensive\] ``` -Applicable: WORKFLOW_ENGINE type detected -Checklist: -- RolloutWorkflow.arun_episode async correctness -- InferenceEngine.agenerate call patterns -- Weight 
version management (set_version/update_weights/WeightUpdateMeta) -- Tensor output format ([batch, seq_len, ...] convention) -- concat_padded_tensors usage correctness -- AsyncRewardWrapper wrapping requirements +- Server and worker extension hooks still match upstream expectations +- Request pause/resume/cancel semantics remain coherent +- Integration-specific monkey-patches are scoped and guarded ``` -### Activation Checkpointing (AC) \[unspecified-high\] +### `archon_model_family` Add-On \[comprehensive\] ``` -Applicable: ACTIVATION_CKPT type detected -Checklist: -- AC application order (must after TP/CP, before FSDP) -- Selective AC op registration correctness -- AC config validation logic -- Compatibility with torch.compile +- ModelSpec registration stays unique and complete for supported model types +- Per-family model/args/spec/state-adapter wiring remains consistent +- Pipelining hooks and model-part boundaries stay compatible with runtime assumptions ``` -### Performance Regression Risk \[unspecified-high\] +### `archon_moe_modeling` Add-On \[comprehensive\] ``` -Applicable: Any non-doc changes, especially TENSOR_OPS, DISTRIBUTED_COMM -Checklist: -- Unnecessary GPU-CPU sync (.item(), .tolist(), printing tensors) -- Memory allocation pattern changes (potential OOM) -- Communication volume increase -- Computational complexity changes -- torch.compile compatibility breakage -- Unnecessary tensor copies +- Router top-k, gating dtype, and expert grouping semantics remain coherent +- GroupedExperts layout and token reordering assumptions still match expert execution +- MoE weight conversion paths stay consistent with runtime sharding expectations ``` -### Context-Aware Review \[unspecified-high\] +### `service_routing_dataflow` Add-On \[comprehensive\] ``` -Applicable: Any code changes -Checklist: -- Read git blame and history of modified code -- Check for accidental rollback of previous fixes -- Check for breaking previously established patterns or conventions 
-- Check if changes violate code comments -- Check for violations of TODO/FIXME constraints -- Check for ignored NOTE/WARNING comments +- Route selection and fallback ordering are deterministic +- Data proxy transformations preserve payload integrity +- Session-key partitioning logic is collision-safe ``` -### Sequence Parallel (SP/CP) Correctness \[deep\] +### `remote_inference_backend` Add-On \[targeted\] ``` -Applicable: sequence_parallel, context_parallel, SP, CP -Checklist: -- scatter/gather operation correctness -- Attention mask handling under SP -- Position encoding sharding -- KV cache handling under CP -- Combination correctness with TP +- Remote backend request/response semantics remain consistent across supported engines +- Backend-specific transport options do not change lifecycle expectations silently +- Shared request payload assumptions remain compatible across remote backends ``` -### Checkpoint and Recovery \[unspecified-high\] +### `weight_sync` Add-On \[comprehensive\] ``` -Applicable: areal/utils/saver.py, areal/utils/recover.py, state_dict, checkpoint -Checklist: -- Checkpoint save/load completeness -- Distributed checkpoint consistency -- Version compatibility (can old checkpoints load) -- Recovery logic correctness -- Optimizer state handling +- Versioned updates are monotonic and race-safe +- Broadcast/all-gather points are aligned with consumer expectations +- Local caching behavior cannot serve stale weights indefinitely ``` -### Reward Function Correctness \[unspecified-high\] +### `activation_checkpointing` Add-On \[targeted\] ``` -Applicable: areal/reward/ directory -Checklist: -- Reward function signature matches (prompt, completions, prompt_ids, completion_ids, **data) -- Deterministic computation (same input produces same output) -- Blocking calls wrapped with AsyncRewardWrapper -- Numerical range reasonableness -- Edge case handling (empty input, abnormal answers) +- Checkpoint wrappers are applied in a parallelism-safe order +- 
Selective checkpoint policies still cover the intended modules only +- Activation recompute paths do not break sharding or sequence-parallel assumptions ``` -### Dataset Loader Correctness \[unspecified-high\] +### `reward_surface` Add-On \[targeted\] ``` -Applicable: areal/dataset/ directory -Checklist: -- Data format validation (messages, answer, image_path fields) -- Tokenizer compatibility -- max_length truncation logic -- Distributed sampling correctness -- Memory efficiency (avoid loading all data at once) +- AsyncRewardWrapper-facing reward interfaces remain contract-compatible +- Reward outputs keep expected shape, dtype, and per-sample semantics +- Workflow assumptions about reward timing and batching remain valid ``` -### Launcher and Scheduler Configuration \[unspecified-high\] +### `compile_dynamo` Add-On \[targeted\] ``` -Applicable: areal/infra/launcher/, areal/infra/scheduler/, areal/infra/rpc/ directories -Checklist: -- Resource config reasonableness (GPU count, memory) -- Process group config matches parallel strategy -- Environment variable passing correctness -- Container/image config compatibility -- Slurm/Ray specific configurations +- torch.compile and dynamo guards still tolerate expected dynamic-shape inputs +- fullgraph and mark_dynamic choices remain compatible with distributed execution paths +- Compile-specific changes do not silently alter runtime fallback behavior ``` -### torch.compile Compatibility \[unspecified-high\] +### `rpc_transport` Add-On \[targeted\] ``` -Applicable: COMPILE type detected or hot path code modified -Checklist: -- Dynamic shape mark_dynamic marking -- Graph break risks (Python control flow, data-dependent branches) -- Unsupported operations (some in-place ops) -- fullgraph=True compatibility -- Interaction with FSDP/TP +- RTensor conversion is reversible and metadata-complete +- Batch fetch/request framing preserves ordering and boundaries +- Retry logic does not replay non-idempotent actions incorrectly ``` 
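The "reversible and metadata-complete" requirement in the `rpc_transport` add-on can be made concrete with a minimal round-trip sketch. The `pack_tensor`/`unpack_tensor` names and dict layout below are hypothetical illustrations, not AReaL's actual RTensor API:

```python
# Hypothetical sketch of metadata-complete, reversible tensor transport.
# pack_tensor/unpack_tensor are illustrative names, not AReaL's real RPC API;
# a real RTensor would also carry device placement, not just shape/dtype.
import numpy as np


def pack_tensor(arr: np.ndarray) -> dict:
    # Capture every attribute needed to rebuild the tensor exactly.
    return {"data": arr.tobytes(), "shape": arr.shape, "dtype": str(arr.dtype)}


def unpack_tensor(msg: dict) -> np.ndarray:
    # Reconstruction must consume the transported metadata, never defaults.
    return np.frombuffer(msg["data"], dtype=msg["dtype"]).reshape(msg["shape"])


x = np.arange(6, dtype=np.float16).reshape(2, 3)
y = unpack_tensor(pack_tensor(x))
assert y.shape == x.shape and y.dtype == x.dtype and np.array_equal(x, y)
```

A review under this add-on would check that every such packing path is lossless in both directions and that retries re-send the packed message unchanged rather than re-packing possibly mutated state.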
-### Documentation Format Check \[quick\] +### `runtime_image_config` Add-On \[targeted\] ``` -Applicable: DOCS type detected -Checklist: -- Markdown format correctness -- Internal link validity -- Code example correctness +- Docker base image and build args still match supported backend variants +- Layer ordering preserves expected cache and dependency behavior +- Image contents remain aligned with runtime assumptions documented in the repo ``` -### Test Coverage Check \[quick\] +### `project_dependency_config` Add-On \[targeted\] ``` -Applicable: TESTS type detected -Checklist: -- Test cases cover main paths -- Boundary condition tests -- Error handling tests +- Python/version constraints and extras remain internally consistent +- Lockfile changes match the intended dependency update scope +- Build backend/tooling changes do not break install or publish workflows ``` -### Logging and Metrics \[quick\] +### `github_workflow_jobs` Add-On \[comprehensive\] ``` -Applicable: logging, stats_tracker, StatsLogger -Checklist: -- Use areal.utils.logging.getLogger not print -- Structured metrics sent via stats_tracker -- Reasonable log levels (no DEBUG on hot paths) -- Sensitive info not logged +- Workflow triggers and job graph still run required validation paths +- Required secrets/permissions are scoped correctly +- Matrix or conditional changes do not silently skip critical jobs ``` -### Import and Dependencies \[quick\] +### `project_docs_metadata` Add-On \[basic\] ``` -Applicable: Any Python file changes -Checklist: -- Avoid wildcard imports (from x import *) -- Correct third-party vs internal import grouping -- Heavy optional deps inside functions -- Circular import risks +- Docs build entrypoints and contributor-facing metadata remain mutually consistent +- Public templates and contributor instructions still match the actual workflow +- Build/preview guidance still points to the supported commands ``` -### Security and Sensitive Information \[quick\] +### 
`skill_definition` / `platform_command_data` Add-On \[targeted\] ``` -Applicable: Config files, environment variables, API calls -Checklist: -- No hardcoded keys/tokens/passwords -- Sensitive info not committed to repo -- API endpoints configurable -- Error messages don't leak sensitive details +- Canonical and derived review-pr data files stay in sync after edits +- Command/import paths remain correct after file moves or renames +- Wrapper-specific routing stays out of canonical reference files ``` diff --git a/.opencode/skills/commit-conventions/SKILL.md b/.opencode/skills/commit-conventions/SKILL.md index 0cd141a900..cdc62dc98e 100644 --- a/.opencode/skills/commit-conventions/SKILL.md +++ b/.opencode/skills/commit-conventions/SKILL.md @@ -115,7 +115,7 @@ delegation instead of hardcoded model names. Key changes: - Create .opencode/command/ with review-pr, create-pr -- Replace Opus/Sonnet/Haiku with deep/unspecified-high/quick +- Replace hardcoded model routing with platform-native review routing - Add expert subagent consultation patterns ``` diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index a03921ba16..7701dab081 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -2006,6 +2006,36 @@ class TensorBoardConfig: path: str | None = None +@dataclass +class TrackioConfig: + """Configuration for Trackio experiment tracking (Hugging Face). + + Trackio is a lightweight, local-first experiment tracking library + with a wandb-compatible API. Dashboards can be viewed locally or + deployed to Hugging Face Spaces. + + See: https://github.com/gradio-app/trackio + """ + + mode: str = "disabled" + """Tracking mode. One of "disabled", "online", or "local".""" + project: str | None = None + """Project name. Defaults to experiment_name if not set.""" + name: str | None = None + """Run name. Defaults to trial_name if not set.""" + space_id: str | None = None + """HF Space ID for remote dashboard deployment (e.g. "user/my-space"). 
+ When set, metrics are also pushed to the specified Hugging Face Space.""" + + def __post_init__(self): + """Validate Trackio configuration.""" + valid_modes = {"disabled", "online", "local"} + if self.mode not in valid_modes: + raise ValueError( + f"Invalid trackio mode: '{self.mode}'. Must be one of {valid_modes}." + ) + + @dataclass class StatsLoggerConfig: """Configuration for experiment statistics logging and tracking services.""" @@ -2025,6 +2055,10 @@ class StatsLoggerConfig: default_factory=TensorBoardConfig, metadata={"help": "TensorBoard configuration. Only 'path' field required."}, ) + trackio: TrackioConfig = field( + default_factory=TrackioConfig, + metadata={"help": "Trackio configuration (Hugging Face experiment tracking)."}, + ) @dataclass diff --git a/areal/experimental/inference_service/guard/__main__.py b/areal/experimental/inference_service/guard/__main__.py index 69b1921029..ca9fdeaa6d 100644 --- a/areal/experimental/inference_service/guard/__main__.py +++ b/areal/experimental/inference_service/guard/__main__.py @@ -1,110 +1,27 @@ -"""CLI entrypoint: python -m areal.experimental.inference_service.guard""" +"""CLI entrypoint: ``python -m areal.experimental.inference_service.guard``""" from __future__ import annotations -import argparse -import os -import signal - -from werkzeug.serving import make_server - -from areal.api.cli_args import NameResolveConfig -from areal.experimental.inference_service.guard import app as guard_app -from areal.experimental.inference_service.guard.app import app as flask_app -from areal.utils import logging, name_resolve, names -from areal.utils.network import gethostip - -logger = logging.getLogger("RPCGuard") +from areal.experimental.inference_service.guard.app import ( + _state, + app, +) +from areal.infra.rpc.guard.app import ( + configure_state_from_args, + make_base_parser, + run_server, +) def main(): - """Main entry point for the RPCGuard service.""" - parser = argparse.ArgumentParser( - description="AReaL 
RPCGuard — HTTP gateway for coordinating forked workers" - ) - parser.add_argument( - "--port", - type=int, - default=0, - help="Port to serve on (default: 0 = auto-assign)", - ) - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" + parser = make_base_parser( + description=("AReaL RPCGuard — HTTP gateway for coordinating forked workers") ) - # name_resolve config - parser.add_argument("--experiment-name", type=str, required=True) - parser.add_argument("--trial-name", type=str, required=True) - parser.add_argument("--role", type=str, required=True) - parser.add_argument("--worker-index", type=int, default=-1) - parser.add_argument("--name-resolve-type", type=str, default="nfs") - parser.add_argument( - "--nfs-record-root", type=str, default="/tmp/areal/name_resolve" - ) - parser.add_argument("--etcd3-addr", type=str, default="localhost:2379") - parser.add_argument( - "--fileroot", - type=str, - default=None, - help="Root directory for log files. If set, forked worker logs are written here.", - ) - args, _ = parser.parse_known_args() - # Set global config in app module for fork endpoint to use - guard_app._server_host = args.host - if guard_app._server_host == "0.0.0.0": - guard_app._server_host = gethostip() - - guard_app._experiment_name = args.experiment_name - guard_app._trial_name = args.trial_name - guard_app._fileroot = args.fileroot - - # Get worker identity - worker_role = args.role - worker_index = args.worker_index - if "SLURM_PROCID" in os.environ: - # Overwriting with slurm task id - worker_index = os.environ["SLURM_PROCID"] - if worker_index == -1: - raise ValueError("Invalid worker index. 
Not found from SLURM environ or args.") - worker_id = f"{worker_role}/{worker_index}" - - # Make a flask server - server = make_server(args.host, args.port, flask_app, threaded=True) - server_port = server.socket.getsockname()[1] - - # Configure name_resolve - name_resolve.reconfigure( - NameResolveConfig( - type=args.name_resolve_type, - nfs_record_root=args.nfs_record_root, - etcd3_addr=args.etcd3_addr, - ) - ) - key = names.worker_discovery( - args.experiment_name, args.trial_name, args.role, worker_index - ) - name_resolve.add(key, f"{guard_app._server_host}:{server_port}", replace=True) - - logger.info( - f"Starting RPCGuard on {guard_app._server_host}:{server_port} for worker {worker_id}" - ) - - def _sigterm_handler(signum, frame): - """Convert SIGTERM to SystemExit so the finally block runs.""" - raise SystemExit(0) - - signal.signal(signal.SIGTERM, _sigterm_handler) + bind_host = configure_state_from_args(_state, args) - try: - server.serve_forever() - except KeyboardInterrupt: - logger.info("Shutting down RPCGuard (SIGINT)") - except SystemExit: - logger.info("Shutting down RPCGuard (SIGTERM)") - finally: - guard_app.cleanup_forked_children() - server.shutdown() + run_server(_state, app, bind_host, args.port) if __name__ == "__main__": diff --git a/areal/experimental/inference_service/guard/app.py b/areal/experimental/inference_service/guard/app.py index 40c1a26994..6a804957f3 100644 --- a/areal/experimental/inference_service/guard/app.py +++ b/areal/experimental/inference_service/guard/app.py @@ -1,406 +1,27 @@ -import getpass -import os -import subprocess -import sys -import time -import traceback -from pathlib import Path -from threading import Lock - -import requests as http_requests -from flask import Flask, jsonify, request - -from areal.infra.utils.proc import kill_process_tree, run_with_streaming_logs +"""RPCGuard — inference service guard backed by the shared guard. + +All guard functionality is now provided by ``areal.infra.rpc.guard``. 
+This module creates and exposes the Flask app and shared state instance +for backward compatibility with existing imports. +""" + +from __future__ import annotations + +from areal.infra.rpc.guard.app import ( + GuardState, + create_app, +) +from areal.infra.rpc.guard.app import ( + cleanup_forked_children as _cleanup_impl, +) from areal.utils import logging -from areal.utils.network import find_free_ports logger = logging.getLogger("RPCGuard") -# Port tracking - allocated ports excluded from future allocations -_allocated_ports: set[int] = set() -_allocated_ports_lock = Lock() - -# Forked child processes - tracked for cleanup -_forked_children: list[subprocess.Popen] = [] -_forked_children_lock = Lock() -# Map (role, worker_index) to forked process for selective killing -_forked_children_map: dict[tuple[str, int], subprocess.Popen] = {} - -# Server address (set at startup) -_server_host: str = "0.0.0.0" - -# Server config (needed for /fork endpoint to spawn children with same config) -_experiment_name: str | None = None -_trial_name: str | None = None -_fileroot: str | None = None - -# Create Flask app -app = Flask(__name__) - - -@app.route("/health", methods=["GET"]) -def health_check(): - """Health check endpoint to verify server is alive.""" - return jsonify( - { - "status": "healthy", - "forked_children": len(_forked_children), - } - ) - - -@app.route("/configure", methods=["POST"]) -def configure(): - """No-op configuration endpoint for scheduler compatibility. - - The LocalScheduler calls ``/configure`` on every worker after creation - when ``exp_config`` is set. RPCGuard does not need experiment config - (the GatewayInferenceController handles setup via ``/alloc_ports`` and - ``/fork``), so we simply acknowledge the request. - """ - logger.debug("Received /configure request (no-op for RPCGuard)") - return jsonify({"status": "ok"}) - - -@app.route("/alloc_ports", methods=["POST"]) -def alloc_ports(): - """Allocate multiple free ports. 
- - Expected JSON payload: - { - "count": 5 # Number of ports to allocate - } - """ - try: - data = request.get_json(silent=True) - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - count = data.get("count") - if count is None: - return jsonify({"error": "Missing 'count' field in request"}), 400 - - if not isinstance(count, int) or count <= 0: - return jsonify({"error": "'count' must be a positive integer"}), 400 - - global _allocated_ports - with _allocated_ports_lock: - ports = find_free_ports(count, exclude_ports=_allocated_ports) - _allocated_ports.update(ports) - - return jsonify({"status": "success", "ports": ports, "host": _server_host}) - - except Exception as e: - logger.error(f"Error in alloc_ports: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -def _wait_for_worker_ready(host: str, port: int, timeout: float = 60) -> bool: - """Wait for a worker to be ready by polling its health endpoint. - - Args: - host: The host address of the worker. - port: The port of the worker. - timeout: Maximum time to wait in seconds (default: 60). - - Returns: - True if the worker is ready, False if timeout is reached. - """ - url = f"http://{host}:{port}/health" - deadline = time.time() + timeout - - while time.time() < deadline: - try: - resp = http_requests.get(url, timeout=2) - if resp.status_code == 200: - return True - except http_requests.exceptions.RequestException: - pass - time.sleep(0.5) - - return False - - -@app.route("/fork", methods=["POST"]) -def fork_worker(): - """Fork a new worker process on the same node. - - Supports two modes: - - **Module-path mode** (``command`` field): - Builds ``python -m {command} --host 0.0.0.0 --port {port} ...`` with - scheduler args injected. Waits for health readiness before returning. - - **Raw-command mode** (``raw_cmd`` field): - Launches the provided command list as-is. 
A port is allocated but NOT - injected into the command (caller provides port in ``raw_cmd``). - Returns immediately after spawn without readiness polling. - - Expected JSON payload (module-path mode): - { - "role": "ref", - "worker_index": 0, - "command": "areal.infra.rpc.rpc_server" - } - - Expected JSON payload (raw-command mode): - { - "role": "sglang", - "worker_index": 0, - "raw_cmd": ["python", "-m", "sglang.launch_server", "--model", "..."] - } - - Returns: - { - "status": "success", - "host": "192.168.1.10", - "port": 8001, - "pid": 12345 - } - """ - global _forked_children, _forked_children_map, _allocated_ports - - try: - data = request.get_json(silent=True) - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - role = data.get("role") - worker_index = data.get("worker_index") - command = data.get("command") # Module-path mode - raw_cmd = data.get("raw_cmd") # Raw-command mode - - if role is None: - return jsonify({"error": "Missing 'role' field in request"}), 400 - if worker_index is None: - return jsonify({"error": "Missing 'worker_index' field in request"}), 400 - - if command is None and raw_cmd is None: - return ( - jsonify( - { - "error": "Must provide either 'command' (module path) " - "or 'raw_cmd' (raw command list)" - } - ), - 400, - ) - - # Allocate a free port for the child process - with _allocated_ports_lock: - ports = find_free_ports(1, exclude_ports=_allocated_ports) - child_port = ports[0] - _allocated_ports.add(child_port) - - # Optional per-process environment overrides (e.g. 
TRITON_CACHE_PATH) - env_overrides: dict[str, str] = data.get("env", {}) - - # Determine if this is raw-command mode or module-path mode - is_raw_mode = raw_cmd is not None - - if is_raw_mode: - # Raw-command mode: use command as-is, do NOT inject port or args - cmd = list(raw_cmd) - else: - # Module-path mode: build command with scheduler args - cmd = [ - sys.executable, - "-m", - command, - "--host", - "0.0.0.0", - "--port", - str(child_port), - "--experiment-name", - _experiment_name, - "--trial-name", - _trial_name, - "--role", - role, - "--worker-index", - str(worker_index), - ] - - logger.info( - f"Forking new worker process for role '{role}' index {worker_index} " - f"on port {child_port} (raw_mode={is_raw_mode})" - ) - - # Build log paths - log_dir = ( - Path(_fileroot or "/tmp") - / "logs" - / getpass.getuser() - / (_experiment_name or "default") - / (_trial_name or "default") - ) - log_dir.mkdir(parents=True, exist_ok=True) - log_file = log_dir / f"{role}.log" - merged_log = log_dir / "merged.log" - - logger.info(f"Forked worker logs will be written to: {log_file}") - - child_env = os.environ.copy() - child_env.update(env_overrides) - - child_process = run_with_streaming_logs( - cmd, - log_file, - merged_log, - role, - env=child_env, - ) - - with _forked_children_lock: - _forked_children.append(child_process) - _forked_children_map[(role, worker_index)] = child_process - - child_host = _server_host - - if not is_raw_mode: - # Module-path mode: wait for child to be ready - if not _wait_for_worker_ready(child_host, child_port): - # Cleanup on failure - try: - kill_process_tree(child_process.pid, timeout=3, graceful=True) - except Exception: - pass - with _forked_children_lock: - if child_process in _forked_children: - _forked_children.remove(child_process) - _forked_children_map.pop((role, worker_index), None) - with _allocated_ports_lock: - _allocated_ports.discard(child_port) - return jsonify( - {"error": "Forked worker failed to start within timeout"} - 
), 500 - - logger.info( - f"Forked worker for role '{role}' index {worker_index} ready at " - f"{child_host}:{child_port} (pid={child_process.pid})" - ) - else: - # Raw-command mode: return immediately without readiness polling - logger.info( - f"Forked raw-command worker for role '{role}' index {worker_index} " - f"spawned (pid={child_process.pid}), port={child_port}" - ) - - return jsonify( - { - "status": "success", - "host": child_host, - "port": child_port, - "pid": child_process.pid, - } - ) - - except Exception as e: - logger.error(f"Error in fork: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/kill_forked_worker", methods=["POST"]) -def kill_forked_worker(): - """Kill a specific forked worker process. - - This endpoint terminates a previously forked child process identified by - its role and worker_index. - - Expected JSON payload: - { - "role": "ref", - "worker_index": 0 - } - - Returns: - { - "status": "success", - "message": "Killed forked worker ref/0 (pid=12345)" - } - """ - global _forked_children, _forked_children_map - - try: - data = request.get_json(silent=True) - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - role = data.get("role") - worker_index = data.get("worker_index") - - if role is None: - return jsonify({"error": "Missing 'role' field in request"}), 400 - if worker_index is None: - return jsonify({"error": "Missing 'worker_index' field in request"}), 400 - - key = (role, worker_index) - - # Remove from tracking structures first (hold lock only for dict/list ops) - with _forked_children_lock: - child_process = _forked_children_map.pop(key, None) - if child_process: - try: - _forked_children.remove(child_process) - except ValueError: - # Defensive: process was in map but not in list - logger.warning( - f"Process for {role}/{worker_index} was in map but not in list" - ) - - if child_process is None: - return jsonify( - {"error": 
f"Forked worker {role}/{worker_index} not found"} - ), 404 - - pid = child_process.pid - - # Kill the process tree (outside the lock to avoid blocking other operations) - try: - if child_process.poll() is None: # Still running - kill_process_tree(pid, timeout=3, graceful=True) - logger.info(f"Killed forked worker {role}/{worker_index} (pid={pid})") - except Exception as e: - logger.error( - f"Error killing forked worker {role}/{worker_index} (pid={pid}): {e}" - ) - return jsonify( - { - "error": f"Failed to kill forked worker: {str(e)}", - "pid": pid, - } - ), 500 - - return jsonify( - { - "status": "success", - "message": f"Killed forked worker {role}/{worker_index} (pid={pid})", - } - ) - - except Exception as e: - logger.error(f"Error in kill_forked_worker: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - +_state = GuardState() -def cleanup_forked_children(): - """Clean up all forked child processes.""" - global _forked_children, _forked_children_map +app = create_app(_state) - # Copy the list under lock, then release before blocking kills - # to avoid holding the lock for up to 4s × N children. 
- with _forked_children_lock: - if not _forked_children: - return - children_to_kill = list(_forked_children) - _forked_children.clear() - _forked_children_map.clear() - logger.info(f"Cleaning up {len(children_to_kill)} forked child processes") - for child in children_to_kill: - try: - if child.poll() is None: # Still running - kill_process_tree(child.pid, timeout=3, graceful=True) - logger.info(f"Killed forked child process {child.pid}") - except Exception as e: - logger.error(f"Error killing forked child {child.pid}: {e}") +def cleanup_forked_children() -> None: + _cleanup_impl(_state) diff --git a/areal/infra/rpc/guard/__init__.py b/areal/infra/rpc/guard/__init__.py new file mode 100644 index 0000000000..3f1c18ad46 --- /dev/null +++ b/areal/infra/rpc/guard/__init__.py @@ -0,0 +1,36 @@ +"""Shared Guard: reusable process management for RPC and inference services. + +The Guard is the base process management layer shared between: + +- ``areal.infra.rpc.rpc_server`` — RPC server (guard + data + engine) +- ``areal.experimental.inference_service.guard`` — inference service guard + +Typical usage:: + + from areal.infra.rpc.guard import GuardState, create_app, run_server + + state = GuardState() + app = create_app(state) + # Optionally register additional blueprints + run_server(state, app, bind_host="0.0.0.0", port=0) +""" + +from .app import ( + GuardState, + cleanup_forked_children, + configure_state_from_args, + create_app, + get_state, + make_base_parser, + run_server, +) + +__all__ = [ + "GuardState", + "cleanup_forked_children", + "configure_state_from_args", + "create_app", + "get_state", + "make_base_parser", + "run_server", +] diff --git a/areal/infra/rpc/guard/__main__.py b/areal/infra/rpc/guard/__main__.py new file mode 100644 index 0000000000..5b57d86acb --- /dev/null +++ b/areal/infra/rpc/guard/__main__.py @@ -0,0 +1,33 @@ +"""CLI entrypoint: ``python -m areal.infra.rpc.guard`` + +Starts a standalone Guard process with only process-management +endpoints (no 
engine, no data storage). +""" + +from __future__ import annotations + +from areal.infra.rpc.guard.app import ( + GuardState, + configure_state_from_args, + create_app, + make_base_parser, + run_server, +) + + +def main(): + """Main entry point for the standalone Guard service.""" + parser = make_base_parser( + description=("AReaL Guard — HTTP gateway for coordinating forked workers") + ) + args, _ = parser.parse_known_args() + + state = GuardState() + bind_host = configure_state_from_args(state, args) + app = create_app(state) + + run_server(state, app, bind_host, args.port) + + +if __name__ == "__main__": + main() diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py new file mode 100644 index 0000000000..fc2d7369f9 --- /dev/null +++ b/areal/infra/rpc/guard/app.py @@ -0,0 +1,612 @@ +"""Shared Guard process: process management, port allocation, and child forking. + +This module provides the base Guard functionality shared between: + +- ``areal.infra.rpc.rpc_server`` (RPC server = guard + data + engine) +- ``areal.experimental.inference_service.guard`` (inference service guard) + +Key components: + +- :class:`GuardState` — mutable shared state with hook system +- :func:`create_app` — Flask app factory with core guard routes +- :func:`make_base_parser` — CLI argument parser shared by entrypoints +- :func:`configure_state_from_args` — populate state from parsed CLI args +- :func:`run_server` — start werkzeug server with name_resolve registration +""" + +from __future__ import annotations + +import argparse +import getpass +import os +import signal +import subprocess +import traceback +from collections.abc import Callable +from pathlib import Path +from threading import Lock +from typing import Any + +from flask import Flask, current_app, jsonify, request + +from areal.infra.utils.proc import kill_process_tree, run_with_streaming_logs +from areal.utils import logging +from areal.utils.network import find_free_ports, format_hostport + +logger = 
logging.getLogger("Guard") + + +class GuardState: + """Mutable shared state for the Guard process. + + All guard-level state lives here so that both core routes and + extension blueprints can access it via :func:`get_state`. + + The hook system allows blueprints to extend core endpoints: + + - **health hooks** — contribute extra fields to ``/health`` response + - **configure hooks** — handle ``/configure`` payload + - **cleanup hooks** — run during server shutdown + """ + + def __init__(self) -> None: + # Server identity + self.server_host: str = "0.0.0.0" + self.server_port: int = 0 + + # Experiment / trial config (used for log paths and name_resolve) + self.experiment_name: str | None = None + self.trial_name: str | None = None + self.fileroot: str | None = None + + # Name-resolve config (used by run_server for service registration) + self.name_resolve_type: str | None = None + self.nfs_record_root: str | None = None + self.etcd3_addr: str | None = None + + # Worker identity + self.role: str | None = None + self.worker_index: int = -1 + + # Port tracking (thread-safe) + self.allocated_ports: set[int] = set() + self.allocated_ports_lock = Lock() + + # Forked child processes (thread-safe) + self.forked_children: list[subprocess.Popen] = [] + self.forked_children_map: dict[tuple[str, int], subprocess.Popen] = {} + self.forked_children_lock = Lock() + + # Hook system — blueprints register hooks to extend core endpoints + self._health_hooks: list[HealthHook] = [] + self._configure_hooks: list[ConfigureHook] = [] + self._cleanup_hooks: list[CleanupHook] = [] + + def register_health_hook(self, hook: HealthHook) -> None: + """Register a hook that contributes fields to ``/health`` response. + + The hook is called with no arguments and must return a dict of + extra fields to merge into the health response. + """ + self._health_hooks.append(hook) + + def register_configure_hook(self, hook: ConfigureHook) -> None: + """Register a hook that handles ``/configure`` payload. 
+ + The hook receives the full JSON dict and returns a result dict. + Raise :class:`ValueError` for 400-worthy client errors. + """ + self._configure_hooks.append(hook) + + def register_cleanup_hook(self, hook: CleanupHook) -> None: + """Register a hook called during server shutdown.""" + self._cleanup_hooks.append(hook) + + @property + def node_addr(self) -> str: + """Return ``host:port`` string for this server (IPv6-safe).""" + return format_hostport(self.server_host, self.server_port) + + +HealthHook = Callable[[], dict[str, Any]] +ConfigureHook = Callable[[dict], dict] +CleanupHook = Callable[[], None] + + +def get_state() -> GuardState: + """Get the :class:`GuardState` from the current Flask app context.""" + return current_app.config["guard_state"] + + +# --------------------------------------------------------------------------- +# Helper utilities +# --------------------------------------------------------------------------- + + +def cleanup_forked_children(state: GuardState) -> None: + """Clean up all forked child processes. + + Copies the child list under the lock, then releases before blocking + kills (avoids holding the lock for up to 4s × N children). 
+ """ + with state.forked_children_lock: + if not state.forked_children: + return + children_to_kill = list(state.forked_children) + state.forked_children.clear() + state.forked_children_map.clear() + + logger.info(f"Cleaning up {len(children_to_kill)} forked child processes") + for child in children_to_kill: + try: + if child.poll() is None: # Still running + kill_process_tree(child.pid, timeout=3, graceful=True) + logger.info(f"Killed forked child process {child.pid}") + except Exception as e: + logger.error(f"Error killing forked child {child.pid}: {e}") + + +# --------------------------------------------------------------------------- +# Flask app factory +# --------------------------------------------------------------------------- + + +def create_app(state: GuardState) -> Flask: + """Create a Flask app with core guard routes. + + Routes provided: + + - ``GET /health`` — health check (extensible via health hooks) + - ``POST /alloc_ports`` — allocate free ports + - ``POST /fork`` — fork a child worker from a raw command + - ``POST /kill_forked_worker`` — kill a specific forked child + - ``POST /configure`` — configure worker (extensible via configure hooks) + + Parameters + ---------- + state : GuardState + Shared mutable state for the guard process. + + Returns + ------- + Flask + Configured Flask application. + """ + app = Flask(__name__) + app.config["guard_state"] = state + + @app.route("/health", methods=["GET"]) + def health_check(): + """Health check endpoint.""" + s = get_state() + result: dict[str, Any] = { + "status": "healthy", + "forked_children": len(s.forked_children), + } + # Collect additional fields from health hooks + for hook in s._health_hooks: + result.update(hook()) + return jsonify(result) + + @app.route("/alloc_ports", methods=["POST"]) + def alloc_ports(): + """Allocate multiple free ports. 
+ + Expected JSON payload:: + + {"count": 5} + """ + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + count = data.get("count") + if count is None: + return jsonify({"error": "Missing 'count' field in request"}), 400 + + if not isinstance(count, int) or count <= 0: + return ( + jsonify({"error": "'count' must be a positive integer"}), + 400, + ) + + s = get_state() + with s.allocated_ports_lock: + ports = find_free_ports(count, exclude_ports=s.allocated_ports) + s.allocated_ports.update(ports) + + return jsonify({"status": "success", "ports": ports, "host": s.server_host}) + + except Exception as e: + logger.error(f"Error in alloc_ports: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + @app.route("/fork", methods=["POST"]) + def fork_worker(): + """Fork a new worker process on the same node. + + Launches the provided command list (``raw_cmd``) as-is. The caller + is responsible for allocating ports (via ``/alloc_ports``), building + the full command, and polling for readiness after the response. 
+ + Expected JSON payload:: + + { + "role": "actor", + "worker_index": 0, + "raw_cmd": ["python", "-m", "some.module", "--port", "8001"], + "env": {"KEY": "value"} // optional + } + + Returns:: + + {"status": "success", "host": "10.0.0.1", "pid": 42} + """ + s = get_state() + + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + role = data.get("role") + worker_index = data.get("worker_index") + raw_cmd = data.get("raw_cmd") + + if role is None: + return ( + jsonify({"error": "Missing 'role' field in request"}), + 400, + ) + if worker_index is None: + return ( + jsonify({"error": "Missing 'worker_index' field in request"}), + 400, + ) + if raw_cmd is None: + return ( + jsonify({"error": "Missing 'raw_cmd' field in request"}), + 400, + ) + + cmd = list(raw_cmd) + + # Optional per-process environment overrides + env_overrides: dict[str, str] = data.get("env", {}) + + logger.info( + f"Forking new worker process for role '{role}' index {worker_index}" + ) + + # Build log paths + log_dir = ( + Path(s.fileroot or "/tmp") + / "logs" + / getpass.getuser() + / (s.experiment_name or "default") + / (s.trial_name or "default") + ) + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"{role}.log" + merged_log = log_dir / "merged.log" + + logger.info(f"Forked worker logs will be written to: {log_file}") + + child_env = os.environ.copy() + child_env.update(env_overrides) + + child_process = run_with_streaming_logs( + cmd, + log_file, + merged_log, + role, + env=child_env, + ) + + with s.forked_children_lock: + s.forked_children.append(child_process) + s.forked_children_map[(role, worker_index)] = child_process + + logger.info( + f"Forked worker for role '{role}' index " + f"{worker_index} spawned (pid={child_process.pid})" + ) + + return jsonify( + { + "status": "success", + "host": s.server_host, + "pid": child_process.pid, + } + ) + + except Exception as e: + logger.error(f"Error in 
fork: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + @app.route("/kill_forked_worker", methods=["POST"]) + def kill_forked_worker(): + """Kill a specific forked worker process. + + Expected JSON payload:: + + {"role": "ref", "worker_index": 0} + """ + s = get_state() + + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + role = data.get("role") + worker_index = data.get("worker_index") + + if role is None: + return ( + jsonify({"error": "Missing 'role' field in request"}), + 400, + ) + if worker_index is None: + return ( + jsonify({"error": "Missing 'worker_index' field in request"}), + 400, + ) + + key = (role, worker_index) + + # Remove from tracking structures (hold lock only for dict/list ops) + with s.forked_children_lock: + child_process = s.forked_children_map.pop(key, None) + if child_process: + try: + s.forked_children.remove(child_process) + except ValueError: + logger.warning( + f"Process for {role}/{worker_index} was in map " + "but not in list" + ) + + if child_process is None: + return ( + jsonify( + {"error": (f"Forked worker {role}/{worker_index} not found")} + ), + 404, + ) + + pid = child_process.pid + + # Kill process tree (outside lock to avoid blocking) + try: + if child_process.poll() is None: # Still running + kill_process_tree(pid, timeout=3, graceful=True) + logger.info( + f"Killed forked worker {role}/{worker_index} (pid={pid})" + ) + except Exception as e: + logger.error( + f"Error killing forked worker " + f"{role}/{worker_index} (pid={pid}): {e}" + ) + return ( + jsonify( + { + "error": f"Failed to kill forked worker: {str(e)}", + "pid": pid, + } + ), + 500, + ) + + return jsonify( + { + "status": "success", + "message": ( + f"Killed forked worker {role}/{worker_index} (pid={pid})" + ), + } + ) + + except Exception as e: + logger.error(f"Error in kill_forked_worker: {e}\n{traceback.format_exc()}") + 
return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + @app.route("/configure", methods=["POST"]) + def configure(): + """Configure the worker process. + + Base implementation is a no-op. Blueprints register configure hooks + to handle the payload (e.g., engine blueprint sets random seeds). + + Hooks may raise :class:`ValueError` for 400-worthy client errors. + """ + s = get_state() + + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + if not s._configure_hooks: + # No hooks registered — no-op (guard-only mode) + logger.debug("Received /configure request (no-op)") + return jsonify({"status": "ok"}) + + # Dispatch to all registered configure hooks + result: dict[str, Any] = {} + for hook in s._configure_hooks: + hook_result = hook(data) + result.update(hook_result) + + result.setdefault("status", "success") + return jsonify(result) + + except ValueError as e: + return jsonify({"error": str(e)}), 400 + except Exception as e: + logger.error( + f"Unexpected error in configure: {e}\n{traceback.format_exc()}" + ) + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + return app + + +# --------------------------------------------------------------------------- +# CLI argument parsing +# --------------------------------------------------------------------------- + + +def make_base_parser( + description: str = "AReaL Guard Service", +) -> argparse.ArgumentParser: + """Create the base argument parser shared across guard-based CLIs. + + Includes: ``--host``, ``--port``, ``--experiment-name``, ``--trial-name``, + ``--role``, ``--worker-index``, ``--name-resolve-type``, + ``--nfs-record-root``, ``--etcd3-addr``, ``--fileroot``. 
+ """ + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "--port", + type=int, + default=0, + help="Port to serve on (default: 0 = auto-assign)", + ) + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to bind to (default: 0.0.0.0)", + ) + # Name-resolve / scheduler config + parser.add_argument("--experiment-name", type=str, required=True) + parser.add_argument("--trial-name", type=str, required=True) + parser.add_argument("--role", type=str, required=True) + parser.add_argument("--worker-index", type=int, default=-1) + parser.add_argument("--name-resolve-type", type=str, default="nfs") + parser.add_argument( + "--nfs-record-root", type=str, default="/tmp/areal/name_resolve" + ) + parser.add_argument("--etcd3-addr", type=str, default="localhost:2379") + parser.add_argument( + "--fileroot", + type=str, + default=None, + help="Root directory for log files.", + ) + return parser + + +def configure_state_from_args(state: GuardState, args: argparse.Namespace) -> str: + """Populate :class:`GuardState` from parsed CLI args. + + Returns the ``bind_host`` address for werkzeug (may differ from + ``state.server_host`` when binding to ``0.0.0.0`` / ``::``). 
+ """ + from areal.utils.network import gethostip + + bind_host = args.host + if bind_host == "0.0.0.0": + host_ip = gethostip() + if ":" in host_ip: + bind_host = "::" + state.server_host = host_ip + elif bind_host == "::": + state.server_host = gethostip() + else: + state.server_host = bind_host + + state.experiment_name = args.experiment_name + state.trial_name = args.trial_name + state.role = args.role + state.fileroot = args.fileroot + + # Name-resolve config + state.name_resolve_type = getattr(args, "name_resolve_type", "nfs") + state.nfs_record_root = getattr(args, "nfs_record_root", "/tmp/areal/name_resolve") + state.etcd3_addr = getattr(args, "etcd3_addr", "localhost:2379") + + # Worker index (SLURM override) + worker_index = args.worker_index + if "SLURM_PROCID" in os.environ: + worker_index = int(os.environ["SLURM_PROCID"]) + if worker_index == -1: + raise ValueError("Invalid worker index. Not found from SLURM environ or args.") + state.worker_index = worker_index + + return bind_host + + +# --------------------------------------------------------------------------- +# Server lifecycle +# --------------------------------------------------------------------------- + + +def run_server( + state: GuardState, + app: Flask, + bind_host: str, + port: int, +) -> None: + """Start the werkzeug server and register with name_resolve. + + This is the shared server loop used by both the rpc_server and + standalone guard entrypoints. Handles SIGTERM, cleanup hooks, + and forked-child cleanup on shutdown. 
+ """ + from werkzeug.serving import make_server + + from areal.api.cli_args import NameResolveConfig + from areal.utils import name_resolve, names + + server = make_server(bind_host, port, app, threaded=True) + state.server_port = server.socket.getsockname()[1] + + with state.allocated_ports_lock: + state.allocated_ports.add(state.server_port) + + # Register with name_resolve + if state.name_resolve_type is not None: + name_resolve.reconfigure( + NameResolveConfig( + type=state.name_resolve_type, + nfs_record_root=(state.nfs_record_root or "/tmp/areal/name_resolve"), + etcd3_addr=state.etcd3_addr or "localhost:2379", + ) + ) + + worker_id = f"{state.role}/{state.worker_index}" + key = names.worker_discovery( + state.experiment_name, + state.trial_name, + state.role, + state.worker_index, + ) + name_resolve.add(key, state.node_addr, replace=True) + + logger.info(f"Starting Guard on {state.node_addr} for worker {worker_id}") + + def _sigterm_handler(signum, frame): + """Convert SIGTERM to SystemExit so the finally block runs.""" + raise SystemExit(0) + + signal.signal(signal.SIGTERM, _sigterm_handler) + + try: + server.serve_forever() + except KeyboardInterrupt: + logger.info("Shutting down (SIGINT)") + except SystemExit: + logger.info("Shutting down (SIGTERM)") + finally: + # Run registered cleanup hooks (engine cleanup, perf_tracer, etc.) + for hook in state._cleanup_hooks: + try: + hook() + except Exception as e: + logger.error(f"Error in cleanup hook: {e}") + cleanup_forked_children(state) + server.shutdown() diff --git a/areal/infra/rpc/guard/data_blueprint.py b/areal/infra/rpc/guard/data_blueprint.py new file mode 100644 index 0000000000..b0b2cdee78 --- /dev/null +++ b/areal/infra/rpc/guard/data_blueprint.py @@ -0,0 +1,164 @@ +"""Data Blueprint: RTensor ``/data/*`` storage endpoints. + +Provides a Flask Blueprint that handles tensor shard storage and +retrieval via HTTP. 
Used by any service that needs local tensor +storage accessible over the network (e.g., the RPC server). + +Routes: + +- ``PUT /data/<shard_id>`` — store a single shard +- ``GET /data/<shard_id>`` — retrieve a single shard +- ``POST /data/batch`` — retrieve multiple shards +- ``DELETE /data/clear`` — clear specified shards +""" + +from __future__ import annotations + +import traceback + +import orjson +from flask import Blueprint, Response, jsonify, request + +from areal.infra.rpc import rtensor +from areal.infra.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging + +logger = logging.getLogger("DataBP") + +data_bp = Blueprint("data", __name__) + + +@data_bp.route("/data/<shard_id>", methods=["PUT"]) +def store_batch_data(shard_id: str): + """Store batch data shard.""" + try: + data_bytes = request.get_data() + + # Deserialize to get tensor (already on CPU) + serialized_data = orjson.loads(data_bytes) + data = deserialize_value(serialized_data) + + rtensor.store(shard_id, data) + + logger.debug(f"Stored batch shard {shard_id} (size={len(data_bytes)} bytes)") + return jsonify({"status": "ok", "shard_id": shard_id}) + + except Exception as e: + logger.error(f"Error storing batch shard {shard_id}: {e}") + return jsonify({"status": "error", "message": str(e)}), 500 + + +@data_bp.route("/data/<shard_id>", methods=["GET"]) +def retrieve_batch_data(shard_id: str): + """Retrieve batch data shard.""" + logger.debug(f"Received data get request for shard {shard_id}") + try: + try: + data = rtensor.fetch(shard_id) + except KeyError: + return ( + jsonify( + { + "status": "error", + "message": f"Shard {shard_id} not found", + } + ), + 404, + ) + + serialized_data = serialize_value(data) + data_bytes = orjson.dumps(serialized_data) + + logger.debug(f"Retrieved batch shard {shard_id} (size={len(data_bytes)} bytes)") + return Response(data_bytes, mimetype="application/octet-stream") + + except Exception as e: + logger.error(f"Error retrieving batch shard {shard_id}: {e}") + return
jsonify({"status": "error", "message": str(e)}), 500 + + +@data_bp.route("/data/batch", methods=["POST"]) +def retrieve_batch_data_many(): + """Retrieve multiple batch data shards in one request.""" + try: + payload = request.get_json(silent=True) or {} + shard_ids = payload.get("shard_ids", []) + if not isinstance(shard_ids, list) or not all( + isinstance(shard_id, str) for shard_id in shard_ids + ): + return ( + jsonify( + { + "status": "error", + "message": ( + "Expected JSON body with string list field 'shard_ids'" + ), + } + ), + 400, + ) + + data = [] + missing_shard_ids = [] + for shard_id in shard_ids: + try: + data.append(rtensor.fetch(shard_id)) + except KeyError: + missing_shard_ids.append(shard_id) + + if missing_shard_ids: + return ( + jsonify( + { + "status": "error", + "message": ("One or more requested shards were not found"), + "missing_shard_ids": missing_shard_ids, + } + ), + 400, + ) + + serialized_data = serialize_value(data) + data_bytes = orjson.dumps(serialized_data) + logger.debug( + "Retrieved %s batch shards (size=%s bytes)", + len(shard_ids), + len(data_bytes), + ) + return Response(data_bytes, mimetype="application/octet-stream") + + except Exception as e: + logger.error(f"Error retrieving batch shards: {e}\n{traceback.format_exc()}") + return jsonify({"status": "error", "message": str(e)}), 500 + + +@data_bp.route("/data/clear", methods=["DELETE"]) +def clear_batch_data(): + """Clear specified batch data shards. 
+ + Expected JSON payload:: + + {"shard_ids": ["id1", "id2", ...]} + """ + try: + data = request.get_json(silent=True) or {} + shard_ids = data.get("shard_ids", []) + if not isinstance(shard_ids, list): + return ( + jsonify({"status": "error", "message": "'shard_ids' must be a list"}), + 400, + ) + + cleared_count = sum(rtensor.remove(sid) for sid in shard_ids) + storage = rtensor.storage_stats() + result = { + "status": "ok", + "cleared_count": cleared_count, + **storage, + } + logger.info(f"Cleared {cleared_count} batch shards. Stats: {result}") + return jsonify(result) + + except Exception as e: + logger.error(f"Error clearing batch data: {e}") + return jsonify({"status": "error", "message": str(e)}), 500 diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py new file mode 100644 index 0000000000..4dbd1c765d --- /dev/null +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -0,0 +1,569 @@ +"""Engine Blueprint: engine lifecycle and method invocation. + +Provides a Flask Blueprint that manages engine threads, engine creation, +and engine method calls. Registers hooks on :class:`GuardState` for +``/configure``, ``/health``, and cleanup. + +Routes: + +- ``POST /set_env`` — set environment variables in the engine thread +- ``POST /create_engine`` — instantiate a TrainEngine or InferenceEngine +- ``POST /call`` — invoke a method on a named engine instance + +The engine thread guarantees serial execution of all engine operations, +which is required for NCCL compatibility. 
+""" + +from __future__ import annotations + +import os +import traceback +from collections.abc import Callable +from concurrent.futures import Future +from queue import Queue +from threading import Lock, Thread +from typing import Any + +from flask import Blueprint, jsonify, request + +from areal.api import InferenceEngine, TrainEngine +from areal.infra.platforms import current_platform +from areal.infra.rpc.guard.app import GuardState, get_state +from areal.infra.rpc.rtensor import RTensor +from areal.infra.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging, perf_tracer, seeding +from areal.utils.data import broadcast_tensor_container, tensor_container_to +from areal.utils.dynamic_import import import_from_string + +logger = logging.getLogger("EngineBP") + +engine_bp = Blueprint("engine", __name__) + +# --------------------------------------------------------------------------- +# Engine-specific module-level state +# --------------------------------------------------------------------------- + +# Global engine instances — keyed by engine_name (e.g., "actor/0", "ref/0") +_engines: dict[str, TrainEngine | InferenceEngine] = {} + +# Engine thread for executing all engine-related operations serially. +# This ensures NCCL compatibility by running engine operations in a single +# thread, while allowing /data/ endpoints to be processed concurrently. 
+_engine_thread: Thread | None = None
+_engine_work_queue: Queue | None = None
+_engine_thread_lock = Lock()
+
+
+# ---------------------------------------------------------------------------
+# Engine thread management
+# ---------------------------------------------------------------------------
+
+
+def _init_engine_thread() -> None:
+    """Lazily initialize the engine worker thread."""
+    global _engine_thread, _engine_work_queue
+
+    with _engine_thread_lock:
+        if _engine_thread is not None:
+            if _engine_thread.is_alive():
+                return  # Already initialized
+            else:
+                raise RuntimeError("Engine thread is dead.")
+
+        _engine_work_queue = Queue()
+
+        def engine_worker():
+            logger.info("Engine thread started")
+            while True:
+                # Pre-bind names so the outer except handler cannot hit an
+                # unbound variable if get()/unpacking itself raises.
+                work_item = None
+                func_name = "<unknown>"
+                try:
+                    work_item = _engine_work_queue.get()
+                    if work_item is None:  # Shutdown signal
+                        logger.info("Engine thread shutting down")
+                        break
+
+                    func, args, kwargs, future, func_name = work_item
+                    try:
+                        result = func(*args, **kwargs)
+                        future.set_result(result)
+                    except Exception as e:
+                        future.set_exception(e)
+                    finally:
+                        _engine_work_queue.task_done()
+                except Exception as e:
+                    logger.error(
+                        f"Error in engine thread when "
+                        f"running {func_name}: {e}\n{traceback.format_exc()}"
+                    )
+                    if work_item is not None and len(work_item) > 3:
+                        work_item[3].set_exception(e)
+
+        _engine_thread = Thread(target=engine_worker, daemon=True, name="EngineWorker")
+        _engine_thread.start()
+        logger.info("Engine thread initialized")
+
+
+def _submit_to_engine_thread(
+    func_name: str, func: Callable, *args: Any, **kwargs: Any
+) -> Any:
+    """Submit work to the engine thread and block until result is available."""
+    global _engine_work_queue
+
+    _init_engine_thread()
+
+    future: Future = Future()
+    _engine_work_queue.put((func, args, kwargs, future, func_name))
+    return future.result()  # Block until result is available
+
+
+# ---------------------------------------------------------------------------
+# Hook registration
+# 
--------------------------------------------------------------------------- + + +def register_engine_hooks(state: GuardState) -> None: + """Register engine-specific hooks on the :class:`GuardState`. + + Must be called after creating the Flask app and before starting + the server. Registers: + + - health hook → adds ``engine_count`` and ``engines`` to /health + - configure hook → sets random seeds in the engine thread + - cleanup hooks → destroy engines and shut down engine thread + """ + state.register_health_hook(_engine_health_hook) + state.register_configure_hook(_engine_configure_hook) + state.register_cleanup_hook(cleanup_engine_thread) + state.register_cleanup_hook(cleanup_engines) + + +def _engine_health_hook() -> dict[str, Any]: + """Contribute engine info to the /health response.""" + return {"engine_count": len(_engines), "engines": list(_engines.keys())} + + +def _engine_configure_hook(data: dict) -> dict: + """Handle /configure by setting random seeds in the engine thread. + + Raises + ------ + ValueError + If required fields (``config``, ``rank``) are missing. 
+    """
+    config_data = data.get("config")
+    if config_data is None:
+        raise ValueError("Missing 'config' field in request")
+
+    rank = data.get("rank")
+    if rank is None:
+        raise ValueError("Missing 'rank' field in request")
+
+    config = deserialize_value(config_data)
+
+    # Capture role from GuardState (we're in a request context)
+    state = get_state()
+    role = state.role
+
+    def execute_configure():
+        seeding.set_random_seed(config.seed, key=f"{role}{rank}")
+        return {
+            "status": "success",
+            "message": "Worker configured successfully.",
+            "result": None,
+        }
+
+    return _submit_to_engine_thread("configure", execute_configure)
+
+
+# ---------------------------------------------------------------------------
+# Cleanup
+# ---------------------------------------------------------------------------
+
+
+def cleanup_engines() -> None:
+    """Destroy all engine instances."""
+    global _engines
+    if _engines:
+        for engine_name, engine in list(_engines.items()):
+            try:
+                engine.destroy()
+                logger.info(f"Engine '{engine_name}' destroyed successfully")
+            except Exception as e:
+                logger.error(f"Error destroying engine '{engine_name}': {e}")
+        _engines.clear()
+
+
+def cleanup_engine_thread() -> None:
+    """Shut down the engine worker thread."""
+    global _engine_thread, _engine_work_queue
+
+    with _engine_thread_lock:
+        if _engine_work_queue is not None:
+            # Send shutdown signal
+            _engine_work_queue.put(None)
+            _engine_work_queue = None
+
+        if _engine_thread is not None:
+            _engine_thread.join(timeout=5.0)
+            if _engine_thread.is_alive():
+                logger.warning("Engine thread did not shut down gracefully")
+            _engine_thread = None
+            logger.info("Engine thread cleaned up")
+
+
+# ---------------------------------------------------------------------------
+# Flask routes
+# ---------------------------------------------------------------------------
+
+
+@engine_bp.route("/set_env", methods=["POST"])
+def set_env():
+    """Set environment variables for the worker process. 
+ + This endpoint is routed to the engine thread for serial execution. + """ + try: + data = request.get_json() + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + env_payload = data.get("env") + if env_payload is None: + return jsonify({"error": "Missing 'env' field in request"}), 400 + if not isinstance(env_payload, dict): + return jsonify({"error": "'env' must be a dictionary"}), 400 + + for key in env_payload.keys(): + if not isinstance(key, str): + return ( + jsonify( + { + "error": ( + "Environment variable name must be str, " + f"got {type(key)}" + ) + } + ), + 400, + ) + + def execute_set_env(): + for key, value in env_payload.items(): + os.environ[key] = str(value) + logger.info(f"Set {key}={value}") + return {"status": "success"} + + result = _submit_to_engine_thread("set_env", execute_set_env) + return jsonify(result) + + except Exception as e: + logger.error(f"Unexpected error in set_env: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@engine_bp.route("/create_engine", methods=["POST"]) +def create_engine(): + """Create and initialize an engine instance on this worker. + + This endpoint is routed to the engine thread for serial execution. + Supports multiple engines per worker, keyed by ``engine_name``. 
+
+    Expected JSON payload::
+
+        {
+            "engine": "areal.engine.fsdp_engine.FSDPPPOActor",
+            "engine_name": "actor/0",
+            "init_args": [...],
+            "init_kwargs": {"config": ...}
+        }
+    """
+    global _engines
+
+    try:
+        # Parse request in main thread (has Flask request context)
+        data = request.get_json()
+        if data is None:
+            return jsonify({"error": "Invalid JSON in request body"}), 400
+
+        engine = data.get("engine")
+        engine_name = data.get("engine_name")
+        # Deserialize init_args and init_kwargs (may contain tensors/dataclasses)
+        init_args = deserialize_value(data.get("init_args", []))
+        init_kwargs = deserialize_value(data.get("init_kwargs", {}))
+
+        if not engine:
+            return (
+                jsonify({"error": "Missing 'engine' field in request"}),
+                400,
+            )
+
+        if not engine_name:
+            return (
+                jsonify({"error": "Missing 'engine_name' field in request"}),
+                400,
+            )
+
+        if engine_name in _engines:
+            return (
+                jsonify(
+                    {
+                        "error": f"Engine '{engine_name}' already exists. "
+                        "Use a different name or delete the existing "
+                        "engine first."
+                    }
+                ),
+                400,
+            )
+
+        # Dynamic import (can be done in main thread)
+        try:
+            engine_class = import_from_string(engine)
+
+            # Validate that the class is a TrainEngine or InferenceEngine
+            if not issubclass(engine_class, TrainEngine) and not issubclass(
+                engine_class, InferenceEngine
+            ):
+                raise TypeError(
+                    "Engine class must be a subclass of TrainEngine or "
+                    f"InferenceEngine, got {engine_class}." 
+ ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine}': {e}") + return ( + jsonify({"error": (f"Failed to import engine '{engine}': {str(e)}")}), + 400, + ) + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + return jsonify({"error": str(e)}), 400 + + # Instantiate engine in engine thread (may involve NCCL init) + def create_engine_in_engine_thread(): + """Create engine in engine thread.""" + try: + engine_obj = engine_class(*init_args, **init_kwargs) + logger.info( + f"Engine '{engine_name}' (class: {engine}) " + "instantiated successfully" + ) + return engine_obj + except Exception as e: + logger.error( + f"Failed to instantiate engine: {e}\n{traceback.format_exc()}" + ) + raise + + try: + engine_obj = _submit_to_engine_thread( + "create_engine", create_engine_in_engine_thread + ) + _engines[engine_name] = engine_obj + return jsonify( + { + "status": "success", + "message": (f"Engine '{engine_name}' created and initialized"), + "engine_name": engine_name, + "result": None, + } + ) + except Exception as e: + return ( + jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), + 500, + ) + + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@engine_bp.route("/call", methods=["POST"]) +def call_engine_method(): + """Call a method on an engine instance. + + This endpoint is routed to the engine thread to ensure all engine + operations run serially in the same thread, preventing NCCL conflicts. 
+
+    Expected JSON payload::
+
+        {
+            "method": "train_batch",
+            "engine_name": "actor/0",
+            "args": [...],
+            "kwargs": {...}
+        }
+    """
+    global _engines
+
+    try:
+        data = request.get_json()
+        if data is None:
+            return jsonify({"error": "Invalid JSON in request body"}), 400
+
+        method_name = data.get("method")
+        engine_name = data.get("engine_name")
+        raw_args = data.get("args", [])
+        raw_kwargs = data.get("kwargs", {})
+
+        if not method_name:
+            return (
+                jsonify({"error": "Missing 'method' field in request"}),
+                400,
+            )
+
+        if not engine_name:
+            return (
+                jsonify({"error": "Missing 'engine_name' field in request"}),
+                400,
+            )
+
+        if engine_name not in _engines:
+            return (
+                jsonify(
+                    {
+                        "error": f"Engine '{engine_name}' not found. "
+                        f"Available engines: {list(_engines.keys())}"
+                    }
+                ),
+                404,
+            )
+
+        # Get the specific engine to call
+        engine = _engines[engine_name]
+
+        # Deserialize data
+        raw_args = deserialize_value(raw_args)
+        raw_kwargs = deserialize_value(raw_kwargs)
+        # Fetch remote tensors
+        args = RTensor.localize(raw_args)
+        kwargs = RTensor.localize(raw_kwargs)
+
+        def execute_in_engine_thread():
+            try:
+                # Broadcast args when engine is a TrainEngine and initialized
+                if isinstance(engine, TrainEngine) and engine.initialized:
+                    logger.debug(
+                        f"Broadcasting data for TrainEngine method: {method_name}"
+                    )
+
+                    # args/kwargs are already localized from their RTensor
+                    # references above, so only the localized containers are
+                    # broadcast here.
+                    args_bcast = tensor_container_to(
+                        args, current_platform.current_device()
+                    )
+                    args_bcast = broadcast_tensor_container(
+                        args_bcast,
+                        
src_rank=engine.current_data_parallel_head(), + group=engine.context_and_model_parallel_group, + ) + kwargs_bcast = tensor_container_to( + kwargs, current_platform.current_device() + ) + kwargs_bcast = broadcast_tensor_container( + kwargs_bcast, + src_rank=engine.current_data_parallel_head(), + group=engine.context_and_model_parallel_group, + ) + logger.debug("Broadcasting data done.") + else: + args_bcast = args + kwargs_bcast = kwargs + + logger.debug(f"Calling engine '{engine_name}' method: {method_name}") + + # Determine trace category based on method name + category = "misc" # Default category + method_lower = method_name.lower() + if any(keyword in method_lower for keyword in ["submit", "wait"]): + category = "scheduler" + elif any( + keyword in method_lower + for keyword in ["update_weights", "broadcast"] + ): + category = "comm" + elif any(keyword in method_lower for keyword in ["save", "load"]): + category = "io" + elif any( + keyword in method_lower + for keyword in [ + "train", + "eval", + "forward", + "compute", + "step", + "update", + "optimizer", + "zero_grad", + "lr_scheduler", + ] + ): + category = "compute" + + # Wrap engine method call with perf_tracer + with perf_tracer.trace_scope( + f"rpc.{method_name}", + category=category, + args={"method": method_name, "engine": engine_name}, + ): + method = getattr(engine, method_name) + result = method(*args_bcast, **kwargs_bcast) + + # Handle update weights future + if isinstance(result, Future): + logger.debug("Waiting for update weights future") + result = result.result() + logger.debug("Update weights future done") + + return result + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + raise ValueError(f"Engine does not have method '{method_name}'") + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: " + f"{e}\n{traceback.format_exc()}" + ) + raise + + try: + result = _submit_to_engine_thread( + f"call_{method_name}", 
execute_in_engine_thread + ) + except Exception as e: + error_msg = str(e) + if "Engine does not have method" in error_msg: + return ( + jsonify({"error": error_msg}), + 400, + ) + return ( + jsonify( + {"error": (f"Engine method '{method_name}' failed: {error_msg}")} + ), + 500, + ) + + # Convert all tensors to RTensors and store locally + state = get_state() + result = RTensor.remotize(result, node_addr=state.node_addr) + serialized_result = serialize_value(result) + return jsonify({"status": "success", "result": serialized_result}) + + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 diff --git a/areal/infra/rpc/rpc_server.py b/areal/infra/rpc/rpc_server.py index 41a3d12223..edc230a7ae 100644 --- a/areal/infra/rpc/rpc_server.py +++ b/areal/infra/rpc/rpc_server.py @@ -1,1005 +1,37 @@ -import argparse -import getpass -import logging as stdlib_logging -import os -import subprocess -import sys -import time -import traceback -from collections.abc import Callable -from concurrent.futures import Future -from pathlib import Path -from queue import Queue -from threading import Lock, Thread -from typing import Any - -import orjson -import requests -from flask import Flask, Response, jsonify, request -from werkzeug.serving import make_server - -from areal.api import InferenceEngine, TrainEngine -from areal.api.cli_args import BaseExperimentConfig, NameResolveConfig -from areal.infra.platforms import current_platform -from areal.infra.rpc import rtensor -from areal.infra.rpc.rtensor import RTensor -from areal.infra.rpc.serialization import ( - deserialize_value, - serialize_value, -) -from areal.infra.utils.proc import kill_process_tree, run_with_streaming_logs -from areal.utils import logging, name_resolve, names, perf_tracer, seeding -from areal.utils.data import ( - broadcast_tensor_container, - tensor_container_to, -) -from areal.utils.dynamic_import 
import import_from_string -from areal.utils.network import ( - find_free_ports, - format_hostport, - gethostip, -) - -logger = logging.getLogger("SyncRPCServer") - -# Global engine instances - keyed by engine_name (e.g., "actor/0", "ref/0") -_engines: dict[str, TrainEngine | InferenceEngine] = {} - -_role: str | None = None - -# Engine thread for executing all engine-related endpoints serially -# This ensures NCCL compatibility by running engine operations in a single thread, -# while allowing /data/ endpoints to be processed concurrently -_engine_thread: Thread | None = None -_engine_work_queue: Queue[tuple[Callable, tuple, dict, Future]] | None = None -_engine_thread_lock = Lock() - -# Server address (set at startup) -_server_host: str = "0.0.0.0" -_server_port: int = 8000 - -_allocated_ports: set[int] = set() - -# Forked child processes - tracked for cleanup -_forked_children: list[subprocess.Popen] = [] -_forked_children_lock = Lock() -# Map (role, worker_index) to forked process for selective killing -_forked_children_map: dict[tuple[str, int], subprocess.Popen] = {} - -# Server config (needed for /fork endpoint to spawn children with same config) -_experiment_name: str | None = None -_trial_name: str | None = None -_name_resolve_type: str = "nfs" -_nfs_record_root: str = "/tmp/areal/name_resolve" -_etcd3_addr: str = "localhost:2379" -_fileroot: str | None = None # Log file directory root - -# Create Flask app -app = Flask(__name__) - - -def _init_engine_thread(): - global _engine_thread, _engine_work_queue - - with _engine_thread_lock: - if _engine_thread is not None: - if _engine_thread.is_alive(): - return # Already initialized - else: - raise RuntimeError("Engine thread is dead.") - - _engine_work_queue = Queue() - - def engine_worker(): - logger.info("Engine thread started") - while True: - try: - work_item = _engine_work_queue.get() - if work_item is None: # Shutdown signal - logger.info("Engine thread shutting down") - break - - func, args, kwargs, 
future, func_name = work_item - try: - result = func(*args, **kwargs) - future.set_result(result) - except Exception as e: - future.set_exception(e) - finally: - _engine_work_queue.task_done() - except Exception as e: - logger.error( - f"Error in engine thread when " - f"running {func_name}: {e}\n{traceback.format_exc()}" - ) - if work_item and len(work_item) > 3: - work_item[3].set_exception(e) - - _engine_thread = Thread(target=engine_worker, daemon=True, name="EngineWorker") - _engine_thread.start() - logger.info("Engine thread initialized") - - -def _submit_to_engine_thread(func_name: str, func: Callable, *args, **kwargs) -> Any: - global _engine_work_queue - - _init_engine_thread() - - future = Future() - _engine_work_queue.put((func, args, kwargs, future, func_name)) - return future.result() # Block until result is available - - -@app.route("/health", methods=["GET"]) -def health_check(): - """Health check endpoint to verify server is alive.""" - global _engines - return jsonify( - { - "status": "healthy", - "engine_count": len(_engines), - "engines": list(_engines.keys()), - } - ) - - -@app.route("/alloc_ports", methods=["POST"]) -def alloc_ports(): - """Allocate multiple free ports. 
- - Expected JSON payload: - { - "count": 5 # Number of ports to allocate - } - """ - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - count = data.get("count") - if count is None: - return jsonify({"error": "Missing 'count' field in request"}), 400 - - if not isinstance(count, int) or count <= 0: - return jsonify({"error": "'count' must be a positive integer"}), 400 - - global _allocated_ports - ports = find_free_ports(count, exclude_ports=_allocated_ports) - _allocated_ports.update(ports) - - return jsonify({"status": "success", "ports": ports, "host": _server_host}) - - except Exception as e: - logger.error(f"Error in alloc_ports: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -def _wait_for_worker_ready(host: str, port: int, timeout: float = 60) -> bool: - """Wait for a worker to be ready by polling its health endpoint. - - Args: - host: The host address of the worker. - port: The port of the worker. - timeout: Maximum time to wait in seconds (default: 60). - - Returns: - True if the worker is ready, False if timeout is reached. - """ - url = f"http://{format_hostport(host, port)}/health" - deadline = time.time() + timeout - - while time.time() < deadline: - try: - resp = requests.get(url, timeout=2) - if resp.status_code == 200: - return True - except requests.exceptions.RequestException: - pass - time.sleep(0.5) - - return False - - -@app.route("/fork", methods=["POST"]) -def fork_worker(): - """Fork a new worker process on the same node. - - This endpoint spawns a new RPC server process as a child of this worker. - The child inherits the same environment (including CUDA_VISIBLE_DEVICES) - but runs as an independent process with its own engine registry. 
- - Expected JSON payload: - { - "role": "ref", # Role name for the forked worker - "worker_index": 0, # Worker index - "command": "areal.infra.rpc.rpc_server" # Optional: custom module to run - } - - Returns: - { - "status": "success", - "host": "192.168.1.10", - "port": 8001, - "pid": 12345 - } - """ - global _forked_children, _forked_children_map, _allocated_ports - - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - role = data.get("role") - worker_index = data.get("worker_index") - command = data.get("command") # Optional custom module path - - if role is None: - return jsonify({"error": "Missing 'role' field in request"}), 400 - if worker_index is None: - return jsonify({"error": "Missing 'worker_index' field in request"}), 400 - - # Allocate a free port for the child process - ports = find_free_ports(1, exclude_ports=_allocated_ports) - child_port = ports[0] - _allocated_ports.add(child_port) - - # Build command for child process - # Use custom module if specified, otherwise default to rpc_server - module = command if command else "areal.infra.rpc.rpc_server" - cmd = [ - sys.executable, - "-m", - module, - "--host", - "0.0.0.0", - "--port", - str(child_port), - "--experiment-name", - _experiment_name, - "--trial-name", - _trial_name, - "--role", - role, - "--worker-index", - str(worker_index), - "--name-resolve-type", - _name_resolve_type, - "--nfs-record-root", - _nfs_record_root, - "--etcd3-addr", - _etcd3_addr, - "--fileroot", - _fileroot, - ] - - logger.info( - f"Forking new worker process for role '{role}' index {worker_index} " - f"on port {child_port}" - ) - - # Build shell command with tee/sed for streaming logs to terminal and files - # This matches LocalScheduler's logging pattern - log_dir = ( - Path(_fileroot) - / "logs" - / getpass.getuser() - / _experiment_name - / _trial_name - ) - log_dir.mkdir(parents=True, exist_ok=True) - log_file = log_dir / f"{role}.log" - merged_log 
= log_dir / "merged.log" - - logger.info(f"Forked worker logs will be written to: {log_file}") - - # Use streaming log utility for terminal, role log, and merged log output - child_process = run_with_streaming_logs( - cmd, - log_file, - merged_log, - role, - env=os.environ.copy(), - ) - - with _forked_children_lock: - _forked_children.append(child_process) - _forked_children_map[(role, worker_index)] = child_process - - # Wait for child to be ready - child_host = _server_host - if not _wait_for_worker_ready(child_host, child_port): - # Cleanup on failure - try: - kill_process_tree(child_process.pid, timeout=3, graceful=True) - except Exception: - pass - with _forked_children_lock: - if child_process in _forked_children: - _forked_children.remove(child_process) - _forked_children_map.pop((role, worker_index), None) - _allocated_ports.discard(child_port) - return jsonify( - {"error": "Forked worker failed to start within timeout"} - ), 500 - - logger.info( - f"Forked worker for role '{role}' index {worker_index} ready at " - f"{child_host}:{child_port} (pid={child_process.pid})" - ) - - return jsonify( - { - "status": "success", - "host": child_host, - "port": child_port, - "pid": child_process.pid, - } - ) - - except Exception as e: - logger.error(f"Error in fork: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/kill_forked_worker", methods=["POST"]) -def kill_forked_worker(): - """Kill a specific forked worker process. - - This endpoint terminates a previously forked child process identified by - its role and worker_index. 
- - Expected JSON payload: - { - "role": "ref", # Role name of the forked worker - "worker_index": 0 # Worker index - } - - Returns: - { - "status": "success", - "message": "Killed forked worker ref/0 (pid=12345)" - } - """ - global _forked_children, _forked_children_map - - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - role = data.get("role") - worker_index = data.get("worker_index") - - if role is None: - return jsonify({"error": "Missing 'role' field in request"}), 400 - if worker_index is None: - return jsonify({"error": "Missing 'worker_index' field in request"}), 400 - - key = (role, worker_index) - - # Remove from tracking structures first (hold lock only for dict/list operations) - with _forked_children_lock: - child_process = _forked_children_map.pop(key, None) - if child_process: - try: - _forked_children.remove(child_process) - except ValueError: - # Defensive: process was in map but not in list - logger.warning( - f"Process for {role}/{worker_index} was in map but not in list" - ) - - if child_process is None: - return jsonify( - {"error": f"Forked worker {role}/{worker_index} not found"} - ), 404 - - pid = child_process.pid - - # Kill the process tree (outside the lock to avoid blocking other operations) - try: - if child_process.poll() is None: # Still running - kill_process_tree(pid, timeout=3, graceful=True) - logger.info(f"Killed forked worker {role}/{worker_index} (pid={pid})") - except Exception as e: - logger.error( - f"Error killing forked worker {role}/{worker_index} (pid={pid}): {e}" - ) - return jsonify( - { - "error": f"Failed to kill forked worker: {str(e)}", - "pid": pid, - } - ), 500 - - return jsonify( - { - "status": "success", - "message": f"Killed forked worker {role}/{worker_index} (pid={pid})", - } - ) - - except Exception as e: - logger.error(f"Error in kill_forked_worker: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: 
{str(e)}"}), 500 - - -@app.route("/configure", methods=["POST"]) -def configure(): - """Configure worker with experiment config. - - This endpoint is routed to the engine thread for serial execution. - """ - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - config = data.get("config") - if config is None: - return jsonify({"error": "Missing 'config' field in request"}), 400 - - rank = data.get("rank") - if rank is None: - return jsonify({"error": "Missing 'rank' field in request"}), 400 - - config = deserialize_value(config) - config: BaseExperimentConfig - - def execute_configure(): - global _role - seeding.set_random_seed(config.seed, key=f"{_role}{rank}") - return { - "status": "success", - "message": "Worker configured successful.", - "result": None, - } - - result = _submit_to_engine_thread("configure", execute_configure) - return jsonify(result) - except Exception as e: - logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/set_env", methods=["POST"]) -def set_env(): - """Set environment variables for the worker process. - - This endpoint is routed to the engine thread for serial execution. - """ - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 +"""AReaL Sync RPC Server — Guard + Data + Engine composition. - env_payload = data.get("env") - if env_payload is None: - return jsonify({"error": "Missing 'env' field in request"}), 400 - if not isinstance(env_payload, dict): - return jsonify({"error": "'env' must be a dictionary"}), 400 +This module composes the shared Guard with data and engine blueprints +to create the full RPC server used by training workers. 
- for key in env_payload.keys(): - if not isinstance(key, str): - return ( - jsonify( - { - "error": ( - f"Environment variable name must be str, got {type(key)}" - ) - } - ), - 400, - ) +Usage:: - def execute_set_env(): - for key, value in env_payload.items(): - os.environ[key] = str(value) - logger.info(f"Set {key}={value}") - return {"status": "success"} + python -m areal.infra.rpc.rpc_server \\ + --experiment-name exp1 --trial-name trial1 \\ + --role actor --worker-index 0 +""" - result = _submit_to_engine_thread("set_env", execute_set_env) - return jsonify(result) +from __future__ import annotations - except Exception as e: - logger.error(f"Unexpected error in set_env: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/create_engine", methods=["POST"]) -def create_engine(): - """ - Create and initialize a TrainEngine or InferenceEngine instance on this worker. - - This endpoint is routed to the engine thread for serial execution. - Supports multiple engines per worker, keyed by engine_name. 
- - Expected JSON payload: - { - "engine": "areal.engine.fsdp_engine.FSDPPPOActor", # Import path - "engine_name": "actor/0", # Unique name for this engine (required) - "init_args": [...], # Positional arguments - "init_kwargs": { - "config": ..., # Engine config - } - } - """ - global _engines - - try: - # Parse request in main thread (has Flask request context) - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - engine = data.get("engine") - engine_name = data.get("engine_name") - # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) - init_args = deserialize_value(data.get("init_args", [])) - init_kwargs = deserialize_value(data.get("init_kwargs", {})) - - if not engine: - return jsonify({"error": "Missing 'engine' field in request"}), 400 - - if not engine_name: - return jsonify({"error": "Missing 'engine_name' field in request"}), 400 - - if engine_name in _engines: - return jsonify( - { - "error": f"Engine '{engine_name}' already exists. " - "Use a different name or delete the existing engine first." - } - ), 400 - - # Dynamic import (can be done in main thread) - try: - engine_class = import_from_string(engine) - - # Validate that the class is a TrainEngine or InferenceEngine - if not issubclass(engine_class, TrainEngine) and not issubclass( - engine_class, InferenceEngine - ): - raise TypeError( - f"Engine class must be a subclass of TrainEngine or InferenceEngine, " - f"got {engine_class}.." 
- ) - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import engine '{engine}': {e}") - return ( - jsonify({"error": f"Failed to import engine '{engine}': {str(e)}"}), - 400, - ) - except TypeError as e: - logger.error(f"Invalid engine type: {e}") - return jsonify({"error": str(e)}), 400 - - # Instantiate engine in engine thread (may involve NCCL initialization) - def create_engine_in_engine_thread(): - """Create engine in engine thread.""" - try: - engine_obj = engine_class(*init_args, **init_kwargs) - logger.info( - f"Engine '{engine_name}' (class: {engine}) instantiated successfully" - ) - return engine_obj - except Exception as e: - logger.error( - f"Failed to instantiate engine: {e}\n{traceback.format_exc()}" - ) - raise - - try: - engine_obj = _submit_to_engine_thread( - "create_engine", create_engine_in_engine_thread - ) - _engines[engine_name] = engine_obj - return jsonify( - { - "status": "success", - "message": f"Engine '{engine_name}' created and initialized", - "engine_name": engine_name, - "result": None, - } - ) - except Exception as e: - return jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), 500 - - except Exception as e: - logger.error( - f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" - ) - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/call", methods=["POST"]) -def call_engine_method(): - """ - Call a method on an engine instance. - - This endpoint is routed to the engine thread to ensure all engine operations - run serially in the same thread, preventing NCCL conflicts. 
- - Expected JSON payload: - { - "method": "train_batch", - "engine_name": "actor/0", # Required: name of engine to call - "args": [...], - "kwargs": {...} - } - """ - global _engines - - try: - data = request.get_json() - if data is None: - return jsonify({"error": "Invalid JSON in request body"}), 400 - - method_name = data.get("method") - engine_name = data.get("engine_name") - raw_args = data.get("args", []) - raw_kwargs = data.get("kwargs", {}) - - if not method_name: - return jsonify({"error": "Missing 'method' field in request"}), 400 - - if not engine_name: - return jsonify({"error": "Missing 'engine_name' field in request"}), 400 - - if engine_name not in _engines: - return ( - jsonify( - { - "error": f"Engine '{engine_name}' not found. " - f"Available engines: {list(_engines.keys())}" - } - ), - 404, - ) - - # Get the specific engine to call - engine = _engines[engine_name] - - # Deserialize data - raw_args = deserialize_value(raw_args) - raw_kwargs = deserialize_value(raw_kwargs) - # Fetch remote tensors - args = RTensor.localize(raw_args) - kwargs = RTensor.localize(raw_kwargs) - - def execute_in_engine_thread(): - try: - # Broadcast args when engine is a TrainEngine and has been initialized - if isinstance(engine, TrainEngine) and engine.initialized: - logger.debug( - f"Broadcasting data for TrainEngine method: {method_name}" - ) - - nonlocal raw_args, raw_kwargs - raw_args = broadcast_tensor_container( - tensor_container_to( - raw_args, current_platform.current_device() - ), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - raw_kwargs = broadcast_tensor_container( - tensor_container_to( - raw_kwargs, current_platform.current_device() - ), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - - args_bcast = tensor_container_to( - args, current_platform.current_device() - ) - args_bcast = broadcast_tensor_container( - args_bcast, - 
src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - kwargs_bcast = tensor_container_to( - kwargs, current_platform.current_device() - ) - kwargs_bcast = broadcast_tensor_container( - kwargs_bcast, - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - logger.debug("Broadcasting data done.") - else: - args_bcast = args - kwargs_bcast = kwargs - - logger.debug(f"Calling engine '{engine_name}' method: {method_name}") - - # Determine trace category based on method name - category = "misc" # Default category - method_lower = method_name.lower() - if any(keyword in method_lower for keyword in ["submit", "wait"]): - category = "scheduler" - elif any( - keyword in method_lower - for keyword in ["update_weights", "broadcast"] - ): - category = "comm" - elif any(keyword in method_lower for keyword in ["save", "load"]): - category = "io" - elif any( - keyword in method_lower - for keyword in [ - "train", - "eval", - "forward", - "compute", - "step", - "update", - "optimizer", - "zero_grad", - "lr_scheduler", - ] - ): - category = "compute" - - # Wrap engine method call with perf_tracer - with perf_tracer.trace_scope( - f"rpc.{method_name}", - category=category, - args={"method": method_name, "engine": engine_name}, - ): - method = getattr(engine, method_name) - result = method(*args_bcast, **kwargs_bcast) - - # Handle update weights future - if isinstance(result, Future): - logger.debug("Waiting for update weights future") - result = result.result() - logger.debug("Update weights future done") - - return result - except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") - raise ValueError(f"Engine does not have method '{method_name}'") - except Exception as e: - logger.error( - f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - raise - - try: - result = _submit_to_engine_thread( - f"call_{method_name}", 
execute_in_engine_thread - ) - except Exception as e: - error_msg = str(e) - if "Engine does not have method" in error_msg: - return ( - jsonify({"error": error_msg}), - 400, - ) - return ( - jsonify( - {"error": f"Engine method '{method_name}' failed: {error_msg}"} - ), - 500, - ) - - # Convert all tensors to RTensors and store the tensor locally - result = RTensor.remotize(result, node_addr=f"{_server_host}:{_server_port}") - serialized_result = serialize_value(result) - return jsonify({"status": "success", "result": serialized_result}) - - except Exception as e: - logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -# ==================== Batch Data Storage Endpoints ==================== -@app.route("/data/", methods=["PUT"]) -def store_batch_data(shard_id: str): - """Store batch data shard.""" - - try: - data_bytes = request.get_data() - - # Deserialize to get tensor (already on CPU) - serialized_data = orjson.loads(data_bytes) - data = deserialize_value(serialized_data) - - rtensor.store(shard_id, data) - - logger.debug(f"Stored batch shard {shard_id} (size={len(data_bytes)} bytes)") - return jsonify({"status": "ok", "shard_id": shard_id}) - - except Exception as e: - logger.error(f"Error storing batch shard {shard_id}: {e}") - return jsonify({"status": "error", "message": str(e)}), 500 - - -@app.route("/data/", methods=["GET"]) -def retrieve_batch_data(shard_id: str): - """Retrieve batch data shard.""" - - logger.debug(f"Received data get request for shard {shard_id}") - try: - try: - data = rtensor.fetch(shard_id) - except KeyError: - return ( - jsonify( - { - "status": "error", - "message": f"Shard {shard_id} not found", - } - ), - 404, - ) - - serialized_data = serialize_value(data) - data_bytes = orjson.dumps(serialized_data) - - logger.debug(f"Retrieved batch shard {shard_id} (size={len(data_bytes)} bytes)") - return Response(data_bytes, 
mimetype="application/octet-stream") - - except Exception as e: - logger.error(f"Error retrieving batch shard {shard_id}: {e}") - return jsonify({"status": "error", "message": str(e)}), 500 - - -@app.route("/data/batch", methods=["POST"]) -def retrieve_batch_data_many(): - """Retrieve multiple batch data shards in one request.""" - - try: - payload = request.get_json(silent=True) or {} - shard_ids = payload.get("shard_ids", []) - if not isinstance(shard_ids, list) or not all( - isinstance(shard_id, str) for shard_id in shard_ids - ): - return ( - jsonify( - { - "status": "error", - "message": "Expected JSON body with string list field 'shard_ids'", - } - ), - 400, - ) - - data = [] - missing_shard_ids = [] - for shard_id in shard_ids: - try: - data.append(rtensor.fetch(shard_id)) - except KeyError: - missing_shard_ids.append(shard_id) - - if missing_shard_ids: - return ( - jsonify( - { - "status": "error", - "message": "One or more requested shards were not found", - "missing_shard_ids": missing_shard_ids, - } - ), - 400, - ) - - serialized_data = serialize_value(data) - data_bytes = orjson.dumps(serialized_data) - logger.debug( - "Retrieved %s batch shards (size=%s bytes)", - len(shard_ids), - len(data_bytes), - ) - return Response(data_bytes, mimetype="application/octet-stream") - - except Exception as e: - logger.error(f"Error retrieving batch shards: {e}\n{traceback.format_exc()}") - return jsonify({"status": "error", "message": str(e)}), 500 - - -@app.route("/data/clear", methods=["DELETE"]) -def clear_batch_data(): - """Clear specified batch data shards. - - Expected JSON payload: - { - "shard_ids": ["id1", "id2", ...] 
- } - """ - try: - data = request.get_json(silent=True) or {} - shard_ids = data.get("shard_ids", []) - if not isinstance(shard_ids, list): - return ( - jsonify({"status": "error", "message": "'shard_ids' must be a list"}), - 400, - ) - - cleared_count = sum(rtensor.remove(sid) for sid in shard_ids) - stats = dict(cleared_count=cleared_count, **rtensor.storage_stats()) - logger.info(f"Cleared {cleared_count} batch shards. Stats: {stats}") - stats.update({"status": "ok"}) - return jsonify(stats) - - except Exception as e: - logger.error(f"Error clearing batch data: {e}") - return jsonify({"status": "error", "message": str(e)}), 500 - - -# ==================== Cleanup ==================== - - -def cleanup_forked_children(): - """Clean up all forked child processes.""" - global _forked_children, _forked_children_map - - with _forked_children_lock: - if not _forked_children: - return - - logger.info(f"Cleaning up {len(_forked_children)} forked child processes") - for child in _forked_children: - try: - if child.poll() is None: # Still running - kill_process_tree(child.pid, timeout=3, graceful=True) - logger.info(f"Killed forked child process {child.pid}") - except Exception as e: - logger.error(f"Error killing forked child {child.pid}: {e}") - _forked_children.clear() - _forked_children_map.clear() - - -def cleanup_engines(): - """Clean up all engines on shutdown.""" - global _engines - if _engines: - for engine_name, engine in list(_engines.items()): - try: - engine.destroy() - logger.info(f"Engine '{engine_name}' destroyed successfully") - except Exception as e: - logger.error(f"Error destroying engine '{engine_name}': {e}") - _engines.clear() - - -def cleanup_engine_thread(): - """Clean up engine thread on shutdown.""" - global _engine_thread, _engine_work_queue +import logging as stdlib_logging - with _engine_thread_lock: - if _engine_work_queue is not None: - # Send shutdown signal - _engine_work_queue.put(None) - _engine_work_queue = None +from 
areal.infra.rpc.guard.app import ( + GuardState, + configure_state_from_args, + create_app, + make_base_parser, + run_server, +) +from areal.infra.rpc.guard.data_blueprint import data_bp +from areal.infra.rpc.guard.engine_blueprint import engine_bp, register_engine_hooks +from areal.utils import logging, perf_tracer - if _engine_thread is not None: - _engine_thread.join(timeout=5.0) - if _engine_thread.is_alive(): - logger.warning("Engine thread did not shut down gracefully") - _engine_thread = None - logger.info("Engine thread cleaned up") +logger = logging.getLogger("SyncRPCServer") def main(): - """Main entry point for the sync RPC server.""" - parser = argparse.ArgumentParser( + parser = make_base_parser( description="AReaL Sync RPC Server for TrainEngine/InferenceEngine" ) - parser.add_argument( - "--port", - type=int, - default=0, - help="Port to serve on (default: 0 = auto-assign)", - ) - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" - ) parser.add_argument( "--werkzeug-log-level", type=str, @@ -1007,102 +39,25 @@ def main(): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Log level for Werkzeug (Flask's WSGI server). Default: WARNING", ) - # name_resolve config - parser.add_argument("--experiment-name", type=str, required=True) - parser.add_argument("--trial-name", type=str, required=True) - parser.add_argument("--role", type=str, required=True) - parser.add_argument("--worker-index", type=int, default=-1) - parser.add_argument("--name-resolve-type", type=str, default="nfs") - parser.add_argument( - "--nfs-record-root", type=str, default="/tmp/areal/name_resolve" - ) - parser.add_argument("--etcd3-addr", type=str, default="localhost:2379") - parser.add_argument( - "--fileroot", - type=str, - default=None, - help="Root directory for log files. 
If set, forked worker logs are written here.", - ) args, _ = parser.parse_known_args() - # Configure Werkzeug logging werkzeug_logger = stdlib_logging.getLogger("werkzeug") werkzeug_logger.setLevel(getattr(stdlib_logging, args.werkzeug_log_level)) - # Set global server address variables - global _server_host, _server_port, _role - global \ - _experiment_name, \ - _trial_name, \ - _name_resolve_type, \ - _nfs_record_root, \ - _etcd3_addr, \ - _fileroot - bind_host = args.host - if bind_host == "0.0.0.0": - host_ip = gethostip() - if ":" in host_ip: - bind_host = "::" - _server_host = host_ip - elif bind_host == "::": - _server_host = gethostip() - else: - _server_host = bind_host - _role = args.role - - # Set global config for fork endpoint - _experiment_name = args.experiment_name - _trial_name = args.trial_name - _name_resolve_type = args.name_resolve_type - _nfs_record_root = args.nfs_record_root - _etcd3_addr = args.etcd3_addr - _fileroot = args.fileroot + state = GuardState() + bind_host = configure_state_from_args(state, args) - # Get worker identity - worker_role = args.role - worker_index = args.worker_index - if "SLURM_PROCID" in os.environ: - # Overwriting with slurm task id - worker_index = os.environ["SLURM_PROCID"] - if worker_index == -1: - raise ValueError("Invalid worker index. 
Not found from SLURM environ or args.") - worker_id = f"{worker_role}/{worker_index}" + app = create_app(state) + app.register_blueprint(data_bp) + app.register_blueprint(engine_bp) + register_engine_hooks(state) - # Make a flask server - server = make_server(bind_host, args.port, app, threaded=True) - _server_port = server.socket.getsockname()[1] + state.register_cleanup_hook(lambda: perf_tracer.save(force=True)) - name_resolve.reconfigure( - NameResolveConfig( - type=args.name_resolve_type, - nfs_record_root=args.nfs_record_root, - etcd3_addr=args.etcd3_addr, - ) - ) - key = names.worker_discovery( - args.experiment_name, args.trial_name, args.role, worker_index - ) - name_resolve.add(key, format_hostport(_server_host, _server_port), replace=True) - - global _allocated_ports - _allocated_ports.add(_server_port) - - logger.info( - f"Starting sync RPC server on {_server_host}:{_server_port} for worker {worker_id}" - ) logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") - try: - server.serve_forever() - except KeyboardInterrupt: - logger.info("Shutting down sync RPC server") - finally: - perf_tracer.save(force=True) - cleanup_forked_children() - cleanup_engine_thread() - cleanup_engines() - server.shutdown() + run_server(state, app, bind_host, args.port) if __name__ == "__main__": diff --git a/areal/infra/rpc/rtensor.py b/areal/infra/rpc/rtensor.py index 620de400db..899c408e35 100644 --- a/areal/infra/rpc/rtensor.py +++ b/areal/infra/rpc/rtensor.py @@ -301,6 +301,19 @@ def set_backend(backend: RTensorBackend | None) -> None: _backend = backend +# ============================================================================= +# Client-side Fetch Buffer +# ============================================================================= +# Caches fetched tensors by shard_id so that repeated fetch() calls for the +# same shard (e.g. when the same rollout_batch is sent to multiple engine +# calls across RPC boundaries) avoid redundant network transfers. 
+# Entries are evicted by clear_node() when clear_batches() runs at the end +# of each train step. + +_fetch_buffer: dict[Any, torch.Tensor] = {} +_fetch_buffer_lock = Lock() + + @dataclass class RTensor: shard: TensorShardInfo @@ -309,7 +322,16 @@ class RTensor: def to_local(self) -> torch.Tensor: if not self.data.is_meta: return self.data + # Check client-side fetch buffer before making a network request. + with _fetch_buffer_lock: + cached = _fetch_buffer.get(self.shard.shard_id) + if cached is not None: + self.data = cached + return self.data + # Buffer miss: fetch from backend and populate buffer. self.data = get_backend().fetch([self.shard])[0] + with _fetch_buffer_lock: + _fetch_buffer[self.shard.shard_id] = self.data return self.data @staticmethod @@ -389,12 +411,26 @@ def localize(obj: Any) -> Any: RTensor._collect_all(obj, rtensors) meta_rtensors = [rt for rt in rtensors if rt.data.is_meta] if meta_rtensors: - shards = [rt.shard for rt in meta_rtensors] - results = get_backend().fetch(shards) - for rt, tensor in zip(meta_rtensors, results): - rt.data = tensor - - # Recursively replace RTensors with local tensors (all cache hits now) + # Resolve as many as possible from the client-side fetch buffer. + to_fetch: list[RTensor] = [] + with _fetch_buffer_lock: + for rt in meta_rtensors: + cached = _fetch_buffer.get(rt.shard.shard_id) + if cached is not None: + rt.data = cached + else: + to_fetch.append(rt) + + # Batch-fetch only the misses from the backend. 
+ if to_fetch: + shards = [rt.shard for rt in to_fetch] + results = get_backend().fetch(shards) + with _fetch_buffer_lock: + for rt, tensor in zip(to_fetch, results, strict=True): + rt.data = tensor + _fetch_buffer[rt.shard.shard_id] = tensor + + # Recursively replace RTensors with local tensors (all buffer hits now) return RTensor._localize_recursive(obj) @staticmethod @@ -459,7 +495,7 @@ def _collect(o: Any) -> None: @staticmethod async def clear_node(node_addr: str, shard_ids: list[Any]) -> None: - """Clear shards from a node. + """Clear shards from a node and evict them from the fetch buffer. Parameters ---------- @@ -468,6 +504,9 @@ async def clear_node(node_addr: str, shard_ids: list[Any]) -> None: shard_ids : list[Any] List of shard IDs to delete """ + with _fetch_buffer_lock: + for sid in shard_ids: + _fetch_buffer.pop(sid, None) await get_backend().delete(node_addr, shard_ids) @property diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 53aa73d3a4..8c1b9a7a35 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -254,6 +254,27 @@ def _prepare_worker_specs( f"schedulings length ({len(schedulings)}) must be 1 or equal to replicas ({num_workers})", ) + @staticmethod + async def _wait_for_fork_ready( + session: aiohttp.ClientSession, + host: str, + port: int, + timeout: float = 60, + ) -> bool: + url = f"http://{format_hostport(host, port)}/health" + deadline = time.time() + timeout + while time.time() < deadline: + try: + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=2) + ) as resp: + if resp.status == 200: + return True + except (TimeoutError, aiohttp.ClientError): + pass + await asyncio.sleep(0.5) + return False + async def _fork_single_worker( self, session: aiohttp.ClientSession, @@ -269,17 +290,65 @@ async def _fork_single_worker( ---------- command : str, optional Custom module path to run instead of the default rpc_server. 
- If specified, the forked process runs this module. """ worker_id = f"{role}/{idx}" - target_url = f"http://{format_hostport(target_wi.worker.ip, int(target_wi.worker.worker_ports[0]))}/fork" + guard_url = f"http://{format_hostport(target_wi.worker.ip, int(target_wi.worker.worker_ports[0]))}" try: - payload = {"role": role, "worker_index": idx} - if command is not None: - payload["command"] = command + # 1. Allocate a port on the target guard async with session.post( - target_url, + f"{guard_url}/alloc_ports", + json={"count": 1}, + ) as alloc_resp: + if alloc_resp.status != 200: + error_text = await alloc_resp.text() + raise WorkerCreationError( + role, + f"Port allocation failed for worker {idx}", + f"HTTP {alloc_resp.status}: {error_text}", + ) + alloc_data = await alloc_resp.json() + forked_host = alloc_data["host"] + forked_port = alloc_data["ports"][0] + + # 2. Build the full raw command + module_path = command or "areal.infra.rpc.rpc_server" + raw_cmd = [ + sys.executable, + "-m", + module_path, + "--host", + "0.0.0.0", + "--port", + str(forked_port), + "--experiment-name", + str(self.experiment_name), + "--trial-name", + str(self.trial_name), + "--role", + role, + "--worker-index", + str(idx), + ] + if self.name_resolve_config.type: + raw_cmd.extend(["--name-resolve-type", self.name_resolve_config.type]) + if self.name_resolve_config.nfs_record_root: + raw_cmd.extend( + ["--nfs-record-root", self.name_resolve_config.nfs_record_root] + ) + if self.name_resolve_config.etcd3_addr: + raw_cmd.extend(["--etcd3-addr", self.name_resolve_config.etcd3_addr]) + if self.fileroot: + raw_cmd.extend(["--fileroot", str(self.fileroot)]) + + # 3. 
Fork via raw_cmd + payload = { + "role": role, + "worker_index": idx, + "raw_cmd": raw_cmd, + } + async with session.post( + f"{guard_url}/fork", json=payload, ) as response: if response.status != 200: @@ -299,15 +368,30 @@ async def _fork_single_worker( result.get("error", "Unknown error"), ) - forked_host = result["host"] - forked_port = result["port"] forked_pid = result.get("pid") - logger.info( - f"Forked worker {worker_id} created at {forked_host}:{forked_port} " - f"(pid={forked_pid}) from {target_role}/{idx}" + # 4. Wait for the forked worker to become ready + if not await self._wait_for_fork_ready(session, forked_host, forked_port): + # Clean up the forked worker on the guard + try: + async with session.post( + f"{guard_url}/kill_forked_worker", + json={"role": role, "worker_index": idx}, + ): + pass + except Exception: + pass + raise WorkerCreationError( + role, + f"Forked worker {idx} failed to become ready", + f"Readiness timeout at {forked_host}:{forked_port}", ) + logger.info( + f"Forked worker {worker_id} created at {forked_host}:{forked_port} " + f"(pid={forked_pid}) from {target_role}/{idx}" + ) + except aiohttp.ClientError as e: raise WorkerCreationError( role, diff --git a/areal/infra/scheduler/ray.py b/areal/infra/scheduler/ray.py index 0307b1cb37..98a8d4b346 100644 --- a/areal/infra/scheduler/ray.py +++ b/areal/infra/scheduler/ray.py @@ -543,24 +543,7 @@ def fork_workers( """Fork new worker processes from existing workers. Creates new Ray actors colocated with existing workers of the target role. - The forked workers share the same placement groups as their target workers. - - Note: The `command` parameter is ignored for RayScheduler since Ray actors - always run the RayRPCServer. For custom module behavior, use LocalScheduler. 
- - Parameters - ---------- - role : str - Role name for the new forked workers (e.g., "proxy") - target_role : str - Role of existing workers to fork from (e.g., "rollout") - command : str, optional - Custom module path (ignored for Ray - Ray actors always run RayRPCServer) - - Returns - ------- - list[str] - List of worker IDs created (e.g., ["proxy/0", "proxy/1"]) + The ``command`` parameter is ignored — Ray actors always run RayRPCServer. """ if command is not None: logger.warning( diff --git a/areal/infra/scheduler/slurm.py b/areal/infra/scheduler/slurm.py index 46d7141e6b..16ef5402d9 100644 --- a/areal/infra/scheduler/slurm.py +++ b/areal/infra/scheduler/slurm.py @@ -3,6 +3,7 @@ import re import shlex import subprocess +import sys import time from dataclasses import dataclass from pathlib import Path @@ -433,6 +434,27 @@ def _get_colocation_nodes(self, target_role: str, replicas: int) -> tuple[int, s except subprocess.CalledProcessError as e: raise WorkerCreationError(target_role, f"Failed to query target job: {e}") + @staticmethod + async def _wait_for_fork_ready( + session: aiohttp.ClientSession, + host: str, + port: int, + timeout: float = 60, + ) -> bool: + url = f"http://{format_hostport(host, port)}/health" + deadline = time.time() + timeout + while time.time() < deadline: + try: + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=2) + ) as resp: + if resp.status == 200: + return True + except (TimeoutError, aiohttp.ClientError): + pass + await asyncio.sleep(0.5) + return False + async def _fork_single_worker( self, session: aiohttp.ClientSession, @@ -448,17 +470,65 @@ async def _fork_single_worker( ---------- command : str, optional Custom module path to run instead of the default rpc_server. - If specified, the forked process runs this module. 
""" worker_id = f"{role}/{idx}" - target_url = f"http://{format_hostport(target_wi.worker.ip, int(target_wi.worker.worker_ports[0]))}/fork" + guard_url = f"http://{format_hostport(target_wi.worker.ip, int(target_wi.worker.worker_ports[0]))}" try: - payload = {"role": role, "worker_index": idx} - if command is not None: - payload["command"] = command + # 1. Allocate a port on the target guard async with session.post( - target_url, + f"{guard_url}/alloc_ports", + json={"count": 1}, + ) as alloc_resp: + if alloc_resp.status != 200: + error_text = await alloc_resp.text() + raise WorkerCreationError( + role, + f"Port allocation failed for worker {idx}", + f"HTTP {alloc_resp.status}: {error_text}", + ) + alloc_data = await alloc_resp.json() + forked_host = alloc_data["host"] + forked_port = alloc_data["ports"][0] + + # 2. Build the full raw command + module_path = command or "areal.infra.rpc.rpc_server" + raw_cmd = [ + sys.executable, + "-m", + module_path, + "--host", + "0.0.0.0", + "--port", + str(forked_port), + "--experiment-name", + str(self.experiment_name), + "--trial-name", + str(self.trial_name), + "--role", + role, + "--worker-index", + str(idx), + ] + if self.name_resolve_config.type: + raw_cmd.extend(["--name-resolve-type", self.name_resolve_config.type]) + if self.name_resolve_config.nfs_record_root: + raw_cmd.extend( + ["--nfs-record-root", self.name_resolve_config.nfs_record_root] + ) + if self.name_resolve_config.etcd3_addr: + raw_cmd.extend(["--etcd3-addr", self.name_resolve_config.etcd3_addr]) + if self.fileroot: + raw_cmd.extend(["--fileroot", str(self.fileroot)]) + + # 3. 
Fork via raw_cmd + payload = { + "role": role, + "worker_index": idx, + "raw_cmd": raw_cmd, + } + async with session.post( + f"{guard_url}/fork", json=payload, ) as response: if response.status != 200: @@ -478,15 +548,29 @@ async def _fork_single_worker( result.get("error", "Unknown error"), ) - forked_host = result["host"] - forked_port = result["port"] forked_pid = result.get("pid") - logger.info( - f"Forked worker {worker_id} created at {forked_host}:{forked_port} " - f"(pid={forked_pid}) from {target_role}/{idx}" + # 4. Wait for the forked worker to become ready + if not await self._wait_for_fork_ready(session, forked_host, forked_port): + try: + async with session.post( + f"{guard_url}/kill_forked_worker", + json={"role": role, "worker_index": idx}, + ): + pass + except Exception: + pass + raise WorkerCreationError( + role, + f"Forked worker {idx} failed to become ready", + f"Readiness timeout at {forked_host}:{forked_port}", ) + logger.info( + f"Forked worker {worker_id} created at {forked_host}:{forked_port} " + f"(pid={forked_pid}) from {target_role}/{idx}" + ) + except aiohttp.ClientError as e: raise WorkerCreationError( role, diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 52a9cbf212..35c8bb8e5d 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -414,7 +414,7 @@ def setup_file_logging( def log_swanlab_wandb_tensorboard(data, step=None, summary_writer=None): - # Logs data to SwanLab、 wandb、 TensorBoard. + # Logs data to SwanLab, wandb, TensorBoard, and Trackio. 
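Both `LocalScheduler` and `SlurmScheduler` now gate fork success on `_wait_for_fork_ready`, a poll-until-deadline loop against the forked worker's `/health` endpoint. The core loop, generalized over an async probe callable — `wait_for_ready` and `flaky_probe` are illustrative names, not the real API, which probes via `aiohttp.ClientSession`:

```python
import asyncio
import time


async def wait_for_ready(probe, timeout: float = 60.0, interval: float = 0.5) -> bool:
    """Poll `probe` (an async callable that returns True once the worker's
    health endpoint answers 200) until it succeeds or the deadline passes."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if await probe():
                return True
        except Exception:
            pass  # connection refused etc. — worker not up yet
        await asyncio.sleep(interval)
    return False


# Example: a probe that only succeeds on the third attempt.
attempts = {"n": 0}


async def flaky_probe():
    attempts["n"] += 1
    return attempts["n"] >= 3


ok = asyncio.run(wait_for_ready(flaky_probe, timeout=5.0, interval=0.01))
```

On timeout, the schedulers above additionally POST `/kill_forked_worker` back to the guard before raising `WorkerCreationError`, so a half-started process is not leaked.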
global _LATEST_LOG_STEP if step is None: @@ -435,6 +435,14 @@ def log_swanlab_wandb_tensorboard(data, step=None, summary_writer=None): wandb.log(data, step=step) + # trackio + try: + import trackio + + trackio.log(data, step=step) + except (ModuleNotFoundError, ImportError): + pass + # tensorboard if summary_writer is not None: for key, val in data.items(): diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index d03db37453..54e695b50c 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -5,6 +5,7 @@ import swanlab import torch.distributed as dist +import trackio import wandb from tensorboardX import SummaryWriter @@ -90,6 +91,19 @@ def init(self): logdir=self.get_log_path(self.config), mode=swanlab_config.mode, ) + + # trackio init + self._trackio_enabled = False + trackio_config = self.config.trackio + if trackio_config.mode != "disabled": + trackio.init( + project=trackio_config.project or self.config.experiment_name, + name=trackio_config.name or self.config.trial_name, + config=exp_config_dict, + space_id=trackio_config.space_id, + ) + self._trackio_enabled = True + # tensorboard logging self.summary_writer = None if self.config.tensorboard.path is not None: @@ -111,6 +125,8 @@ def close(self): ) wandb.finish() swanlab.finish() + if getattr(self, "_trackio_enabled", False): + trackio.finish() if self.summary_writer is not None: self.summary_writer.close() @@ -133,6 +149,8 @@ def commit(self, epoch: int, step: int, global_step: int, data: dict | list[dict self.print_stats(item) wandb.log(item, step=log_step + i) swanlab.log(item, step=log_step + i) + if getattr(self, "_trackio_enabled", False): + trackio.log(item, step=log_step + i) if self.summary_writer is not None: for key, val in item.items(): self.summary_writer.add_scalar(f"{key}", val, log_step + i) diff --git a/docs/generate_cli_docs.py b/docs/generate_cli_docs.py index bee4d9fda7..270d1f7b05 100644 --- a/docs/generate_cli_docs.py +++ 
b/docs/generate_cli_docs.py @@ -85,6 +85,7 @@ def categorize_dataclasses( "WandBConfig", "SwanlabConfig", "TensorBoardConfig", + "TrackioConfig", "SaverConfig", "EvaluatorConfig", "RecoverConfig", diff --git a/pyproject.toml b/pyproject.toml index e5dbd0d512..c9844e288d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ dependencies = [ # Monitoring and logging "wandb", "tensorboardx", + "trackio", "colorama", "colorlog", "swanboard==0.1.9b1", diff --git a/tests/experimental/inference_service/test_guard.py b/tests/experimental/inference_service/test_guard.py index bc7fa892ad..0aac9bf19b 100644 --- a/tests/experimental/inference_service/test_guard.py +++ b/tests/experimental/inference_service/test_guard.py @@ -15,52 +15,39 @@ from areal.experimental.inference_service.guard import app as guard_module from areal.experimental.inference_service.guard.app import app, cleanup_forked_children -# ============================================================================= -# Fixtures -# ============================================================================= +GUARD_APP = "areal.infra.rpc.guard.app" @pytest.fixture(autouse=True) def _reset_guard_globals(): - """Reset module-level global state before each test.""" - guard_module._allocated_ports = set() - guard_module._forked_children = [] - guard_module._forked_children_map = {} - guard_module._server_host = "10.0.0.1" - guard_module._experiment_name = "test-exp" - guard_module._trial_name = "test-trial" - guard_module._fileroot = None + guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} + guard_module._state.server_host = "10.0.0.1" + guard_module._state.experiment_name = "test-exp" + guard_module._state.trial_name = "test-trial" + guard_module._state.fileroot = None yield - # Cleanup after test - guard_module._allocated_ports = set() - guard_module._forked_children = [] - guard_module._forked_children_map = {} + 
guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} @pytest.fixture() def client(): - """Flask test client for RPCGuard app.""" app.config["TESTING"] = True with app.test_client() as c: yield c def _make_mock_process(pid: int = 12345, running: bool = True) -> MagicMock: - """Create a mock subprocess.Popen with controllable poll().""" proc = MagicMock(spec=subprocess.Popen) proc.pid = pid proc.poll.return_value = None if running else 0 return proc -# ============================================================================= -# TestHealth -# ============================================================================= - - class TestHealth: - """GET /health returns healthy status with child count.""" - def test_health_returns_200(self, client): resp = client.get("/health") assert resp.status_code == 200 @@ -69,22 +56,18 @@ def test_health_returns_200(self, client): assert data["forked_children"] == 0 def test_health_counts_forked_children(self, client): - # Add mock children to the global list - guard_module._forked_children = [MagicMock(), MagicMock(), MagicMock()] + guard_module._state.forked_children = [ + MagicMock(), + MagicMock(), + MagicMock(), + ] resp = client.get("/health") data = resp.get_json() assert data["forked_children"] == 3 -# ============================================================================= -# TestAllocPorts -# ============================================================================= - - class TestAllocPorts: - """POST /alloc_ports allocates unique ports and tracks exclusions.""" - - @patch("areal.experimental.inference_service.guard.app.find_free_ports") + @patch(f"{GUARD_APP}.find_free_ports") def test_alloc_ports_success(self, mock_find, client): mock_find.return_value = [9001, 9002, 9003] resp = client.post("/alloc_ports", json={"count": 3}) @@ -93,12 +76,10 @@ def test_alloc_ports_success(self, mock_find, client): assert data["status"] == 
"success" assert data["ports"] == [9001, 9002, 9003] assert data["host"] == "10.0.0.1" - # Ports tracked in exclusion set - assert guard_module._allocated_ports == {9001, 9002, 9003} + assert guard_module._state.allocated_ports == {9001, 9002, 9003} - @patch("areal.experimental.inference_service.guard.app.find_free_ports") + @patch(f"{GUARD_APP}.find_free_ports") def test_alloc_ports_excludes_previous(self, mock_find, client): - """Second allocation excludes ports from the first.""" mock_find.return_value = [9001, 9002, 9003] client.post("/alloc_ports", json={"count": 3}) @@ -108,14 +89,18 @@ def test_alloc_ports_excludes_previous(self, mock_find, client): data = resp.get_json() assert data["ports"] == [9004, 9005] - # find_free_ports called with prior exclusions _, kwargs = mock_find.call_args assert 9001 in kwargs.get("exclude_ports", set()) assert 9002 in kwargs.get("exclude_ports", set()) assert 9003 in kwargs.get("exclude_ports", set()) - # All 5 ports tracked - assert guard_module._allocated_ports == {9001, 9002, 9003, 9004, 9005} + assert guard_module._state.allocated_ports == { + 9001, + 9002, + 9003, + 9004, + 9005, + } def test_alloc_ports_missing_count(self, client): resp = client.post("/alloc_ports", json={}) @@ -139,146 +124,19 @@ def test_alloc_ports_no_json_body(self, client): assert resp.status_code == 400 -# ============================================================================= -# TestForkModulePath -# ============================================================================= - - -class TestForkModulePath: - """POST /fork with command field — module-path mode.""" - - @patch( - "areal.experimental.inference_service.guard.app._wait_for_worker_ready", - return_value=True, - ) - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_module_path_builds_correct_cmd( - self, mock_find, mock_run, mock_wait, client - ): - 
mock_find.return_value = [8001] - mock_proc = _make_mock_process(pid=42) - mock_run.return_value = mock_proc - - resp = client.post( - "/fork", - json={"role": "test-worker", "worker_index": 0, "command": "some.module"}, - ) - assert resp.status_code == 200 - data = resp.get_json() - assert data["status"] == "success" - assert data["host"] == "10.0.0.1" - assert data["port"] == 8001 - assert data["pid"] == 42 - - # Verify command was built correctly - call_args = mock_run.call_args - cmd = call_args[0][0] - assert "-m" in cmd - assert "some.module" in cmd - assert "--role" in cmd - assert "test-worker" in cmd - assert "--worker-index" in cmd - assert "0" in cmd - assert "--experiment-name" in cmd - assert "test-exp" in cmd - assert "--trial-name" in cmd - assert "test-trial" in cmd - assert "--port" in cmd - assert "8001" in cmd - - @patch( - "areal.experimental.inference_service.guard.app._wait_for_worker_ready", - return_value=True, - ) - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_module_path_tracks_child( - self, mock_find, mock_run, mock_wait, client - ): - mock_find.return_value = [8001] - mock_proc = _make_mock_process(pid=42) - mock_run.return_value = mock_proc - - client.post( - "/fork", - json={"role": "test", "worker_index": 0, "command": "some.module"}, - ) - - assert mock_proc in guard_module._forked_children - assert ("test", 0) in guard_module._forked_children_map - assert guard_module._forked_children_map[("test", 0)] is mock_proc - - @patch( - "areal.experimental.inference_service.guard.app._wait_for_worker_ready", - return_value=True, - ) - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_module_path_waits_for_ready( - self, mock_find, mock_run, mock_wait, client - ): - """Module-path mode polls 
health before returning.""" - mock_find.return_value = [8001] - mock_run.return_value = _make_mock_process() - - client.post( - "/fork", - json={"role": "test", "worker_index": 0, "command": "some.module"}, - ) - - mock_wait.assert_called_once_with("10.0.0.1", 8001) - - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") - @patch( - "areal.experimental.inference_service.guard.app._wait_for_worker_ready", - return_value=False, - ) - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_module_path_cleanup_on_ready_timeout( - self, mock_find, mock_run, mock_wait, mock_kill, client - ): - """If readiness polling fails, child is killed and cleaned up.""" - mock_find.return_value = [8001] - mock_proc = _make_mock_process(pid=99) - mock_run.return_value = mock_proc - - resp = client.post( - "/fork", - json={"role": "test", "worker_index": 0, "command": "some.module"}, - ) - assert resp.status_code == 500 - assert "failed to start" in resp.get_json()["error"].lower() - - # Child cleaned up - assert mock_proc not in guard_module._forked_children - assert ("test", 0) not in guard_module._forked_children_map - # Port freed - assert 8001 not in guard_module._allocated_ports - # kill_process_tree called - mock_kill.assert_called_once_with(99, timeout=3, graceful=True) - - -# ============================================================================= -# TestForkRawCommand -# ============================================================================= - - -class TestForkRawCommand: - """POST /fork with raw_cmd field — raw-command mode.""" - - @patch("areal.experimental.inference_service.guard.app._wait_for_worker_ready") - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_raw_cmd_passes_command_as_is( - self, 
mock_find, mock_run, mock_wait, client - ): - mock_find.return_value = [8001] +class TestFork: + @patch(f"{GUARD_APP}.run_with_streaming_logs") + def test_fork_raw_cmd_passes_command_as_is(self, mock_run, client): mock_proc = _make_mock_process(pid=55) mock_run.return_value = mock_proc - raw = ["python", "-m", "sglang.launch_server", "--model", "test-model"] + raw = [ + "python", + "-m", + "sglang.launch_server", + "--model", + "test-model", + ] resp = client.post( "/fork", json={"role": "sglang", "worker_index": 0, "raw_cmd": raw}, @@ -286,141 +144,123 @@ def test_fork_raw_cmd_passes_command_as_is( assert resp.status_code == 200 data = resp.get_json() assert data["status"] == "success" + assert data["host"] == "10.0.0.1" assert data["pid"] == 55 + assert "port" not in data - # Command passed as-is — NO scheduler args injected call_args = mock_run.call_args cmd = call_args[0][0] assert cmd == raw - assert "--experiment-name" not in cmd - assert "--role" not in cmd - - @patch("areal.experimental.inference_service.guard.app._wait_for_worker_ready") - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_raw_cmd_skips_readiness_polling( - self, mock_find, mock_run, mock_wait, client - ): - """Raw-command mode returns immediately without polling health.""" - mock_find.return_value = [8001] - mock_run.return_value = _make_mock_process() + + @patch(f"{GUARD_APP}.run_with_streaming_logs") + def test_fork_tracks_child(self, mock_run, client): + mock_proc = _make_mock_process(pid=42) + mock_run.return_value = mock_proc client.post( "/fork", json={ - "role": "sglang", + "role": "test", "worker_index": 0, - "raw_cmd": ["python", "-m", "sglang.launch_server"], + "raw_cmd": ["echo", "hello"], }, ) - mock_wait.assert_not_called() + assert mock_proc in guard_module._state.forked_children + assert ("test", 0) in guard_module._state.forked_children_map + assert 
guard_module._state.forked_children_map[("test", 0)] is mock_proc - @patch("areal.experimental.inference_service.guard.app.run_with_streaming_logs") - @patch("areal.experimental.inference_service.guard.app.find_free_ports") - def test_fork_raw_cmd_allocates_port_but_not_injected( - self, mock_find, mock_run, client - ): - """A port is allocated for tracking but NOT injected into raw_cmd.""" - mock_find.return_value = [9999] + @patch(f"{GUARD_APP}.run_with_streaming_logs") + def test_fork_with_env_overrides(self, mock_run, client): mock_run.return_value = _make_mock_process() resp = client.post( "/fork", json={ - "role": "sglang", + "role": "test", "worker_index": 0, "raw_cmd": ["echo", "hello"], + "env": {"MY_VAR": "my_value"}, }, ) - data = resp.get_json() - assert data["port"] == 9999 - assert 9999 in guard_module._allocated_ports - + assert resp.status_code == 200 -# ============================================================================= -# TestForkErrorHandling -# ============================================================================= + call_kwargs = mock_run.call_args + child_env = call_kwargs[1]["env"] + assert child_env["MY_VAR"] == "my_value" class TestForkErrorHandling: - """POST /fork error cases — missing fields and validation.""" - def test_fork_missing_role(self, client): - resp = client.post("/fork", json={"worker_index": 0, "command": "some.module"}) + resp = client.post( + "/fork", + json={"worker_index": 0, "raw_cmd": ["echo"]}, + ) assert resp.status_code == 400 assert "role" in resp.get_json()["error"].lower() def test_fork_missing_worker_index(self, client): - resp = client.post("/fork", json={"role": "test", "command": "some.module"}) + resp = client.post( + "/fork", + json={"role": "test", "raw_cmd": ["echo"]}, + ) assert resp.status_code == 400 assert "worker_index" in resp.get_json()["error"].lower() - def test_fork_missing_command_and_raw_cmd(self, client): + def test_fork_missing_raw_cmd(self, client): resp = 
client.post("/fork", json={"role": "test", "worker_index": 0}) assert resp.status_code == 400 - assert "command" in resp.get_json()["error"].lower() + assert "raw_cmd" in resp.get_json()["error"].lower() def test_fork_no_json_body(self, client): resp = client.post("/fork", data="not json", content_type="text/plain") assert resp.status_code == 400 - def test_fork_invalid_count(self, client): - """Verify /alloc_ports validates count type.""" + def test_alloc_ports_invalid_count(self, client): resp = client.post("/alloc_ports", json={"count": 1.5}) assert resp.status_code == 400 -# ============================================================================= -# TestKillForkedWorker -# ============================================================================= - - class TestKillForkedWorker: - """POST /kill_forked_worker kills correct child.""" - - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_kill_known_worker(self, mock_kill, client): - """Kill a tracked worker — removes from tracking, calls kill_process_tree.""" mock_proc = _make_mock_process(pid=123) - guard_module._forked_children.append(mock_proc) - guard_module._forked_children_map[("test", 0)] = mock_proc + guard_module._state.forked_children.append(mock_proc) + guard_module._state.forked_children_map[("test", 0)] = mock_proc resp = client.post( - "/kill_forked_worker", json={"role": "test", "worker_index": 0} + "/kill_forked_worker", + json={"role": "test", "worker_index": 0}, ) assert resp.status_code == 200 data = resp.get_json() assert data["status"] == "success" assert "123" in data["message"] - # Removed from tracking - assert mock_proc not in guard_module._forked_children - assert ("test", 0) not in guard_module._forked_children_map + assert mock_proc not in guard_module._state.forked_children + assert ("test", 0) not in guard_module._state.forked_children_map - # kill_process_tree called 
mock_kill.assert_called_once_with(123, timeout=3, graceful=True) def test_kill_unknown_worker_returns_404(self, client): - """Killing a non-existent worker returns 404.""" resp = client.post( - "/kill_forked_worker", json={"role": "ghost", "worker_index": 99} + "/kill_forked_worker", + json={"role": "ghost", "worker_index": 99}, ) assert resp.status_code == 404 assert "not found" in resp.get_json()["error"].lower() - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_kill_already_exited_worker(self, mock_kill, client): - """Worker that already exited (poll() != None) — no kill needed.""" mock_proc = _make_mock_process(pid=456, running=False) - guard_module._forked_children.append(mock_proc) - guard_module._forked_children_map[("done", 0)] = mock_proc + guard_module._state.forked_children.append(mock_proc) + guard_module._state.forked_children_map[("done", 0)] = mock_proc resp = client.post( - "/kill_forked_worker", json={"role": "done", "worker_index": 0} + "/kill_forked_worker", + json={"role": "done", "worker_index": 0}, ) assert resp.status_code == 200 - # kill_process_tree NOT called because poll() returned non-None mock_kill.assert_not_called() def test_kill_missing_role(self, client): @@ -433,38 +273,35 @@ def test_kill_missing_worker_index(self, client): assert resp.status_code == 400 assert "worker_index" in resp.get_json()["error"].lower() - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_kill_then_kill_again_returns_404(self, mock_kill, client): - """Killing the same worker twice — second attempt gets 404.""" mock_proc = _make_mock_process(pid=789) - guard_module._forked_children.append(mock_proc) - guard_module._forked_children_map[("test", 0)] = mock_proc + guard_module._state.forked_children.append(mock_proc) + guard_module._state.forked_children_map[("test", 0)] = mock_proc resp1 = client.post( - 
"/kill_forked_worker", json={"role": "test", "worker_index": 0} + "/kill_forked_worker", + json={"role": "test", "worker_index": 0}, ) assert resp1.status_code == 200 resp2 = client.post( - "/kill_forked_worker", json={"role": "test", "worker_index": 0} + "/kill_forked_worker", + json={"role": "test", "worker_index": 0}, ) assert resp2.status_code == 404 -# ============================================================================= -# TestCleanup -# ============================================================================= - - class TestCleanup: - """cleanup_forked_children() kills all tracked children.""" - - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_cleanup_kills_all_running_children(self, mock_kill): proc1 = _make_mock_process(pid=100) proc2 = _make_mock_process(pid=200) - guard_module._forked_children = [proc1, proc2] - guard_module._forked_children_map = {("a", 0): proc1, ("b", 0): proc2} + guard_module._state.forked_children = [proc1, proc2] + guard_module._state.forked_children_map = { + ("a", 0): proc1, + ("b", 0): proc2, + } cleanup_forked_children() @@ -472,47 +309,45 @@ def test_cleanup_kills_all_running_children(self, mock_kill): pids_killed = {call.args[0] for call in mock_kill.call_args_list} assert pids_killed == {100, 200} - # Tracking cleared - assert guard_module._forked_children == [] - assert guard_module._forked_children_map == {} + assert guard_module._state.forked_children == [] + assert guard_module._state.forked_children_map == {} - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_cleanup_skips_already_exited(self, mock_kill): - """Already-exited children (poll() != None) are not killed.""" running = _make_mock_process(pid=100, running=True) exited = _make_mock_process(pid=200, running=False) - guard_module._forked_children = [running, exited] - 
guard_module._forked_children_map = {("a", 0): running, ("b", 0): exited} + guard_module._state.forked_children = [running, exited] + guard_module._state.forked_children_map = { + ("a", 0): running, + ("b", 0): exited, + } cleanup_forked_children() - # Only the running child gets killed mock_kill.assert_called_once_with(100, timeout=3, graceful=True) - # Tracking still fully cleared - assert guard_module._forked_children == [] - assert guard_module._forked_children_map == {} + assert guard_module._state.forked_children == [] + assert guard_module._state.forked_children_map == {} - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_cleanup_no_children_is_noop(self, mock_kill): - """Cleanup with no children does nothing.""" cleanup_forked_children() mock_kill.assert_not_called() - @patch("areal.experimental.inference_service.guard.app.kill_process_tree") + @patch(f"{GUARD_APP}.kill_process_tree") def test_cleanup_tolerates_kill_exception(self, mock_kill): - """If kill_process_tree raises, other children are still cleaned.""" proc1 = _make_mock_process(pid=100) proc2 = _make_mock_process(pid=200) - guard_module._forked_children = [proc1, proc2] - guard_module._forked_children_map = {("a", 0): proc1, ("b", 0): proc2} + guard_module._state.forked_children = [proc1, proc2] + guard_module._state.forked_children_map = { + ("a", 0): proc1, + ("b", 0): proc2, + } mock_kill.side_effect = [OSError("boom"), None] cleanup_forked_children() - # Both attempted despite first raising assert mock_kill.call_count == 2 - # Tracking cleared even on error - assert guard_module._forked_children == [] - assert guard_module._forked_children_map == {} + assert guard_module._state.forked_children == [] + assert guard_module._state.forked_children_map == {} diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py index fc04c929e5..3ec4f66032 100644 --- a/tests/test_local_scheduler.py +++ 
b/tests/test_local_scheduler.py @@ -1,5 +1,6 @@ import asyncio import os +import sys import time from unittest.mock import AsyncMock, Mock, call, patch @@ -2054,10 +2055,35 @@ def test_fork_endpoint_spawns_new_process(self, rpc_server_process): """Should spawn a new RPC server process when /fork is called.""" _, host, port = rpc_server_process - # Call /fork endpoint + alloc_resp = requests.post( + f"http://{host}:{port}/alloc_ports", + json={"count": 1}, + timeout=10, + ) + assert alloc_resp.status_code == 200 + child_port = alloc_resp.json()["ports"][0] + + raw_cmd = [ + sys.executable, + "-m", + "areal.infra.rpc.rpc_server", + "--host", + "0.0.0.0", + "--port", + str(child_port), + "--experiment-name", + "test_fork_exp", + "--trial-name", + "test_fork_trial", + "--role", + "ref", + "--worker-index", + "0", + ] + response = requests.post( f"http://{host}:{port}/fork", - json={"role": "ref", "worker_index": 0}, + json={"role": "ref", "worker_index": 0, "raw_cmd": raw_cmd}, timeout=60, ) @@ -2065,40 +2091,85 @@ def test_fork_endpoint_spawns_new_process(self, rpc_server_process): result = response.json() assert result["status"] == "success" assert "host" in result - assert "port" in result assert "pid" in result forked_pid = result["pid"] - forked_port = result["port"] # Verify new process exists assert psutil.pid_exists(forked_pid) - # Verify forked server is responsive - forked_response = requests.get( - f"http://{result['host']}:{forked_port}/health", timeout=5 - ) - assert forked_response.status_code == 200 + deadline = time.time() + 30 + while time.time() < deadline: + try: + forked_response = requests.get( + f"http://{result['host']}:{child_port}/health", timeout=2 + ) + if forked_response.status_code == 200: + break + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ): + pass + time.sleep(0.5) + else: + pytest.fail("Forked worker did not become ready") def test_forked_worker_inherits_environment(self, rpc_server_process): 
"""Forked worker should inherit environment variables from parent.""" _, host, port = rpc_server_process - # Call /fork endpoint + alloc_resp = requests.post( + f"http://{host}:{port}/alloc_ports", + json={"count": 1}, + timeout=10, + ) + assert alloc_resp.status_code == 200 + child_port = alloc_resp.json()["ports"][0] + + raw_cmd = [ + sys.executable, + "-m", + "areal.infra.rpc.rpc_server", + "--host", + "0.0.0.0", + "--port", + str(child_port), + "--experiment-name", + "test_fork_exp", + "--trial-name", + "test_fork_trial", + "--role", + "ref", + "--worker-index", + "0", + ] + response = requests.post( f"http://{host}:{port}/fork", - json={"role": "ref", "worker_index": 0}, + json={"role": "ref", "worker_index": 0, "raw_cmd": raw_cmd}, timeout=60, ) assert response.status_code == 200 result = response.json() - # Verify forked server is alive and accessible - forked_response = requests.get( - f"http://{result['host']}:{result['port']}/health", timeout=5 - ) - assert forked_response.status_code == 200 + deadline = time.time() + 30 + while time.time() < deadline: + try: + forked_response = requests.get( + f"http://{result['host']}:{child_port}/health", timeout=2 + ) + if forked_response.status_code == 200: + break + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ): + pass + time.sleep(0.5) + else: + pytest.fail("Forked worker did not become ready") def test_create_forked_workers_via_scheduler(self, tmp_path): """LocalScheduler should create forked workers through /fork endpoint.""" diff --git a/tests/test_rtensor.py b/tests/test_rtensor.py index 81f0149e75..423c815e9e 100644 --- a/tests/test_rtensor.py +++ b/tests/test_rtensor.py @@ -819,6 +819,244 @@ def test_remotize_trims_padding_from_attention_mask(self, rpc_server): assert torch.equal(localized["attention_mask"], expected_mask) +class TestFetchBuffer: + """Test client-side fetch buffer for RTensor caching. 
+ + The fetch buffer avoids redundant network fetches when the same + rollout_batch is sent to multiple engine calls across RPC boundaries. + """ + + def setup_method(self): + """Clear fetch buffer before each test.""" + from areal.infra.rpc.rtensor import _fetch_buffer, _fetch_buffer_lock + + with _fetch_buffer_lock: + _fetch_buffer.clear() + + def test_to_local_populates_buffer(self, rpc_server): + """to_local() should populate the fetch buffer on first access.""" + from areal.infra.rpc.rtensor import _fetch_buffer + + tensor = torch.randn(3, 5).cpu() + shard_id = str(uuid.uuid4()) + + serialized = serialize_value(tensor) + requests.put( + f"http://{rpc_server}/data/{shard_id}", + data=orjson.dumps(serialized), + ) + + rtensor = RTensor( + shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server), + data=torch.empty(3, 5, device="meta"), + ) + + result = rtensor.to_local() + assert torch.allclose(result, tensor) + assert shard_id in _fetch_buffer + + def test_to_local_serves_from_buffer(self, rpc_server): + """Second to_local() with a fresh RTensor (same shard_id) should + hit the buffer without making a network request.""" + tensor = torch.randn(4, 6).cpu() + shard_id = str(uuid.uuid4()) + + serialized = serialize_value(tensor) + requests.put( + f"http://{rpc_server}/data/{shard_id}", + data=orjson.dumps(serialized), + ) + + # First access: populates buffer + rt1 = RTensor( + shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server), + data=torch.empty(4, 6, device="meta"), + ) + result1 = rt1.to_local() + + # Delete shard from server so a real fetch would fail + requests.delete( + f"http://{rpc_server}/data/clear", + json={"shard_ids": [shard_id]}, + ) + + # Second access with a new RTensor object (simulates RPC boundary) + rt2 = RTensor( + shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server), + data=torch.empty(4, 6, device="meta"), + ) + result2 = rt2.to_local() + assert torch.allclose(result1, result2) + + def 
test_localize_populates_buffer(self, rpc_server): + """localize() should populate the fetch buffer for all fetched shards.""" + from areal.infra.rpc.rtensor import _fetch_buffer + + tensor1 = torch.randn(2, 3).cpu() + tensor2 = torch.randn(4, 5).cpu() + shard_id1 = str(uuid.uuid4()) + shard_id2 = str(uuid.uuid4()) + + for sid, t in [(shard_id1, tensor1), (shard_id2, tensor2)]: + serialized = serialize_value(t) + requests.put( + f"http://{rpc_server}/data/{sid}", + data=orjson.dumps(serialized), + ) + + nested = { + "a": RTensor( + shard=TensorShardInfo(shard_id=shard_id1, node_addr=rpc_server), + data=torch.empty(2, 3, device="meta"), + ), + "b": RTensor( + shard=TensorShardInfo(shard_id=shard_id2, node_addr=rpc_server), + data=torch.empty(4, 5, device="meta"), + ), + } + + localized = RTensor.localize(nested) + assert torch.allclose(localized["a"], tensor1) + assert torch.allclose(localized["b"], tensor2) + assert shard_id1 in _fetch_buffer + assert shard_id2 in _fetch_buffer + + def test_localize_serves_from_buffer(self, rpc_server): + """Second localize() with fresh meta RTensors (same shard_ids) should + resolve entirely from the buffer.""" + tensor = torch.randn(3, 4).cpu() + shard_id = str(uuid.uuid4()) + + serialized = serialize_value(tensor) + requests.put( + f"http://{rpc_server}/data/{shard_id}", + data=orjson.dumps(serialized), + ) + + def _make_rtensor(): + return RTensor( + shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server), + data=torch.empty(3, 4, device="meta"), + ) + + # First localize: populates buffer + result1 = RTensor.localize({"x": _make_rtensor()}) + + # Remove from server + requests.delete( + f"http://{rpc_server}/data/clear", + json={"shard_ids": [shard_id]}, + ) + + # Second localize with fresh meta RTensor: should hit buffer + result2 = RTensor.localize({"x": _make_rtensor()}) + assert torch.allclose(result1["x"], result2["x"]) + + def test_localize_partial_buffer_hit(self, rpc_server): + """When some shards are in the buffer 
and others are not, only the + misses should be fetched from the backend.""" + from areal.infra.rpc.rtensor import _fetch_buffer + + tensor_a = torch.randn(2, 3).cpu() + tensor_b = torch.randn(4, 5).cpu() + shard_a = str(uuid.uuid4()) + shard_b = str(uuid.uuid4()) + + for sid, t in [(shard_a, tensor_a), (shard_b, tensor_b)]: + serialized = serialize_value(t) + requests.put( + f"http://{rpc_server}/data/{sid}", + data=orjson.dumps(serialized), + ) + + # Warm buffer with shard_a only + rt_a = RTensor( + shard=TensorShardInfo(shard_id=shard_a, node_addr=rpc_server), + data=torch.empty(2, 3, device="meta"), + ) + RTensor.localize(rt_a) + assert shard_a in _fetch_buffer + assert shard_b not in _fetch_buffer + + # Delete shard_a from server; shard_b remains + requests.delete( + f"http://{rpc_server}/data/clear", + json={"shard_ids": [shard_a]}, + ) + + # Localize both: shard_a from buffer, shard_b from backend + nested = { + "a": RTensor( + shard=TensorShardInfo(shard_id=shard_a, node_addr=rpc_server), + data=torch.empty(2, 3, device="meta"), + ), + "b": RTensor( + shard=TensorShardInfo(shard_id=shard_b, node_addr=rpc_server), + data=torch.empty(4, 5, device="meta"), + ), + } + result = RTensor.localize(nested) + assert torch.allclose(result["a"], tensor_a) + assert torch.allclose(result["b"], tensor_b) + + def test_clear_node_evicts_from_buffer(self, rpc_server): + """clear_node() should remove entries from the fetch buffer.""" + from areal.infra.rpc.rtensor import _fetch_buffer + + tensor = torch.randn(2, 3).cpu() + shard_id = str(uuid.uuid4()) + + serialized = serialize_value(tensor) + requests.put( + f"http://{rpc_server}/data/{shard_id}", + data=orjson.dumps(serialized), + ) + + # Populate buffer + rt = RTensor( + shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server), + data=torch.empty(2, 3, device="meta"), + ) + rt.to_local() + assert shard_id in _fetch_buffer + + # clear_node evicts from buffer + asyncio.run(RTensor.clear_node(rpc_server, [shard_id])) + 
+        assert shard_id not in _fetch_buffer
+
+    def test_buffer_thread_safety(self, rpc_server):
+        """Concurrent to_local() calls with the same shard_id should not crash."""
+        import threading
+
+        tensor = torch.randn(5, 8).cpu()
+        shard_id = str(uuid.uuid4())
+
+        serialized = serialize_value(tensor)
+        requests.put(
+            f"http://{rpc_server}/data/{shard_id}",
+            data=orjson.dumps(serialized),
+        )
+
+        results = [None] * 10
+
+        def fetch_shard(idx):
+            rt = RTensor(
+                shard=TensorShardInfo(shard_id=shard_id, node_addr=rpc_server),
+                data=torch.empty(5, 8, device="meta"),
+            )
+            results[idx] = rt.to_local()
+
+        threads = [threading.Thread(target=fetch_shard, args=(i,)) for i in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        for result in results:
+            assert result is not None
+            assert torch.allclose(result, tensor)
+
+
 class TestTensorShardInfoDocumentation:
     """Tests verifying TensorShardInfo construction and field semantics."""
diff --git a/tests/test_trackio_backend.py b/tests/test_trackio_backend.py
new file mode 100644
index 0000000000..1460ec199e
--- /dev/null
+++ b/tests/test_trackio_backend.py
@@ -0,0 +1,195 @@
+"""Tests for Trackio experiment tracking backend integration."""
+
+from dataclasses import fields
+from unittest.mock import MagicMock, patch
+
+from areal.api.cli_args import (
+    StatsLoggerConfig,
+    TrackioConfig,
+)
+
+
+class TestTrackioConfig:
+    """Tests for TrackioConfig dataclass."""
+
+    def test_default_mode_is_disabled(self):
+        """TrackioConfig should default to disabled mode."""
+        config = TrackioConfig()
+        assert config.mode == "disabled"
+
+    def test_default_optional_fields_are_none(self):
+        """Optional fields should default to None."""
+        config = TrackioConfig()
+        assert config.project is None
+        assert config.name is None
+        assert config.space_id is None
+
+    def test_custom_values(self):
+        """TrackioConfig should accept custom values."""
+        config = TrackioConfig(
+            mode="online",
+            project="my-project",
+            name="my-run",
+            space_id="user/my-space",
+        )
+        assert config.mode == "online"
+        assert config.project == "my-project"
+        assert config.name == "my-run"
+        assert config.space_id == "user/my-space"
+
+    def test_invalid_mode_raises_error(self):
+        """TrackioConfig should reject invalid mode values."""
+        import pytest
+
+        with pytest.raises(ValueError, match="Invalid trackio mode"):
+            TrackioConfig(mode="invalid")
+
+    def test_all_valid_modes_accepted(self):
+        """TrackioConfig should accept all valid mode values."""
+        for mode in ("disabled", "online", "local"):
+            config = TrackioConfig(mode=mode)
+            assert config.mode == mode
+
+
+class TestStatsLoggerConfigTrackio:
+    """Tests for Trackio field in StatsLoggerConfig."""
+
+    def test_trackio_field_exists(self):
+        """StatsLoggerConfig should have a trackio field."""
+        field_names = [f.name for f in fields(StatsLoggerConfig)]
+        assert "trackio" in field_names
+
+    def test_trackio_field_default_is_disabled(self):
+        """StatsLoggerConfig.trackio should default to disabled TrackioConfig."""
+        config = StatsLoggerConfig(
+            experiment_name="test_exp",
+            trial_name="trial_0",
+            fileroot="/tmp/test",
+        )
+        assert isinstance(config.trackio, TrackioConfig)
+        assert config.trackio.mode == "disabled"
+
+
+def _make_test_config(trackio_config=None):
+    """Create a minimal BaseExperimentConfig for testing StatsLogger."""
+    from areal.api.cli_args import BaseExperimentConfig
+
+    config = BaseExperimentConfig(
+        experiment_name="test_exp",
+        trial_name="trial_0",
+        total_train_epochs=1,
+    )
+    config.stats_logger.experiment_name = "test_exp"
+    config.stats_logger.trial_name = "trial_0"
+    config.stats_logger.fileroot = "/tmp/test"
+    if trackio_config is not None:
+        config.stats_logger.trackio = trackio_config
+    return config
+
+
+def _make_ft_spec():
+    """Create a mock FinetuneSpec for testing."""
+    from areal.api import FinetuneSpec
+
+    ft_spec = MagicMock(spec=FinetuneSpec)
+    ft_spec.total_train_epochs = 1
+    ft_spec.steps_per_epoch = 10
+    ft_spec.total_train_steps = 10
+    return ft_spec
+
+
+class TestStatsLoggerTrackioIntegration:
+    """Tests for Trackio integration in StatsLogger (mocked)."""
+
+    @patch("areal.utils.stats_logger.trackio")
+    @patch("areal.utils.stats_logger.wandb")
+    @patch("areal.utils.stats_logger.swanlab")
+    @patch("areal.utils.stats_logger.dist")
+    def test_trackio_init_called_when_enabled(
+        self, mock_dist, mock_swanlab, mock_wandb, mock_trackio
+    ):
+        """trackio.init() should be called when mode is not disabled."""
+        mock_dist.is_initialized.return_value = False
+
+        from areal.utils.stats_logger import StatsLogger
+
+        config = _make_test_config(TrackioConfig(mode="online"))
+        logger = StatsLogger(config, _make_ft_spec())
+        mock_trackio.init.assert_called_once()
+        assert logger._trackio_enabled is True
+
+    @patch("areal.utils.stats_logger.trackio")
+    @patch("areal.utils.stats_logger.wandb")
+    @patch("areal.utils.stats_logger.swanlab")
+    @patch("areal.utils.stats_logger.dist")
+    def test_trackio_not_init_when_disabled(
+        self, mock_dist, mock_swanlab, mock_wandb, mock_trackio
+    ):
+        """trackio.init() should NOT be called when mode is disabled."""
+        mock_dist.is_initialized.return_value = False
+
+        from areal.utils.stats_logger import StatsLogger
+
+        config = _make_test_config()  # trackio defaults to disabled
+        logger = StatsLogger(config, _make_ft_spec())
+        mock_trackio.init.assert_not_called()
+        assert logger._trackio_enabled is False
+
+    @patch("areal.utils.stats_logger.trackio")
+    @patch("areal.utils.stats_logger.wandb")
+    @patch("areal.utils.stats_logger.swanlab")
+    @patch("areal.utils.stats_logger.dist")
+    def test_trackio_log_called_on_commit(
+        self, mock_dist, mock_swanlab, mock_wandb, mock_trackio
+    ):
+        """trackio.log() should be called during commit when enabled."""
+        mock_dist.is_initialized.return_value = False
+
+        from areal.utils.stats_logger import StatsLogger
+
+        config = _make_test_config(TrackioConfig(mode="online"))
+        logger = StatsLogger(config, _make_ft_spec())
+        mock_trackio.log.reset_mock()
+
+        data = {"loss/avg": 0.5, "reward/avg": 1.0}
+        logger.commit(epoch=0, step=0, global_step=0, data=data)
+        mock_trackio.log.assert_called_once_with(data, step=0)
+
+    @patch("areal.utils.stats_logger.trackio")
+    @patch("areal.utils.stats_logger.wandb")
+    @patch("areal.utils.stats_logger.swanlab")
+    @patch("areal.utils.stats_logger.dist")
+    def test_trackio_finish_called_on_close(
+        self, mock_dist, mock_swanlab, mock_wandb, mock_trackio
+    ):
+        """trackio.finish() should be called during close when enabled."""
+        mock_dist.is_initialized.return_value = False
+
+        from areal.utils.stats_logger import StatsLogger
+
+        config = _make_test_config(TrackioConfig(mode="online"))
+        logger = StatsLogger(config, _make_ft_spec())
+        mock_trackio.finish.reset_mock()
+
+        logger.close()
+        mock_trackio.finish.assert_called_once()
+
+    @patch("areal.utils.stats_logger.trackio")
+    @patch("areal.utils.stats_logger.wandb")
+    @patch("areal.utils.stats_logger.swanlab")
+    @patch("areal.utils.stats_logger.dist")
+    def test_trackio_not_logged_when_disabled(
+        self, mock_dist, mock_swanlab, mock_wandb, mock_trackio
+    ):
+        """trackio.log() should NOT be called during commit when disabled."""
+        mock_dist.is_initialized.return_value = False
+
+        from areal.utils.stats_logger import StatsLogger
+
+        config = _make_test_config()  # trackio defaults to disabled
+        logger = StatsLogger(config, _make_ft_spec())
+        mock_trackio.log.reset_mock()
+
+        data = {"loss/avg": 0.5}
+        logger.commit(epoch=0, step=0, global_step=0, data=data)
+        mock_trackio.log.assert_not_called()
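
Taken together, the mocked tests above pin down one pattern: every trackio call is guarded by a mode check made at construction time (`_trackio_enabled`), with `init` on construction, `log` on commit, and `finish` on close. A minimal, self-contained sketch of that guard pattern follows; it is hypothetical — the class name, constructor signature, and recording stub are invented for illustration and are not the real `areal.utils.stats_logger.StatsLogger`:

```python
from types import SimpleNamespace


class GuardedTrackioLogger:
    """Hypothetical sketch of the guard pattern the tests assert:
    the trackio backend is initialized, written to, and finished
    only when mode != "disabled"; otherwise it is never touched."""

    def __init__(self, trackio, mode="disabled", project=None, name=None):
        self._trackio = trackio
        # Decide once, at construction time, whether the backend is live.
        self._trackio_enabled = mode != "disabled"
        if self._trackio_enabled:
            trackio.init(project=project, name=name)

    def commit(self, epoch, step, global_step, data):
        if self._trackio_enabled:
            # Mirrors the call shape asserted above: trackio.log(data, step=...)
            self._trackio.log(data, step=global_step)

    def close(self):
        if self._trackio_enabled:
            self._trackio.finish()


# Recording stub standing in for the trackio module.
calls = []
stub = SimpleNamespace(
    init=lambda **kw: calls.append(("init", kw)),
    log=lambda data, step: calls.append(("log", data, step)),
    finish=lambda: calls.append(("finish",)),
)

logger = GuardedTrackioLogger(stub, mode="online", project="demo")
logger.commit(epoch=0, step=0, global_step=0, data={"loss/avg": 0.5})
logger.close()
enabled_calls = list(calls)  # init, log, finish — in that order

calls.clear()
disabled = GuardedTrackioLogger(stub)  # mode defaults to "disabled"
disabled.commit(epoch=0, step=0, global_step=0, data={"loss/avg": 0.5})
disabled.close()
disabled_calls = list(calls)  # empty: the backend is never reached
```

Checking the flag once in the constructor, rather than re-reading config on every `commit`, is what makes the `logger._trackio_enabled` assertions in the tests meaningful: the attribute alone determines whether any backend call can happen later.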