diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..f987027ff7 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,29 @@ +# Default: catch-all +* @garrett4wade + +# Core package +/areal/api/ @garrett4wade +/areal/engine/ @rchardx +/areal/experimental/inference_service @nuzant +/areal/infra/ @garrett4wade +/areal/models/ @rchardx +/areal/trainer/ @garrett4wade + +# Tests & Examples +/tests/ @garrett4wade @rchardx @nuzant +/examples/ @garrett4wade + +# Documentation +/docs/ @garrett4wade + +# CI/CD & infrastructure +/.github/ @garrett4wade @nuzant +/Dockerfile @garrett4wade @fishcrap +pyproject.toml @garrett4wade @fishcrap +pyproject.vllm.toml @garrett4wade @fishcrap +uv.lock @garrett4wade @fishcrap +uv.vllm.lock @garrett4wade @fishcrap + +# Governance & community +GOVERNANCE.md @garrett4wade +CONTRIBUTING.md @garrett4wade diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..9277946f55 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,155 @@ +# Contributor Covenant 3.0 Code of Conduct + +## Our Pledge + +We pledge to make our community welcoming, safe, and equitable for all. + +We are committed to fostering an environment that respects and promotes the dignity, +rights, and contributions of all individuals, regardless of characteristics including +race, ethnicity, caste, color, age, physical characteristics, neurodiversity, +disability, sex or gender, gender identity or expression, sexual orientation, language, +philosophy or religion, national or social origin, socio-economic position, level of +education, or other status. The same privileges of participation are extended to +everyone who participates in good faith and in accordance with this Covenant. + +## Encouraged Behaviors + +While acknowledging differences in social norms, we all strive to meet our community's +expectations for positive behavior. We also understand that our words and actions may be +interpreted differently than we intend based on culture, background, or native language. + +With these considerations in mind, we agree to behave mindfully toward each other and +act in ways that center our shared values, including: + +1. Respecting the **purpose of our community**, our activities, and our ways of + gathering. +1. Engaging **kindly and honestly** with others. +1. Respecting **different viewpoints** and experiences. +1. **Taking responsibility** for our actions and contributions. +1. Gracefully giving and accepting **constructive feedback**. +1. Committing to **repairing harm** when it occurs. +1. Behaving in other ways that promote and sustain the **well-being of our community**. + +## Restricted Behaviors + +We agree to restrict the following behaviors in our community. Instances, threats, and +promotion of these behaviors are violations of this Code of Conduct. + +1. **Harassment.** Violating explicitly expressed boundaries or engaging in unnecessary + personal attention after any clear request to stop. +1. **Character attacks.** Making insulting, demeaning, or pejorative comments directed + at a community member or group of people. +1. **Stereotyping or discrimination.** Characterizing anyone’s personality or behavior + on the basis of immutable identities or traits. +1. **Sexualization.** Behaving in a way that would generally be considered + inappropriately intimate in the context or purpose of the community. +1. **Violating confidentiality**. Sharing or acting on someone's personal or private + information without their permission. +1. 
**Endangerment.** Causing, encouraging, or threatening violence or other harm toward + any person or group. +1. Behaving in other ways that **threaten the well-being** of our community. + +### Other Restrictions + +1. **Misleading identity.** Impersonating someone else for any reason, or pretending to + be someone else to evade enforcement actions. +1. **Failing to credit sources.** Not properly crediting the sources of content you + contribute. +1. **Promotional materials**. Sharing marketing or other commercial content in a way + that is outside the norms of the community. +1. **Irresponsible communication.** Failing to responsibly present content which + includes, links or describes any other restricted behaviors. + +## Reporting an Issue + +Tensions can occur between community members even when they are trying their best to +collaborate. Not every conflict represents a code of conduct violation, and this Code of +Conduct reinforces encouraged behaviors and norms that can help avoid conflicts and +minimize harm. + +When an incident does occur, it is important to report it promptly. To report a possible +violation, **send an email to the project maintainer fuwth17@gmail.com**. + +Community Moderators take reports of violations seriously and will make every effort to +respond in a timely manner. They will investigate all reports of code of conduct +violations, reviewing messages, logs, and recordings, or interviewing witnesses and +other participants. Community Moderators will keep investigation and enforcement actions +as transparent as possible while prioritizing safety and confidentiality. In order to +honor these values, enforcement actions are carried out in private with the involved +parties, but communicating to the whole community may be part of a mutually agreed upon +resolution. + +## Addressing and Repairing Harm + +If an investigation by the Community Moderators finds that this Code of Conduct has been +violated, the following enforcement ladder may be used to determine how best to repair +harm, based on the incident's impact on the individuals involved and the community as a +whole. Depending on the severity of a violation, lower rungs on the ladder may be +skipped. + +1. Warning + 1. Event: A violation involving a single incident or series of incidents. + 1. Consequence: A private, written warning from the Community Moderators. + 1. Repair: Examples of repair include a private written apology, acknowledgement of + responsibility, and seeking clarification on expectations. +1. Temporarily Limited Activities + 1. Event: A repeated incidence of a violation that previously resulted in a warning, + or the first incidence of a more serious violation. + 1. Consequence: A private, written warning with a time-limited cooldown period + designed to underscore the seriousness of the situation and give the community + members involved time to process the incident. The cooldown period may be limited + to particular communication channels or interactions with particular community + members. + 1. Repair: Examples of repair may include making an apology, using the cooldown + period to reflect on actions and impact, and being thoughtful about re-entering + community spaces after the period is over. +1. Temporary Suspension + 1. Event: A pattern of repeated violation which the Community Moderators have tried + to address with warnings, or a single serious violation. + 1. Consequence: A private written warning with conditions for return from suspension. 
+      In general, temporary suspensions give the person being suspended time to reflect
+      upon their behavior and possible corrective actions.
+   1. Repair: Examples of repair include respecting the spirit of the suspension,
+      meeting the specified conditions for return, and being thoughtful about how to
+      reintegrate with the community when the suspension is lifted.
+1. Permanent Ban
+   1. Event: A pattern of repeated code of conduct violations that other steps on the
+      ladder have failed to resolve, or a violation so serious that the Community
+      Moderators determine there is no way to keep the community safe with this person
+      as a member.
+   1. Consequence: Access to all community spaces, tools, and communication channels is
+      removed. In general, permanent bans should be rarely used, should have strong
+      reasoning behind them, and should only be resorted to if working through other
+      remedies has failed to change the behavior.
+   1. Repair: There is no possible repair in cases of this severity.
+
+This enforcement ladder is intended as a guideline. It does not limit the ability of
+Community Managers to use their discretion and judgment, in keeping with the best
+interests of our community.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when an
+individual is officially representing the community in public or other spaces. Examples
+of representing our community include using an official email address, posting via an
+official social media account, or acting as an appointed representative at an online or
+offline event.
+
+## Attribution
+
+This Code of Conduct is adapted from the Contributor Covenant, version 3.0, permanently
+available at
+[https://www.contributor-covenant.org/version/3/0/](https://www.contributor-covenant.org/version/3/0/).
+
+Contributor Covenant is stewarded by the Organization for Ethical Source and licensed
+under CC BY-SA 4.0. To view a copy of this license, visit
+[https://creativecommons.org/licenses/by-sa/4.0/](https://creativecommons.org/licenses/by-sa/4.0/)
+
+For answers to common questions about Contributor Covenant, see the FAQ at
+[https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq).
+Translations are provided at
+[https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations).
+Additional enforcement and community guideline resources can be found at
+[https://www.contributor-covenant.org/resources](https://www.contributor-covenant.org/resources).
+The enforcement ladder was inspired by the work of
+[Mozilla’s code of conduct team](https://github.com/mozilla/inclusion).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1e0108e84a..9a892786f7 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,13 +1,15 @@
 # Contributing to AReaL
 
 Thank you for your interest in contributing to AReaL! We welcome contributions from
 everyone, whether you're fixing bugs, improving documentation, adding new features, or
 helping with code reviews. This guide will help you get started.
 
+Please review our [Code of Conduct](CODE_OF_CONDUCT.md) before participating and our
+[Governance](GOVERNANCE.md) document to understand how the project is managed.
+
 ## Table of Contents
 
 - [Quick Start](#quick-start)
-- [Ways to Contribute](#ways-to-contribute)
 - [Tips for Using AI-Assisted Coding](#tips-for-using-ai-assisted-coding)
 - [CI/CD](#cicd)
 
@@ -33,7 +35,7 @@ helping with code reviews.
This guide will help you get started. # Install hooks (includes formatting, linting, and commit message checks) pre-commit install --install-hooks # Subsequent commits will automatically check your files and commit messages: - git commit -a -m 'feat: my change' + git commit -a -m 'feat(engine): my change' ``` 1. **Find an Issue:** @@ -85,50 +87,8 @@ helping with code reviews. This guide will help you get started. 1. **Submit a Pull Request** -We suggest applying our provided claude command `/create-pr` whenever possible. - -## Ways to Contribute - -### 🐛 Bug Reports - -Found a bug? Please create a -[bug report](https://github.com/inclusionAI/AReaL/issues/new?template=bug.md) with: - -- A clear description of the issue -- Steps to reproduce -- Expected vs. actual behavior -- Environment details (commit ID, hardware, software) -- Full logs when possible - -### ✨ Feature Requests - -Have an idea? Submit a -[feature request](https://github.com/inclusionAI/AReaL/issues/new?template=feature.md) -with: - -- Background and use case -- Proposed solution or implementation approach -- Expected benefits to the community - -### 📚 Documentation - -Documentation improvements are always welcome: - -- Fix typos or clarify existing docs -- Add examples or tutorials -- Improve API documentation -- Write blog posts or guides - -### 💻 Code Contributions - -We accept various types of code contributions: - -- Bug fixes -- New features -- Performance improvements -- Algorithm implementations -- Test coverage improvements -- Code refactoring +We suggest applying our provided agent harness command `/create-pr` whenever possible. +Use that in `claude`, `opencode`, or any other coding agent CLI. **IMPORTANT**: For new features and code refactoring, please submit a corresponding issue or open a draft PR to discuss with the core developers before making any code @@ -205,8 +165,6 @@ def test_some_multi_gpu_functionality(): ### Image Building -> **NOTE:** The image building CI workflow is experimental and subject to change. - The image building workflow can be triggered manually from any branch by users with write permissions to the repository. diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 0000000000..191e6f87e9 --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,65 @@ +# AReaL Project Governance + +This document describes how the AReaL project is governed. + +## Roles + +### Contributors + +Anyone who files issues, submits pull requests, or participates in discussions is +considered a contributor. All contributors are expected to follow the +[Code of Conduct](CODE_OF_CONDUCT.md). + +### Maintainers + +Maintainers have write access to the repository and are responsible for reviewing and +merging pull requests, triaging issues, and guiding the technical direction of the +project. + +| Name | Organization | GitHub | +| ------------ | ------------------------- | ------------- | +| Wei Fu | IIIS, Tsinghua University | @garrett4wade | +| Wentai Zhang | AReaL Team, Ant Group | @rchardx | +| Zhiyu Mei | AReaL Team, Ant Group | @nuzant | +| Xujie Shen | AReaL Team, Ant Group | @fishcrap | +| Tongkai Yang | AReaL Team, Ant Group | @fredy12 | + +### Lead Maintainer (BDFL) + +Wei Fu ([@garrett4wade](https://github.com/garrett4wade)) serves as the lead maintainer. +The lead maintainer has final authority on technical decisions when maintainers cannot +reach consensus. 
+ +### Community Moderators + +The [Code of Conduct](CODE_OF_CONDUCT.md) refers to "Community Moderators" as the +individuals responsible for enforcement. In this project, community moderators are the +current maintainers listed above. + +## Decision-Making + +Decisions are made by consensus among maintainers whenever possible. When consensus +cannot be reached, the lead maintainer makes the final decision. + +Pull request approval policy: + +- Bug fixes or minor improvements: approved by at least one maintainer. +- New features, architectural changes, or API modifications: approved by at least two + maintainers or the lead maintainer. + +## Becoming a Maintainer + +New maintainers are added through nomination by an existing maintainer, followed by +consensus approval from the current maintainers. There are no strict criteria, but +candidates are generally expected to have a track record of quality contributions and +constructive participation in the project. + +## Code of Conduct + +All participants are expected to follow the [Code of Conduct](CODE_OF_CONDUCT.md). +Violations can be reported to fuwth17@gmail.com. + +## Amendments + +Changes to this governance document require consensus among maintainers. If consensus +cannot be reached, the lead maintainer decides. diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 03dd7a1b7e..6716091d6f 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1135,6 +1135,33 @@ class TrainEngineConfig: "e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required." }, ) + + # v2 controller options + _version: str = field( + default="v1", + metadata={ + "help": "Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController.", + "choices": ["v1", "v2"], + }, + ) + admin_api_key: str = field( + default="areal-admin-key", + metadata={ + "help": "Admin API key used by gateway/router/data-proxy in controller v2." + }, + ) + log_level: str = field( + default="warning", + metadata={"help": "Gateway stack log level for controller v2."}, + ) + request_timeout: float = field( + default=3600.0, + metadata={"help": "Gateway request timeout in seconds for controller v2."}, + ) + setup_timeout: float = field( + default=3600.0, + metadata={"help": "Gateway setup timeout in seconds for controller v2."}, + ) scheduling_strategy: SchedulingStrategy = field( default_factory=SchedulingStrategy, metadata={ @@ -1158,6 +1185,10 @@ def __post_init__(self): "memory_efficient_load is for loading pretrained weights on CPU, " "but init_from_scratch creates a model without loading any weights." 
) + if self._version not in ("v1", "v2"): + raise ValueError( + f"_version must be either 'v1' or 'v2', got '{self._version}'" + ) @dataclass diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 91c737af92..845874ba39 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1839,6 +1839,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.actor import PPOActorControllerV2 + + return PPOActorControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.actor import PPOActorController return PPOActorController(train_engine=cls, config=config, scheduler=scheduler) @@ -1862,6 +1871,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.critic import PPOCriticControllerV2 + + return PPOCriticControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.critic import PPOCriticController return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler) @@ -1884,6 +1902,15 @@ def evaluate_lm(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.sft.lm_engine import LMControllerV2 + + return LMControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.sft.lm_engine import LMController return LMController(train_engine=cls, config=config, scheduler=scheduler) @@ -1913,6 +1940,11 @@ def evaluate_rw(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.rw.rw_engine import RWControllerV2 + + return RWControllerV2(train_engine=cls, config=config, scheduler=scheduler) + from areal.trainer.rw.rw_engine import RWController return RWController(train_engine=cls, config=config, scheduler=scheduler) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5c4d22b3c8..56f641525c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1820,6 +1820,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.actor import PPOActorControllerV2 + + return PPOActorControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.actor import PPOActorController return PPOActorController(train_engine=cls, config=config, scheduler=scheduler) @@ -1843,6 +1852,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.critic import PPOCriticControllerV2 + + return PPOCriticControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.critic import PPOCriticController return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler) @@ -1865,6 +1883,15 @@ def evaluate_lm(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.sft.lm_engine import LMControllerV2 + + return LMControllerV2( 
+ train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.sft.lm_engine import LMController return LMController(train_engine=cls, config=config, scheduler=scheduler) @@ -1894,6 +1921,11 @@ def evaluate_rw(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.rw.rw_engine import RWControllerV2 + + return RWControllerV2(train_engine=cls, config=config, scheduler=scheduler) + from areal.trainer.rw.rw_engine import RWController return RWController(train_engine=cls, config=config, scheduler=scheduler) diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 82adaed6c4..830fec6eb6 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -1359,6 +1359,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.actor import PPOActorControllerV2 + + return PPOActorControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.actor import PPOActorController return PPOActorController(train_engine=cls, config=config, scheduler=scheduler) @@ -1382,6 +1391,15 @@ def ppo_update(self, *args, **kwargs) -> None: @classmethod def as_controller(cls, config, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.ppo.critic import PPOCriticControllerV2 + + return PPOCriticControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.ppo.critic import PPOCriticController return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler) @@ -1404,6 +1422,15 @@ def evaluate_lm(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.sft.lm_engine import LMControllerV2 + + return LMControllerV2( + train_engine=cls, + config=config, + scheduler=scheduler, + ) + from areal.trainer.sft.lm_engine import LMController return LMController(train_engine=cls, config=config, scheduler=scheduler) @@ -1433,6 +1460,11 @@ def evaluate_rw(self, data): @classmethod def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.trainer.rw.rw_engine import RWControllerV2 + + return RWControllerV2(train_engine=cls, config=config, scheduler=scheduler) + from areal.trainer.rw.rw_engine import RWController return RWController(train_engine=cls, config=config, scheduler=scheduler) diff --git a/areal/experimental/training_service/__init__.py b/areal/experimental/training_service/__init__.py new file mode 100644 index 0000000000..783c4c897f --- /dev/null +++ b/areal/experimental/training_service/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Serverless training service — microservice-based training gateway. + +Workers are individual SPMD processes (one per 5D-parallel rank), each +wrapped in a synchronous HTTP server. A data proxy orchestrates a full +5D-parallel group and provides partitioned dispatch. A router maintains +API key → data proxy mappings. A gateway provides the public entry +point with authentication and forwarding. 
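+
+A typical request path, as wired by ``GatewayTrainController`` in this
+package: the client calls the gateway with a Bearer API key; the gateway
+resolves that key to a data proxy through the router; the data proxy
+partitions tensor batches across data-parallel head workers (non-head
+ranks receive empty envelopes so they can still join collective ops)
+and merges the per-shard results before responding.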
+""" diff --git a/areal/experimental/training_service/controller/__init__.py b/areal/experimental/training_service/controller/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/training_service/controller/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/training_service/controller/controller.py b/areal/experimental/training_service/controller/controller.py new file mode 100644 index 0000000000..5c27e42b69 --- /dev/null +++ b/areal/experimental/training_service/controller/controller.py @@ -0,0 +1,875 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import sys +import time +import traceback +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from areal.utils import logging +from areal.utils.network import format_hostport + +if TYPE_CHECKING: + from areal.api import ParallelStrategy, TrainEngine + from areal.api.cli_args import TrainEngineConfig + from areal.api.io_struct import FinetuneSpec + from areal.api.scheduler_api import Scheduler, Worker + +logger = logging.getLogger("GatewayTrainController") + + +class GatewayTrainController: + _GUARD_SUFFIX = "-guard" + + # TODO(agent): Controller v2 is not yet a drop-in replacement for + # TrainController on PPO/GRPO paths. Add parity for connect_engine, + # prepare_batch/rollout_batch, and update_weights (plus the matching + # gateway/data-proxy/worker endpoints), or keep RL controllers on v1. + + def __init__( + self, + train_engine: type[TrainEngine] | str, + config: TrainEngineConfig, + scheduler: Scheduler, + ) -> None: + from areal.api.alloc_mode import ModelAllocation + + self.train_engine = train_engine + self.scheduler = scheduler + self.config = config + self.train_alloc = ModelAllocation.from_str(config.backend) + self.api_key: str | None = None + self._gateway_addr: str = "" + self._router_addr: str = "" + self._model_addr: str = "" + self._worker_addrs: list[str] = [] + self._forked_services: list[tuple[str, str, int]] = [] + self._service_roles: list[str] = [] + self._role: str = "" + self._parallel_strategy = self.train_alloc.parallel + self._own_process_group = False + + # -- Initialize -------------------------------------------------------- + + def initialize( + self, role: str, ft_spec: FinetuneSpec | None = None, **kwargs: Any + ) -> None: + from areal.infra.utils.concurrent import run_async_task + + self._role = role + run_async_task(self._async_initialize, role, ft_spec, **kwargs) + logger.info( + "GatewayTrainController initialized (role=%s, api_key=%s, gateway=%s)", + role, + self.api_key, + self._gateway_addr, + ) + + async def _async_initialize( + self, + role: str, + ft_spec: FinetuneSpec | None = None, + **kwargs: Any, + ) -> None: + from dataclasses import asdict + + import httpx + + from areal.api.cli_args import SchedulingSpec + from areal.api.scheduler_api import Job + + cfg = self.config + + world_size = self.train_alloc.parallel.world_size + + try: + # ============================================================== + # Step 0: Create world_size guards via scheduler (one per GPU rank) + # ============================================================== + # Each guard is allocated a GPU by the scheduler (like TrainController + # workers). Forked workers inherit the guard's GPU environment. + if len(cfg.scheduling_spec) != 1: + raise ValueError( + "GatewayTrainController (controller v2) requires exactly " + "one scheduling_spec. 
Legacy 2-spec worker/engine layouts " + "are only supported by TrainController (controller v1)." + ) + + guard_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0])) + guard_spec.cmd = "python -m areal.experimental.training_service.guard" + + guard_role = f"{role}{self._GUARD_SUFFIX}" + guard_job = Job( + replicas=world_size, + tasks=[guard_spec], + scheduling_strategy=cfg.scheduling_strategy, + role=guard_role, + ) + await asyncio.to_thread(self.scheduler.create_workers, job=guard_job) + self._service_roles.append(guard_role) + guard_workers = await asyncio.to_thread( + self.scheduler.get_workers, + role=guard_role, + timeout=int(self.config.setup_timeout), + ) + logger.info("Guards ready: %s", [w.id for w in guard_workers]) + + # ============================================================== + # Step 1: Allocate master addr/port for NCCL rendezvous + # ============================================================== + guard_addr_0 = f"http://{format_hostport(guard_workers[0].ip, int(guard_workers[0].worker_ports[0]))}" + master_addr = guard_workers[0].ip + + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{guard_addr_0}/alloc_ports", json={"count": 1} + ) + resp.raise_for_status() + master_port = resp.json()["ports"][0] + + # ============================================================== + # Step 1.5: Set NCCL env on each guard so forked workers inherit it + # ============================================================== + def _guard_addr(worker: Worker) -> str: + return ( + f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" + ) + + await self._async_set_guards_env( + guard_workers, + _guard_addr, + world_size=world_size, + master_addr=master_addr, + master_port=master_port, + ) + + # ============================================================== + # Step 2: Fork one train worker per guard + # ============================================================== + async def _fork_worker(rank: int) -> str: + guard = _guard_addr(guard_workers[rank]) + worker_cmd = [ + sys.executable, + "-m", + "areal.experimental.training_service.worker", + "--log-level", + cfg.log_level, + ] + + host, port = await self._async_fork_on_guard( + guard_addr=guard, + role="train-worker", + worker_index=rank, + raw_cmd=worker_cmd, + ) + return f"http://{format_hostport(host, port)}" + + self._worker_addrs = list( + await asyncio.gather( + *[_fork_worker(rank) for rank in range(world_size)] + ) + ) + logger.info("Workers: %s", self._worker_addrs) + + # ============================================================== + # Step 3: Create engines on all workers (coordinated NCCL init) + # ============================================================== + if isinstance(self.train_engine, str): + engine_class = self.train_engine + else: + engine_class = ( + f"{self.train_engine.__module__}.{self.train_engine.__name__}" + ) + await asyncio.gather( + *[ + self._create_engine_on_worker( + worker_addr=addr, + engine_class=engine_class, + init_args=[], + init_kwargs={"config": self.config}, + ) + for addr in self._worker_addrs + ] + ) + logger.info("Engines created on all workers") + + pg_kwargs = {"parallel_strategy": self._parallel_strategy} + await asyncio.gather( + *[ + self._call_worker_engine_endpoint( + addr, + "/create_process_group", + args=[], + kwargs=pg_kwargs, + timeout=self.config.setup_timeout, + ) + for addr in self._worker_addrs + ] + ) + + await asyncio.gather( + *[ + self._call_worker_engine_endpoint( + addr, + "/initialize", + args=[], + kwargs={ + "addr": 
kwargs.get("addr"), + "ft_spec": ft_spec, + }, + timeout=self.config.setup_timeout, + ) + for addr in self._worker_addrs + ] + ) + logger.info("Engines initialized on all workers") + + # ============================================================== + # Step 4: Fork Router on guard 0 + # ============================================================== + router_cmd = [ + sys.executable, + "-m", + "areal.experimental.training_service.router", + "--admin-api-key", + cfg.admin_api_key, + "--log-level", + cfg.log_level, + ] + router_host, router_port = await self._async_fork_on_guard( + guard_addr=guard_addr_0, + role="router", + worker_index=0, + raw_cmd=router_cmd, + ) + self._router_addr = f"http://{format_hostport(router_host, router_port)}" + logger.info("Router: %s", self._router_addr) + + # ============================================================== + # Step 5: Fork Data Proxy on a guard + # ============================================================== + data_proxy_cmd = [ + sys.executable, + "-m", + "areal.experimental.training_service.data_proxy", + "--worker-addrs", + ",".join(self._worker_addrs), + "--admin-api-key", + cfg.admin_api_key, + "--log-level", + cfg.log_level, + ] + dp_host, dp_port = await self._async_fork_on_guard( + guard_addr=guard_addr_0, + role="data-proxy", + worker_index=0, + raw_cmd=data_proxy_cmd, + ) + self._model_addr = f"http://{format_hostport(dp_host, dp_port)}" + logger.info("Model endpoint: %s", self._model_addr) + + # ============================================================== + # Step 6: Fork Gateway on guard 0 + # ============================================================== + gw_cmd = [ + sys.executable, + "-m", + "areal.experimental.training_service.gateway", + "--admin-api-key", + cfg.admin_api_key, + "--router-addr", + self._router_addr, + "--forward-timeout", + str(cfg.request_timeout), + "--log-level", + cfg.log_level, + ] + gw_host, gw_port = await self._async_fork_on_guard( + guard_addr=guard_addr_0, + role="gateway", + worker_index=0, + raw_cmd=gw_cmd, + ) + self._gateway_addr = f"http://{format_hostport(gw_host, gw_port)}" + logger.info("Gateway: %s", self._gateway_addr) + + # ============================================================== + # Step 7: Register data proxy with API key in router + # ============================================================== + self.api_key = f"ak-{role}-{uuid4().hex[:12]}" + await self._register_in_router( + self._router_addr, self._model_addr, self.api_key + ) + logger.info("Model registered with api_key=%s", self.api_key) + except Exception: + logger.error( + "GatewayTrainController initialization failed, rolling back", + exc_info=True, + ) + self._cleanup_runtime_state() + raise + + # -- Engine creation --------------------------------------------------- + + async def _async_set_guards_env( + self, + guard_workers: list[Worker], + guard_addr_fn: Any, + *, + world_size: int, + master_addr: str, + master_port: int, + ) -> None: + import httpx + + async def _set_env(rank: int) -> None: + addr = guard_addr_fn(guard_workers[rank]) + env = { + "RANK": str(rank), + "LOCAL_RANK": "0", + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": master_addr, + "MASTER_PORT": str(master_port), + } + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post(f"{addr}/set_env", json={"env": env}) + resp.raise_for_status() + + await asyncio.gather(*[_set_env(rank) for rank in range(len(guard_workers))]) + logger.info("NCCL env set on %d guards", len(guard_workers)) + + async def 
_create_engine_on_worker( + self, + worker_addr: str, + engine_class: str, + init_args: list[Any], + init_kwargs: dict[str, Any], + ) -> None: + import httpx + + from areal.infra.rpc.serialization import serialize_value + + payload = { + "engine_class": engine_class, + "init_args": serialize_value(init_args), + "init_kwargs": serialize_value(init_kwargs), + } + async with httpx.AsyncClient(timeout=self.config.setup_timeout) as client: + resp = await client.post(f"{worker_addr}/create_engine", json=payload) + if resp.status_code >= 400: + raise RuntimeError( + f"Engine creation failed on {worker_addr}: {resp.text}" + ) + + async def _call_worker_engine_endpoint( + self, + worker_addr: str, + path: str, + *, + args: list[Any] | None = None, + kwargs: dict[str, Any] | None = None, + timeout: float, + ) -> Any: + import httpx + + from areal.infra.rpc.serialization import deserialize_value, serialize_value + + payload = { + "args": serialize_value(args or []), + "kwargs": serialize_value(kwargs or {}), + } + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post(f"{worker_addr}{path}", json=payload) + if resp.status_code >= 400: + raise RuntimeError( + f"Worker endpoint call failed on {worker_addr}{path}: {resp.text}" + ) + data = resp.json() + return deserialize_value(data.get("result")) + + # -- Router registration ----------------------------------------------- + + async def _register_in_router( + self, router_addr: str, model_addr: str, api_key: str + ) -> None: + import httpx + + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{router_addr}/register", + json={ + "model_addr": model_addr, + "api_key": api_key, + "name": self._role, + }, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + ) + resp.raise_for_status() + + # -- Guard fork helpers ------------------------------------------------ + + def _fork_on_guard( + self, + guard_addr: str, + role: str, + worker_index: int, + raw_cmd: list[str], + env: dict[str, str] | None = None, + health_path: str = "/health", + ) -> tuple[str, int]: + import requests + + resp = requests.post(f"{guard_addr}/alloc_ports", json={"count": 1}, timeout=30) + resp.raise_for_status() + port_data = resp.json() + host = port_data["host"] + port = port_data["ports"][0] + + cmd = list(raw_cmd) + ["--host", host, "--port", str(port)] + + fork_payload: dict[str, Any] = { + "role": role, + "worker_index": worker_index, + "raw_cmd": cmd, + } + if env: + fork_payload["env"] = env + + resp = requests.post(f"{guard_addr}/fork", json=fork_payload, timeout=30) + resp.raise_for_status() + + self._forked_services.append((guard_addr, role, worker_index)) + + addr = f"http://{format_hostport(host, port)}" + self._wait_for_service(f"{addr}{health_path}", role) + + return host, port + + async def _async_fork_on_guard( + self, + guard_addr: str, + role: str, + worker_index: int, + raw_cmd: list[str], + env: dict[str, str] | None = None, + health_path: str = "/health", + ) -> tuple[str, int]: + import httpx + + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post(f"{guard_addr}/alloc_ports", json={"count": 1}) + resp.raise_for_status() + port_data = resp.json() + host = port_data["host"] + port = port_data["ports"][0] + + cmd = list(raw_cmd) + ["--host", host, "--port", str(port)] + fork_payload: dict[str, Any] = { + "role": role, + "worker_index": worker_index, + "raw_cmd": cmd, + } + if env: + fork_payload["env"] = env + + resp = await client.post(f"{guard_addr}/fork", 
json=fork_payload) + resp.raise_for_status() + + self._forked_services.append((guard_addr, role, worker_index)) + + addr = f"http://{format_hostport(host, port)}" + await self._async_wait_for_service(f"{addr}{health_path}", role) + + return host, port + + def _kill_forked_service( + self, guard_addr: str, role: str, worker_index: int + ) -> None: + import requests + + try: + resp = requests.post( + f"{guard_addr}/kill_forked_worker", + json={"role": role, "worker_index": worker_index}, + timeout=10, + ) + if resp.status_code == 200: + logger.info("Killed forked service %s/%d", role, worker_index) + else: + logger.warning( + "Failed to kill %s/%d: %s", role, worker_index, resp.text + ) + except Exception as exc: + logger.error("Error killing %s/%d: %s", role, worker_index, exc) + + # -- Health checks ----------------------------------------------------- + + def _wait_for_service( + self, url: str, name: str, timeout: float | None = None + ) -> None: + import requests as _requests + + timeout = timeout or self.config.setup_timeout + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = _requests.get(url, timeout=2) + if resp.status_code == 200: + logger.info("%s is ready at %s", name, url) + return + except _requests.RequestException: + pass + time.sleep(0.1) + raise TimeoutError(f"{name} not healthy at {url} within {timeout}s") + + async def _async_wait_for_service( + self, url: str, name: str, timeout: float | None = None + ) -> None: + import httpx + + timeout = timeout or self.config.setup_timeout + deadline = time.monotonic() + timeout + async with httpx.AsyncClient(timeout=2.0) as client: + while time.monotonic() < deadline: + try: + resp = await client.get(url) + if resp.status_code == 200: + logger.info("%s is ready at %s", name, url) + return + except Exception: + pass + await asyncio.sleep(0.1) + raise TimeoutError(f"{name} not healthy at {url} within {timeout}s") + + # -- Gateway HTTP helpers (duck-type TrainController interface) --------- + + def _gateway_post(self, path: str, payload: Any = None) -> Any: + import requests + + url = f"{self._gateway_addr}{path}" + resp = requests.post( + url, + json=payload, + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=self.config.request_timeout, + ) + if resp.status_code >= 400: + raise RuntimeError( + f"Gateway {path} returned {resp.status_code}: {resp.text}" + ) + return resp.json() + + def _gateway_get(self, path: str) -> Any: + import requests + + url = f"{self._gateway_addr}{path}" + resp = requests.get( + url, + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=self.config.request_timeout, + ) + if resp.status_code >= 400: + raise RuntimeError( + f"Gateway {path} returned {resp.status_code}: {resp.text}" + ) + return resp.json() + + def _gateway_post_result(self, path: str, payload: Any = None) -> Any: + from areal.infra.rpc.serialization import deserialize_value + + data = self._gateway_post(path, payload) + if not isinstance(data, dict) or "result" not in data: + raise RuntimeError(f"Gateway {path} response missing 'result': {data!r}") + return deserialize_value(data["result"]) + + def _gateway_get_result(self, path: str) -> Any: + from areal.infra.rpc.serialization import deserialize_value + + data = self._gateway_get(path) + if not isinstance(data, dict) or "result" not in data: + raise RuntimeError(f"Gateway {path} response missing 'result': {data!r}") + return deserialize_value(data["result"]) + + # -- TrainController duck-type interface 
-------------------------------- + + @staticmethod + def _require_list_batch(input_: Any, method_name: str) -> list[dict[str, Any]]: + if not isinstance(input_, list): + raise TypeError( + f"{method_name} expects `input_` as list[dict[str, Any]] for training-service dispatch; " + f"got {type(input_).__name__}." + ) + return input_ + + def train_batch( + self, + input_: list[dict[str, Any]] | None = None, + loss_fn: Any = None, + loss_weight_fn: Any = None, + ) -> Any: + from areal.infra.rpc.serialization import serialize_value + + if input_ is None: + raise TypeError("train_batch expects non-None list[dict[str, Any]] input.") + batch = self._require_list_batch(input_, "train_batch") + + payload = { + "args": serialize_value([batch]), + "kwargs": serialize_value( + {"loss_fn": loss_fn, "loss_weight_fn": loss_weight_fn} + ), + } + return self._gateway_post_result("/train_batch", payload) + + def forward_batch( + self, input_: list[dict[str, Any]] | None = None, **kwargs: Any + ) -> Any: + from areal.infra.rpc.serialization import serialize_value + + if input_ is None: + raise TypeError( + "forward_batch expects non-None list[dict[str, Any]] input." + ) + batch = self._require_list_batch(input_, "forward_batch") + + payload = { + "args": serialize_value([batch]), + "kwargs": serialize_value(kwargs), + } + return self._gateway_post_result("/forward_batch", payload) + + def eval_batch( + self, + input_: list[dict[str, Any]] | None = None, + loss_fn: Any = None, + loss_weight_fn: Any = None, + ) -> Any: + from areal.infra.rpc.serialization import serialize_value + + if input_ is None: + raise TypeError("eval_batch expects non-None list[dict[str, Any]] input.") + batch = self._require_list_batch(input_, "eval_batch") + + payload = { + "args": serialize_value([batch]), + "kwargs": serialize_value( + {"loss_fn": loss_fn, "loss_weight_fn": loss_weight_fn} + ), + } + return self._gateway_post_result("/eval_batch", payload) + + def train(self, mode: bool = True) -> GatewayTrainController: + from areal.infra.rpc.serialization import serialize_value + + self._gateway_post( + "/train", + { + "args": serialize_value([mode]), + "kwargs": serialize_value({}), + }, + ) + return self + + def eval(self) -> GatewayTrainController: + self._gateway_post("/eval") + return self + + def set_version(self, version: int) -> None: + from areal.infra.rpc.serialization import serialize_value + + self._gateway_post( + "/set_version", + { + "args": serialize_value([version]), + "kwargs": serialize_value({}), + }, + ) + + def get_version(self) -> int: + return int(self._gateway_get_result("/get_version")) + + def save(self, meta: Any) -> None: + from areal.infra.rpc.serialization import serialize_value + + self._gateway_post( + "/save", + { + "args": serialize_value([meta]), + "kwargs": serialize_value({}), + }, + ) + + def load(self, meta: Any) -> None: + from areal.infra.rpc.serialization import serialize_value + + self._gateway_post( + "/load", + { + "args": serialize_value([meta]), + "kwargs": serialize_value({}), + }, + ) + + def offload(self) -> None: + self._gateway_post("/offload") + + def onload(self) -> None: + self._gateway_post("/onload") + + def step_lr_scheduler(self) -> None: + self._gateway_post("/step_lr_scheduler") + + def optimizer_zero_grad(self) -> None: + self._gateway_post("/optimizer_zero_grad") + + def optimizer_step(self) -> Any: + return self._gateway_post_result("/optimizer_step") + + def export_stats(self) -> dict[str, Any]: + from areal.utils import stats_tracker + + stats = 
stats_tracker.export_all() + stats.update(self._gateway_get_result("/export_stats")) + return stats + + def get_device_stats(self) -> Any: + from areal.infra.rpc.serialization import serialize_value + + payload = { + "args": serialize_value([]), + "kwargs": serialize_value({}), + } + return self._gateway_post_result("/get_device_stats", payload) + + def config_perf_tracer(self, config: Any, role: str) -> None: + from areal.infra.rpc.serialization import serialize_value + + payload = { + "args": serialize_value([]), + "kwargs": serialize_value({"config": config, "role": role}), + } + self._gateway_post("/config_perf_tracer", payload) + + def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None: + from areal.infra.rpc.serialization import serialize_value + + payload = { + "args": serialize_value([]), + "kwargs": serialize_value({"step": step, "force": force}), + } + self._gateway_post("/save_perf_tracer", payload) + + def clear_batches(self, *targets: Any) -> None: + from areal.infra.rpc.serialization import serialize_value + + payload = { + "args": serialize_value(list(targets)), + "kwargs": serialize_value({}), + } + self._gateway_post("/clear_batches", payload) + + def current_data_parallel_head(self) -> int: + return 0 + + @property + def context_and_model_parallel_group(self): + return self.cpu_group + + @property + def parallel_strategy(self): + return self._parallel_strategy + + @property + def data_parallel_world_size(self) -> int: + return 1 + + @property + def data_parallel_rank(self) -> int: + return 0 + + # -- Properties (duck-type compat) ------------------------------------- + + @property + def cpu_group(self): + return None + + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + self._parallel_strategy = parallel_strategy + import torch.distributed as dist + + from areal.utils.network import find_free_ports + + if not dist.is_initialized(): + port = find_free_ports(1)[0] + dist.init_process_group( + backend="gloo", + init_method=f"tcp://localhost:{port}", + rank=0, + world_size=1, + ) + self._own_process_group = True + + def is_data_parallel_head(self) -> bool: + return True + + # -- Destroy ----------------------------------------------------------- + + def _cleanup_runtime_state(self) -> None: + if self._router_addr and self._model_addr: + try: + import requests + + requests.post( + f"{self._router_addr}/unregister", + json={"model_addr": self._model_addr}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + timeout=10, + ) + except Exception: + logger.error("Failed to unregister model: %s", traceback.format_exc()) + + for guard_addr, role, worker_index in reversed(self._forked_services): + try: + self._kill_forked_service(guard_addr, role, worker_index) + except Exception: + logger.error( + "Error killing %s/%d: %s", + role, + worker_index, + traceback.format_exc(), + ) + self._forked_services.clear() + + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + logger.info("Workers deleted for role: %s", role) + except Exception: + logger.error( + "Error deleting workers for %s: %s", role, traceback.format_exc() + ) + self._service_roles.clear() + self._worker_addrs.clear() + self._router_addr = "" + self._gateway_addr = "" + self._model_addr = "" + self.api_key = None + + import torch.distributed as dist + + if self._own_process_group: + try: + if dist.is_initialized(): + dist.destroy_process_group() + except Exception: + logger.error( + "Failed to destroy 
process group: %s", traceback.format_exc() + ) + finally: + self._own_process_group = False + + def destroy(self) -> None: + self._cleanup_runtime_state() diff --git a/areal/experimental/training_service/data_proxy/__init__.py b/areal/experimental/training_service/data_proxy/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/training_service/data_proxy/__main__.py b/areal/experimental/training_service/data_proxy/__main__.py new file mode 100644 index 0000000000..d7d869f0f1 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/__main__.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse + +import uvicorn + +from areal.experimental.training_service.data_proxy.app import create_app +from areal.experimental.training_service.data_proxy.config import TrainDataProxyConfig + + +def main(): + parser = argparse.ArgumentParser(description="AReaL Train Data Proxy") + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=9082) + parser.add_argument("--worker-addrs", required=True) + parser.add_argument("--admin-api-key", default="areal-admin-key") + parser.add_argument("--idle-timeout", type=float, default=60.0) + parser.add_argument("--warmup-timeout", type=float, default=120.0) + parser.add_argument("--request-timeout", type=float, default=600.0) + parser.add_argument( + "--log-level", + default="info", + choices=["debug", "info", "warning", "error"], + ) + args, _ = parser.parse_known_args() + + worker_addrs = [ + addr.strip() for addr in args.worker_addrs.split(",") if addr.strip() + ] + + config = TrainDataProxyConfig( + host=args.host, + port=args.port, + worker_addrs=worker_addrs, + admin_api_key=args.admin_api_key, + log_level=args.log_level, + request_timeout=args.request_timeout, + warmup_timeout=args.warmup_timeout, + ) + + app = create_app(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/training_service/data_proxy/app.py b/areal/experimental/training_service/data_proxy/app.py new file mode 100644 index 0000000000..ac888918fb --- /dev/null +++ b/areal/experimental/training_service/data_proxy/app.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.responses import Response as RawResponse + +from areal.experimental.training_service.data_proxy.config import TrainDataProxyConfig +from areal.experimental.training_service.data_proxy.dispatcher import Dispatcher +from areal.experimental.training_service.data_proxy.engine import register_engine_routes +from areal.experimental.training_service.data_proxy.topology import discover_topology +from areal.utils import logging + +logger = logging.getLogger("TrainDataProxy") + + +def _raw_json_response(content: bytes) -> RawResponse: + return RawResponse(content=content, media_type="application/json") + + +def create_app(config: TrainDataProxyConfig) -> FastAPI: + @asynccontextmanager + async def lifespan(app: FastAPI): + logger.info( + "Train data proxy starting with %d workers", len(config.worker_addrs) + ) + topology = await discover_topology( + config.worker_addrs, + timeout=min(config.request_timeout, 30.0), + ) + 
dispatcher = Dispatcher( + topology=topology, request_timeout=config.request_timeout + ) + + app.state.config = config + app.state.topology = topology + app.state.dispatcher = dispatcher + yield + await dispatcher.close() + logger.info("Train data proxy shutting down") + + app = FastAPI(title="AReaL Train Data Proxy", lifespan=lifespan) + register_engine_routes(app, _raw_json_response=_raw_json_response) + + return app diff --git a/areal/experimental/training_service/data_proxy/config.py b/areal/experimental/training_service/data_proxy/config.py new file mode 100644 index 0000000000..64b2f32ea4 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/config.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class TrainDataProxyConfig: + host: str = "0.0.0.0" + port: int = 9082 + worker_addrs: list[str] = field(default_factory=list) + admin_api_key: str = "areal-admin-key" + log_level: str = "info" + request_timeout: float = 600.0 + warmup_timeout: float = 120.0 diff --git a/areal/experimental/training_service/data_proxy/dispatcher.py b/areal/experimental/training_service/data_proxy/dispatcher.py new file mode 100644 index 0000000000..87df614c42 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/dispatcher.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Partitioned HTTP dispatcher for one 5D-parallel worker group. + +Replicates the dispatch semantics of +:class:`~areal.infra.controller.train_controller.TrainController`: + +- Detect tensor vs scalar inputs. +- Partition tensor inputs across DP groups via + ``balanced_greedy_partition``. +- Fan out to all workers (DP heads receive their data slice; non-DP-head + workers receive an empty signal so they can participate in NCCL + collectives via intra-group broadcast). +- Collect results from DP heads and merge them back into the original + trajectory order. +- Pad the batch to a multiple of ``dp_size * group_size`` when not + evenly divisible (eval-padding behaviour from PR 1109). 
+ +Usage:: + + result = await dispatcher.dispatch("/train_batch").post(body) + version = await dispatcher.dispatch("/get_version").get() + all_stats = await dispatcher.broadcast("/export_stats").get() + responses = await dispatcher.broadcast("/set_version").post(body) +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass +from typing import Any + +import aiohttp +import orjson + +from areal.experimental.training_service.data_proxy.topology import WorkerTopology +from areal.infra.controller.train_controller import ( + _dispatch_tensors, + _is_tensor_like, + _merge_tensors, + _pad_eval_batch, +) +from areal.infra.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging + +logger = logging.getLogger("TrainDataProxy") + + +@dataclass +class _WorkerResponse: + """Internal container for a validated worker HTTP response.""" + + addr: str + status: int + content: bytes + + +class Dispatcher: + """Partitioned HTTP dispatcher for one 5D-parallel worker group.""" + + def __init__( + self, + topology: WorkerTopology, + request_timeout: float = 600.0, + *, + _session: Any | None = None, + ): + self._topology = topology + self._request_timeout = request_timeout + self._session = _session or aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=request_timeout), + ) + + async def close(self) -> None: + """Close the underlying HTTP session.""" + await self._session.close() + + # ------------------------------------------------------------------ + # Fluent API + # ------------------------------------------------------------------ + + def dispatch(self, path: str, *, pad_eval_batch: bool = False) -> DispatchRequest: + """Return a dispatch builder for *path*. + + Dispatch operations route tensors via DP-aware partitioning + and return a single merged result. Non-partitionable payloads + are sent to all workers (DP heads receive the original body; + non-heads receive an empty envelope to participate in + intra-group collectives) and the first DP head's response is + returned. + """ + return DispatchRequest(self, path, pad_eval_batch=pad_eval_batch) + + def broadcast(self, path: str) -> BroadcastRequest: + """Return a broadcast builder for *path*. + + Broadcast operations send the same request to every worker + and return all responses. 
+ """ + return BroadcastRequest(self, path) + + # ------------------------------------------------------------------ + # Internal HTTP helpers + # ------------------------------------------------------------------ + + async def _do_get(self, addr: str, path: str) -> _WorkerResponse: + async with self._session.get(f"{addr}{path}") as resp: + content = await resp.read() + return _WorkerResponse(addr=addr, status=resp.status, content=content) + + async def _do_post(self, addr: str, path: str, body: bytes) -> _WorkerResponse: + async with self._session.post( + f"{addr}{path}", + data=body, + headers={"Content-Type": "application/json"}, + ) as resp: + content = await resp.read() + return _WorkerResponse(addr=addr, status=resp.status, content=content) + + async def _gather_validated( + self, + tasks: Sequence[Awaitable[_WorkerResponse]], + addrs: list[str], + ) -> list[_WorkerResponse]: + """Run *tasks* concurrently and validate every response.""" + raw = await asyncio.gather(*tasks, return_exceptions=True) + validated: list[_WorkerResponse] = [] + for i, result in enumerate(raw): + if isinstance(result, BaseException): + raise RuntimeError(f"Worker {addrs[i]} failed: {result}") + _raise_for_worker(result) + validated.append(result) + return validated + + +class DispatchRequest: + """Builder for DP-aware dispatch operations. + + Tensor payloads are partitioned across DP groups. Non-partitionable + payloads are sent to all workers (DP heads receive the original + body; non-heads receive an empty envelope to participate in + intra-group collectives) and the first DP head's response is + returned. + + Obtain via :meth:`Dispatcher.dispatch`. + """ + + __slots__ = ("_dispatcher", "_path", "_pad_eval_batch") + + def __init__( + self, dispatcher: Dispatcher, path: str, *, pad_eval_batch: bool = False + ) -> None: + self._dispatcher = dispatcher + self._path = path + self._pad_eval_batch = pad_eval_batch + + async def get(self) -> bytes: + """GET from all DP heads, return the first response.""" + d = self._dispatcher + dp_head_addrs = [d._topology.workers[i].addr for i in d._topology.dp_heads] + tasks = [d._do_get(addr, self._path) for addr in dp_head_addrs] + responses = await d._gather_validated(tasks, dp_head_addrs) + return responses[0].content + + async def post(self, body: bytes) -> bytes: + """POST with tensor-aware dispatch. + + If the payload contains partitionable tensor batches, it is split + across DP groups. Eval endpoints can opt into padding before + dispatch so the per-shard results still merge in original + trajectory order. + + Otherwise the body is forwarded to all workers (DP heads receive + the original body; non-heads receive an empty envelope) and the + first DP head's response is returned. + """ + data = orjson.loads(body) + raw_args = deserialize_value(data.get("args", [])) + raw_kwargs = deserialize_value(data.get("kwargs", {})) + + group_size: int = 1 + if isinstance(raw_kwargs, dict): + group_size = raw_kwargs.pop("group_size", 1) + + if ( + _is_tensor_like(raw_args) or _is_tensor_like(raw_kwargs) + ) and _contains_partitionable_tensor_batch(raw_args, raw_kwargs): + return await self._tensor_dispatch( + raw_args, + raw_kwargs, + group_size, + pad_eval_batch=self._pad_eval_batch, + ) + return await self._scalar_fan_out(body) + + async def _scalar_fan_out(self, body: bytes) -> bytes: + """POST to all workers; DP heads get *body*, non-heads get empty. 
+ + Compute routes on the worker side use ``require_broadcast=True``, + so every rank must receive an HTTP request to participate in + intra-group NCCL collectives. + """ + d = self._dispatcher + if not d._topology.dp_heads: + raise RuntimeError("No DP head available for scalar compute dispatch") + + dp_head_set = set(d._topology.dp_heads) + empty = _empty_payload() + + addrs = [w.addr for w in d._topology.workers] + tasks = [ + d._do_post(addr, self._path, body if i in dp_head_set else empty) + for i, addr in enumerate(addrs) + ] + responses = await d._gather_validated(tasks, addrs) + + first_dp_head_idx = d._topology.dp_heads[0] + return responses[first_dp_head_idx].content + + # ------------------------------------------------------------------ + # Tensor dispatch (partitioned fan-out + merge) + # ------------------------------------------------------------------ + + async def _tensor_dispatch( + self, + raw_args: list[Any], + raw_kwargs: dict[str, Any], + group_size: int, + *, + pad_eval_batch: bool, + ) -> bytes: + d = self._dispatcher + dp_size = d._topology.dp_size + + if pad_eval_batch: + args_tuple = _pad_eval_batch(tuple(raw_args), dp_size, group_size) + raw_args = list(args_tuple) + + dp_args, dp_kwargs, group_indices = self._partition_inputs( + raw_args, raw_kwargs, group_size + ) + + dp_head_results = await self._fan_out(dp_args, dp_kwargs) + + merged = _merge_tensors(dp_head_results, group_indices) + + return orjson.dumps({"status": "success", "result": serialize_value(merged)}) + + # ------------------------------------------------------------------ + # Partitioning (mirrors TrainController._partition_inputs) + # ------------------------------------------------------------------ + + def _partition_inputs( + self, + args: list[Any], + kwargs: dict[str, Any], + group_size: int, + ) -> tuple[list[list[Any]], dict[str, list[Any]], list[list[int]]]: + dp_size = self._dispatcher._topology.dp_size + group_indices: list[list[int]] | None = None + + def _split(item: Any) -> list[Any]: + nonlocal group_indices + if _is_tensor_like(item): + if group_indices is None: + splits, group_indices = _dispatch_tensors( + item, dp_size, group_size=group_size + ) + return splits + return [[item[i] for i in idxs] for idxs in group_indices] + return [item] * dp_size + + dp_args = [_split(a) for a in args] + dp_kwargs = {k: _split(v) for k, v in kwargs.items()} + + if group_indices is None: + raise RuntimeError( + "dispatch_compute called with tensor detection but no " + "tensor-like arg was found during partitioning" + ) + + return dp_args, dp_kwargs, group_indices + + # ------------------------------------------------------------------ + # Fan-out to workers (DP heads get partition, others get empty) + # ------------------------------------------------------------------ + + async def _fan_out( + self, + dp_args: list[list[Any]], + dp_kwargs: dict[str, list[Any]], + ) -> list[Any]: + d = self._dispatcher + dp_head_set = set(d._topology.dp_heads) + + payloads: list[bytes] = [] + dp_idx = 0 + for i in range(len(d._topology.workers)): + if i in dp_head_set: + worker_args = [splits[dp_idx] for splits in dp_args] + worker_kwargs = {k: splits[dp_idx] for k, splits in dp_kwargs.items()} + dp_idx += 1 + else: + payloads.append(_empty_payload()) + continue + + payloads.append( + orjson.dumps( + { + "args": serialize_value(worker_args), + "kwargs": serialize_value(worker_kwargs), + } + ) + ) + + addrs = [w.addr for w in d._topology.workers] + tasks = [ + d._do_post(addrs[i], self._path, payloads[i]) + for i 
in range(len(d._topology.workers)) + ] + responses = await d._gather_validated(tasks, addrs) + + dp_head_results: list[Any] = [] + for i in d._topology.dp_heads: + result_data = orjson.loads(responses[i].content) + result = deserialize_value(result_data.get("result")) + dp_head_results.append(result) + + return dp_head_results + + +class BroadcastRequest: + """Builder for broadcast operations across all workers. + + Every worker receives the same request and all responses are + returned. + + Obtain via :meth:`Dispatcher.broadcast`. + """ + + __slots__ = ("_dispatcher", "_path") + + def __init__(self, dispatcher: Dispatcher, path: str) -> None: + self._dispatcher = dispatcher + self._path = path + + async def get(self) -> list[bytes]: + """GET from every worker, return all responses.""" + d = self._dispatcher + addrs = [w.addr for w in d._topology.workers] + tasks = [d._do_get(addr, self._path) for addr in addrs] + responses = await d._gather_validated(tasks, addrs) + return [r.content for r in responses] + + async def post(self, body: bytes) -> list[bytes]: + """POST the same body to every worker, return all responses.""" + d = self._dispatcher + addrs = [w.addr for w in d._topology.workers] + tasks = [d._do_post(addr, self._path, body) for addr in addrs] + responses = await d._gather_validated(tasks, addrs) + return [r.content for r in responses] + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _empty_payload() -> bytes: + """Empty args/kwargs envelope for non-DP-head workers. + + Non-heads only need to enter the endpoint to participate in + intra-group NCCL collectives via ``broadcast_tensor_container``. + This must stay in sync with the envelope format expected by + :func:`~areal.infra.rpc.serialization.deserialize_value`. + """ + return orjson.dumps({"args": serialize_value([]), "kwargs": serialize_value({})}) + + +def _raise_for_worker(resp: _WorkerResponse) -> None: + if resp.status >= 400: + text = resp.content.decode("utf-8", errors="replace") + raise RuntimeError(f"Worker {resp.addr} returned {resp.status}: {text}") + + +def _contains_partitionable_tensor_batch( + args: list[Any], kwargs: dict[str, Any] +) -> bool: + """Return True when payload matches list-of-items partition contract. + + The current partitioner (``_dispatch_tensors``) operates on list-like batches of + per-item dict payloads. Some endpoints (e.g. ``forward_batch`` with packed + tensor dicts) send tensor-containing dicts directly; those should use scalar + fan-out instead of list partitioning. 
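+
+    Illustrative shapes (tensor names here are placeholders, not tied to a
+    specific endpoint)::
+
+        # A list of per-item tensor dicts is partitionable -> tensor dispatch.
+        _contains_partitionable_tensor_batch([[{"ids": t0}, {"ids": t1}]], {})
+
+        # A packed tensor dict is not -> falls back to scalar fan-out.
+        _contains_partitionable_tensor_batch([{"ids": packed}], {})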
+ """ + + def _is_partitionable(v: Any) -> bool: + return isinstance(v, list) and len(v) > 0 and _is_tensor_like(v) + + return any(_is_partitionable(v) for v in args) or any( + _is_partitionable(v) for v in kwargs.values() + ) diff --git a/areal/experimental/training_service/data_proxy/engine.py b/areal/experimental/training_service/data_proxy/engine.py new file mode 100644 index 0000000000..e3825a43d1 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/engine.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import asdict +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + + +def register_engine_routes( + app: FastAPI, + *, + _raw_json_response: Callable[[bytes], Any], +) -> None: + # -- core routes ------------------------------------------------------- + + @app.get("/health") + async def health(): + topology = app.state.topology + return { + "status": "ok", + "worker_count": len(topology.workers), + "dp_size": topology.dp_size, + "dp_heads": topology.dp_heads, + } + + @app.get("/topology") + async def topology(): + t = app.state.topology + return { + "workers": [asdict(w) for w in t.workers], + "dp_heads": t.dp_heads, + "dp_size": t.dp_size, + "dp_groups": t.dp_groups, + "pp_size": t.pp_size, + "tp_size": t.tp_size, + "cp_size": t.cp_size, + "ep_size": t.ep_size, + } + + # -- dispatch helpers -------------------------------------------------- + + def _dispatch_compute_route(path: str, *, pad_eval_batch: bool = False): + async def handler(request: Request): + dispatcher = app.state.dispatcher + try: + body = await request.body() + return _raw_json_response( + await dispatcher.dispatch(path, pad_eval_batch=pad_eval_batch).post( + body + ) + ) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + + return handler + + def _broadcast_post_route(path: str, *, require_non_empty: bool = False): + async def handler(request: Request): + dispatcher = app.state.dispatcher + try: + body = await request.body() + responses = await dispatcher.broadcast(path).post(body) + if require_non_empty and not responses: + raise RuntimeError(f"No worker responses for {path}") + return _raw_json_response(responses[0]) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + + return handler + + def _dispatch_get_route(path: str): + async def handler(): + dispatcher = app.state.dispatcher + try: + return _raw_json_response(await dispatcher.dispatch(path).get()) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + + return handler + + def _broadcast_get_route(path: str, *, require_non_empty: bool = False): + async def handler(): + dispatcher = app.state.dispatcher + try: + responses = await dispatcher.broadcast(path).get() + if require_non_empty and not responses: + raise RuntimeError(f"No worker responses for {path}") + return _raw_json_response(responses[0]) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + + return handler + + # -- engine routes ----------------------------------------------------- + + app.post("/train_batch")(_dispatch_compute_route("/train_batch")) + app.post("/forward_batch")(_dispatch_compute_route("/forward_batch")) + app.post("/eval_batch")(_dispatch_compute_route("/eval_batch", pad_eval_batch=True)) + + app.post("/train")(_broadcast_post_route("/train")) + 
app.post("/eval")(_broadcast_post_route("/eval")) + app.post("/offload")(_broadcast_post_route("/offload")) + app.post("/onload")(_broadcast_post_route("/onload")) + app.post("/set_version")(_broadcast_post_route("/set_version")) + app.get("/get_version")(_dispatch_get_route("/get_version")) + app.post("/save")(_broadcast_post_route("/save", require_non_empty=True)) + app.post("/load")(_broadcast_post_route("/load", require_non_empty=True)) + app.post("/step_lr_scheduler")(_broadcast_post_route("/step_lr_scheduler")) + app.post("/optimizer_zero_grad")( + _broadcast_post_route("/optimizer_zero_grad", require_non_empty=True) + ) + app.post("/optimizer_step")( + _broadcast_post_route("/optimizer_step", require_non_empty=True) + ) + app.post("/get_device_stats")( + _broadcast_post_route("/get_device_stats", require_non_empty=True) + ) + app.post("/config_perf_tracer")( + _broadcast_post_route("/config_perf_tracer", require_non_empty=True) + ) + app.post("/save_perf_tracer")( + _broadcast_post_route("/save_perf_tracer", require_non_empty=True) + ) + app.post("/clear_batches")( + _broadcast_post_route("/clear_batches", require_non_empty=True) + ) + app.get("/export_stats")( + _broadcast_get_route("/export_stats", require_non_empty=True) + ) + + # -- SFT routes -------------------------------------------------------- + + app.post("/sft/train")(_dispatch_compute_route("/sft/train")) + app.post("/sft/evaluate")( + _dispatch_compute_route("/sft/evaluate", pad_eval_batch=True) + ) + + # -- PPO actor routes -------------------------------------------------- + + app.post("/ppo/actor/compute_logp")( + _dispatch_compute_route("/ppo/actor/compute_logp") + ) + app.post("/ppo/actor/compute_advantages")( + _dispatch_compute_route("/ppo/actor/compute_advantages") + ) + app.post("/ppo/actor/update")(_dispatch_compute_route("/ppo/actor/update")) + + # -- PPO critic routes ------------------------------------------------- + + app.post("/ppo/critic/compute_values")( + _dispatch_compute_route("/ppo/critic/compute_values") + ) + app.post("/ppo/critic/update")(_dispatch_compute_route("/ppo/critic/update")) + + # -- RW routes --------------------------------------------------------- + + app.post("/rw/train")(_dispatch_compute_route("/rw/train")) + app.post("/rw/evaluate")( + _dispatch_compute_route("/rw/evaluate", pad_eval_batch=True) + ) diff --git a/areal/experimental/training_service/data_proxy/topology.py b/areal/experimental/training_service/data_proxy/topology.py new file mode 100644 index 0000000000..a13a71d501 --- /dev/null +++ b/areal/experimental/training_service/data_proxy/topology.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field + +import aiohttp + +from areal.utils import logging + +logger = logging.getLogger("TrainDataProxy") + + +@dataclass +class WorkerInfo: + addr: str + rank: int = 0 + world_size: int = 1 + dp_rank: int = 0 + dp_size: int = 1 + is_dp_head: bool = True + local_rank: int = 0 + + +@dataclass +class WorkerTopology: + workers: list[WorkerInfo] = field(default_factory=list) + dp_heads: list[int] = field(default_factory=list) + dp_size: int = 1 + dp_groups: list[list[int]] = field(default_factory=list) + pp_size: int = 1 + tp_size: int = 1 + cp_size: int = 1 + ep_size: int = 1 + + +async def discover_topology( + worker_addrs: list[str], + timeout: float = 10.0, +) -> WorkerTopology: + workers: list[WorkerInfo] = [] + meta: dict[str, int] = {} + + async def _fetch(session: 
aiohttp.ClientSession, addr: str) -> dict: + async with session.get(f"{addr}/topology") as resp: + if resp.status >= 400: + text = await resp.text() + raise RuntimeError( + f"Failed to discover topology from {addr}: " + f"HTTP {resp.status}: {text}" + ) + return await resp.json(content_type=None) + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + tasks = [_fetch(session, addr) for addr in worker_addrs] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for i, resp in enumerate(responses): + addr = worker_addrs[i] + if isinstance(resp, BaseException): + raise RuntimeError(f"Failed to discover topology from {addr}: {resp}") + data = resp + if not meta: + meta = { + "pp_size": int(data.get("pp_size", 1)), + "tp_size": int(data.get("tp_size", 1)), + "cp_size": int(data.get("cp_size", 1)), + "ep_size": int(data.get("ep_size", 1)), + } + workers.append( + WorkerInfo( + addr=addr, + rank=data.get("rank", 0), + world_size=data.get("world_size", 1), + dp_rank=data.get("dp_rank", 0), + dp_size=data.get("dp_size", 1), + is_dp_head=data.get("is_dp_head", True), + local_rank=data.get("local_rank", 0), + ) + ) + + dp_heads = [i for i, w in enumerate(workers) if w.is_dp_head] + dp_size = workers[0].dp_size if workers else 1 + dp_groups: list[list[int]] = [[] for _ in range(max(dp_size, 1))] + for i, w in enumerate(workers): + if w.dp_rank >= len(dp_groups): + dp_groups.extend([] for _ in range(w.dp_rank - len(dp_groups) + 1)) + dp_groups[w.dp_rank].append(i) + + return WorkerTopology( + workers=workers, + dp_heads=dp_heads, + dp_size=dp_size, + dp_groups=dp_groups, + pp_size=meta.get("pp_size", 1), + tp_size=meta.get("tp_size", 1), + cp_size=meta.get("cp_size", 1), + ep_size=meta.get("ep_size", 1), + ) diff --git a/areal/experimental/training_service/gateway/__init__.py b/areal/experimental/training_service/gateway/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/training_service/gateway/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/training_service/gateway/__main__.py b/areal/experimental/training_service/gateway/__main__.py new file mode 100644 index 0000000000..23ca2cc069 --- /dev/null +++ b/areal/experimental/training_service/gateway/__main__.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse + + +def main(): + parser = argparse.ArgumentParser(description="AReaL Training Gateway") + parser.add_argument("--host", default="0.0.0.0", help="Bind address") + parser.add_argument("--port", type=int, default=9080, help="Bind port") + parser.add_argument( + "--admin-api-key", + default="areal-admin-key", + help="Admin API key for privileged operations", + ) + parser.add_argument( + "--router-addr", + default="http://localhost:8081", + help="Router service address", + ) + parser.add_argument( + "--router-timeout", + type=float, + default=2.0, + help="Timeout (seconds) for router /route calls", + ) + parser.add_argument( + "--forward-timeout", + type=float, + default=600.0, + help="Timeout (seconds) for forwarding requests to data proxies", + ) + parser.add_argument( + "--log-level", + default="info", + choices=["debug", "info", "warning", "error"], + help="Log level", + ) + args, _ = parser.parse_known_args() + + from areal.experimental.training_service.gateway.app import create_app + from areal.experimental.training_service.gateway.config import GatewayConfig + + config = 
GatewayConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + router_addr=args.router_addr, + router_timeout=args.router_timeout, + forward_timeout=args.forward_timeout, + log_level=args.log_level, + ) + + import uvicorn + + app = create_app(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/training_service/gateway/app.py b/areal/experimental/training_service/gateway/app.py new file mode 100644 index 0000000000..f539a28b10 --- /dev/null +++ b/areal/experimental/training_service/gateway/app.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response + +from areal.experimental.training_service.gateway import streaming +from areal.experimental.training_service.gateway.auth import extract_bearer_token +from areal.experimental.training_service.gateway.config import GatewayConfig +from areal.experimental.training_service.gateway.engine import register_engine_routes +from areal.utils import logging + +logger = logging.getLogger("TrainGateway") + + +def _router_error_response(exc: Exception) -> JSONResponse: + if isinstance(exc, streaming.RouterUnreachableError): + return JSONResponse({"error": str(exc)}, status_code=502) + if isinstance(exc, streaming.RouterKeyRejectedError): + status = 401 if exc.status_code == 404 else exc.status_code + return JSONResponse({"error": exc.detail}, status_code=status) + return JSONResponse({"error": str(exc)}, status_code=500) + + +async def _forward_post( + request: Request, + path: str, + config: GatewayConfig, + *, + use_admin_auth_for_upstream: bool = False, +) -> Response: + token = extract_bearer_token(request) + try: + model_addr = await streaming.query_router( + config.router_addr, + token, + config.router_timeout, + admin_api_key=config.admin_api_key, + client=request.app.state.router_client, + ) + except (streaming.RouterUnreachableError, streaming.RouterKeyRejectedError) as exc: + return _router_error_response(exc) + + body = await request.body() + headers = streaming._forwarding_headers(dict(request.headers)) + if use_admin_auth_for_upstream: + for key in list(headers.keys()): + if key.lower() == "authorization": + headers.pop(key) + headers["Authorization"] = f"Bearer {config.admin_api_key}" + try: + resp = await streaming.forward_request( + f"{model_addr}{path}", + body, + headers, + config.forward_timeout, + client=request.app.state.upstream_client, + ) + except Exception as exc: + logger.error("Forwarding POST failed for %s: %s", path, exc) + return JSONResponse({"error": str(exc)}, status_code=502) + return Response( + content=resp.content, + status_code=resp.status_code, + media_type=resp.headers.get("content-type"), + ) + + +async def _forward_get(request: Request, path: str, config: GatewayConfig) -> Response: + token = extract_bearer_token(request) + try: + model_addr = await streaming.query_router( + config.router_addr, + token, + config.router_timeout, + admin_api_key=config.admin_api_key, + client=request.app.state.router_client, + ) + except (streaming.RouterUnreachableError, streaming.RouterKeyRejectedError) as exc: + return _router_error_response(exc) + + try: + resp = await request.app.state.upstream_client.get( + f"{model_addr}{path}", + headers=streaming._forwarding_headers(dict(request.headers)), + 
timeout=config.forward_timeout,
+        )
+    except Exception as exc:
+        logger.error("Forwarding GET failed for %s: %s", path, exc)
+        return JSONResponse({"error": str(exc)}, status_code=502)
+    return Response(
+        content=resp.content,
+        status_code=resp.status_code,
+        media_type=resp.headers.get("content-type"),
+    )
+
+
+def create_app(config: GatewayConfig) -> FastAPI:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        router_client = httpx.AsyncClient(timeout=config.router_timeout)
+        upstream_client = httpx.AsyncClient(timeout=config.forward_timeout)
+        app.state.router_client = router_client
+        app.state.upstream_client = upstream_client
+        try:
+            yield
+        finally:
+            await upstream_client.aclose()
+            await router_client.aclose()
+
+    app = FastAPI(title="AReaL Training Gateway", lifespan=lifespan)
+
+    register_engine_routes(
+        app,
+        config,
+        _forward_post=_forward_post,
+        _forward_get=_forward_get,
+    )
+
+    return app
diff --git a/areal/experimental/training_service/gateway/auth.py b/areal/experimental/training_service/gateway/auth.py
new file mode 100644
index 0000000000..41f2f8e600
--- /dev/null
+++ b/areal/experimental/training_service/gateway/auth.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Authentication helpers for the training gateway."""
+
+from __future__ import annotations
+
+import hmac
+from dataclasses import dataclass
+
+from fastapi import HTTPException, Request
+
+
+@dataclass
+class AuthResult:
+    """Result of API key authentication."""
+
+    key_type: str  # "admin" | "session"
+    api_key: str
+
+
+class AuthError(Exception):
+    """Raised when auth fails."""
+
+    def __init__(self, status_code: int, detail: str):
+        self.status_code = status_code
+        self.detail = detail
+
+
+def extract_bearer_token(request: Request) -> str:
+    """Extract API token from Authorization header.
+
+    Raises HTTPException(401) if missing or malformed.
+    """
+    auth_header = request.headers.get("authorization", "")
+    if auth_header.lower().startswith("bearer "):
+        return auth_header[7:].strip()
+    raise HTTPException(
+        status_code=401,
+        detail="Missing or malformed Authorization header. Expected 'Bearer <token>'.",
+    )
+
+
+def require_admin_key(request: Request, admin_api_key: str) -> str:
+    """Validate that the request carries the admin API key.
+
+    Returns the bearer token on success. Raises HTTPException(403) on failure.
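+
+    Illustrative use inside a route handler (``config`` stands for whatever
+    object carries the configured admin key, e.g. ``GatewayConfig``)::
+
+        token = require_admin_key(request, config.admin_api_key)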
+ """ + token = extract_bearer_token(request) + if not hmac.compare_digest(token, admin_api_key): + raise HTTPException(status_code=403, detail="Admin API key required.") + return token diff --git a/areal/experimental/training_service/gateway/config.py b/areal/experimental/training_service/gateway/config.py new file mode 100644 index 0000000000..182274624c --- /dev/null +++ b/areal/experimental/training_service/gateway/config.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class GatewayConfig: + host: str = "0.0.0.0" + port: int = 9080 + router_addr: str = "" + admin_api_key: str = "areal-admin-key" + log_level: str = "info" + router_timeout: float = 2.0 + forward_timeout: float = 600.0 diff --git a/areal/experimental/training_service/gateway/engine.py b/areal/experimental/training_service/gateway/engine.py new file mode 100644 index 0000000000..736b5e4968 --- /dev/null +++ b/areal/experimental/training_service/gateway/engine.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from fastapi import FastAPI, Request + +from areal.experimental.training_service.gateway.config import GatewayConfig + + +def register_engine_routes( + app: FastAPI, + config: GatewayConfig, + *, + _forward_post: Callable[..., Any], + _forward_get: Callable[..., Any], +) -> None: + # -- core routes ------------------------------------------------------- + + @app.get("/health") + async def health(): + return {"status": "ok", "router_addr": config.router_addr} + + @app.post("/train_batch") + async def train_batch(request: Request): + return await _forward_post(request, "/train_batch", config) + + @app.post("/forward_batch") + async def forward_batch(request: Request): + return await _forward_post(request, "/forward_batch", config) + + @app.post("/eval_batch") + async def eval_batch(request: Request): + return await _forward_post(request, "/eval_batch", config) + + @app.post("/train") + async def train(request: Request): + return await _forward_post(request, "/train", config) + + @app.post("/eval") + async def eval_(request: Request): + return await _forward_post(request, "/eval", config) + + @app.post("/set_version") + async def set_version(request: Request): + return await _forward_post(request, "/set_version", config) + + @app.get("/get_version") + async def get_version(request: Request): + return await _forward_get(request, "/get_version", config) + + @app.post("/save") + async def save(request: Request): + return await _forward_post(request, "/save", config) + + @app.post("/load") + async def load(request: Request): + return await _forward_post(request, "/load", config) + + @app.post("/offload") + async def offload(request: Request): + return await _forward_post( + request, + "/offload", + config, + use_admin_auth_for_upstream=True, + ) + + @app.post("/onload") + async def onload(request: Request): + return await _forward_post( + request, + "/onload", + config, + use_admin_auth_for_upstream=True, + ) + + @app.post("/step_lr_scheduler") + async def step_lr_scheduler(request: Request): + return await _forward_post(request, "/step_lr_scheduler", config) + + @app.post("/optimizer_zero_grad") + async def optimizer_zero_grad(request: Request): + return await _forward_post(request, "/optimizer_zero_grad", config) + + @app.post("/optimizer_step") + async def optimizer_step(request: Request): + return await 
_forward_post(request, "/optimizer_step", config) + + @app.post("/get_device_stats") + async def get_device_stats(request: Request): + return await _forward_post(request, "/get_device_stats", config) + + @app.post("/config_perf_tracer") + async def config_perf_tracer(request: Request): + return await _forward_post(request, "/config_perf_tracer", config) + + @app.post("/save_perf_tracer") + async def save_perf_tracer(request: Request): + return await _forward_post(request, "/save_perf_tracer", config) + + @app.post("/clear_batches") + async def clear_batches(request: Request): + return await _forward_post(request, "/clear_batches", config) + + @app.get("/export_stats") + async def export_stats(request: Request): + return await _forward_get(request, "/export_stats", config) + + # -- SFT routes -------------------------------------------------------- + + @app.post("/sft/train") + async def train_sft(request: Request): + return await _forward_post(request, "/sft/train", config) + + @app.post("/sft/evaluate") + async def evaluate_sft(request: Request): + return await _forward_post(request, "/sft/evaluate", config) + + # -- PPO actor routes -------------------------------------------------- + + @app.post("/ppo/actor/compute_logp") + async def actor_compute_logp(request: Request): + return await _forward_post(request, "/ppo/actor/compute_logp", config) + + @app.post("/ppo/actor/compute_advantages") + async def actor_compute_advantages(request: Request): + return await _forward_post(request, "/ppo/actor/compute_advantages", config) + + @app.post("/ppo/actor/update") + async def actor_update(request: Request): + return await _forward_post(request, "/ppo/actor/update", config) + + # -- PPO critic routes ------------------------------------------------- + + @app.post("/ppo/critic/compute_values") + async def critic_compute_values(request: Request): + return await _forward_post(request, "/ppo/critic/compute_values", config) + + @app.post("/ppo/critic/update") + async def critic_update(request: Request): + return await _forward_post(request, "/ppo/critic/update", config) + + # -- RW routes --------------------------------------------------------- + + @app.post("/rw/train") + async def rw_train(request: Request): + return await _forward_post(request, "/rw/train", config) + + @app.post("/rw/evaluate") + async def rw_evaluate(request: Request): + return await _forward_post(request, "/rw/evaluate", config) diff --git a/areal/experimental/training_service/gateway/streaming.py b/areal/experimental/training_service/gateway/streaming.py new file mode 100644 index 0000000000..a3de135109 --- /dev/null +++ b/areal/experimental/training_service/gateway/streaming.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import httpx + +from areal.utils import logging + +logger = logging.getLogger("TrainGateway") + + +class RouterUnreachableError(Exception): + pass + + +class RouterKeyRejectedError(Exception): + def __init__(self, detail: str, status_code: int = 404): + super().__init__(detail) + self.detail = detail + self.status_code = status_code + + +async def query_router( + router_addr: str, + api_key: str, + timeout: float = 2.0, + *, + admin_api_key: str | None = None, + client: httpx.AsyncClient, +) -> str: + payload = {"api_key": api_key} + try: + headers = {} + if admin_api_key is not None: + headers["Authorization"] = f"Bearer {admin_api_key}" + resp = await client.post( + f"{router_addr}/route", + json=payload, + headers=headers, + timeout=timeout, + ) + if 
resp.status_code in {404, 503}: + try: + data = resp.json() + detail = data.get("detail", data.get("error", resp.text)) + except Exception: + detail = resp.text + raise RouterKeyRejectedError(detail, resp.status_code) + resp.raise_for_status() + return resp.json()["model_addr"] + except (httpx.ConnectError, httpx.ConnectTimeout) as exc: + raise RouterUnreachableError(f"Router unreachable: {exc}") from exc + except httpx.TimeoutException as exc: + raise RouterUnreachableError(f"Router timed out: {exc}") from exc + except httpx.HTTPStatusError as exc: + raise RouterUnreachableError( + f"Router returned HTTP {exc.response.status_code}: {exc}" + ) from exc + + +def _forwarding_headers(raw_headers: dict[str, str]) -> dict[str, str]: + skip = {"host", "content-length", "transfer-encoding"} + return {k: v for k, v in raw_headers.items() if k.lower() not in skip} + + +async def forward_request( + upstream_url: str, + body: bytes, + headers: dict[str, str], + timeout: float = 600.0, + *, + client: httpx.AsyncClient, +) -> httpx.Response: + fwd_headers = _forwarding_headers(headers) + resp = await client.post( + upstream_url, + content=body, + headers=fwd_headers, + timeout=timeout, + ) + return resp diff --git a/areal/experimental/training_service/guard/__init__.py b/areal/experimental/training_service/guard/__init__.py new file mode 100644 index 0000000000..4a7ff04f59 --- /dev/null +++ b/areal/experimental/training_service/guard/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Training service local guard package.""" diff --git a/areal/experimental/training_service/guard/__main__.py b/areal/experimental/training_service/guard/__main__.py new file mode 100644 index 0000000000..3028624721 --- /dev/null +++ b/areal/experimental/training_service/guard/__main__.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entrypoint: ``python -m areal.experimental.training_service.guard``""" + +from __future__ import annotations + +from areal.experimental.training_service.guard.app import _state, app +from areal.infra.rpc.guard.app import ( + configure_state_from_args, + make_base_parser, + run_server, +) + + +def main() -> None: + parser = make_base_parser( + description="AReaL Train RPCGuard — HTTP gateway for coordinating forked workers" + ) + args, _ = parser.parse_known_args() + bind_host = configure_state_from_args(_state, args) + run_server(_state, app, bind_host, args.port) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/training_service/guard/app.py b/areal/experimental/training_service/guard/app.py new file mode 100644 index 0000000000..b268ef1a1d --- /dev/null +++ b/areal/experimental/training_service/guard/app.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Training service guard backed by the shared RPC guard.""" + +from __future__ import annotations + +from areal.infra.rpc.guard.app import GuardState, create_app +from areal.infra.rpc.guard.app import cleanup_forked_children as _cleanup_impl +from areal.utils import logging + +logger = logging.getLogger("TrainRPCGuard") + +_state = GuardState() + +app = create_app(_state) + + +def cleanup_forked_children() -> None: + _cleanup_impl(_state) diff --git a/areal/experimental/training_service/router/__init__.py b/areal/experimental/training_service/router/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/training_service/router/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git 
a/areal/experimental/training_service/router/__main__.py b/areal/experimental/training_service/router/__main__.py new file mode 100644 index 0000000000..eaa07fda57 --- /dev/null +++ b/areal/experimental/training_service/router/__main__.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import importlib + + +def main(): + parser = argparse.ArgumentParser(description="AReaL Train Router") + parser.add_argument("--host", default="0.0.0.0", help="Bind address") + parser.add_argument("--port", type=int, default=9081, help="Bind port") + parser.add_argument( + "--admin-api-key", + default="areal-admin-key", + help="Admin API key for privileged operations", + ) + parser.add_argument( + "--poll-interval", + type=float, + default=5.0, + help="Seconds between model health polls", + ) + parser.add_argument( + "--worker-health-timeout", + type=float, + default=2.0, + help="Timeout (seconds) per model health check", + ) + parser.add_argument( + "--log-level", + default="info", + choices=["debug", "info", "warning", "error"], + help="Log level", + ) + args, _ = parser.parse_known_args() + + from areal.experimental.training_service.router.app import create_app + from areal.experimental.training_service.router.config import RouterConfig + + config = RouterConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + poll_interval=args.poll_interval, + worker_health_timeout=args.worker_health_timeout, + log_level=args.log_level, + ) + + app = create_app(config) + uvicorn = importlib.import_module("uvicorn") + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/training_service/router/app.py b/areal/experimental/training_service/router/app.py new file mode 100644 index 0000000000..ad8b1aa257 --- /dev/null +++ b/areal/experimental/training_service/router/app.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import hmac +import importlib +from contextlib import asynccontextmanager +from typing import Any + +from areal.experimental.training_service.router.config import RouterConfig +from areal.experimental.training_service.router.state import ModelRegistry +from areal.utils import logging + +httpx = importlib.import_module("httpx") +fastapi = importlib.import_module("fastapi") +pydantic = importlib.import_module("pydantic") +FastAPI = fastapi.FastAPI +HTTPException = fastapi.HTTPException +Request = fastapi.Request +BaseModel = pydantic.BaseModel + +logger = logging.getLogger("TrainRouter") + + +async def _probe_model_health( + model_registry: ModelRegistry, + model_addr: str, + client: Any, +) -> None: + try: + resp = await client.get(f"{model_addr}/health") + healthy = resp.status_code == 200 + except Exception: + healthy = False + await model_registry.update_health(model_addr, healthy) + + +def _extract_bearer_token(request: Request) -> str: + auth_header = request.headers.get("authorization", "") + if auth_header.lower().startswith("bearer "): + return auth_header[7:].strip() + raise HTTPException( + status_code=401, + detail="Missing or malformed Authorization header.", + ) + + +def _require_admin_key(request: Request, admin_key: str) -> str: + token = _extract_bearer_token(request) + if not hmac.compare_digest(token, admin_key): + raise HTTPException(status_code=403, detail="Invalid admin API key.") + return token + + +class RouteRequest(BaseModel): + 
api_key: str | None = None + + +class RegisterRequest(BaseModel): + model_addr: str + api_key: str + name: str = "" + + +class UnregisterRequest(BaseModel): + model_addr: str + + +def create_app(config: RouterConfig) -> FastAPI: + model_registry = ModelRegistry() + + async def _poll_models(client: Any) -> None: + while True: + models = await model_registry.get_all() + for model in models: + await _probe_model_health(model_registry, model.model_addr, client) + await asyncio.sleep(config.poll_interval) + + @asynccontextmanager + async def lifespan(app: FastAPI): + logger.info( + "Train router starting — poll_interval=%.1fs", + config.poll_interval, + ) + health_client = httpx.AsyncClient(timeout=config.worker_health_timeout) + app.state.model_registry = model_registry + app.state.health_client = health_client + poll_task = asyncio.create_task(_poll_models(health_client)) + try: + yield + finally: + poll_task.cancel() + try: + await poll_task + except asyncio.CancelledError: + pass + await health_client.aclose() + logger.info("Train router shutting down") + + app = FastAPI(title="AReaL Train Router", lifespan=lifespan) + app.state.model_registry = model_registry + + @app.get("/health") + async def health(): + model_count = await model_registry.count() + return { + "status": "ok", + "models": model_count, + } + + @app.post("/route") + async def route(body: RouteRequest, request: Request): + _require_admin_key(request, config.admin_api_key) + if body.api_key is None: + raise HTTPException(status_code=422, detail="api_key required") + if hmac.compare_digest(body.api_key, config.admin_api_key): + raise HTTPException( + status_code=400, + detail="Admin key cannot be used for data-plane routing", + ) + model_info = await model_registry.lookup_by_key(body.api_key) + if model_info is None: + raise HTTPException(status_code=404, detail="Unknown API key") + if not model_info.is_healthy: + raise HTTPException(status_code=503, detail="Pinned model is unhealthy") + return {"model_addr": model_info.model_addr} + + @app.post("/register") + async def register(body: RegisterRequest, request: Request): + _require_admin_key(request, config.admin_api_key) + await model_registry.register(body.model_addr, body.api_key, body.name) + logger.info( + "Model registered: %s", + body.model_addr, + ) + return {"status": "ok"} + + @app.post("/unregister") + async def unregister(body: UnregisterRequest, request: Request): + _require_admin_key(request, config.admin_api_key) + await model_registry.deregister(body.model_addr) + logger.info( + "Model unregistered: %s", + body.model_addr, + ) + return {"status": "ok"} + + @app.get("/models") + async def list_models(request: Request): + _require_admin_key(request, config.admin_api_key) + models = await model_registry.get_all() + return { + "models": [ + { + "model_addr": m.model_addr, + "api_key": m.api_key, + "name": m.name, + "is_healthy": m.is_healthy, + "registered_at": m.registered_at, + } + for m in models + ] + } + + return app diff --git a/areal/experimental/training_service/router/config.py b/areal/experimental/training_service/router/config.py new file mode 100644 index 0000000000..f027b1dbf6 --- /dev/null +++ b/areal/experimental/training_service/router/config.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class RouterConfig: + host: str = "0.0.0.0" + port: int = 9081 + admin_api_key: str = "areal-admin-key" + log_level: str = "info" + poll_interval: float = 5.0 + 
worker_health_timeout: float = 2.0 diff --git a/areal/experimental/training_service/router/state.py b/areal/experimental/training_service/router/state.py new file mode 100644 index 0000000000..4d2407cd70 --- /dev/null +++ b/areal/experimental/training_service/router/state.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass, field + +_MAX_CONSECUTIVE_HEALTH_FAILURES = 2 + + +@dataclass +class ModelInfo: + model_addr: str + api_key: str + name: str = "" + is_healthy: bool = True + consecutive_health_failures: int = 0 + registered_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self) -> None: + self._models: dict[str, ModelInfo] = {} + self._key_to_addr: dict[str, str] = {} + self._lock = asyncio.Lock() + + async def register(self, model_addr: str, api_key: str, name: str = "") -> None: + async with self._lock: + existing_model = self._models.get(model_addr) + if existing_model is not None and existing_model.api_key != api_key: + self._key_to_addr.pop(existing_model.api_key, None) + + existing_addr_for_key = self._key_to_addr.get(api_key) + if ( + existing_addr_for_key is not None + and existing_addr_for_key != model_addr + ): + self._models.pop(existing_addr_for_key, None) + + if existing_model is None: + info = ModelInfo(model_addr=model_addr, api_key=api_key, name=name) + else: + existing_model.api_key = api_key + if name: + existing_model.name = name + info = existing_model + + self._models[model_addr] = info + self._key_to_addr[api_key] = model_addr + + async def deregister(self, model_addr: str) -> None: + async with self._lock: + model = self._models.pop(model_addr, None) + if model is not None: + self._key_to_addr.pop(model.api_key, None) + + async def lookup_by_key(self, api_key: str) -> ModelInfo | None: + async with self._lock: + model_addr = self._key_to_addr.get(api_key) + if model_addr is None: + return None + return self._models.get(model_addr) + + async def update_health(self, model_addr: str, healthy: bool) -> None: + async with self._lock: + model = self._models.get(model_addr) + if model is not None: + if healthy: + model.is_healthy = True + model.consecutive_health_failures = 0 + return + + model.consecutive_health_failures += 1 + if ( + model.consecutive_health_failures + >= _MAX_CONSECUTIVE_HEALTH_FAILURES + ): + model.is_healthy = False + + async def get_all(self) -> list[ModelInfo]: + async with self._lock: + return list(self._models.values()) + + async def count(self) -> int: + async with self._lock: + return len(self._models) diff --git a/areal/experimental/training_service/worker/__init__.py b/areal/experimental/training_service/worker/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/training_service/worker/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/training_service/worker/__main__.py b/areal/experimental/training_service/worker/__main__.py new file mode 100644 index 0000000000..3aee4abc61 --- /dev/null +++ b/areal/experimental/training_service/worker/__main__.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entrypoint for the train worker.""" + +from __future__ import annotations + +import argparse + + +def main(): + parser = argparse.ArgumentParser(description="AReaL Train Worker") + parser.add_argument("--host", default="0.0.0.0", help="Bind address") + parser.add_argument("--port", type=int, 
default=30000, help="Bind port") + parser.add_argument( + "--admin-api-key", + default="areal-admin-key", + help="Admin API key for privileged operations", + ) + parser.add_argument( + "--log-level", + default="info", + choices=["debug", "info", "warning", "error"], + help="Log level", + ) + args, _ = parser.parse_known_args() + + from areal.experimental.training_service.worker.app import create_app + from areal.experimental.training_service.worker.config import TrainWorkerConfig + + config = TrainWorkerConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + log_level=args.log_level, + ) + + import logging as _logging + + _logging.getLogger("werkzeug").setLevel( + getattr(_logging, config.log_level.upper(), _logging.WARNING) + ) + + app = create_app(config) + app.run(host=config.host, port=config.port, threaded=True) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/training_service/worker/app.py b/areal/experimental/training_service/worker/app.py new file mode 100644 index 0000000000..0ddb10b324 --- /dev/null +++ b/areal/experimental/training_service/worker/app.py @@ -0,0 +1,204 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib +import traceback +from collections.abc import Callable +from concurrent.futures import Future +from queue import Queue +from threading import Lock, Thread +from typing import Any + +from areal.api import TrainEngine +from areal.experimental.training_service.worker.config import TrainWorkerConfig +from areal.experimental.training_service.worker.engine import create_engine_module +from areal.infra.platforms import current_platform +from areal.infra.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging +from areal.utils.data import broadcast_tensor_container, tensor_container_to + +logger = logging.getLogger("TrainWorker") + +_engine: TrainEngine | None = None +_node_addr: str = "" + +_engine_thread: Thread | None = None +_engine_work_queue: Queue | None = None +_engine_thread_lock = Lock() + + +def _init_engine_thread() -> None: + global _engine_thread, _engine_work_queue + + with _engine_thread_lock: + if _engine_thread is not None: + if _engine_thread.is_alive(): + return + else: + raise RuntimeError("Engine thread is dead.") + + _engine_work_queue = Queue() + + def engine_worker(): + logger.info("Engine thread started") + work_queue = _engine_work_queue + if work_queue is None: + raise RuntimeError("Engine work queue not initialized") + while True: + work_item = None + func_name = "" + try: + work_item = work_queue.get() + if work_item is None: + logger.info("Engine thread shutting down") + break + + func, args, kwargs, future, func_name = work_item + try: + result = func(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + logger.error( + f"Error in engine thread when " + f"running {func_name}: {e}\n{traceback.format_exc()}" + ) + finally: + work_queue.task_done() + except Exception as e: + logger.error( + f"Error in engine thread when " + f"running {func_name}: {e}\n{traceback.format_exc()}" + ) + if work_item and len(work_item) > 3: + work_item[3].set_exception(e) + + _engine_thread = Thread(target=engine_worker, daemon=True, name="EngineWorker") + _engine_thread.start() + logger.info("Engine thread initialized") + + +def _submit_to_engine_thread( + func_name: str, func: Callable, *args: Any, **kwargs: Any +) -> Any: + global _engine_work_queue + + _init_engine_thread() + if 
_engine_work_queue is None: + raise RuntimeError("Engine work queue not initialized") + + future: Future = Future() + _engine_work_queue.put((func, args, kwargs, future, func_name)) + return future.result() + + +def _require_engine() -> TrainEngine: + if _engine is None: + raise RuntimeError("Engine not created. Call /create_engine first.") + return _engine + + +def _execute_compute( + method_name: str, + args: Any, + kwargs: Any, + *, + require_broadcast: bool = False, +) -> Any: + engine = _require_engine() + method = getattr(engine, method_name, None) + if not callable(method): + raise RuntimeError(f"Engine does not implement method '{method_name}'") + + def execute(): + nonlocal args, kwargs + if require_broadcast: + group = engine.context_and_model_parallel_group + if group is None: + if engine.data_parallel_world_size > 1: + raise RuntimeError( + "Broadcast required for endpoint, but " + "engine.context_and_model_parallel_group is None" + ) + else: + args = broadcast_tensor_container( + tensor_container_to(args, current_platform.current_device()), + src_rank=engine.current_data_parallel_head(), + group=group, + ) + kwargs = broadcast_tensor_container( + tensor_container_to(kwargs, current_platform.current_device()), + src_rank=engine.current_data_parallel_head(), + group=group, + ) + return method(*args, **kwargs) + + return _submit_to_engine_thread(method_name, execute) + + +def _parse_args_kwargs(data: dict[str, Any] | None) -> tuple[Any, Any]: + if data is None: + raise ValueError("Invalid JSON in request body") + raw_args = deserialize_value(data.get("args", [])) + raw_kwargs = deserialize_value(data.get("kwargs", {})) + return raw_args, raw_kwargs + + +def create_app(config: TrainWorkerConfig): + global _node_addr + _node_addr = f"{config.host}:{config.port}" + + flask = importlib.import_module("flask") + jsonify = flask.jsonify + + app = flask.Flask(__name__) + + def _run_endpoint( + endpoint_name: str, + action: Callable[[], Any], + return_result: bool = True, + ): + try: + result = action() + if return_result: + return jsonify({"status": "success", "result": serialize_value(result)}) + return jsonify({"status": "success", "result": None}) + except RuntimeError as e: + return jsonify({"error": str(e)}), 400 + except ValueError as e: + return jsonify({"error": str(e)}), 400 + except Exception as e: + logger.error(f"Error in {endpoint_name}: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + def _get_engine() -> TrainEngine | None: + return _engine + + def _set_engine(engine: TrainEngine) -> None: + global _engine + _engine = engine + + def _get_node_addr() -> str: + return _node_addr + + app.register_blueprint( + create_engine_module( + flask_module=flask, + config=config, + get_engine=_get_engine, + set_engine=_set_engine, + submit_to_engine_thread=_submit_to_engine_thread, + parse_args_kwargs=_parse_args_kwargs, + require_engine=_require_engine, + run_endpoint=_run_endpoint, + execute_compute=_execute_compute, + get_node_addr=_get_node_addr, + ) + ) + + from areal.infra.rpc.guard.data_blueprint import data_bp + + app.register_blueprint(data_bp) + + return app diff --git a/areal/experimental/training_service/worker/config.py b/areal/experimental/training_service/worker/config.py new file mode 100644 index 0000000000..fda015564c --- /dev/null +++ b/areal/experimental/training_service/worker/config.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import 
dataclass + + +@dataclass +class TrainWorkerConfig: + host: str = "0.0.0.0" + port: int = 0 + admin_api_key: str = "areal-admin-key" + log_level: str = "info" diff --git a/areal/experimental/training_service/worker/engine.py b/areal/experimental/training_service/worker/engine.py new file mode 100644 index 0000000000..086c63248f --- /dev/null +++ b/areal/experimental/training_service/worker/engine.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import traceback +from collections.abc import Callable +from typing import Any + +from areal.api import TrainEngine +from areal.infra.rpc.rtensor import RTensor +from areal.infra.rpc.serialization import deserialize_value +from areal.utils import logging +from areal.utils.dynamic_import import import_from_string + +logger = logging.getLogger("TrainWorker") + + +def create_engine_module( + *, + flask_module: Any, + config: Any, + get_engine: Callable[[], TrainEngine | None], + set_engine: Callable[[TrainEngine], None], + submit_to_engine_thread: Callable[..., Any], + parse_args_kwargs: Callable[[dict[str, Any] | None], tuple[Any, Any]], + require_engine: Callable[[], TrainEngine], + run_endpoint: Callable[[str, Callable[[], Any]], Any], + execute_compute: Callable[..., Any], + get_node_addr: Callable[[], str], +) -> Any: + Blueprint = flask_module.Blueprint + jsonify = flask_module.jsonify + request = flask_module.request + + bp = Blueprint("worker_engine", __name__) + + # -- core routes ------------------------------------------------------- + + @bp.route("/health", methods=["GET"]) + def health_check(): + rank = int(os.environ.get("RANK", 0)) + role = os.environ.get("ROLE", "train_worker") + ready = get_engine() is not None + return jsonify( + { + "status": "healthy", + "rank": rank, + "role": role, + "ready": ready, + } + ) + + @bp.route("/create_engine", methods=["POST"]) + def create_engine(): + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + engine_class_path = data.get("engine_class") + if engine_class_path is None: + engine_class_path = data.get("engine") + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) + + if not engine_class_path: + return jsonify( + {"error": "Missing 'engine_class' field in request"} + ), 400 + if get_engine() is not None: + return jsonify({"error": "Engine already exists on this worker"}), 400 + + try: + engine_class = import_from_string(engine_class_path) + if not issubclass(engine_class, TrainEngine): + raise TypeError( + "Engine class must be a subclass of TrainEngine, " + f"got {engine_class}." 
+ ) + except (ValueError, ImportError, AttributeError) as e: + return ( + jsonify( + { + "error": ( + f"Failed to import engine class '{engine_class_path}': {str(e)}" + ) + } + ), + 400, + ) + except TypeError as e: + return jsonify({"error": str(e)}), 400 + + def create_in_engine_thread(): + return engine_class(*init_args, **init_kwargs) + + engine = submit_to_engine_thread("create_engine", create_in_engine_thread) + set_engine(engine) + return jsonify( + { + "status": "success", + "message": "Engine created and initialized", + "result": None, + } + ) + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + @bp.route("/configure", methods=["POST"]) + def configure(): + data = request.get_json(silent=True) + raw_args, raw_kwargs = parse_args_kwargs(data) + + def action(): + engine = require_engine() + configure_fn = getattr(engine, "configure", None) + if callable(configure_fn): + return configure_fn(*raw_args, **raw_kwargs) + return None + + return run_endpoint( + "configure", + lambda: submit_to_engine_thread("configure", action), + ) + + @bp.route("/topology", methods=["GET"]) + def topology(): + try: + engine = require_engine() + return jsonify( + { + "rank": int(os.environ.get("RANK", 0)), + "world_size": int(os.environ.get("WORLD_SIZE", 1)), + "dp_rank": engine.data_parallel_rank, + "dp_size": engine.data_parallel_world_size, + "is_dp_head": engine.is_data_parallel_head(), + "local_rank": int(os.environ.get("LOCAL_RANK", 0)), + } + ) + except RuntimeError as e: + return jsonify({"error": str(e)}), 400 + except Exception as e: + logger.error(f"Error in topology: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + @bp.route("/get_param_info", methods=["GET"]) + def get_param_info(): + def action(): + engine = require_engine() + get_param_info_fn = getattr(engine, "get_param_info", None) + if callable(get_param_info_fn): + return get_param_info_fn() + get_parameter_info_fn = getattr(engine, "get_parameter_info", None) + if callable(get_parameter_info_fn): + return get_parameter_info_fn() + return None + + return run_endpoint( + "get_param_info", + lambda: submit_to_engine_thread("get_param_info", action), + ) + + # -- dispatch helpers -------------------------------------------------- + + def _register_compute_route( + path: str, method_name: str, *, endpoint_prefix: str = "" + ) -> None: + def handler(): + data = request.get_json(silent=True) + raw_args, raw_kwargs = parse_args_kwargs(data) + args = RTensor.localize(raw_args) + kwargs = RTensor.localize(raw_kwargs) + result = execute_compute( + method_name, + args, + kwargs, + require_broadcast=True, + ) + return RTensor.remotize(result, node_addr=get_node_addr()) + + ep_name = ( + f"{endpoint_prefix}{method_name}_endpoint" + if endpoint_prefix + else f"{method_name}_compute_endpoint" + ) + bp.add_url_rule( + path, + ep_name, + lambda: run_endpoint(method_name, handler), + methods=["POST"], + ) + + def _register_engine_route( + path: str, + method_name: str, + *, + methods: list[str] | None = None, + return_result: bool = True, + ) -> None: + http_methods = methods or ["POST"] + + def handler(): + if request.method == "GET": + args, kwargs = [], {} + else: + data = request.get_json(silent=True) or {} + args, kwargs = parse_args_kwargs(data) + + return run_endpoint( + method_name, + lambda: submit_to_engine_thread( + method_name, + lambda: 
getattr(require_engine(), method_name)(*args, **kwargs), + ), + return_result=return_result, + ) + + bp.add_url_rule(path, f"{method_name}_endpoint", handler, methods=http_methods) + + # -- engine routes ----------------------------------------------------- + + _register_engine_route("/train", "train", return_result=False) + _register_engine_route("/eval", "eval", return_result=False) + _register_compute_route("/train_batch", "train_batch") + _register_compute_route("/forward_batch", "forward_batch") + _register_compute_route("/eval_batch", "eval_batch") + _register_engine_route("/create_process_group", "create_process_group") + _register_engine_route("/initialize", "initialize") + _register_engine_route("/set_version", "set_version") + _register_engine_route("/get_version", "get_version", methods=["GET"]) + _register_engine_route("/save", "save") + _register_engine_route("/load", "load") + _register_engine_route("/offload", "offload") + _register_engine_route("/onload", "onload") + _register_engine_route("/optimizer_zero_grad", "optimizer_zero_grad") + _register_engine_route("/optimizer_step", "optimizer_step") + _register_engine_route("/step_lr_scheduler", "step_lr_scheduler") + _register_engine_route("/get_device_stats", "get_device_stats") + _register_engine_route("/config_perf_tracer", "config_perf_tracer") + _register_engine_route("/save_perf_tracer", "save_perf_tracer") + _register_engine_route("/clear_batches", "clear_batches") + _register_engine_route("/export_stats", "export_stats", methods=["GET"]) + + # -- SFT routes -------------------------------------------------------- + + _register_compute_route("/sft/train", "train_lm") + _register_compute_route("/sft/evaluate", "evaluate_lm") + + # -- PPO actor routes -------------------------------------------------- + + _register_compute_route( + "/ppo/actor/compute_logp", "compute_logp", endpoint_prefix="ppo_actor_" + ) + _register_compute_route( + "/ppo/actor/compute_advantages", + "compute_advantages", + endpoint_prefix="ppo_actor_", + ) + _register_compute_route( + "/ppo/actor/update", "ppo_update", endpoint_prefix="ppo_actor_" + ) + + # -- PPO critic routes ------------------------------------------------- + + _register_compute_route( + "/ppo/critic/compute_values", + "compute_values", + endpoint_prefix="ppo_critic_", + ) + _register_compute_route( + "/ppo/critic/update", "ppo_update", endpoint_prefix="ppo_critic_" + ) + + # -- RW routes --------------------------------------------------------- + + _register_compute_route("/rw/train", "train_rw", endpoint_prefix="rw_") + _register_compute_route("/rw/evaluate", "evaluate_rw", endpoint_prefix="rw_") + + return bp diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py index 8138c0fed4..69a0a80f37 100644 --- a/areal/infra/rpc/guard/app.py +++ b/areal/infra/rpc/guard/app.py @@ -415,6 +415,37 @@ def kill_forked_worker(): logger.error(f"Error in kill_forked_worker: {e}\n{traceback.format_exc()}") return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + @app.route("/set_env", methods=["POST"]) + def set_env(): + """Set environment variables on the guard process. + + Forked child processes will inherit these via ``os.environ``. 
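+        Values are coerced to strings via ``str(value)`` before being stored.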
+ + Expected JSON payload:: + + {"env": {"KEY": "value", "KEY2": "value2"}} + """ + try: + data = request.get_json(silent=True) + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + env_payload = data.get("env") + if env_payload is None: + return jsonify({"error": "Missing 'env' field in request"}), 400 + if not isinstance(env_payload, dict): + return jsonify({"error": "'env' must be a dictionary"}), 400 + + for key, value in env_payload.items(): + os.environ[key] = str(value) + + logger.info("Updated %d environment variables", len(env_payload)) + return jsonify({"status": "success"}) + + except Exception as e: + logger.error(f"Error in set_env: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + @app.route("/configure", methods=["POST"]) def configure(): """Configure the worker process. diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index ebcc0fdc45..7b5688a9ab 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -56,6 +56,8 @@ logger = logging.getLogger("LocalScheduler") +_MAX_STARTUP_PORT_CONFLICT_RETRIES = 3 + @dataclass class WorkerInfo: @@ -234,6 +236,20 @@ def _allocate_ports(self, count: int) -> list[int]: except ValueError as e: raise PortAllocationError(str(e)) from e + def _release_ports(self, ports: list[int]) -> None: + for port in ports: + self._allocated_ports.discard(port) + + @staticmethod + def _is_port_conflict_error(details: str) -> bool: + lowered = details.lower() + return ( + "address already in use" in lowered + or "errno 98" in lowered + or "errno 48" in lowered + or ("port " in lowered and "is in use by another program" in lowered) + ) + def _prepare_worker_specs( self, role: str, num_workers: int, schedulings: list[SchedulingSpec] | None ) -> list[SchedulingSpec]: @@ -704,10 +720,8 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: scheduling = schedulings[idx] try: - # Allocate GPUs and ports for this worker gpu_devices = self._allocate_gpus(scheduling.gpu) logger.debug(f"Worker {worker_id} allocated GPUs {gpu_devices}") - ports = self._allocate_ports(scheduling.port_count) except ( GPUAllocationError, PortAllocationError, @@ -756,49 +770,103 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: "Custom command should not include --port argument", "The scheduler automatically allocates and provides the port.", ) - cmd = shlex.split(scheduling.cmd) - cmd.extend(["--port", str(ports[0])]) - # Add name_resolve and worker identity args - cmd.extend(["--experiment-name", str(self.experiment_name)]) - cmd.extend(["--trial-name", str(self.trial_name)]) - cmd.extend(["--role", role]) - cmd.extend(["--worker-index", str(idx)]) - cmd.extend(["--name-resolve-type", self.name_resolve_config.type]) - cmd.extend( - ["--nfs-record-root", self.name_resolve_config.nfs_record_root] - ) - cmd.extend(["--etcd3-addr", self.name_resolve_config.etcd3_addr]) - cmd.extend(["--fileroot", str(self.fileroot)]) + cmd_prefix = shlex.split(scheduling.cmd) + cmd_suffix = [ + "--experiment-name", + str(self.experiment_name), + "--trial-name", + str(self.trial_name), + "--role", + role, + "--worker-index", + str(idx), + "--name-resolve-type", + self.name_resolve_config.type, + "--nfs-record-root", + self.name_resolve_config.nfs_record_root, + "--etcd3-addr", + self.name_resolve_config.etcd3_addr, + "--fileroot", + str(self.fileroot), + ] - logger.info(f"Starting worker {worker_id}: {' '.join(cmd)}") - if 
cmd[0].startswith("python"): - cmd[0] = sys.executable + process = None + ports = [] + for attempt in range(1, _MAX_STARTUP_PORT_CONFLICT_RETRIES + 1): + try: + ports = self._allocate_ports(scheduling.port_count) + except ( + PortAllocationError, + WorkerNotFoundError, + ValueError, + ) as e: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Resource allocation failed for worker {idx}", + str(e), + ) from e + + cmd = [*cmd_prefix, "--port", str(ports[0]), *cmd_suffix] - try: - process = run_with_streaming_logs( - cmd, - log_file, - merged_log, - role, - env_vars_in_cmd=env, + logger.info( + "Starting worker %s (attempt %s/%s): %s", + worker_id, + attempt, + _MAX_STARTUP_PORT_CONFLICT_RETRIES, + " ".join(cmd), ) - except Exception as e: - self._cleanup_workers(workers) - raise WorkerCreationError( - role, - f"Failed to spawn subprocess for worker {idx}", - str(e), - ) from e + if cmd[0].startswith("python"): + cmd[0] = sys.executable + + try: + process = run_with_streaming_logs( + cmd, + log_file, + merged_log, + role, + env_vars_in_cmd=env, + ) + except Exception as e: + self._release_ports(ports) + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Failed to spawn subprocess for worker {idx}", + str(e), + ) from e + + time.sleep(0.1) + if process.poll() is None: + break + + stderr = self._read_log_tail(str(log_file)) + self._release_ports(ports) + + if self._is_port_conflict_error(stderr): + logger.warning( + "Worker %s hit port conflict on startup attempt %s/%s; retrying with new ports.", + worker_id, + attempt, + _MAX_STARTUP_PORT_CONFLICT_RETRIES, + ) + if attempt < _MAX_STARTUP_PORT_CONFLICT_RETRIES: + time.sleep(0.1 * attempt) + continue - time.sleep(0.1) - if process.poll() is not None: - stderr = self._read_log_tail(log_file) self._cleanup_workers(workers) raise WorkerCreationError( role, f"Worker {worker_id} exited immediately with code {process.returncode}", stderr, ) + else: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Worker {worker_id} failed to start after {_MAX_STARTUP_PORT_CONFLICT_RETRIES} attempts", + self._read_log_tail(str(log_file)), + ) worker = Worker( id=worker_id, diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 68d7855376..c9e61dd290 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -7,7 +7,11 @@ from areal.api import TrainEngine from areal.api.cli_args import MicroBatchSpec, PPOActorConfig +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) from areal.infra import TrainController +from areal.infra.rpc.serialization import serialize_value from areal.trainer.ppo.stats import infer_token_denominator from areal.utils import logging, stats_tracker from areal.utils.constants import ( @@ -373,6 +377,29 @@ def ppo_update(self, *args, **kwargs) -> None: ) +class PPOActorControllerV2(GatewayTrainController): + def compute_logp(self, *args, **kwargs): + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + return self._gateway_post_result("/ppo/actor/compute_logp", payload) + + def compute_advantages(self, *args, **kwargs): + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + return self._gateway_post_result("/ppo/actor/compute_advantages", payload) + + def ppo_update(self, *args, **kwargs) -> None: + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + 
self._gateway_post("/ppo/actor/update", payload) + + def grpo_loss_fn( logprobs: torch.Tensor, entropy: torch.Tensor, diff --git a/areal/trainer/ppo/critic.py b/areal/trainer/ppo/critic.py index 45bb7f24f2..a554181bc2 100644 --- a/areal/trainer/ppo/critic.py +++ b/areal/trainer/ppo/critic.py @@ -7,7 +7,11 @@ from areal.api import TrainEngine from areal.api.cli_args import MicroBatchSpec, PPOCriticConfig +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) from areal.infra import TrainController +from areal.infra.rpc.serialization import serialize_value from areal.trainer.ppo.stats import infer_token_denominator from areal.utils import stats_tracker from areal.utils.data import ( @@ -82,6 +86,22 @@ def ppo_update(self, *args, **kwargs): ) +class PPOCriticControllerV2(GatewayTrainController): + def compute_values(self, *args, **kwargs): + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + return self._gateway_post_result("/ppo/critic/compute_values", payload) + + def ppo_update(self, *args, **kwargs) -> None: + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + self._gateway_post("/ppo/critic/update", payload) + + def ppo_loss_fn( value: torch.Tensor, input_data: dict, diff --git a/areal/trainer/rw/rw_engine.py b/areal/trainer/rw/rw_engine.py index b899e5c2a7..d8c22e832c 100644 --- a/areal/trainer/rw/rw_engine.py +++ b/areal/trainer/rw/rw_engine.py @@ -5,7 +5,11 @@ import torch from areal.api import TrainEngine +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) from areal.infra import TrainController +from areal.infra.rpc.serialization import serialize_value from areal.utils import logging, stats_tracker from areal.utils.data import batched_call from areal.utils.perf_tracer import trace_perf @@ -90,6 +94,24 @@ def evaluate_rw(self, *args, **kwargs): ) +class RWControllerV2(GatewayTrainController): + def train_rw(self, *args, **kwargs): + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + self._gateway_post_result("/rw/train", payload) + + def evaluate_rw(self, *args, **kwargs): + kwargs = dict(kwargs) + kwargs.setdefault("group_size", 2) + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + self._gateway_post_result("/rw/evaluate", payload) + + def compute_rw_loss(scores: torch.Tensor, input_: dict[str, Any]) -> torch.Tensor: device = scores.device cu_seqlens = input_["cu_seqlens"] diff --git a/areal/trainer/sft/lm_engine.py b/areal/trainer/sft/lm_engine.py index 253f9a65de..4551a4ae7b 100644 --- a/areal/trainer/sft/lm_engine.py +++ b/areal/trainer/sft/lm_engine.py @@ -5,7 +5,11 @@ import torch from areal.api import TrainEngine +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) from areal.infra import TrainController +from areal.infra.rpc.serialization import serialize_value from areal.utils import stats_tracker from areal.utils.data import batched_call from areal.utils.perf_tracer import trace_perf @@ -56,6 +60,22 @@ def evaluate_lm(self, *args, **kwargs): ) +class LMControllerV2(GatewayTrainController): + def train_lm(self, *args, **kwargs): + payload = { + "args": serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + self._gateway_post_result("/sft/train", payload) + + def evaluate_lm(self, *args, **kwargs): + payload = { + "args": 
serialize_value(list(args)), + "kwargs": serialize_value(kwargs), + } + self._gateway_post_result("/sft/evaluate", payload) + + def compute_packed_sft_loss( logprobs: torch.Tensor, entropy: torch.Tensor, diff --git a/areal/utils/recover.py b/areal/utils/recover.py index fed069fa47..f332d5f6d6 100644 --- a/areal/utils/recover.py +++ b/areal/utils/recover.py @@ -171,9 +171,53 @@ def recover_info_path( "recover_info", ) + @staticmethod + def _is_gateway_train_controller( + engine: TrainEngine + | TrainController + | dict[str, TrainEngine | TrainController], + ) -> bool: + from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, + ) + + if isinstance(engine, GatewayTrainController): + return True + if isinstance(engine, dict): + return any( + isinstance(controller, GatewayTrainController) + for controller in engine.values() + ) + return False + + def _ensure_recover_supported( + self, + engine: TrainEngine + | TrainController + | dict[str, TrainEngine | TrainController], + ) -> None: + if self._is_gateway_train_controller(engine): + raise NotImplementedError( + "Recovery is not supported with GatewayTrainController " + '(`_version="v2"`) yet. Disable `recover.mode` or use ' + '`_version="v1"`.' + ) + + @staticmethod + def _normalize_recover_engines( + engine: TrainEngine + | TrainController + | dict[str, TrainEngine | TrainController], + ) -> dict[str, TrainEngine | TrainController]: + if isinstance(engine, dict): + return engine + return {"default": engine} + def dump( self, - engine: TrainEngine | dict[str, TrainEngine], + engine: TrainEngine + | TrainController + | dict[str, TrainEngine | TrainController], step_info: StepInfo, saver: Saver, evaluator: Evaluator, @@ -185,15 +229,17 @@ def dump( ): if self.config.mode in ("disabled", "off"): return + self._ensure_recover_supported(engine) # currently only support recover on one engine if not self.freq_ctl.check( epochs=int(step_info.epoch_step == self.ft_spec.steps_per_epoch - 1), steps=1, ): return - if isinstance(engine, TrainEngine): - engine = {"default": engine} - for name, engine_ in engine.items(): + normalized_engine: dict[str, TrainEngine | TrainController] = ( + self._normalize_recover_engines(engine) + ) + for name, engine_ in normalized_engine.items(): self._save_checkpoint( engine_, name=name, @@ -232,11 +278,17 @@ def load( ) -> RecoverInfo | None: if self.config.mode in ("disabled", "off"): return + self._ensure_recover_supported(engine) if inference_engine is not None and weight_update_meta is None: raise ValueError("Weight update meta is required for recovery.") - if isinstance(engine, (TrainEngine, TrainController)): - engine = {"default": engine} + # TODO(agent): GatewayTrainController is currently duck-typed and does + # not satisfy this TrainController type check. Extend recovery to accept + # controller-v2 instances (or make v2 inherit TrainController) before + # relying on resumed runs with `_version="v2"`. 
+ normalized_engine: dict[str, TrainEngine | TrainController] = ( + self._normalize_recover_engines(engine) + ) recover_info_path = self.recover_info_path( self.config.experiment_name, @@ -253,13 +305,13 @@ def load( stats_logger.load_state_dict(recover_info.stats_logger_info) dataloader.load_state_dict(recover_info.dataloader_info) - for name, engine_ in engine.items(): + for name, engine_ in normalized_engine.items(): self._load_checkpoint(engine_, name=name) global_step = recover_info.last_step_info.global_step if inference_engine is not None: assert weight_update_meta is not None - update_engine = engine[inference_engine_update_from] + update_engine = normalized_engine[inference_engine_update_from] recovery_version = global_step + 1 versioned_meta = weight_update_meta.with_version(recovery_version) update_engine.connect_engine(inference_engine, versioned_meta) diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index f519635cf4..9ccf70e141 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -362,6 +362,11 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.2` | Clipping factor for policy ratio | @@ -430,6 +435,11 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. 
**Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.5` | Clipping factor for value loss | @@ -471,6 +481,11 @@ Core configuration for model training, including optimization and backend settin | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | (section-generation-hyperparameters)= @@ -1087,6 +1102,11 @@ Configuration class: TeacherConfig | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. 
| | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.2` | Clipping factor for policy ratio | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 252830b7f8..b382cd26ba 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -360,6 +360,11 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.2` | Clipping factor for policy ratio | @@ -428,6 +433,11 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. 
| | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.5` | Clipping factor for value loss | @@ -469,6 +479,11 @@ Core configuration for model training, including optimization and backend settin | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | (section-generation-hyperparameters)= @@ -1085,6 +1100,11 @@ Configuration class: TeacherConfig | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. 
| | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.2` | Clipping factor for policy ratio | diff --git a/tests/experimental/training_service/__init__.py b/tests/experimental/training_service/__init__.py new file mode 100644 index 0000000000..c41c581fc9 --- /dev/null +++ b/tests/experimental/training_service/__init__.py @@ -0,0 +1 @@ +"""Tests for experimental training service.""" diff --git a/tests/experimental/training_service/fake_train_engine.py b/tests/experimental/training_service/fake_train_engine.py new file mode 100644 index 0000000000..e886a0ccad --- /dev/null +++ b/tests/experimental/training_service/fake_train_engine.py @@ -0,0 +1,201 @@ +"""CPU-only fake TrainEngine for training-service integration tests.""" + +from __future__ import annotations + +from typing import Any + +from areal.api import TrainEngine + + +def _sum_numbers(value: Any) -> float: + if isinstance(value, dict): + return sum(_sum_numbers(v) for v in value.values()) + if isinstance(value, list): + return sum(_sum_numbers(v) for v in value) + if isinstance(value, tuple): + return sum(_sum_numbers(v) for v in value) + if isinstance(value, bool): + return float(value) + if isinstance(value, (int, float)): + return float(value) + return 0.0 + + +class FakeTrainEngine(TrainEngine): + """Minimal concrete TrainEngine used by integration tests.""" + + def __init__(self, *args: Any, **kwargs: Any): + self._initialized = False + self._version = 0 + self._train_mode = True + self._offloaded = False + self._zero_grad_calls = 0 + self._optimizer_step_calls = 0 + self._lr_step_calls = 0 + self._last_saved_meta: Any = None + self._last_loaded_meta: Any = None + self._init_kwargs = dict(kwargs) + + def create_process_group(self, parallel_strategy=None): + return None + + def initialize(self, *args, **kwargs): + self._initialized = True + return None + + @property + def data_parallel_group(self): + return None + + @property + def data_parallel_rank(self) -> int: + return 0 + + @property + def data_parallel_world_size(self) -> int: + return 1 + + def current_data_parallel_head(self) -> int: + return 0 + + def is_data_parallel_head(self) -> bool: + return True + + @property + def context_and_model_parallel_group(self): + return None + + @property + def cpu_group(self): + return None + + @property + def initialized(self) -> bool: + return self._initialized + + def train(self, mode: bool = True): + self._train_mode = mode + return None + + def update_weights(self, meta): + return None + + def connect_engine(self, engine, meta): + return None + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow, + workflow_kwargs: dict[str, Any] | None = None, + group_size: int = 1, + ) -> list[dict[str, Any]]: + return data + + def prepare_batch( + self, + dataloader, + workflow, + workflow_kwargs: dict[str, Any] | None = None, + should_accept_fn=None, + group_size: int = 1, + dynamic_bs: bool = False, + ) -> list[dict[str, Any]]: + return [] + + def set_version(self, version: int): + self._version = int(version) + + def get_version(self) -> int: + return self._version + + def save(self, meta): + self._last_saved_meta = meta + + def load(self, meta): + self._last_loaded_meta = meta + + def optimizer_zero_grad(self): + self._zero_grad_calls += 1 + + def optimizer_step(self): + self._optimizer_step_calls += 1 + return { + "update_successful": 1.0, + "grad_norm": float(self._optimizer_step_calls), + "lr": 1e-3, + } + + def lr_scheduler_step(self): + self._lr_step_calls += 1 
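+
+    # The train_batch / eval_batch / forward_batch methods below return values
+    # derived from the request payload and the internal counters above, so
+    # tests can assert exact round-trip results without loading a real model.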
+ + def forward_backward_batch( + self, + mb_list, + process_output_fn, + forward_only: bool = False, + ) -> None: + return None + + def train_batch( + self, + input_: dict[str, Any], + loss_fn=None, + loss_weight_fn=None, + ) -> dict[str, float]: + return { + "total": _sum_numbers(input_), + "version": float(self._version), + "train_mode": float(self._train_mode), + } + + def eval_batch( + self, + input_: dict[str, Any], + loss_fn=None, + loss_weight_fn=None, + ) -> float: + return _sum_numbers(input_) + self._version + + def forward_batch( + self, + input_: dict[str, Any], + output_seqlens: list[int] | None = None, + aggregate_fn=None, + ) -> dict[str, Any]: + return { + "total": _sum_numbers(input_), + "version": self._version, + "train_mode": self._train_mode, + "output_seqlens": output_seqlens, + } + + def train_lm(self, input_, **kwargs): + _ = kwargs + return self.train_batch(input_) + + def evaluate_lm(self, input_, **kwargs): + _ = kwargs + return self.eval_batch(input_) + + def export_stats(self) -> dict[str, float]: + return { + "version": float(self._version), + "train_mode": float(self._train_mode), + "offloaded": float(self._offloaded), + "zero_grad_calls": float(self._zero_grad_calls), + "optimizer_step_calls": float(self._optimizer_step_calls), + "lr_step_calls": float(self._lr_step_calls), + "saved_meta_size": float(len(str(self._last_saved_meta))), + "loaded_meta_size": float(len(str(self._last_loaded_meta))), + "world_size": float(self._init_kwargs.get("world_size", -1)), + } + + def onload(self) -> None: + self._offloaded = False + + def offload(self) -> None: + self._offloaded = True + + def get_device_stats(self): + return {"device": "cpu"} diff --git a/tests/experimental/training_service/test_controller_integration.py b/tests/experimental/training_service/test_controller_integration.py new file mode 100644 index 0000000000..d304ea9b64 --- /dev/null +++ b/tests/experimental/training_service/test_controller_integration.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import os +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass + +import pytest +import requests +import torch + +from areal.api.cli_args import ( + MicroBatchSpec, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, +) +from areal.api.io_struct import FinetuneSpec, SaveLoadMeta +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) +from areal.infra.platforms import current_platform +from areal.infra.scheduler.local import LocalScheduler + +LOCAL_MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" +LOCAL_MOE_MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/" + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + + +def _resolve_model_path_or_skip() -> str: + if os.path.exists(LOCAL_MODEL_PATH): + return LOCAL_MODEL_PATH + pytest.skip( + "Local model path not found for CUDA integration test: " + f"{LOCAL_MODEL_PATH} (HF model: Qwen/Qwen3-0.6B)" + ) + raise RuntimeError("unreachable after pytest.skip") + + +def _resolve_moe_model_path_or_skip() -> str: + if os.path.exists(LOCAL_MOE_MODEL_PATH): + return LOCAL_MOE_MODEL_PATH + pytest.skip( + "Local MoE model path not found for CUDA integration test: " + f"{LOCAL_MOE_MODEL_PATH} (HF model: Qwen/Qwen3-30B-A3B)" + ) + raise RuntimeError("unreachable after pytest.skip") + + +@dataclass(frozen=True) +class _StrategyCase: + name: str + train_engine: str + backend: str + 
model_resolver: Callable[[], str] + expected_dp_size: int + + +def _strategy_cases_2gpu() -> list[_StrategyCase]: + return [ + _StrategyCase( + name="fsdp_dp2", + train_engine="areal.engine.FSDPEngine", + backend="fsdp:d2", + model_resolver=_resolve_model_path_or_skip, + expected_dp_size=2, + ), + _StrategyCase( + name="megatron_dp2", + train_engine="areal.engine.MegatronEngine", + backend="megatron:d2", + model_resolver=_resolve_model_path_or_skip, + expected_dp_size=2, + ), + _StrategyCase( + name="megatron_tp2", + train_engine="areal.engine.MegatronEngine", + backend="megatron:t2", + model_resolver=_resolve_model_path_or_skip, + expected_dp_size=1, + ), + _StrategyCase( + name="megatron_cp2", + train_engine="areal.engine.MegatronEngine", + backend="megatron:c2", + model_resolver=_resolve_model_path_or_skip, + expected_dp_size=1, + ), + _StrategyCase( + name="megatron_pp2", + train_engine="areal.engine.MegatronEngine", + backend="megatron:p2", + model_resolver=_resolve_model_path_or_skip, + expected_dp_size=1, + ), + _StrategyCase( + name="megatron_ep2", + train_engine="areal.engine.MegatronEngine", + backend="megatron:d2e2", + model_resolver=_resolve_moe_model_path_or_skip, + expected_dp_size=2, + ), + ] + + +@contextmanager +def _build_cuda_gateway_controller( + tmp_path_factory: pytest.TempPathFactory, + *, + case: _StrategyCase, +): + if current_platform.device_count() < 2: + pytest.skip("This test requires 2 GPUs") + + model_path = case.model_resolver() + tmp_path = tmp_path_factory.mktemp(f"training_gateway_cuda_{case.name}") + fileroot = tmp_path / "fileroot" + fileroot.mkdir() + name_resolve_root = tmp_path / "name_resolve" + name_resolve_root.mkdir() + + scheduler = LocalScheduler( + gpu_devices=[0, 1], + log_dir=str(tmp_path), + enable_tms_offload=True, + experiment_name=f"test_training_gateway_controller_cuda_{case.name}", + trial_name="trial_0", + fileroot=str(fileroot), + nfs_record_root=str(name_resolve_root), + ) + + config = TrainEngineConfig( + experiment_name=f"test_training_gateway_controller_cuda_{case.name}", + trial_name="trial_0", + backend=case.backend, + scheduling_spec=( + SchedulingSpec( + cpu=1, + gpu=1, + mem=2048, + port_count=1, + cmd="python -m areal.infra.rpc.rpc_server", + ), + ), + path=model_path, + admin_api_key=f"test-admin-key-cuda-{case.name}", + request_timeout=180.0, + setup_timeout=300.0, + offload=False, + mb_spec=MicroBatchSpec(max_tokens_per_mb=128), + optimizer=OptimizerConfig(), + ) + + controller = GatewayTrainController( + train_engine=case.train_engine, + scheduler=scheduler, + config=config, + ) + try: + controller.initialize( + role=f"train-gateway-cuda-{case.name}", + ft_spec=FinetuneSpec( + total_train_epochs=1, + dataset_size=8, + train_batch_size=2, + ), + ) + except Exception as exc: + try: + controller.destroy() + finally: + scheduler.delete_workers(role=None) + msg = str(exc).lower() + if "out of memory" in msg: + pytest.skip( + f"Skipping {case.name} due to transient CUDA OOM during bootstrap: {exc}" + ) + raise + + try: + yield controller, tmp_path_factory + finally: + controller.destroy() + scheduler.delete_workers(role=None) + + +def _make_batch(n: int = 4, seq_len: int = 16) -> list[dict[str, torch.Tensor]]: + return [ + { + "input_ids": torch.randint(0, 100, (1, seq_len), dtype=torch.long), + "attention_mask": torch.ones((1, seq_len), dtype=torch.bool), + } + for _ in range(n) + ] + + +@pytest.mark.multi_gpu +@pytest.mark.slow +@pytest.mark.parametrize("case", _strategy_cases_2gpu(), ids=lambda c: c.name) +def 
test_gateway_controller_integration( + tmp_path_factory: pytest.TempPathFactory, + case: _StrategyCase, +): + with _build_cuda_gateway_controller(tmp_path_factory, case=case) as ( + controller, + tmp_factory, + ): + # -- health -------------------------------------------------------- + + gateway_resp = requests.get(f"{controller._gateway_addr}/health", timeout=15) + assert gateway_resp.status_code == 200 + assert gateway_resp.json()["status"] == "ok" + + router_resp = requests.get(f"{controller._router_addr}/health", timeout=15) + assert router_resp.status_code == 200 + assert router_resp.json()["status"] == "ok" + + # -- topology ------------------------------------------------------ + + topology_resp = requests.get(f"{controller._model_addr}/topology", timeout=15) + assert topology_resp.status_code == 200 + topology = topology_resp.json() + + assert topology["dp_size"] == case.expected_dp_size + assert len(topology["workers"]) == 2 + worker_dp_ranks = sorted(w["dp_rank"] for w in topology["workers"]) + worker_dp_heads = [w["is_dp_head"] for w in topology["workers"]] + + if case.expected_dp_size == 2: + assert len(topology["dp_heads"]) == 2 + assert len(topology["dp_groups"]) == 2 + assert all(len(g) == 1 for g in topology["dp_groups"]) + assert worker_dp_ranks == [0, 1] + assert worker_dp_heads.count(True) == 2 + else: + assert len(topology["dp_heads"]) == 1 + assert len(topology["dp_groups"]) == 1 + assert len(topology["dp_groups"][0]) == 2 + assert worker_dp_ranks == [0, 0] + assert worker_dp_heads.count(True) == 1 + + # -- train / eval mode toggle -------------------------------------- + + controller.train(mode=False) + controller.train(mode=True) + controller.eval() + + # -- version ------------------------------------------------------- + + controller.set_version(11) + assert controller.get_version() == 11 + + controller.set_version(23) + assert controller.get_version() == 23 + + # -- forward_batch ------------------------------------------------- + + forward_result = controller.forward_batch(_make_batch(4)) + assert forward_result is not None + + # -- export_stats -------------------------------------------------- + + stats = controller.export_stats() + assert isinstance(stats, dict) + + # -- offload / onload cycle ---------------------------------------- + + controller.offload() + controller.onload() + + # -- save / load --------------------------------------------------- + + model_path = case.model_resolver() + save_load_path = str(tmp_factory.mktemp("hf_saveload")) + save_meta = SaveLoadMeta( + path=save_load_path, + weight_format="hf", + with_optim=False, + base_model_path=model_path, + ) + controller.save(save_meta) + + load_meta = SaveLoadMeta( + path=save_load_path, + weight_format="hf", + with_optim=False, + base_model_path=model_path, + ) + controller.load(load_meta) + + # -- step_lr_scheduler --------------------------------------------- + + controller.step_lr_scheduler() + + # -- optimizer_zero_grad / optimizer_step -------------------------- + + controller.optimizer_zero_grad() + controller.optimizer_step() + + # -- clear_batches ------------------------------------------------- + + controller.clear_batches() + + # -- second offload / onload cycle --------------------------------- + + controller.offload() + controller.onload() + + # -- final stats --------------------------------------------------- + + stats = controller.export_stats() + assert isinstance(stats, dict) diff --git a/tests/experimental/training_service/test_controller_unit.py 
b/tests/experimental/training_service/test_controller_unit.py new file mode 100644 index 0000000000..bd3b1d3651 --- /dev/null +++ b/tests/experimental/training_service/test_controller_unit.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from areal.api.cli_args import SchedulingSpec, TrainEngineConfig +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) + +MODULE = "areal.experimental.training_service.controller.controller" + + +def _make_response(method: str, url: str, *, json=None) -> httpx.Response: + return httpx.Response( + 200, + json=json, + request=httpx.Request(method, url), + ) + + +def _make_controller(scheduler: MagicMock | None = None) -> GatewayTrainController: + return GatewayTrainController( + train_engine="areal.engine.FSDPEngine", + scheduler=scheduler or MagicMock(), + config=TrainEngineConfig( + experiment_name="test-exp", + trial_name="trial-0", + backend="fsdp:d2", + scheduling_spec=( + SchedulingSpec( + cpu=1, + gpu=1, + mem=1024, + port_count=1, + cmd="python -m areal.infra.rpc.rpc_server", + ), + ), + admin_api_key="test-admin-key", + request_timeout=5.0, + setup_timeout=5.0, + ), + ) + + +class _FakeAsyncClient: + def __init__(self, responses_or_errors): + self._responses_or_errors = list(responses_or_errors) + self.get = AsyncMock(side_effect=self._get) + self.post = AsyncMock(side_effect=self._post) + + async def __aenter__(self): + return self + + async def __aexit__(self, *_args): + return None + + async def _get(self, _url: str): + next_item = self._responses_or_errors.pop(0) + if isinstance(next_item, Exception): + raise next_item + return next_item + + async def _post(self, _url: str, json=None): + _ = json + next_item = self._responses_or_errors.pop(0) + if isinstance(next_item, Exception): + raise next_item + return next_item + + +class TestGatewayTrainControllerAsyncHelpers: + @pytest.mark.asyncio + async def test_async_wait_for_service_reuses_single_client_across_retries(self): + controller = _make_controller() + fake_client = _FakeAsyncClient( + [ + httpx.ConnectError("not ready"), + _make_response("GET", "http://service/health"), + ] + ) + + with ( + patch("httpx.AsyncClient", return_value=fake_client) as mock_client_cls, + patch(f"{MODULE}.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + await controller._async_wait_for_service( + "http://service/health", "service", timeout=1.0 + ) + + mock_client_cls.assert_called_once_with(timeout=2.0) + assert fake_client.get.await_count == 2 + mock_sleep.assert_awaited_once_with(0.1) + + +class TestGatewayTrainControllerInitialization: + @pytest.mark.asyncio + async def test_async_initialize_offloads_scheduler_and_uses_async_helpers(self): + worker0 = MagicMock(ip="127.0.0.1", worker_ports=[18000], id="guard-0") + worker1 = MagicMock(ip="127.0.0.1", worker_ports=[18001], id="guard-1") + + scheduler = MagicMock() + scheduler.create_workers.return_value = ["guard-0", "guard-1"] + scheduler.get_workers.return_value = [worker0, worker1] + + controller = _make_controller(scheduler) + controller._role = "train-role" + + port_client = _FakeAsyncClient( + [ + _make_response( + "POST", + "http://127.0.0.1:18000/alloc_ports", + json={"ports": [29500]}, + ) + ] + ) + + async def _run_in_thread(func, *args, **kwargs): + return func(*args, **kwargs) + + with ( + patch("httpx.AsyncClient", return_value=port_client), + patch( + f"{MODULE}.asyncio.to_thread", 
side_effect=_run_in_thread + ) as mock_to_thread, + patch.object( + controller, "_async_set_guards_env", new_callable=AsyncMock + ) as mock_set_env, + patch.object( + controller, + "_async_fork_on_guard", + new_callable=AsyncMock, + side_effect=[ + ("127.0.0.1", 19001), + ("127.0.0.1", 19002), + ("127.0.0.1", 18081), + ("127.0.0.1", 18082), + ("127.0.0.1", 18080), + ], + ) as mock_async_fork, + patch.object(controller, "_fork_on_guard", autospec=True) as mock_sync_fork, + patch.object( + controller, "_create_engine_on_worker", new_callable=AsyncMock + ) as mock_create_engine, + patch.object( + controller, + "_call_worker_engine_endpoint", + new_callable=AsyncMock, + ) as mock_call_engine, + patch.object( + controller, "_register_in_router", new_callable=AsyncMock + ) as mock_register, + ): + await controller._async_initialize(role="train-role") + + assert mock_to_thread.await_count == 2 + create_call = mock_to_thread.await_args_list[0] + get_call = mock_to_thread.await_args_list[1] + assert create_call.args[0] is scheduler.create_workers + assert get_call.args[0] is scheduler.get_workers + assert get_call.kwargs == { + "role": "train-role-guard", + "timeout": 5, + } + + mock_set_env.assert_awaited_once() + assert mock_async_fork.await_count == 5 + mock_sync_fork.assert_not_called() + assert mock_create_engine.await_count == 2 + assert mock_call_engine.await_count == 4 + mock_register.assert_awaited_once_with( + "http://127.0.0.1:18081", + "http://127.0.0.1:18082", + controller.api_key, + ) + + assert controller._worker_addrs == [ + "http://127.0.0.1:19001", + "http://127.0.0.1:19002", + ] + assert controller._router_addr == "http://127.0.0.1:18081" + assert controller._model_addr == "http://127.0.0.1:18082" + assert controller._gateway_addr == "http://127.0.0.1:18080" + assert controller.api_key is not None + assert controller.api_key.startswith("ak-train-role-") diff --git a/tests/experimental/training_service/test_data_proxy_unit.py b/tests/experimental/training_service/test_data_proxy_unit.py new file mode 100644 index 0000000000..5a209dbdd9 --- /dev/null +++ b/tests/experimental/training_service/test_data_proxy_unit.py @@ -0,0 +1,650 @@ +"""Unit tests for training-service data proxy.""" + +from __future__ import annotations + +from typing import Any + +import httpx +import orjson +import pytest +import pytest_asyncio +import torch + +from areal.experimental.training_service.data_proxy.app import create_app +from areal.experimental.training_service.data_proxy.config import TrainDataProxyConfig +from areal.experimental.training_service.data_proxy.dispatcher import Dispatcher +from areal.experimental.training_service.data_proxy.topology import ( + WorkerInfo, + WorkerTopology, +) +from areal.infra.controller.train_controller import _dispatch_tensors +from areal.infra.rpc.serialization import deserialize_value, serialize_value + +ADMIN_KEY = "dp-admin-key" + + +# ------------------------------------------------------------------ +# aiohttp mock helpers +# ------------------------------------------------------------------ + + +class _FakeAiohttpResponse: + def __init__(self, content: bytes, status: int = 200): + self._content = content + self.status = status + + async def read(self) -> bytes: + return self._content + + +class _AsyncCM: + def __init__(self, value): + self._value = value + + async def __aenter__(self): + return self._value + + async def __aexit__(self, *args): + pass + + +class _NoOpSession: + async def close(self): + pass + + +class _CapturingSession: + def __init__(self, *, 
post_handler=None): + self.captured_payloads: list[dict[str, Any]] = [] + self._post_handler = post_handler + + def post(self, url, *, data=b"", headers=None): + _ = headers + payload = orjson.loads(data) + self.captured_payloads.append( + { + "url": url, + "payload": payload, + "args": deserialize_value(payload["args"]), + "kwargs": deserialize_value(payload["kwargs"]), + } + ) + if self._post_handler: + content = self._post_handler(url, data, headers) + else: + content = orjson.dumps({"status": "success", "result": payload["args"]}) + return _AsyncCM(_FakeAiohttpResponse(content)) + + async def close(self): + pass + + +# ------------------------------------------------------------------ +# Fake Dispatcher for app-level route tests +# ------------------------------------------------------------------ + + +class _FakeDispatchRequest: + def __init__( + self, + parent: _FakeDispatcher, + path: str, + *, + pad_eval_batch: bool = False, + ): + self._parent = parent + self._path = path + self._pad_eval_batch = pad_eval_batch + + async def get(self) -> bytes: + return b'{"status":"success","result":{"path":"get"}}' + + async def post(self, body: bytes) -> bytes: + _ = body + self._parent.dispatch_calls.append( + {"path": self._path, "pad_eval_batch": self._pad_eval_batch} + ) + return b'{"status":"success","result":{"path":"compute"}}' + + +class _FakeBroadcastRequest: + def __init__(self, parent: _FakeDispatcher, path: str): + self._parent = parent + self._path = path + + async def get(self) -> list[bytes]: + return self._parent.broadcast_get_return + + async def post(self, body: bytes) -> list[bytes]: + _ = body + return self._parent.broadcast_return + + +class _FakeDispatcher: + def __init__(self): + self.broadcast_return: list[bytes] = [ + b'{"status":"success","result":{"ok":true}}' + ] + self.broadcast_get_return: list[bytes] = [ + b'{"status":"success","result":{"path":"broadcast_get"}}' + ] + self.dispatch_calls: list[dict[str, Any]] = [] + + def dispatch( + self, path: str, *, pad_eval_batch: bool = False + ) -> _FakeDispatchRequest: + return _FakeDispatchRequest(self, path, pad_eval_batch=pad_eval_batch) + + def broadcast(self, path: str) -> _FakeBroadcastRequest: + return _FakeBroadcastRequest(self, path) + + +@pytest.fixture +def config() -> TrainDataProxyConfig: + return TrainDataProxyConfig( + host="127.0.0.1", + port=18082, + worker_addrs=["http://worker-0:19001", "http://worker-1:19001"], + admin_api_key=ADMIN_KEY, + request_timeout=10.0, + ) + + +@pytest_asyncio.fixture +async def app_client(config): + app = create_app(config) + app.state.topology = WorkerTopology( + workers=[ + WorkerInfo(addr="http://worker-0:19001", rank=0, dp_rank=0, dp_size=2), + WorkerInfo(addr="http://worker-1:19001", rank=1, dp_rank=1, dp_size=2), + ], + dp_heads=[0, 1], + dp_size=2, + dp_groups=[[0], [1]], + ) + app.state.dispatcher = _FakeDispatcher() + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield app, c + + +def _admin_headers() -> dict[str, str]: + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +class TestDataProxyBasics: + @pytest.mark.asyncio + async def test_health_and_topology(self, app_client): + _app, client = app_client + health = await client.get("/health") + assert health.status_code == 200 + assert health.json()["dp_size"] == 2 + + topo = await client.get("/topology") + assert topo.status_code == 200 + assert len(topo.json()["workers"]) == 2 + + @pytest.mark.asyncio + async def 
test_train_batch_uses_dispatcher(self, app_client): + _app, client = app_client + resp = await client.post("/train_batch", content=b"{}") + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + async def test_ppo_actor_compute_logp_uses_dispatcher(self, app_client): + _app, client = app_client + resp = await client.post("/ppo/actor/compute_logp", content=b"{}") + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + async def test_sft_train_uses_dispatcher(self, app_client): + _app, client = app_client + resp = await client.post("/sft/train", content=b"{}") + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + async def test_rw_train_uses_dispatcher(self, app_client): + _app, client = app_client + resp = await client.post("/rw/train", content=b"{}") + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + async def test_eval_routes_opt_in_to_padding(self, app_client): + app, client = app_client + + await client.post("/eval_batch", content=b"{}") + await client.post("/sft/evaluate", content=b"{}") + await client.post("/rw/evaluate", content=b"{}") + + assert app.state.dispatcher.dispatch_calls == [ + {"path": "/eval_batch", "pad_eval_batch": True}, + {"path": "/sft/evaluate", "pad_eval_batch": True}, + {"path": "/rw/evaluate", "pad_eval_batch": True}, + ] + + @pytest.mark.asyncio + async def test_training_routes_do_not_enable_padding(self, app_client): + app, client = app_client + + await client.post("/train_batch", content=b"{}") + await client.post("/forward_batch", content=b"{}") + await client.post("/sft/train", content=b"{}") + await client.post("/ppo/actor/update", content=b"{}") + await client.post("/ppo/critic/update", content=b"{}") + await client.post("/rw/train", content=b"{}") + + assert app.state.dispatcher.dispatch_calls == [ + {"path": "/train_batch", "pad_eval_batch": False}, + {"path": "/forward_batch", "pad_eval_batch": False}, + {"path": "/sft/train", "pad_eval_batch": False}, + {"path": "/ppo/actor/update", "pad_eval_batch": False}, + {"path": "/ppo/critic/update", "pad_eval_batch": False}, + {"path": "/rw/train", "pad_eval_batch": False}, + ] + + @pytest.mark.asyncio + async def test_optimizer_step_empty_broadcast_returns_502(self, app_client): + app, client = app_client + app.state.dispatcher.broadcast_return = [] + resp = await client.post("/optimizer_step", content=b"{}") + assert resp.status_code == 502 + assert "No worker responses" in resp.json()["error"] + + @pytest.mark.asyncio + async def test_export_stats_uses_broadcast_get(self, app_client): + app, client = app_client + app.state.dispatcher.broadcast_get_return = [ + b'{"status":"success","result":{"stats":1}}' + ] + resp = await client.get("/export_stats") + assert resp.status_code == 200 + assert resp.json()["result"]["stats"] == 1 + + +def _make_dispatcher( + *, + dp_size: int, + dp_heads: list[int], + dp_ranks: list[int], + session=None, +) -> Dispatcher: + workers = [ + WorkerInfo( + addr=f"http://worker-{i}:19001", + rank=i, + dp_rank=dp_ranks[i], + dp_size=dp_size, + is_dp_head=(i in dp_heads), + ) + for i in range(len(dp_ranks)) + ] + max_dp_rank = max(dp_ranks) if dp_ranks else 0 + dp_groups = [[] for _ in range(max_dp_rank + 1)] + for i, dp_rank in enumerate(dp_ranks): + dp_groups[dp_rank].append(i) + topology = WorkerTopology( + workers=workers, + dp_heads=dp_heads, + dp_size=dp_size, + dp_groups=dp_groups, + 
) + return Dispatcher( + topology=topology, + request_timeout=10.0, + _session=session or _NoOpSession(), + ) + + +def _make_tensor_item(seq_len: int) -> dict[str, torch.Tensor]: + return { + "input_ids": torch.randint(0, 100, (1, seq_len), dtype=torch.long), + "attention_mask": torch.ones((1, seq_len), dtype=torch.bool), + } + + +class TestDispatcherParityWithTrainController: + def test_partition_inputs_matches_train_controller_dispatch(self): + dispatcher = _make_dispatcher(dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1]) + req = dispatcher.dispatch("/any") + + batch = [ + _make_tensor_item(16), + _make_tensor_item(8), + _make_tensor_item(12), + _make_tensor_item(10), + ] + args = [batch] + kwargs: dict[str, object] = {"tag": "x"} + + dp_args, dp_kwargs, group_indices = req._partition_inputs( + args=args, + kwargs=kwargs, + group_size=1, + ) + + expected_splits, expected_indices = _dispatch_tensors(batch, 2, group_size=1) + assert group_indices == expected_indices + assert dp_args[0] == expected_splits + assert dp_kwargs["tag"] == ["x", "x"] + + def test_partition_inputs_respects_group_size_atomicity(self): + dispatcher = _make_dispatcher(dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1]) + req = dispatcher.dispatch("/any") + + batch = [ + _make_tensor_item(16), + _make_tensor_item(8), + _make_tensor_item(12), + _make_tensor_item(10), + ] + dp_args, _dp_kwargs, group_indices = req._partition_inputs( + args=[batch], + kwargs={}, + group_size=2, + ) + + assert len(dp_args[0]) == 2 + for idxs in group_indices: + assert len(idxs) % 2 == 0 + for i in range(0, len(idxs), 2): + left, right = idxs[i], idxs[i + 1] + assert right == left + 1 + assert left % 2 == 0 + + @pytest.mark.asyncio + async def test_fan_out_dispatches_only_to_dp_heads_for_model_parallel( + self, + ): + session = _CapturingSession() + dispatcher = _make_dispatcher( + dp_size=1, dp_heads=[0], dp_ranks=[0, 0], session=session + ) + req = dispatcher.dispatch("/forward_batch") + + dp_args = [[[_make_tensor_item(16), _make_tensor_item(8)]]] + dp_kwargs = {"output_seqlens": [[2, 3]]} + + await req._fan_out(dp_args=dp_args, dp_kwargs=dp_kwargs) + + captured = session.captured_payloads + assert len(captured) == 2 + assert captured[0]["args"] != [] + assert captured[0]["kwargs"] == {"output_seqlens": [2, 3]} + assert "rpc_meta" not in captured[0]["payload"] + assert captured[1]["args"] == [] + assert captured[1]["kwargs"] == {} + assert "rpc_meta" not in captured[1]["payload"] + + @pytest.mark.asyncio + async def test_fan_out_omits_rpc_meta_for_algorithm_paths( + self, + ): + session = _CapturingSession() + dispatcher = _make_dispatcher( + dp_size=1, dp_heads=[0], dp_ranks=[0, 0], session=session + ) + req = dispatcher.dispatch("/sft/train") + + dp_args = [[[_make_tensor_item(16), _make_tensor_item(8)]]] + dp_kwargs = {"output_seqlens": [[2, 3]]} + + await req._fan_out(dp_args=dp_args, dp_kwargs=dp_kwargs) + + captured = session.captured_payloads + assert len(captured) == 2 + assert "rpc_meta" not in captured[0]["payload"] + assert "rpc_meta" not in captured[1]["payload"] + + @pytest.mark.asyncio + async def test_fan_out_dispatches_per_dp_shard_when_dp_size_two( + self, + ): + session = _CapturingSession() + dispatcher = _make_dispatcher( + dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1], session=session + ) + req = dispatcher.dispatch("/forward_batch") + + shard0 = [_make_tensor_item(4)] + shard1 = [_make_tensor_item(12), _make_tensor_item(6)] + dp_args = [[shard0, shard1]] + dp_kwargs = {"output_seqlens": [[1], [2, 1]]} + + await 
req._fan_out(dp_args=dp_args, dp_kwargs=dp_kwargs) + + captured = session.captured_payloads + assert len(captured) == 2 + assert len(captured[0]["args"][0]) == len(shard0) + assert len(captured[1]["args"][0]) == len(shard1) + assert ( + captured[0]["args"][0][0]["input_ids"].shape == shard0[0]["input_ids"].shape + ) + assert ( + captured[1]["args"][0][0]["input_ids"].shape == shard1[0]["input_ids"].shape + ) + assert captured[0]["kwargs"] == {"output_seqlens": [1]} + assert captured[1]["kwargs"] == {"output_seqlens": [2, 1]} + + @pytest.mark.asyncio + async def test_dispatch_post_merges_results_in_original_order( + self, + ): + def _shard_handler(url, data, headers): + _ = url, headers + payload = orjson.loads(data) + args = deserialize_value(payload["args"]) + shard = args[0] if args else [] + shard_result = [int(item["attention_mask"].sum().item()) for item in shard] + return orjson.dumps( + { + "status": "success", + "result": serialize_value(shard_result), + } + ) + + session = _CapturingSession(post_handler=_shard_handler) + dispatcher = _make_dispatcher( + dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1], session=session + ) + + batch = [ + _make_tensor_item(5), + _make_tensor_item(11), + _make_tensor_item(7), + _make_tensor_item(13), + ] + body = orjson.dumps( + { + "args": serialize_value([batch]), + "kwargs": serialize_value({}), + } + ) + + result_bytes = await dispatcher.dispatch("/forward_batch").post(body) + result_payload = orjson.loads(result_bytes) + merged = deserialize_value(result_payload["result"]) + + assert merged == [5, 11, 7, 13] + + @pytest.mark.asyncio + async def test_dispatch_post_does_not_pad_training_routes(self): + def _shard_handler(url, data, headers): + _ = url, headers + payload = orjson.loads(data) + args = deserialize_value(payload["args"]) + shard = args[0] if args else [] + return orjson.dumps( + { + "status": "success", + "result": serialize_value( + [int(item["attention_mask"].sum().item()) for item in shard] + ), + } + ) + + session = _CapturingSession(post_handler=_shard_handler) + dispatcher = _make_dispatcher( + dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1], session=session + ) + + batch = [_make_tensor_item(5), _make_tensor_item(11), _make_tensor_item(7)] + body = orjson.dumps( + { + "args": serialize_value([batch]), + "kwargs": serialize_value({}), + } + ) + + with pytest.raises(ValueError, match="divisible by K"): + await dispatcher.dispatch("/train_batch").post(body) + + @pytest.mark.asyncio + async def test_dispatch_post_pads_eval_routes_only(self): + def _shard_handler(url, data, headers): + _ = url, headers + payload = orjson.loads(data) + args = deserialize_value(payload["args"]) + shard = args[0] if args else [] + return orjson.dumps( + { + "status": "success", + "result": serialize_value( + [int(item["attention_mask"].sum().item()) for item in shard] + ), + } + ) + + session = _CapturingSession(post_handler=_shard_handler) + dispatcher = _make_dispatcher( + dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1], session=session + ) + + batch = [_make_tensor_item(5), _make_tensor_item(11), _make_tensor_item(7)] + body = orjson.dumps( + { + "args": serialize_value([batch]), + "kwargs": serialize_value({}), + } + ) + + result_bytes = await dispatcher.dispatch( + "/eval_batch", pad_eval_batch=True + ).post(body) + result_payload = orjson.loads(result_bytes) + merged = deserialize_value(result_payload["result"]) + + assert merged == [5, 11, 7, 0] + + +class TestScalarFanOut: + """Tests for _scalar_fan_out: non-partitionable payloads must reach all 
workers.""" + + @pytest.mark.asyncio + async def test_non_partitionable_tensor_fans_out_to_all_workers(self): + session = _CapturingSession() + dispatcher = _make_dispatcher( + dp_size=1, dp_heads=[0], dp_ranks=[0, 0], session=session + ) + + packed = {"tensor_a": [1, 2, 3]} + body = orjson.dumps( + { + "args": serialize_value([packed]), + "kwargs": serialize_value({}), + } + ) + + await dispatcher.dispatch("/forward_batch").post(body) + + captured = session.captured_payloads + assert len(captured) == 2 + + assert captured[0]["args"] == [packed] + assert captured[0]["kwargs"] == {} + + assert captured[1]["args"] == [] + assert captured[1]["kwargs"] == {} + + @pytest.mark.asyncio + async def test_returns_first_dp_head_with_non_contiguous_heads(self): + def _echo_handler(url, data, headers): + _ = headers + addr = url.rsplit("/", 1)[0] + return orjson.dumps({"status": "success", "result": serialize_value(addr)}) + + session = _CapturingSession(post_handler=_echo_handler) + dispatcher = _make_dispatcher( + dp_size=2, + dp_heads=[1, 3], + dp_ranks=[0, 0, 1, 1], + session=session, + ) + + body = orjson.dumps( + { + "args": serialize_value(["scalar_value"]), + "kwargs": serialize_value({}), + } + ) + + result_bytes = await dispatcher.dispatch("/train_batch").post(body) + result_payload = orjson.loads(result_bytes) + + assert result_payload == { + "status": "success", + "result": serialize_value("http://worker-1:19001"), + } + + @pytest.mark.asyncio + async def test_partitionable_list_still_uses_tensor_dispatch(self): + def _shard_handler(url, data, headers): + _ = url, headers + payload = orjson.loads(data) + args = deserialize_value(payload["args"]) + shard = args[0] if args else [] + return orjson.dumps( + {"status": "success", "result": serialize_value(len(shard))} + ) + + session = _CapturingSession(post_handler=_shard_handler) + dispatcher = _make_dispatcher( + dp_size=2, dp_heads=[0, 1], dp_ranks=[0, 1], session=session + ) + + batch = [_make_tensor_item(4), _make_tensor_item(8)] + body = orjson.dumps( + { + "args": serialize_value([batch]), + "kwargs": serialize_value({}), + } + ) + + result_bytes = await dispatcher.dispatch("/forward_batch").post(body) + result_payload = orjson.loads(result_bytes) + merged = deserialize_value(result_payload["result"]) + + assert merged == [1, 1] + + @pytest.mark.asyncio + async def test_single_worker_single_dp_head(self): + session = _CapturingSession() + dispatcher = _make_dispatcher( + dp_size=1, dp_heads=[0], dp_ranks=[0], session=session + ) + + body = orjson.dumps( + { + "args": serialize_value(["value"]), + "kwargs": serialize_value({}), + } + ) + + await dispatcher.dispatch("/eval_batch").post(body) + + captured = session.captured_payloads + assert len(captured) == 1 + assert captured[0]["args"] == ["value"] diff --git a/tests/experimental/training_service/test_gateway_unit.py b/tests/experimental/training_service/test_gateway_unit.py new file mode 100644 index 0000000000..84c55dbcc7 --- /dev/null +++ b/tests/experimental/training_service/test_gateway_unit.py @@ -0,0 +1,293 @@ +"""Unit tests for training-service gateway.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +import pytest_asyncio + +from areal.experimental.training_service.gateway.app import create_app +from areal.experimental.training_service.gateway.config import GatewayConfig +from areal.experimental.training_service.gateway.streaming import ( + RouterKeyRejectedError, + RouterUnreachableError, + 
forward_request, + query_router, +) + +MODULE = "areal.experimental.training_service.gateway.app" +ADMIN_KEY = "test-admin-key" +SESSION_KEY = "session-key" +WORKER_ADDR = "http://mock-worker:18082" + + +@pytest.fixture +def config() -> GatewayConfig: + return GatewayConfig( + host="127.0.0.1", + port=18080, + router_addr="http://mock-router:18081", + admin_api_key=ADMIN_KEY, + router_timeout=2.0, + forward_timeout=20.0, + ) + + +@pytest_asyncio.fixture +async def client(config, router_client, upstream_client): + app = create_app(config) + app.state.router_client = router_client + app.state.upstream_client = upstream_client + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +@pytest.fixture +def router_client(): + client = MagicMock() + client.get = AsyncMock() + client.post = AsyncMock() + return client + + +@pytest.fixture +def upstream_client(): + client = MagicMock() + client.get = AsyncMock() + client.post = AsyncMock() + return client + + +class TestGatewayHealth: + @pytest.mark.asyncio + async def test_health_reports_router(self, client, config): + resp = await client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["router_addr"] == config.router_addr + + +class TestGatewayRoutingAndForwarding: + @pytest.mark.asyncio + async def test_query_router_uses_provided_client(self, router_client, config): + router_client.post.return_value = httpx.Response( + 200, + json={"model_addr": WORKER_ADDR}, + request=httpx.Request("POST", f"{config.router_addr}/route"), + ) + + model_addr = await query_router( + config.router_addr, + SESSION_KEY, + config.router_timeout, + admin_api_key=ADMIN_KEY, + client=router_client, + ) + + assert model_addr == WORKER_ADDR + router_client.post.assert_awaited_once_with( + f"{config.router_addr}/route", + json={"api_key": SESSION_KEY}, + headers={"Authorization": f"Bearer {ADMIN_KEY}"}, + timeout=config.router_timeout, + ) + + @pytest.mark.asyncio + async def test_forward_request_uses_provided_client(self, upstream_client): + upstream_client.post.return_value = httpx.Response( + 200, + json={"status": "success"}, + request=httpx.Request("POST", f"{WORKER_ADDR}/train_batch"), + ) + + resp = await forward_request( + f"{WORKER_ADDR}/train_batch", + b"{}", + {"Authorization": f"Bearer {SESSION_KEY}", "Host": "ignored"}, + client=upstream_client, + ) + + assert resp.status_code == 200 + upstream_client.post.assert_awaited_once_with( + f"{WORKER_ADDR}/train_batch", + content=b"{}", + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + timeout=600.0, + ) + + @pytest.mark.asyncio + async def test_missing_bearer_token_returns_401(self, client): + resp = await client.post("/train_batch", json={"args": [], "kwargs": {}}) + assert resp.status_code == 401 + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_router_unreachable_maps_to_502(self, mock_query_router, client): + mock_query_router.side_effect = RouterUnreachableError("router unavailable") + + resp = await client.post( + "/train_batch", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 502 + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_router_404_key_rejected_maps_to_401(self, mock_query_router, client): + mock_query_router.side_effect = RouterKeyRejectedError("unknown 
key", 404) + + resp = await client.post( + "/eval_batch", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 401 + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.forward_request", new_callable=AsyncMock) + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_forward_batch_forwards_response( + self, + mock_query_router, + mock_forward_request, + client, + router_client, + upstream_client, + config, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward_request.return_value = httpx.Response( + 200, + json={"status": "success", "result": {"ok": True}}, + ) + + resp = await client.post( + "/forward_batch", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + mock_query_router.assert_awaited_once_with( + config.router_addr, + SESSION_KEY, + config.router_timeout, + admin_api_key=ADMIN_KEY, + client=router_client, + ) + assert mock_forward_request.await_args.kwargs["client"] is upstream_client + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.forward_request", new_callable=AsyncMock) + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_ppo_actor_compute_logp_forwards_response( + self, + mock_query_router, + mock_forward_request, + client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward_request.return_value = httpx.Response( + 200, + json={"status": "success", "result": {"ok": True}}, + ) + + resp = await client.post( + "/ppo/actor/compute_logp", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.forward_request", new_callable=AsyncMock) + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_sft_train_forwards_response( + self, + mock_query_router, + mock_forward_request, + client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward_request.return_value = httpx.Response( + 200, + json={"status": "success", "result": {"ok": True}}, + ) + + resp = await client.post( + "/sft/train", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.forward_request", new_callable=AsyncMock) + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_offload_uses_admin_auth_upstream( + self, + mock_query_router, + mock_forward_request, + client, + ): + mock_query_router.return_value = WORKER_ADDR + + async def _check_forward(_url, _body, headers, _timeout, *, client): + assert client is not None + assert headers["Authorization"] == f"Bearer {ADMIN_KEY}" + return httpx.Response(200, json={"status": "success", "result": None}) + + mock_forward_request.side_effect = _check_forward + + resp = await client.post( + "/offload", + json={"args": [], "kwargs": {}}, + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + @pytest.mark.asyncio + @patch(f"{MODULE}.streaming.query_router", new_callable=AsyncMock) + async def test_get_version_uses_shared_clients( + self, + mock_query_router, + client, + router_client, + upstream_client, + config, + 
): + mock_query_router.return_value = WORKER_ADDR + upstream_client.get.return_value = httpx.Response( + 200, + json={"status": "success", "result": 11}, + request=httpx.Request("GET", f"{WORKER_ADDR}/get_version"), + ) + + resp = await client.get( + "/get_version", + headers={"Authorization": f"Bearer {SESSION_KEY}"}, + ) + + assert resp.status_code == 200 + mock_query_router.assert_awaited_once_with( + config.router_addr, + SESSION_KEY, + config.router_timeout, + admin_api_key=ADMIN_KEY, + client=router_client, + ) + upstream_client.get.assert_awaited_once() + upstream_call = upstream_client.get.await_args + assert upstream_call.args == (f"{WORKER_ADDR}/get_version",) + assert upstream_call.kwargs["timeout"] == config.forward_timeout + assert ( + upstream_call.kwargs["headers"]["authorization"] == f"Bearer {SESSION_KEY}" + ) diff --git a/tests/experimental/training_service/test_router_unit.py b/tests/experimental/training_service/test_router_unit.py new file mode 100644 index 0000000000..2f9bae4295 --- /dev/null +++ b/tests/experimental/training_service/test_router_unit.py @@ -0,0 +1,197 @@ +"""Unit tests for training-service router and registry behavior.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +import pytest_asyncio + +from areal.experimental.training_service.router.app import ( + _probe_model_health, + create_app, +) +from areal.experimental.training_service.router.config import RouterConfig + +ADMIN_KEY = "router-admin-key" + + +def _admin_headers() -> dict[str, str]: + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +@pytest.fixture +def config() -> RouterConfig: + return RouterConfig( + host="127.0.0.1", + port=18081, + admin_api_key=ADMIN_KEY, + poll_interval=3600.0, + worker_health_timeout=0.5, + ) + + +@pytest_asyncio.fixture +async def app_client(config): + app = create_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield app, c + + +class TestRouterHealthAndRegistry: + @pytest.mark.asyncio + async def test_probe_model_health_uses_provided_client(self, app_client): + app, _client = app_client + model_addr = "http://worker-health:19001" + await app.state.model_registry.register(model_addr, "health-key") + + health_client = MagicMock() + health_client.get = AsyncMock(return_value=httpx.Response(200)) + + await _probe_model_health(app.state.model_registry, model_addr, health_client) + + health_client.get.assert_awaited_once_with(f"{model_addr}/health") + models = await app.state.model_registry.get_all() + assert models[0].is_healthy is True + assert models[0].consecutive_health_failures == 0 + + @pytest.mark.asyncio + async def test_probe_model_health_requires_two_consecutive_failures( + self, app_client + ): + app, _client = app_client + model_addr = "http://worker-flaky:19001" + await app.state.model_registry.register(model_addr, "health-key") + + health_client = MagicMock() + health_client.get = AsyncMock(side_effect=httpx.ConnectTimeout("boom")) + + await _probe_model_health(app.state.model_registry, model_addr, health_client) + + models = await app.state.model_registry.get_all() + assert models[0].is_healthy is True + assert models[0].consecutive_health_failures == 1 + + await _probe_model_health(app.state.model_registry, model_addr, health_client) + + models = await app.state.model_registry.get_all() + assert models[0].is_healthy is False + assert models[0].consecutive_health_failures == 2 + + 
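+    # Hedged sketch of the health bookkeeping the probe tests around here assume
+    # (hypothetical; the actual ModelRegistry implementation may differ): each
+    # failed probe bumps a per-model counter, a successful probe resets it, and a
+    # model is only marked unhealthy after two consecutive failures, e.g.
+    #
+    #   failures = failures + 1 if probe_failed else 0
+    #   is_healthy = failures < 2
+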
@pytest.mark.asyncio + async def test_probe_model_health_success_resets_failure_streak(self, app_client): + app, _client = app_client + model_addr = "http://worker-recover:19001" + await app.state.model_registry.register(model_addr, "health-key") + + failing_client = MagicMock() + failing_client.get = AsyncMock(side_effect=httpx.ConnectTimeout("boom")) + + await _probe_model_health(app.state.model_registry, model_addr, failing_client) + + healthy_client = MagicMock() + healthy_client.get = AsyncMock(return_value=httpx.Response(200)) + + await _probe_model_health(app.state.model_registry, model_addr, healthy_client) + + models = await app.state.model_registry.get_all() + assert models[0].is_healthy is True + assert models[0].consecutive_health_failures == 0 + + @pytest.mark.asyncio + async def test_health_reports_model_count(self, app_client): + _app, client = app_client + resp = await client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["models"] == 0 + + @pytest.mark.asyncio + async def test_register_then_route_success(self, app_client): + _app, client = app_client + model_addr = "http://worker-a:19001" + model_api_key = "model-key-a" + + resp = await client.post( + "/register", + json={"model_addr": model_addr, "api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 200 + + resp = await client.post( + "/route", + json={"api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["model_addr"] == model_addr + + @pytest.mark.asyncio + async def test_route_unknown_key_returns_404(self, app_client): + _app, client = app_client + resp = await client.post( + "/route", + json={"api_key": "unknown-key"}, + headers=_admin_headers(), + ) + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_route_rejects_admin_key_as_data_key(self, app_client): + _app, client = app_client + resp = await client.post( + "/route", + json={"api_key": ADMIN_KEY}, + headers=_admin_headers(), + ) + assert resp.status_code == 400 + + @pytest.mark.asyncio + async def test_route_unhealthy_model_returns_503(self, app_client): + app, client = app_client + model_addr = "http://worker-b:19001" + model_api_key = "model-key-b" + + resp = await client.post( + "/register", + json={"model_addr": model_addr, "api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 200 + + await app.state.model_registry.update_health(model_addr, False) + await app.state.model_registry.update_health(model_addr, False) + + resp = await client.post( + "/route", + json={"api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 503 + + @pytest.mark.asyncio + async def test_route_tolerates_single_transient_health_failure(self, app_client): + app, client = app_client + model_addr = "http://worker-c:19001" + model_api_key = "model-key-c" + + resp = await client.post( + "/register", + json={"model_addr": model_addr, "api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 200 + + await app.state.model_registry.update_health(model_addr, False) + + resp = await client.post( + "/route", + json={"api_key": model_api_key}, + headers=_admin_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["model_addr"] == model_addr diff --git a/tests/experimental/training_service/test_worker_unit.py b/tests/experimental/training_service/test_worker_unit.py new file mode 100644 index 
0000000000..8d3c71e7aa --- /dev/null +++ b/tests/experimental/training_service/test_worker_unit.py @@ -0,0 +1,242 @@ +"""Unit tests for training-service worker Flask app.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from areal.experimental.training_service.worker.config import TrainWorkerConfig +from areal.infra.rpc.serialization import deserialize_value, serialize_value + +MODULE = "areal.experimental.training_service.worker.app" + + +@pytest.fixture(autouse=True) +def reset_worker_state(): + import areal.experimental.training_service.worker.app as worker_app + + if worker_app._engine_work_queue is not None: + worker_app._engine_work_queue.put(None) + if worker_app._engine_thread is not None: + worker_app._engine_thread.join(timeout=1.0) + + worker_app._engine = None + worker_app._node_addr = "" + worker_app._engine_thread = None + worker_app._engine_work_queue = None + + yield + + if worker_app._engine_work_queue is not None: + worker_app._engine_work_queue.put(None) + if worker_app._engine_thread is not None: + worker_app._engine_thread.join(timeout=1.0) + + worker_app._engine = None + worker_app._node_addr = "" + worker_app._engine_thread = None + worker_app._engine_work_queue = None + + +@pytest.fixture +def client(): + from areal.experimental.training_service.worker.app import create_app + + app = create_app( + TrainWorkerConfig( + host="127.0.0.1", + port=19001, + admin_api_key="worker-admin", + ) + ) + return app.test_client() + + +class TestWorkerEngineCreation: + def test_create_engine_requires_engine_class(self, client): + resp = client.post( + "/create_engine", + json={"init_args": [], "init_kwargs": {}}, + ) + assert resp.status_code == 400 + assert "engine_class" in resp.get_json()["error"] + + def test_create_engine_success(self, client): + resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert resp.status_code == 200 + payload = resp.get_json() + assert payload["status"] == "success" + + +class TestWorkerEndpoints: + def test_topology_before_create_engine_returns_400(self, client): + resp = client.get("/topology") + assert resp.status_code == 400 + assert "Engine not created" in resp.get_json()["error"] + + def test_train_batch_after_create_engine(self, client): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + train_resp = client.post( + "/train_batch", + json={ + "args": serialize_value( + [{"token_ids": [1, 2, 3], "metadata": {"weight": 2.0}}] + ), + "kwargs": serialize_value({}), + }, + ) + assert train_resp.status_code == 200 + result = deserialize_value(train_resp.get_json()["result"]) + assert isinstance(result, dict) + assert "total" in result + + def test_topology_after_create_engine(self, client): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + with patch.dict( + "os.environ", + {"RANK": "0", "WORLD_SIZE": "1", "LOCAL_RANK": "0"}, + clear=False, + ): 
+ topo_resp = client.get("/topology") + assert topo_resp.status_code == 200 + topo = topo_resp.get_json() + assert topo["rank"] == 0 + assert topo["world_size"] == 1 + assert topo["dp_size"] == 1 + + def test_ppo_endpoints_return_400_when_engine_method_missing(self, client): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + payload = { + "args": serialize_value([[{"token_ids": [1, 2, 3]}]]), + "kwargs": serialize_value({}), + } + for path in [ + "/ppo/actor/compute_logp", + "/ppo/actor/compute_advantages", + "/ppo/actor/update", + "/ppo/critic/compute_values", + "/ppo/critic/update", + ]: + resp = client.post(path, json=payload) + assert resp.status_code == 400 + assert "does not implement method" in resp.get_json()["error"] + + def test_forward_batch_after_initialize_succeeds_without_distributed_group( + self, client + ): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + init_resp = client.post( + "/initialize", + json={ + "args": serialize_value([]), + "kwargs": serialize_value({"addr": None, "ft_spec": None}), + }, + ) + assert init_resp.status_code == 200 + + forward_resp = client.post( + "/forward_batch", + json={ + "args": serialize_value( + [[{"token_ids": [1, 2, 3], "metadata": {"weight": 2.0}}]] + ), + "kwargs": serialize_value({"output_seqlens": [3]}), + }, + ) + assert forward_resp.status_code == 200 + result = deserialize_value(forward_resp.get_json()["result"]) + assert isinstance(result, dict) + assert result["output_seqlens"] == [3] + + def test_sft_route_succeeds_without_distributed_group_for_single_worker( + self, client + ): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + resp = client.post( + "/sft/train", + json={ + "args": serialize_value([[{"token_ids": [1, 2, 3]}]]), + "kwargs": serialize_value({}), + }, + ) + assert resp.status_code == 200 + result = deserialize_value(resp.get_json()["result"]) + assert isinstance(result, dict) + assert "total" in result + + def test_sft_route_ignores_rpc_meta_override_for_single_worker(self, client): + create_resp = client.post( + "/create_engine", + json={ + "engine_class": "tests.experimental.training_service.fake_train_engine.FakeTrainEngine", + "init_args": serialize_value([]), + "init_kwargs": serialize_value({"world_size": 1}), + }, + ) + assert create_resp.status_code == 200 + + resp = client.post( + "/sft/train", + json={ + "args": serialize_value([[{"token_ids": [1, 2, 3]}]]), + "kwargs": serialize_value({}), + "rpc_meta": {"broadcast": False}, + }, + ) + assert resp.status_code == 200 + result = deserialize_value(resp.get_json()["result"]) + assert isinstance(result, dict) + assert "total" in result diff --git a/tests/sft/test_sft.py b/tests/sft/test_sft.py index a9cdbe3055..1800fbcd31 100644 --- a/tests/sft/test_sft.py +++ b/tests/sft/test_sft.py @@ -12,8 +12,22 @@ from areal.api.cli_args import SFTConfig, 
load_expr_config -@pytest.mark.parametrize("backend", ["fsdp", "megatron", "archon"]) -def test_sft(tmp_path: str, backend: str) -> None: +@pytest.mark.parametrize( + ("backend", "v2"), + [ + ("fsdp", False), + ("megatron", False), + ("archon", False), + ("fsdp", True), + ], + ids=[ + "fsdp-v1", + "megatron-v1", + "archon-v1", + "fsdp-v2", + ], +) +def test_sft(tmp_path: str, backend: str, v2: bool) -> None: base_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(base_dir, f"config_{backend}.yaml") ref_losses_path = os.path.join(base_dir, f"ref_losses_{backend}.json") @@ -55,6 +69,8 @@ def test_sft(tmp_path: str, backend: str) -> None: os.path.join(tmp_path, "config", "config.yaml"), f"cluster.fileroot={tmp_path}", ] + if v2: + cmd.append("actor._version=v2") result = subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr, env=os.environ) assert result.returncode == 0, ( diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py index 3ec4f66032..c910750940 100644 --- a/tests/test_local_scheduler.py +++ b/tests/test_local_scheduler.py @@ -786,6 +786,42 @@ def test_create_workers_subprocess_fails_immediately( scheduler.create_workers(job) assert "exited immediately with code 1" in str(exc_info.value) + assert mock_popen.call_count == 1 + + @patch("areal.infra.scheduler.local.gethostip") + @patch("areal.infra.scheduler.local.subprocess.Popen") + @patch("areal.infra.scheduler.local.find_free_ports") + def test_create_workers_retries_immediate_port_conflict( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000, 8001], [8002, 8003]] + + conflict_proc = Mock() + conflict_proc.pid = 1234 + conflict_proc.poll.return_value = 1 + conflict_proc.returncode = 1 + + success_proc = Mock() + success_proc.pid = 1235 + success_proc.poll.return_value = None + + mock_popen.side_effect = [conflict_proc, success_proc] + + scheduler = create_scheduler(tmp_path) + job = Job(replicas=1, role="test") + + with patch.object( + scheduler, + "_read_log_tail", + return_value="Address already in use\nPort 8000 is in use by another program", + ): + worker_ids = scheduler.create_workers(job) + + assert worker_ids == ["test/0"] + assert mock_popen.call_count == 2 + assert scheduler._workers["test"][0].worker.worker_ports == ["8002", "8003"] + assert scheduler._allocated_ports == {8002, 8003} @patch("areal.infra.scheduler.local.gethostip") @patch("areal.infra.scheduler.local.subprocess.Popen") diff --git a/tests/test_recover.py b/tests/test_recover.py index 5155fd4033..57e5634cd5 100644 --- a/tests/test_recover.py +++ b/tests/test_recover.py @@ -1,11 +1,16 @@ """Tests for the recovery configuration and functionality.""" import tempfile +from unittest.mock import Mock import pytest from areal.api.cli_args import RecoverConfig -from areal.utils.recover import check_if_auto_recover, check_if_recover +from areal.api.io_struct import FinetuneSpec, StepInfo +from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, +) +from areal.utils.recover import RecoverHandler, check_if_auto_recover, check_if_recover class TestRecoverConfig: @@ -167,3 +172,65 @@ def test_off_equals_disabled(self): # Both should return False assert check_if_recover(config_off, 0) is False assert check_if_recover(config_disabled, 0) is False + + +class TestRecoverHandler: + @staticmethod + def _make_handler(tmpdir: str, mode: str) -> RecoverHandler: + config = RecoverConfig( + 
experiment_name="test_exp", + trial_name="test_trial", + fileroot=tmpdir, + mode=mode, + ) + ft_spec = FinetuneSpec( + total_train_epochs=1, + dataset_size=8, + train_batch_size=2, + ) + return RecoverHandler(config, ft_spec) + + @staticmethod + def _make_gateway_controller() -> GatewayTrainController: + return GatewayTrainController.__new__(GatewayTrainController) + + @pytest.mark.parametrize("mode", ["on", "auto"]) + def test_load_rejects_gateway_train_controller(self, mode): + with tempfile.TemporaryDirectory() as tmpdir: + handler = self._make_handler(tmpdir, mode) + + with pytest.raises(NotImplementedError) as exc_info: + handler.load( + self._make_gateway_controller(), + Mock(), + Mock(), + Mock(), + Mock(), + ) + + assert "GatewayTrainController" in str(exc_info.value) + assert '`_version="v2"`' in str(exc_info.value) + + @pytest.mark.parametrize("mode", ["on", "auto"]) + def test_dump_rejects_gateway_train_controller(self, mode): + with tempfile.TemporaryDirectory() as tmpdir: + handler = self._make_handler(tmpdir, mode) + step_info = StepInfo( + epoch=0, + epoch_step=0, + global_step=0, + steps_per_epoch=handler.ft_spec.steps_per_epoch, + ) + + with pytest.raises(NotImplementedError) as exc_info: + handler.dump( + self._make_gateway_controller(), + step_info, + Mock(), + Mock(), + Mock(), + Mock(), + ) + + assert "GatewayTrainController" in str(exc_info.value) + assert "recover.mode" in str(exc_info.value)