From 4034edad3fad1f471aea0e0a98b1cbbf5c6f0c0d Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:53:50 +0800 Subject: [PATCH] feat(infra): add distributed data loading service (#1120) * feat(infra): add distributed data loading service Introduce a controller/router/gateway/worker data service so single-controller training can offload dataset access to remote workers while preserving existing trainer workflows. Key changes: - add RDataset-based remote dataset registration and fetch path - integrate DataController lifecycle into RL/SFT/RW trainers - add guard reuse, data-service APIs, and full infra test coverage * chore: revert stats_logger to make trackio test happy --- .github/workflows/test-areal.yml | 2 +- areal/__init__.py | 16 +- areal/api/cli_args.py | 30 + areal/api/io_struct.py | 2 +- areal/dataset/__init__.py | 51 +- areal/infra/data_service/__init__.py | 9 + .../infra/data_service/controller/__init__.py | 0 areal/infra/data_service/controller/config.py | 39 ++ .../data_service/controller/controller.py | 529 +++++++++++++++ areal/infra/data_service/gateway/__init__.py | 0 areal/infra/data_service/gateway/__main__.py | 37 ++ areal/infra/data_service/gateway/app.py | 342 ++++++++++ areal/infra/data_service/gateway/auth.py | 62 ++ areal/infra/data_service/gateway/config.py | 13 + areal/infra/data_service/guard/__init__.py | 0 areal/infra/data_service/guard/__main__.py | 6 + areal/infra/data_service/guard/app.py | 10 + areal/infra/data_service/rdataset.py | 330 ++++++++++ areal/infra/data_service/router/__init__.py | 0 areal/infra/data_service/router/__main__.py | 38 ++ areal/infra/data_service/router/app.py | 139 ++++ areal/infra/data_service/router/config.py | 12 + areal/infra/data_service/types.py | 43 ++ areal/infra/data_service/worker/__init__.py | 0 areal/infra/data_service/worker/__main__.py | 37 ++ areal/infra/data_service/worker/app.py | 203 ++++++ areal/infra/data_service/worker/config.py | 
12 + areal/trainer/rl_trainer.py | 127 +++- areal/trainer/rw_trainer.py | 68 +- areal/trainer/sft_trainer.py | 91 ++- areal/utils/data.py | 5 +- areal/utils/dataloader.py | 28 +- areal/utils/recover.py | 7 +- areal/utils/stats_logger.py | 7 +- docs/en/cli_reference.md | 48 +- docs/zh/cli_reference.md | 48 +- tests/infra/__init__.py | 0 tests/infra/data_service/__init__.py | 0 tests/infra/data_service/test_auth.py | 128 ++++ tests/infra/data_service/test_controller.py | 241 +++++++ .../data_service/test_data_service_e2e.py | 537 +++++++++++++++ .../infra/data_service/test_epoch_crossing.py | 612 ++++++++++++++++++ tests/infra/data_service/test_gateway.py | 461 +++++++++++++ tests/infra/data_service/test_guard.py | 104 +++ tests/infra/data_service/test_performance.py | 117 ++++ tests/infra/data_service/test_router.py | 221 +++++++ .../infra/data_service/test_trainer_compat.py | 249 +++++++ tests/infra/data_service/test_worker.py | 230 +++++++ tests/sft/entrypoint.py | 7 + 49 files changed, 5157 insertions(+), 141 deletions(-) create mode 100644 areal/infra/data_service/__init__.py create mode 100644 areal/infra/data_service/controller/__init__.py create mode 100644 areal/infra/data_service/controller/config.py create mode 100644 areal/infra/data_service/controller/controller.py create mode 100644 areal/infra/data_service/gateway/__init__.py create mode 100644 areal/infra/data_service/gateway/__main__.py create mode 100644 areal/infra/data_service/gateway/app.py create mode 100644 areal/infra/data_service/gateway/auth.py create mode 100644 areal/infra/data_service/gateway/config.py create mode 100644 areal/infra/data_service/guard/__init__.py create mode 100644 areal/infra/data_service/guard/__main__.py create mode 100644 areal/infra/data_service/guard/app.py create mode 100644 areal/infra/data_service/rdataset.py create mode 100644 areal/infra/data_service/router/__init__.py create mode 100644 areal/infra/data_service/router/__main__.py create mode 100644 
areal/infra/data_service/router/app.py create mode 100644 areal/infra/data_service/router/config.py create mode 100644 areal/infra/data_service/types.py create mode 100644 areal/infra/data_service/worker/__init__.py create mode 100644 areal/infra/data_service/worker/__main__.py create mode 100644 areal/infra/data_service/worker/app.py create mode 100644 areal/infra/data_service/worker/config.py create mode 100644 tests/infra/__init__.py create mode 100644 tests/infra/data_service/__init__.py create mode 100644 tests/infra/data_service/test_auth.py create mode 100644 tests/infra/data_service/test_controller.py create mode 100644 tests/infra/data_service/test_data_service_e2e.py create mode 100644 tests/infra/data_service/test_epoch_crossing.py create mode 100644 tests/infra/data_service/test_gateway.py create mode 100644 tests/infra/data_service/test_guard.py create mode 100644 tests/infra/data_service/test_performance.py create mode 100644 tests/infra/data_service/test_router.py create mode 100644 tests/infra/data_service/test_trainer_compat.py create mode 100644 tests/infra/data_service/test_worker.py diff --git a/.github/workflows/test-areal.yml b/.github/workflows/test-areal.yml index cc6f17cc9b..56334fcb2e 100644 --- a/.github/workflows/test-areal.yml +++ b/.github/workflows/test-areal.yml @@ -322,7 +322,7 @@ jobs: VIRTUAL_ENV: /AReaL/.venv run: | export PATH="/AReaL/.venv/bin:$PATH" - pytest -m "(not slow or ci) and not ${EXCLUDE_BACKEND}" --durations=20 -s -vv tests/test_*.py tests/experimental/ + pytest -m "(not slow or ci) and not ${EXCLUDE_BACKEND}" --durations=20 -s -vv tests/test_*.py tests/experimental/ tests/infra/ - name: Run SFT integration tests env: diff --git a/areal/__init__.py b/areal/__init__.py index b675ce133e..4878751985 100644 --- a/areal/__init__.py +++ b/areal/__init__.py @@ -10,7 +10,21 @@ current_platform, workflow_context, ) -from .trainer import PPOTrainer, RWTrainer, SFTTrainer + + +def __getattr__(name: str): + if name in 
("PPOTrainer", "RWTrainer", "SFTTrainer"): + from .trainer import PPOTrainer, RWTrainer, SFTTrainer + + _map = { + "PPOTrainer": PPOTrainer, + "RWTrainer": RWTrainer, + "SFTTrainer": SFTTrainer, + } + globals().update(_map) + return _map[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "PPOTrainer", diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index e08c852ec4..5430f32aed 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -2223,6 +2223,10 @@ class SchedulerConfig: class _DatasetConfig: """Configuration for dataset loading and preprocessing.""" + split: str = field( + default="train", + metadata={"help": "Dataset split to use, e.g., 'train', 'test'."}, + ) path: str = field( default=MISSING, metadata={ @@ -2248,6 +2252,12 @@ class _DatasetConfig: num_workers: int = field( default=0, metadata={"help": "Number of worker processes for data loading"} ) + num_dataset_workers: int = field( + default=1, + metadata={ + "help": "Number of remote data-service worker processes to launch when using scheduling_spec." + }, + ) drop_last: bool = field( default=True, metadata={"help": "Drop the last incomplete batch"} ) @@ -2257,6 +2267,22 @@ class _DatasetConfig: "help": "Maximum token length of sequences in dataset. Longer sequences are filtered out." }, ) + dataset_kwargs: dict[str, Any] = field( + default_factory=dict, + metadata={ + "help": "Additional keyword arguments for dataset loading. " + "These are passed to the dataset loading function `get_custom_dataset`." + }, + ) + scheduling_spec: SchedulingSpec | None = field( + default_factory=lambda: SchedulingSpec( + cpu=1, gpu=0, mem=10, cmd="python3 -m areal.infra.rpc.guard" + ), + metadata={ + "help": "Scheduling spec for remote data loading workers. " + "If set, dataset loading will be offloaded to a data service with remote workers." 
+ }, + ) @dataclass @@ -2272,6 +2298,10 @@ class ValidDatasetConfig(_DatasetConfig): `shuffle` and `drop_last` default to False. """ + split: str = field( + default="test", + metadata={"help": "Dataset split to use, e.g., 'train', 'test'."}, + ) shuffle: bool = field( default=False, metadata={"help": "Whether to shuffle the dataset"} ) diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index e63f849230..f1d7cdf8f8 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -337,7 +337,7 @@ class LocalInfServerInfo: host: str port: int - process: subprocess.Popen + process: subprocess.Popen | None @dataclass diff --git a/areal/dataset/__init__.py b/areal/dataset/__init__.py index cc98db403f..8d86a61db3 100644 --- a/areal/dataset/__init__.py +++ b/areal/dataset/__init__.py @@ -8,6 +8,8 @@ from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + from areal.infra.data_service.rdataset import RDataset + VALID_DATASETS = [ "gsm8k", "clevr_count_70k", @@ -120,10 +122,32 @@ def _get_custom_dataset( **kwargs, ) else: - raise ValueError( - f"Dataset {path} with split {split} and training type {type} is not supported. " - f"Supported datasets are: {VALID_DATASETS}. " - ) + # Fallback: try loading as a generic HuggingFace dataset from disk. + # This supports arbitrary datasets saved via dataset.save_to_disk(). + try: + from datasets import DatasetDict, load_from_disk + + dataset = load_from_disk(path) + if isinstance(dataset, DatasetDict): + if split is not None: + if split in dataset: + return dataset[split] + available = list(dataset.keys()) + raise ValueError( + f"Requested split '{split}' not found in DatasetDict at {path}. 
" + f"Available splits: {available}" + ) + available = list(dataset.keys()) + if available: + return dataset[available[0]] + raise ValueError(f"Empty DatasetDict at {path}") + return dataset + except Exception as load_err: + raise ValueError( + f"Dataset {path} with split {split} and training type {type} is not supported. " + f"Supported datasets are: {VALID_DATASETS}. " + f"Also failed to load from disk: {load_err}" + ) def get_custom_dataset( @@ -132,7 +156,24 @@ def get_custom_dataset( tokenizer: Optional["PreTrainedTokenizerFast"] = None, processor: Optional["ProcessorMixin"] = None, **kwargs, -) -> "Dataset": +) -> "Dataset | RDataset": + from areal.utils.environ import is_single_controller + + if ( + is_single_controller() + and dataset_config is not None + and dataset_config.scheduling_spec is not None + ): + from areal.infra.data_service.rdataset import RDataset + + return RDataset( + path=dataset_config.path, + type=dataset_config.type, + split=split, + max_length=dataset_config.max_length, + dataset_kwargs=getattr(dataset_config, "dataset_kwargs", None), + ) + if dataset_config is not None: return _get_custom_dataset( path=dataset_config.path, diff --git a/areal/infra/data_service/__init__.py b/areal/infra/data_service/__init__.py new file mode 100644 index 0000000000..2f412b953c --- /dev/null +++ b/areal/infra/data_service/__init__.py @@ -0,0 +1,9 @@ +from areal.infra.data_service.controller.config import DataServiceConfig +from areal.infra.data_service.controller.controller import DataController +from areal.infra.data_service.rdataset import RDataset + +__all__ = [ + "DataController", + "DataServiceConfig", + "RDataset", +] diff --git a/areal/infra/data_service/controller/__init__.py b/areal/infra/data_service/controller/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/areal/infra/data_service/controller/config.py b/areal/infra/data_service/controller/config.py new file mode 100644 index 0000000000..be906dd746 --- /dev/null +++ 
b/areal/infra/data_service/controller/config.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +from areal.api.cli_args import SchedulingSpec, SchedulingStrategy + + +@dataclass +class DataServiceConfig: + """Internal config for the data service controller. + + Constructed from ``_DatasetConfig`` fields by the trainer. + Not exposed to end users directly. + """ + + num_workers: int = 1 + scheduling_spec: SchedulingSpec = field( + default_factory=lambda: SchedulingSpec( + cpu=1, gpu=0, mem=10, cmd="python3 -m areal.infra.rpc.guard" + ), + ) + # Always separation — data controller starts before engines. + scheduling_strategy: SchedulingStrategy = field( + default_factory=lambda: SchedulingStrategy(type="separation"), + ) + setup_timeout: float = 120.0 + dataloader_num_workers: int = 4 + + @staticmethod + def from_dataset_config(dataset_config) -> DataServiceConfig: + """Build from a ``_DatasetConfig`` instance.""" + return DataServiceConfig( + num_workers=max(dataset_config.num_dataset_workers, 1), + scheduling_spec=dataset_config.scheduling_spec, + dataloader_num_workers=max(dataset_config.num_workers, 0), + ) + + +__all__ = ["DataServiceConfig"] diff --git a/areal/infra/data_service/controller/controller.py b/areal/infra/data_service/controller/controller.py new file mode 100644 index 0000000000..9afa823a7c --- /dev/null +++ b/areal/infra/data_service/controller/controller.py @@ -0,0 +1,529 @@ +"""DataController — orchestrator for the distributed data loading service. + +Manages the full lifecycle: create RPCGuard workers → fork DataWorkers, +Router, Gateway → register datasets → serve batches → shutdown. + +Follows the same patterns as ``GatewayInferenceController``. 
+""" + +from __future__ import annotations + +import asyncio +import os +import sys +import time +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +import aiohttp + +from areal.api.scheduler_api import Job +from areal.infra.data_service.controller.config import DataServiceConfig +from areal.utils import logging +from areal.utils.network import format_hostport + +if TYPE_CHECKING: + from areal.api.scheduler_api import Scheduler, Worker + +logger = logging.getLogger("DataController") + + +class DataController: + """Controller for the distributed data loading service. + + API follows ``TrainController`` / ``GatewayInferenceController`` patterns: + ``__init__(config, scheduler)`` then ``initialize(role, ...)``. + """ + + _GUARD_SUFFIX = "-data" + _ADMIN_API_KEY = os.environ.get("AREAL_ADMIN_KEY", "areal-data-admin") + + def __init__( + self, + config: DataServiceConfig, + scheduler: Scheduler, + ) -> None: + self.config = config + self.scheduler = scheduler + + self.workers: list[Worker] = [] + self._worker_role: str = "" + + self._gateway_addr: str = "" + self._router_addr: str = "" + self._worker_addrs: list[str] = [] + + self._service_roles: list[str] = [] + self._forked_services: list[tuple[str, str, int]] = [] + + self._admin_api_key: str = self._ADMIN_API_KEY + + self._datasets: dict[str, dict[str, Any]] = {} + + # -- Initialize -------------------------------------------------------- + + def initialize( + self, + role: str, + num_dataset_workers: int = 1, + **kwargs: Any, + ) -> None: + from areal.infra.utils.concurrent import run_async_task + + self._worker_role = role + run_async_task(self._async_initialize, num_dataset_workers, **kwargs) + + async def _async_initialize( + self, + num_dataset_workers: int, + **kwargs: Any, + ) -> None: + cfg = self.config + spec = cfg.scheduling_spec + if spec is None: + raise ValueError( + "DataServiceConfig.scheduling_spec must be set to launch data service workers" + ) + + # Use sys.executable as the 
interpreter; don't mutate cfg.scheduling_spec + cmd = spec.cmd + if not cmd: + raise ValueError( + "DataServiceConfig.scheduling_spec.cmd must be set to launch RPC guards" + ) + parts = cmd.split("-m", 1) + if len(parts) == 2: + module = parts[1].strip() + guard_cmd = f"{sys.executable} -m {module}" + else: + guard_cmd = f"{sys.executable} {cmd}" + guard_spec = replace(spec, cmd=guard_cmd) + + guard_role = f"{self._worker_role}{self._GUARD_SUFFIX}" + + guard_job = Job( + replicas=num_dataset_workers, + tasks=[guard_spec for _ in range(num_dataset_workers)], + scheduling_strategy=cfg.scheduling_strategy, + role=guard_role, + ) + + self.scheduler.create_workers(job=guard_job) + self._service_roles.append(guard_role) + guard_workers = self.scheduler.get_workers(role=guard_role) + self.workers = guard_workers + logger.info("RPCGuard workers ready: %s", [w.id for w in guard_workers]) + + guard_addrs = [ + f"http://{format_hostport(w.ip, int(w.worker_ports[0]))}" + for w in guard_workers + ] + guard_addr_0 = guard_addrs[0] + + try: + async with aiohttp.ClientSession() as session: + # Wave 1: Fork all DataWorkers + Router in parallel + worker_tasks = [ + self._async_fork_on_guard( + session, + guard_addrs[rank], + "data-worker", + rank, + [ + sys.executable, + "-m", + "areal.infra.data_service.worker", + "--rank", + str(rank), + "--world-size", + str(num_dataset_workers), + "--dataloader-num-workers", + str(cfg.dataloader_num_workers), + ], + ) + for rank in range(num_dataset_workers) + ] + router_task = self._async_fork_on_guard( + session, + guard_addr_0, + "data-router", + 0, + [ + sys.executable, + "-m", + "areal.infra.data_service.router", + "--admin-api-key", + self._admin_api_key, + ], + ) + + results = await asyncio.gather(*worker_tasks, router_task) + + for host, port in results[:-1]: + self._worker_addrs.append(f"http://{format_hostport(host, port)}") + router_host, router_port = results[-1] + self._router_addr = ( + f"http://{format_hostport(router_host, 
router_port)}" + ) + logger.info("DataWorkers: %s", self._worker_addrs) + logger.info("Router: %s", self._router_addr) + + # Wave 2: Fork Gateway + Register workers with Router + async def _register_workers() -> None: + for worker_addr in self._worker_addrs: + async with session.post( + f"{self._router_addr}/register", + json={"worker_addr": worker_addr}, + headers={"Authorization": f"Bearer {self._admin_api_key}"}, + timeout=aiohttp.ClientTimeout(total=5), + ) as resp: + resp.raise_for_status() + logger.info("Registered DataWorker %s in router", worker_addr) + + gw_result, _ = await asyncio.gather( + self._async_fork_on_guard( + session, + guard_addr_0, + "data-gateway", + 0, + [ + sys.executable, + "-m", + "areal.infra.data_service.gateway", + "--admin-api-key", + self._admin_api_key, + "--router-addr", + self._router_addr, + "--forward-timeout", + str(60.0), + ], + ), + _register_workers(), + ) + gw_host, gw_port = gw_result + self._gateway_addr = f"http://{format_hostport(gw_host, gw_port)}" + logger.info("Gateway: %s", self._gateway_addr) + except Exception: + # Rollback: kill forked services and delete scheduler workers + logger.error( + "DataController initialization failed, rolling back", + exc_info=True, + ) + if self._forked_services: + await self._async_kill_forked_services( + list(reversed(self._forked_services)) + ) + self._forked_services.clear() + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + except Exception: + pass + self._service_roles.clear() + self.workers.clear() + self._worker_addrs.clear() + self._router_addr = "" + self._gateway_addr = "" + raise + + logger.info("DataController initialized with %d workers", num_dataset_workers) + + # -- Register / Unregister Datasets ------------------------------------ + + def register_dataset( + self, + dataset_id: str, + dataset_path: str, + dataset_type: str, + dataset_kwargs: dict[str, Any] | None = None, + tokenizer_or_processor_path: str = "", + split: 
str = "train", + seed: int = 42, + shuffle: bool = True, + drop_last: bool = True, + max_length: int | None = None, + ) -> dict[str, Any]: + """Register a dataset with the service. + + POST /v1/datasets/register on Gateway. + """ + + payload = { + "dataset_id": dataset_id, + "dataset_path": dataset_path, + "dataset_type": dataset_type, + "split": split, + "tokenizer_or_processor_path": tokenizer_or_processor_path, + "seed": seed, + "max_length": max_length, + "shuffle": shuffle, + "drop_last": drop_last, + "dataset_kwargs": dataset_kwargs or {}, + } + + from areal.infra.utils.concurrent import run_async_task + + data = run_async_task( + self._async_gateway_post, + "/v1/datasets/register", + self._admin_api_key, + payload, + self.config.setup_timeout, + ) + + total_samples = data["dataset_size"] + + self._datasets[data["api_key"]] = { + "dataset_id": data["dataset_id"], + "total_samples": total_samples, + "drop_last": drop_last, + } + logger.info( + "Registered dataset %s: total_samples=%d, workers=%d", + dataset_id, + total_samples, + data["num_workers"], + ) + return { + "api_key": data["api_key"], + "dataset_id": data["dataset_id"], + "dataset_size": total_samples, + "total_samples": total_samples, + "num_workers": data["num_workers"], + } + + def unregister_dataset(self, dataset_id: str) -> None: + """Unregister a dataset from the service.""" + from areal.infra.utils.concurrent import run_async_task + + run_async_task( + self._async_gateway_post, + "/v1/datasets/unregister", + self._admin_api_key, + {"dataset_id": dataset_id}, + 30, + ) + + to_remove = [ + k for k, v in self._datasets.items() if v["dataset_id"] == dataset_id + ] + for k in to_remove: + del self._datasets[k] + + logger.info("Unregistered dataset %s", dataset_id) + + # -- Batch cleanup ----------------------------------------------------- + + def clear_batches(self) -> None: + """Clear batch caches and tensor stores on all data workers. 
+ + Called by trainers after each training step, alongside + ``actor.clear_batches()``, to free memory held by the data + service instead of relying on TTL-based eviction. + """ + if not self._worker_addrs: + return + from areal.infra.utils.concurrent import run_async_task + + run_async_task(self._async_clear_batches) + + async def _async_clear_batches(self) -> None: + async def _clear_one(session: aiohttp.ClientSession, addr: str) -> None: + try: + async with session.delete( + f"{addr}/data/clear", + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + resp.raise_for_status() + except Exception: + logger.debug("Failed to clear batches on %s", addr) + + async with aiohttp.ClientSession() as session: + await asyncio.gather( + *(_clear_one(session, addr) for addr in self._worker_addrs), + return_exceptions=True, + ) + + # -- Destroy ----------------------------------------------------------- + + def destroy(self) -> None: + """Shutdown service: unload all datasets, kill services, delete workers.""" + from areal.infra.utils.concurrent import run_async_task + + if self._gateway_addr: + try: + run_async_task( + self._async_gateway_post, + "/v1/shutdown", + self._admin_api_key, + {}, + 5, + ) + except Exception as exc: + logger.debug( + "Gateway shutdown request failed (expected during teardown): %s", + exc, + ) + + # Kill forked services concurrently + if self._forked_services: + run_async_task( + self._async_kill_forked_services, + list(reversed(self._forked_services)), + ) + self._forked_services.clear() + + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + logger.info("Workers deleted for role: %s", role) + except Exception as exc: + logger.debug("Could not delete workers for role %s: %s", role, exc) + + self._service_roles.clear() + self.workers.clear() + self._worker_addrs.clear() + self._router_addr = "" + self._gateway_addr = "" + self._datasets.clear() + + # -- Internal HTTP helpers (async) 
------------------------------------- + + async def _async_fork_on_guard( + self, + session: Any, + guard_addr: str, + role: str, + worker_index: int, + raw_cmd: list[str], + health_path: str = "/health", + ) -> tuple[str, int]: + async with session.post( + f"{guard_addr}/alloc_ports", + json={"count": 1}, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + resp.raise_for_status() + alloc_data = await resp.json() + host = alloc_data["host"] + port = alloc_data["ports"][0] + + cmd = list(raw_cmd) + ["--host", host, "--port", str(port)] + + async with session.post( + f"{guard_addr}/fork", + json={ + "role": role, + "worker_index": worker_index, + "raw_cmd": cmd, + }, + timeout=aiohttp.ClientTimeout(total=30), + ) as resp: + resp.raise_for_status() + + self._forked_services.append((guard_addr, role, worker_index)) + + addr = f"http://{format_hostport(host, port)}" + await self._async_wait_for_service(session, f"{addr}{health_path}", role) + + return host, port + + async def _async_wait_for_service( + self, + session: Any, + url: str, + name: str, + timeout: float | None = None, + ) -> None: + timeout_val = timeout or self.config.setup_timeout + deadline = time.monotonic() + timeout_val + while time.monotonic() < deadline: + try: + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=2) + ) as resp: + if resp.status == 200: + logger.info("%s is ready at %s", name, url) + return + except Exception: + pass + await asyncio.sleep(0.1) + raise TimeoutError( + f"{name} did not become healthy at {url} within {timeout_val}s" + ) + + async def _async_kill_forked_services( + self, services: list[tuple[str, str, int]] + ) -> None: + async def _kill_one( + session: aiohttp.ClientSession, + guard_addr: str, + role: str, + worker_index: int, + ) -> None: + try: + async with session.post( + f"{guard_addr}/kill_forked_worker", + json={"role": role, "worker_index": worker_index}, + timeout=aiohttp.ClientTimeout(total=5), + ) as resp: + if resp.status == 200: + 
logger.info("Killed forked service %s/%d", role, worker_index) + else: + text = await resp.text() + logger.warning( + "Failed to kill %s/%d: HTTP %d: %s", + role, + worker_index, + resp.status, + text, + ) + except Exception as exc: + logger.error( + "Error killing forked service %s/%d: %s", + role, + worker_index, + exc, + ) + + async with aiohttp.ClientSession() as session: + await asyncio.gather( + *(_kill_one(session, *svc) for svc in services), + return_exceptions=True, + ) + + def _gateway_post( + self, + endpoint: str, + api_key: str, + payload: dict[str, Any], + ) -> dict[str, Any]: + from areal.infra.utils.concurrent import run_async_task + + return run_async_task(self._async_gateway_post, endpoint, api_key, payload) + + async def _async_gateway_post( + self, + endpoint: str, + api_key: str, + payload: dict[str, Any], + timeout: float = 60, + ) -> dict[str, Any]: + url = f"{self._gateway_addr}{endpoint}" + try: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + async with session.post( + url, + json=payload, + headers={"Authorization": f"Bearer {api_key}"}, + ) as resp: + if resp.status >= 400: + text = await resp.text() + raise RuntimeError( + f"Gateway {endpoint} returned {resp.status}: {text}" + ) + return await resp.json() + except aiohttp.ClientError as exc: + raise RuntimeError(f"Failed to POST {endpoint}: {exc}") from exc diff --git a/areal/infra/data_service/gateway/__init__.py b/areal/infra/data_service/gateway/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/areal/infra/data_service/gateway/__main__.py b/areal/infra/data_service/gateway/__main__.py new file mode 100644 index 0000000000..90daae5463 --- /dev/null +++ b/areal/infra/data_service/gateway/__main__.py @@ -0,0 +1,37 @@ +"""CLI entrypoint: python -m areal.infra.data_service.gateway""" + +from __future__ import annotations + +import argparse + + +def main(): + parser = argparse.ArgumentParser(description="AReaL 
Data Service Gateway") + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8090) + parser.add_argument("--admin-api-key", default="areal-data-admin") + parser.add_argument("--router-addr", default="http://localhost:8091") + parser.add_argument("--router-timeout", type=float, default=2.0) + parser.add_argument("--forward-timeout", type=float, default=60.0) + args, _ = parser.parse_known_args() + + from areal.infra.data_service.gateway.app import create_gateway_app + from areal.infra.data_service.gateway.config import GatewayConfig + + config = GatewayConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + router_addr=args.router_addr, + router_timeout=args.router_timeout, + forward_timeout=args.forward_timeout, + ) + + import uvicorn + + app = create_gateway_app(config) + uvicorn.run(app, host=config.host, port=config.port, log_level="warning") + + +if __name__ == "__main__": + main() diff --git a/areal/infra/data_service/gateway/app.py b/areal/infra/data_service/gateway/app.py new file mode 100644 index 0000000000..b2bfd5f0f6 --- /dev/null +++ b/areal/infra/data_service/gateway/app.py @@ -0,0 +1,342 @@ +"""Data Service Gateway — thin HTTP proxy with auth, routing, and forwarding.""" + +from __future__ import annotations + +import httpx +from fastapi import FastAPI, HTTPException, Request + +from areal.infra.data_service.gateway.auth import ( + DatasetKeyRegistry, + extract_bearer_token, + require_admin_key, +) +from areal.infra.data_service.gateway.config import GatewayConfig +from areal.utils import logging + +logger = logging.getLogger("DataGateway") + + +async def _query_router(router_addr: str, admin_key: str, timeout: float) -> str: + """Get a worker address from the router via round-robin.""" + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + f"{router_addr}/route", + headers={"Authorization": f"Bearer {admin_key}"}, + ) + if resp.status_code != 
async def _get_all_worker_addrs(
    router_addr: str, admin_key: str, timeout: float
) -> list[str]:
    """Get all worker addresses from the router.

    Queries the router's ``/workers`` endpoint with the admin bearer
    token and returns the raw address strings.

    Raises
    ------
    HTTPException
        502 if the router responds with a non-200 status.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.get(
            f"{router_addr}/workers",
            headers={"Authorization": f"Bearer {admin_key}"},
        )
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Router error: {resp.text}")
        return [w["addr"] for w in resp.json()["workers"]]


async def _broadcast_to_workers(
    worker_addrs: list[str], endpoint: str, payload: dict, timeout: float
) -> list[dict]:
    """Broadcast a POST request to all workers and collect responses.

    Never raises for an individual worker: transport failures are
    recorded as ``{"addr": ..., "status": 500, "error": ...}`` entries so
    the caller can inspect partial failures. Responses whose body is not
    JSON are wrapped as ``{"raw": <text>}``.

    NOTE(review): workers are contacted sequentially, not concurrently —
    total latency grows linearly with the worker count.
    """
    results: list[dict] = []
    async with httpx.AsyncClient(timeout=timeout) as client:
        for addr in worker_addrs:
            try:
                resp = await client.post(f"{addr}{endpoint}", json=payload)
                try:
                    data = resp.json()
                except Exception:
                    data = {"raw": resp.text}
                results.append({"addr": addr, "status": resp.status_code, "data": data})
            except Exception as exc:
                results.append({"addr": addr, "status": 500, "error": str(exc)})
    return results


def create_gateway_app(config: GatewayConfig) -> FastAPI:
    """Build the gateway FastAPI app.

    The gateway fronts the router/worker fleet: admin endpoints
    (register/unregister/shutdown/workers) are authenticated with the
    admin API key, consumer endpoints (fetch/epoch/state/status) with a
    per-dataset key issued at registration time.
    """
    app = FastAPI(title="AReaL Data Gateway")
    # In-memory key store; keys do not survive a gateway restart.
    registry = DatasetKeyRegistry(config.admin_api_key)

    # Helper: resolve dataset key to dataset_id, raise if invalid
    def _resolve_dataset_key(token: str) -> str:
        dataset_id = registry.resolve(token)
        if dataset_id is None:
            raise HTTPException(status_code=401, detail="Invalid dataset API key")
        return dataset_id

    def _check_broadcast_results(results: list[dict], operation: str) -> None:
        """Raise HTTPException if any worker failed during a broadcast."""
        failed = [r for r in results if r["status"] != 200]
        if failed:
            details = ", ".join(
                f"{r['addr']}: {r.get('error', r.get('data'))}" for r in failed
            )
            raise HTTPException(
                status_code=502,
                detail=(
                    f"{operation} failed on {len(failed)}/{len(results)} "
                    f"workers: {details}"
                ),
            )

    # ===== Health =====
    @app.get("/health")
    async def health():
        """Liveness probe; also reports the configured router address."""
        return {"status": "ok", "router_addr": config.router_addr}

    # ===== Admin: Register Dataset =====
    @app.post("/v1/datasets/register")
    async def register_dataset(request: Request):
        """Load a dataset on every worker and mint a consumer API key.

        On partial failure the successfully-loaded workers are rolled
        back via ``/datasets/unload`` before a 502 is returned.
        """
        require_admin_key(request, config.admin_api_key)
        body = await request.json()

        # Default dataset_id: "<split>-<basename of dataset_path>".
        dataset_id = body.get(
            "dataset_id",
            f"{body.get('split', 'train')}-{body.get('dataset_path', 'unknown').split('/')[-1]}",
        )
        # Broadcast /datasets/load to all workers
        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        if not worker_addrs:
            raise HTTPException(status_code=503, detail="No workers available")
        load_payload = {**body, "dataset_id": dataset_id}
        results = await _broadcast_to_workers(
            worker_addrs,
            "/datasets/load",
            load_payload,
            config.forward_timeout,
        )

        successful_addrs = [r["addr"] for r in results if r["status"] == 200]
        failed = [r for r in results if r["status"] != 200]
        if failed:
            rollback_error_detail = ""
            if successful_addrs:
                # Best-effort rollback so no worker is left holding a
                # half-registered dataset.
                rollback_results = await _broadcast_to_workers(
                    successful_addrs,
                    "/datasets/unload",
                    {"dataset_id": dataset_id},
                    config.forward_timeout,
                )
                rollback_failed = [r for r in rollback_results if r["status"] != 200]
                if rollback_failed:
                    rollback_error_detail = (
                        " Rollback failed on "
                        f"{len(rollback_failed)}/{len(rollback_results)} workers."
                    )

            details = ", ".join(
                f"{r['addr']}: {r.get('error', r.get('data'))}" for r in failed
            )
            raise HTTPException(
                status_code=502,
                detail=(
                    f"register_dataset failed on {len(failed)}/{len(results)} workers: "
                    f"{details}.{rollback_error_detail}"
                ),
            )

        # Revoke any previous key for this id, then issue a fresh one.
        registry.revoke(dataset_id)
        api_key = registry.generate_key(dataset_id)

        # Sum per-worker shard sizes reported by /datasets/load.
        total_size = 0
        for result in results:
            if result["status"] == 200:
                d = result.get("data", {})
                total_size += d.get("dataset_size", 0)

        return {
            "api_key": api_key,
            "dataset_id": dataset_id,
            "dataset_size": total_size,
            "num_workers": len(worker_addrs),
        }

    # ===== Admin: Unregister Dataset =====
    @app.post("/v1/datasets/unregister")
    async def unregister_dataset(request: Request):
        """Unload a dataset from all workers and revoke its API key."""
        require_admin_key(request, config.admin_api_key)
        body = await request.json()
        dataset_id = body.get("dataset_id")
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        results = await _broadcast_to_workers(
            worker_addrs,
            "/datasets/unload",
            {"dataset_id": dataset_id},
            config.forward_timeout,
        )
        _check_broadcast_results(results, "unregister_dataset")
        registry.revoke(dataset_id)
        return {"status": "ok"}

    # ===== Admin: Shutdown =====
    @app.post("/v1/shutdown")
    async def shutdown(request: Request):
        """Best-effort teardown: unload every registered dataset.

        Errors are logged and swallowed — shutdown always returns ok so
        callers are never blocked on a dying fleet.
        """
        require_admin_key(request, config.admin_api_key)
        try:
            worker_addrs = await _get_all_worker_addrs(
                config.router_addr,
                config.admin_api_key,
                config.router_timeout,
            )
            # Snapshot keys first: revoke() mutates the underlying dict.
            dataset_ids = list(registry._dataset_to_key.keys())
            for dataset_id in dataset_ids:
                await _broadcast_to_workers(
                    worker_addrs,
                    "/datasets/unload",
                    {"dataset_id": dataset_id},
                    config.forward_timeout,
                )
                registry.revoke(dataset_id)
        except Exception as exc:
            logger.warning("Error during shutdown broadcast: %s", exc)
        return {"status": "ok"}

    # ===== Admin: Workers =====
    @app.get("/v1/workers")
    async def list_workers(request: Request):
        """List worker addresses known to the router (admin only)."""
        require_admin_key(request, config.admin_api_key)
        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        return {"workers": [{"addr": addr} for addr in worker_addrs]}

    # ===== Consumer: Fetch Samples by Index =====
    @app.post("/v1/samples/fetch")
    async def fetch_samples(request: Request):
        """Fetch samples by index from one router-selected worker.

        NOTE(review): a single worker serves the whole batch — this
        assumes every worker can serve any index; confirm against the
        worker's load path.
        """
        token = extract_bearer_token(request)
        dataset_id = _resolve_dataset_key(token)
        body = await request.json()
        indices = body.get("indices", [])

        worker_addr = await _query_router(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        async with httpx.AsyncClient(timeout=config.forward_timeout) as client:
            resp = await client.post(
                f"{worker_addr}/v1/samples/fetch",
                json={"dataset_id": dataset_id, "indices": indices},
            )
            if resp.status_code != 200:
                raise HTTPException(
                    status_code=502,
                    detail=f"Worker fetch_samples error: {resp.text}",
                )
            return resp.json()

    # ===== Consumer: Epoch Advance =====
    @app.post("/v1/epochs/advance")
    async def epoch_advance(request: Request):
        """Broadcast an epoch reset to all workers for this dataset."""
        token = extract_bearer_token(request)
        dataset_id = _resolve_dataset_key(token)
        body = await request.json()
        epoch = body.get("epoch", 0)

        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        if not worker_addrs:
            raise HTTPException(status_code=503, detail="No workers available")
        results = await _broadcast_to_workers(
            worker_addrs,
            "/epoch/reset",
            {"dataset_id": dataset_id, "epoch": epoch},
            config.forward_timeout,
        )
        _check_broadcast_results(results, "epoch_advance")
        return {
            "status": "ok",
            "workers_reset": sum(1 for result in results if result["status"] == 200),
        }

    # ===== Consumer: State Save =====
    @app.post("/v1/state/save")
    async def state_save(request: Request):
        """Ask every worker to persist its dataloader state under *path*."""
        token = extract_bearer_token(request)
        dataset_id = _resolve_dataset_key(token)
        body = await request.json()
        path = body.get("path", "")

        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        results = await _broadcast_to_workers(
            worker_addrs,
            "/state/save",
            {"dataset_id": dataset_id, "path": path},
            config.forward_timeout,
        )
        _check_broadcast_results(results, "state_save")
        return {"status": "ok", "path": path}

    # ===== Consumer: State Load =====
    @app.post("/v1/state/load")
    async def state_load(request: Request):
        """Ask every worker to restore its dataloader state from *path*."""
        token = extract_bearer_token(request)
        dataset_id = _resolve_dataset_key(token)
        body = await request.json()
        path = body.get("path", "")

        worker_addrs = await _get_all_worker_addrs(
            config.router_addr,
            config.admin_api_key,
            config.router_timeout,
        )
        results = await _broadcast_to_workers(
            worker_addrs,
            "/state/load",
            {"dataset_id": dataset_id, "path": path},
            config.forward_timeout,
        )
        _check_broadcast_results(results, "state_load")
        return {"status": "ok"}

    # ===== Consumer: Status =====
    @app.get("/v1/status")
    async def status(request: Request):
        """Best-effort status: proxy one worker's /health, else plain ok.

        Any transport or router error is deliberately swallowed so this
        endpoint never fails once the dataset key is valid.
        """
        token = extract_bearer_token(request)
        dataset_id = _resolve_dataset_key(token)

        try:
            worker_addr = await _query_router(
                config.router_addr,
                config.admin_api_key,
                config.router_timeout,
            )
            async with httpx.AsyncClient(timeout=config.forward_timeout) as client:
                resp = await client.get(f"{worker_addr}/health")
                if resp.status_code == 200:
                    payload = resp.json()
                    payload["dataset_id"] = dataset_id
                    return payload
        except Exception:
            pass
        return {"status": "ok", "dataset_id": dataset_id}

    return app
class DatasetKeyRegistry:
    """Maps API keys to dataset IDs. Manages admin + dataset keys.

    In-memory only: keys are lost on gateway restart. One active key per
    dataset; issuing a new key for a dataset invalidates the old one.
    """

    def __init__(self, admin_api_key: str):
        self._admin_key = admin_api_key
        self._key_to_dataset: dict[str, str] = {}  # api_key → dataset_id
        self._dataset_to_key: dict[str, str] = {}  # dataset_id → api_key

    def generate_key(self, dataset_id: str) -> str:
        """Generate a new API key for a dataset.

        Fix: revoke any key previously issued for *dataset_id* first.
        Previously the old entry stayed in ``_key_to_dataset``, so a
        stale key kept resolving to the dataset after re-registration.
        """
        old_key = self._dataset_to_key.pop(dataset_id, None)
        if old_key is not None:
            self._key_to_dataset.pop(old_key, None)
        api_key = f"ds-{uuid.uuid4().hex[:16]}"
        self._key_to_dataset[api_key] = dataset_id
        self._dataset_to_key[dataset_id] = api_key
        logger.info("Generated API key for dataset %s", dataset_id)
        return api_key

    def resolve(self, api_key: str) -> str | None:
        """Resolve API key to dataset_id. Returns None if not found."""
        return self._key_to_dataset.get(api_key)

    def revoke(self, dataset_id: str) -> str | None:
        """Revoke API key for a dataset. Returns the revoked key."""
        api_key = self._dataset_to_key.pop(dataset_id, None)
        if api_key:
            self._key_to_dataset.pop(api_key, None)
        return api_key

    def is_admin(self, api_key: str) -> bool:
        # Constant-time comparison to avoid timing side channels.
        return hmac.compare_digest(api_key, self._admin_key)

    def is_valid_dataset_key(self, api_key: str) -> bool:
        return api_key in self._key_to_dataset


def extract_bearer_token(request: Request) -> str:
    """Return the bearer token from the Authorization header.

    Raises HTTPException(401) when the header is absent or not of the
    form ``Bearer <token>`` (scheme match is case-insensitive).
    """
    auth_header = request.headers.get("authorization", "")
    if auth_header.lower().startswith("bearer "):
        return auth_header[7:].strip()
    raise HTTPException(
        status_code=401, detail="Missing or malformed Authorization header."
    )


def require_admin_key(request: Request, admin_api_key: str) -> str:
    """Validate the request's bearer token against the admin key.

    Returns the token on success; raises HTTPException(403) otherwise.
    Uses a constant-time comparison, matching ``DatasetKeyRegistry``.
    """
    token = extract_bearer_token(request)
    if not hmac.compare_digest(token, admin_api_key):
        raise HTTPException(status_code=403, detail="Admin API key required.")
    return token
class _PrefetchBuffer:
    """Background thread that fetches samples from remote workers.

    Fetches proceed in the exact index order that the
    ``DistributedSampler`` will request, ensuring near-100% cache hit
    rate. The buffer is bounded by *max_cached*; when full the
    prefetch thread pauses until space is freed by ``get()`` calls.

    Parameters
    ----------
    fetch_fn : callable
        ``fn(indices: list[int]) -> list[Any]`` — batch-fetch samples
        by index from remote workers.
    chunk_size : int
        Number of indices to fetch in a single HTTP round-trip.
    max_cached : int
        Maximum number of samples to hold in the local cache before
        the prefetch thread pauses.

    NOTE(review): a whole chunk is inserted after the space check, so
    the cache may transiently exceed ``max_cached`` by up to
    ``chunk_size`` entries.
    """

    def __init__(
        self,
        fetch_fn: Any,
        chunk_size: int = 64,
        max_cached: int = 512,
    ) -> None:
        self._fetch_fn = fetch_fn
        self._chunk_size = chunk_size
        self._max_cached = max_cached

        # idx → sample; guarded by _lock together with _indices/_pos.
        self._cache: dict[int, Any] = {}
        self._lock = threading.Lock()
        self._indices: list[int] = []
        self._pos: int = 0

        self._thread: threading.Thread | None = None
        self._stop = threading.Event()
        # Set while the cache has room; cleared by _run when full,
        # re-set by get() whenever it pops an entry.
        self._space_available = threading.Event()
        self._space_available.set()

    # -- Public API --------------------------------------------------------

    def set_index_order(self, indices: list[int]) -> None:
        """Reset the cache and start prefetching in *indices* order.

        Called by ``_PrefetchAwareSampler.set_epoch`` at the beginning
        of each epoch. Stops and joins any previous prefetch thread
        before starting a fresh one.
        """
        self._stop.set()
        if self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=10)

        with self._lock:
            self._cache.clear()
            self._indices = list(indices)
            self._pos = 0

        self._stop.clear()
        self._space_available.set()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def get(self, idx: int) -> Any:
        """Return the sample for *idx*.

        Pops from the prefetch cache on hit. On miss, performs a
        blocking single-index fetch.

        NOTE(review): on a miss, the background thread may still fetch
        and cache *idx* later; that copy is never popped this epoch and
        occupies a cache slot until the next ``set_index_order``.
        """
        with self._lock:
            if idx in self._cache:
                sample = self._cache.pop(idx)
                self._space_available.set()
                return sample

        logger.debug("Prefetch cache miss for index %d, fetching directly", idx)
        return self._fetch_fn([idx])[0]

    def stop(self) -> None:
        """Signal the prefetch thread to stop."""
        self._stop.set()
        if self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=10)

    # -- Background thread -------------------------------------------------

    def _run(self) -> None:
        # Loop: reserve the next chunk under the lock, wait for cache
        # space, fetch over HTTP, then publish the chunk under the lock.
        while not self._stop.is_set():
            with self._lock:
                if len(self._cache) >= self._max_cached:
                    self._space_available.clear()

                if self._pos >= len(self._indices):
                    break

                chunk = self._indices[self._pos : self._pos + self._chunk_size]
                self._pos += len(chunk)

            # Poll in 0.1 s slices so a stop request is noticed quickly.
            while not self._space_available.wait(timeout=0.1):
                if self._stop.is_set():
                    return

            if self._stop.is_set():
                return

            try:
                samples = self._fetch_fn(chunk)
            except Exception:
                logger.exception(
                    "Prefetch failed for chunk starting at pos %d",
                    self._pos - len(chunk),
                )
                # Back off briefly, then rewind so the chunk is retried.
                time.sleep(0.5)
                with self._lock:
                    self._pos -= len(chunk)
                continue

            with self._lock:
                for idx, sample in zip(chunk, samples):
                    self._cache[idx] = sample
+ """ + + def __init__( + self, + path: str, + type: str = "rl", + split: str | None = None, + max_length: int | None = None, + dataset_kwargs: dict[str, Any] | None = None, + ) -> None: + self._path = path + self._type = type + self._split = split + self._max_length = max_length + self._dataset_kwargs = dataset_kwargs or {} + + self._controller: DataController | None = None + self._api_key: str = "" + self._dataset_id: str = "" + self._total_samples: int = 0 + self._connected: bool = False + self._prefetch_buffer: _PrefetchBuffer | None = None + + # -- Connection -------------------------------------------------------- + + def connect( + self, + controller: DataController, + dataset_id: str, + tokenizer_or_processor_path: str = "", + seed: int = 42, + shuffle: bool = True, + drop_last: bool = True, + prefetch_chunk_size: int = 64, + prefetch_max_cached: int = 512, + ) -> None: + """Register with *controller* and enable data fetching. + + Called by the trainer after the ``DataController`` is + initialised. The controller broadcasts dataset loading to all + workers and returns the total sample count. 
+ """ + if self._connected: + raise RuntimeError("RDataset is already connected") + + handle = controller.register_dataset( + dataset_id=dataset_id, + dataset_path=self._path, + dataset_type=self._type, + split=self._split or "train", + max_length=self._max_length, + dataset_kwargs=self._dataset_kwargs, + tokenizer_or_processor_path=tokenizer_or_processor_path, + seed=seed, + shuffle=shuffle, + drop_last=drop_last, + ) + + self._controller = controller + self._api_key = handle["api_key"] + self._dataset_id = handle["dataset_id"] + self._total_samples = handle["total_samples"] + self._connected = True + self._prefetch_buffer = _PrefetchBuffer( + fetch_fn=self._fetch_samples, + chunk_size=prefetch_chunk_size, + max_cached=prefetch_max_cached, + ) + logger.info( + "RDataset connected: id=%s, total_samples=%d", + self._dataset_id, + self._total_samples, + ) + + # -- Map-style dataset interface --------------------------------------- + + def __len__(self) -> int: + if not self._connected: + raise RuntimeError( + "RDataset is not connected to a DataController. " + "Call connect() before using the dataset." + ) + return self._total_samples + + def __getitem__(self, idx: int) -> Any: + if not self._connected or self._prefetch_buffer is None: + raise RuntimeError( + "RDataset is not connected to a DataController. " + "Call connect() before using the dataset." 
+ ) + return self._prefetch_buffer.get(idx) + + # -- Prefetch control (called by _PrefetchAwareSampler) ---------------- + + def _start_prefetch(self, indices: list[int]) -> None: + """Kick off the prefetch buffer in *indices* order.""" + if self._prefetch_buffer is not None: + self._prefetch_buffer.set_index_order(indices) + + # -- Remote fetch ------------------------------------------------------ + + def _fetch_samples(self, indices: list[int]) -> list[Any]: + """Batch-fetch samples by index from remote workers via the gateway.""" + assert self._controller is not None + from areal.infra.rpc.serialization import deserialize_value + + resp = self._controller._gateway_post( + "/v1/samples/fetch", + self._api_key, + {"indices": indices}, + ) + return [deserialize_value(s) for s in resp["samples"]] + + # -- Lifecycle --------------------------------------------------------- + + def close(self) -> None: + """Stop prefetching and unregister from the controller.""" + if self._prefetch_buffer is not None: + self._prefetch_buffer.stop() + self._prefetch_buffer = None + if self._connected and self._controller is not None: + try: + self._controller.unregister_dataset(self._dataset_id) + except Exception: + logger.debug( + "Failed to unregister dataset %s (expected during teardown)", + self._dataset_id, + ) + self._connected = False + + @property + def connected(self) -> bool: + return self._connected + + +class _PrefetchAwareSampler(DistributedSampler): + """``DistributedSampler`` that triggers ``RDataset`` prefetch on epoch change. + + When ``cycle_dataloader`` calls ``sampler.set_epoch(epoch)``, this + sampler generates the deterministic index order for the new epoch + and passes it to the ``RDataset``'s prefetch buffer so that + samples are fetched in the order the ``DataLoader`` will request + them. 
+ """ + + def __init__(self, dataset: RDataset, *args: Any, **kwargs: Any) -> None: + super().__init__(dataset, *args, **kwargs) + self._rdataset = dataset + self._trigger_prefetch() + + def set_epoch(self, epoch: int) -> None: + super().set_epoch(epoch) + self._trigger_prefetch() + + def _trigger_prefetch(self) -> None: + indices = list(super().__iter__()) + self._rdataset._start_prefetch(indices) diff --git a/areal/infra/data_service/router/__init__.py b/areal/infra/data_service/router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/areal/infra/data_service/router/__main__.py b/areal/infra/data_service/router/__main__.py new file mode 100644 index 0000000000..d46feb966e --- /dev/null +++ b/areal/infra/data_service/router/__main__.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import argparse +import importlib + + +def main(): + parser = argparse.ArgumentParser(description="AReaL Data Service Router") + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8091) + parser.add_argument("--admin-api-key", default="areal-data-admin") + parser.add_argument("--poll-interval", type=float, default=5.0) + parser.add_argument("--worker-health-timeout", type=float, default=3.0) + args, _ = parser.parse_known_args() + + router_app_module = importlib.import_module("areal.infra.data_service.router.app") + router_config_module = importlib.import_module( + "areal.infra.data_service.router.config" + ) + create_router_app = router_app_module.create_router_app + RouterConfig = router_config_module.RouterConfig + + config = RouterConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + poll_interval=args.poll_interval, + worker_health_timeout=args.worker_health_timeout, + ) + + uvicorn = importlib.import_module("uvicorn") + + app = create_router_app(config) + uvicorn.run(app, host=config.host, port=config.port, log_level="warning") + + +if __name__ == "__main__": + main() diff 
"""Router for the AReaL data service.

Tracks registered data workers, polls their health, and hands out
worker addresses round-robin to the gateway.
"""

from __future__ import annotations

import asyncio
import hmac
import importlib
from contextlib import asynccontextmanager

from areal.infra.data_service.router.config import RouterConfig
from areal.utils import logging

# Imported lazily via importlib so the package can be imported without
# the HTTP stack installed.
httpx = importlib.import_module("httpx")
_fastapi = importlib.import_module("fastapi")
FastAPI = _fastapi.FastAPI
HTTPException = _fastapi.HTTPException
Request = _fastapi.Request
BaseModel = importlib.import_module("pydantic").BaseModel

logger = logging.getLogger("DataRouter")


def _extract_bearer_token(request: Request) -> str:
    """Return the bearer token or raise HTTPException(401)."""
    auth_header = request.headers.get("authorization", "")
    if auth_header.lower().startswith("bearer "):
        return auth_header[7:].strip()
    raise HTTPException(
        status_code=401,
        detail="Missing or malformed Authorization header.",
    )


def _require_admin_key(request: Request, admin_key: str) -> str:
    """Validate the admin bearer token (constant-time) or raise 403."""
    token = _extract_bearer_token(request)
    if not hmac.compare_digest(token, admin_key):
        raise HTTPException(status_code=403, detail="Invalid admin API key.")
    return token


class RegisterWorkerRequest(BaseModel):
    # Worker base URL, e.g. "http://host:port".
    worker_addr: str


class UnregisterWorkerRequest(BaseModel):
    worker_addr: str


def create_router_app(config: RouterConfig) -> FastAPI:
    """Build the router FastAPI app.

    State lives in closure variables: the registered worker list, a
    health map, and a round-robin cursor. Handlers mutate them under an
    asyncio lock; the poll task reads a snapshot without the lock,
    which is safe on the single-threaded event loop.
    """
    registered_workers: list[str] = []
    worker_healthy: dict[str, bool] = {}
    rr_idx: int = 0
    lock = asyncio.Lock()

    async def _poll_workers() -> None:
        # Background task: mark each worker healthy iff /health returns 200.
        while True:
            for addr in list(registered_workers):
                try:
                    async with httpx.AsyncClient(
                        timeout=config.worker_health_timeout
                    ) as client:
                        resp = await client.get(f"{addr}/health")
                        worker_healthy[addr] = resp.status_code == 200
                except Exception:
                    worker_healthy[addr] = False
            await asyncio.sleep(config.poll_interval)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Start the health poller for the app's lifetime; cancel on exit.
        logger.info("DataRouter starting")
        poll_task = asyncio.create_task(_poll_workers())
        yield
        poll_task.cancel()
        try:
            await poll_task
        except asyncio.CancelledError:
            pass
        logger.info("DataRouter shutting down")

    app = FastAPI(title="AReaL Data Router", lifespan=lifespan)
    # Exposed for tests/introspection.
    app.state.worker_healthy = worker_healthy

    @app.get("/health")
    async def health():
        """Liveness probe with worker counts."""
        return {
            "status": "ok",
            "workers": len(registered_workers),
            "healthy": sum(1 for h in worker_healthy.values() if h),
        }

    @app.post("/register")
    async def register(body: RegisterWorkerRequest, request: Request):
        """Add a worker (idempotent); assumes healthy until first poll."""
        _require_admin_key(request, config.admin_api_key)
        async with lock:
            if body.worker_addr not in registered_workers:
                registered_workers.append(body.worker_addr)
                worker_healthy[body.worker_addr] = True
                logger.info(
                    "Worker registered: %s (total=%d)",
                    body.worker_addr,
                    len(registered_workers),
                )
        return {"status": "ok"}

    @app.post("/unregister")
    async def unregister(body: UnregisterWorkerRequest, request: Request):
        """Remove a worker (idempotent)."""
        _require_admin_key(request, config.admin_api_key)
        async with lock:
            if body.worker_addr in registered_workers:
                registered_workers.remove(body.worker_addr)
                worker_healthy.pop(body.worker_addr, None)
                logger.info("Worker unregistered: %s", body.worker_addr)
        return {"status": "ok"}

    @app.post("/route")
    async def route(request: Request):
        """Pick the next healthy worker round-robin; 503 if none."""
        nonlocal rr_idx
        _require_admin_key(request, config.admin_api_key)
        async with lock:
            healthy = [
                addr for addr in registered_workers if worker_healthy.get(addr, False)
            ]
            if not healthy:
                raise HTTPException(
                    status_code=503,
                    detail="No healthy workers available",
                )
            addr = healthy[rr_idx % len(healthy)]
            rr_idx += 1
        return {"worker_addr": addr}

    @app.get("/workers")
    async def list_workers(request: Request):
        """List all registered workers with their last-known health."""
        _require_admin_key(request, config.admin_api_key)
        return {
            "workers": [
                {"addr": addr, "healthy": worker_healthy.get(addr, False)}
                for addr in registered_workers
            ]
        }

    return app
def main():
    """CLI entrypoint: parse flags, build the worker app, serve it."""
    parser = argparse.ArgumentParser(description="AReaL Data Service Worker")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=0)
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world-size", type=int, default=1)
    parser.add_argument("--dataloader-num-workers", type=int, default=4)
    # Ignore unknown flags so wrappers can pass extra arguments through.
    args, _ = parser.parse_known_args()

    # Heavy dependencies are imported only once we are actually starting
    # a server (keeps `--help` and arg errors cheap).
    from areal.infra.data_service.worker.app import create_worker_app
    from areal.infra.data_service.worker.config import DataWorkerConfig

    worker_config = DataWorkerConfig(
        host=args.host,
        port=args.port,
        rank=args.rank,
        world_size=args.world_size,
        dataloader_num_workers=args.dataloader_num_workers,
    )

    import uvicorn

    uvicorn.run(
        create_worker_app(worker_config),
        host=worker_config.host,
        port=worker_config.port,
        log_level="warning",
    )


if __name__ == "__main__":
    main()
def _identity_collate(samples: list[Any]) -> list[Any]:
    """Pass-through collate_fn: keep samples as a plain list."""
    return samples


@dataclass
class _DatasetState:
    """Per-dataset state held by a worker."""

    dataset_id: str
    raw_dataset: Any  # the concrete dataset returned by _get_custom_dataset
    dataloader: Any  # StatefulDataLoader wrapping raw_dataset
    sampler: DistributedSampler | None
    epoch: int
    exhausted: bool
    seed: int  # seed used when (re)seeding on epoch reset
    # Serialises mutations (unload/reset/save/load) per dataset.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)


def create_worker_app(config: DataWorkerConfig) -> FastAPI:
    """Build the worker FastAPI app serving dataset samples over HTTP."""
    # dataset_id → _DatasetState; cleared when the app shuts down.
    datasets: dict[str, _DatasetState] = {}

    @asynccontextmanager
    async def lifespan(app: Any):
        app.state.config = config
        app.state.datasets = datasets
        try:
            yield
        finally:
            datasets.clear()

    app = FastAPI(title="AReaL Data Worker", lifespan=lifespan)

    def _require_dataset(dataset_id: str) -> _DatasetState:
        """Look up a loaded dataset or raise HTTPException(404)."""
        state = datasets.get(dataset_id)
        if state is None:
            raise HTTPException(
                status_code=404, detail=f"Unknown dataset_id: {dataset_id}"
            )
        return state

    @app.get("/health")
    async def health():
        """Liveness probe with this worker's rank and loaded-dataset count."""
        return {
            "status": "ok",
            "rank": config.rank,
            "datasets": len(datasets),
        }

    @app.post("/datasets/load")
    async def load_dataset(body: WorkerLoadDatasetRequest):
        """Load a dataset and build its sampler + stateful dataloader.

        Returns the per-rank shard size (``sampler.num_samples``) so the
        gateway can sum sizes across workers.
        """
        if body.dataset_id in datasets:
            raise HTTPException(
                status_code=409,
                detail=f"Dataset {body.dataset_id} is already loaded",
            )

        tokenizer = None
        processor = None
        if body.tokenizer_or_processor_path:
            processor, tokenizer = load_hf_processor_and_tokenizer(
                body.tokenizer_or_processor_path
            )

        seeding.set_random_seed(body.seed, key=f"data_worker_{config.rank}")

        # Workers must load real datasets, not RDataset proxies.
        # Call _get_custom_dataset directly to bypass the is_single_controller()
        # gate in get_custom_dataset() that would create an RDataset.
        dataset = _get_custom_dataset(
            path=body.dataset_path,
            type=body.dataset_type,
            split=body.split,
            max_length=body.max_length,
            tokenizer=tokenizer,
            processor=processor,
            **body.dataset_kwargs,
        )

        # Eval-style datasets (drop_last=False) use the eval sampler so no
        # sample is dropped; training uses the standard DistributedSampler.
        sampler_cls = DistributedSampler if body.drop_last else EvalDistributedSampler
        sampler = sampler_cls(
            dataset,
            num_replicas=config.world_size,
            rank=config.rank,
            shuffle=body.shuffle,
            drop_last=body.drop_last,
        )

        dataloader = StatefulDataLoader(
            dataset,
            batch_size=1,
            num_workers=config.dataloader_num_workers,
            sampler=sampler,
            drop_last=False,
            collate_fn=_identity_collate,
        )

        datasets[body.dataset_id] = _DatasetState(
            dataset_id=body.dataset_id,
            raw_dataset=dataset,
            dataloader=dataloader,
            sampler=sampler,
            epoch=0,
            exhausted=False,
            seed=body.seed,
        )

        return {
            "status": "ok",
            "dataset_size": sampler.num_samples,
            "steps_per_epoch": len(dataloader),
        }

    @app.post("/v1/samples/fetch")
    async def fetch_samples(body: FetchSamplesRequest):
        """Serve samples by raw dataset index, serialized for transport."""
        state = _require_dataset(body.dataset_id)
        samples = [serialize_value(state.raw_dataset[idx]) for idx in body.indices]
        return {"samples": samples}

    @app.post("/datasets/unload")
    async def unload_dataset(body: WorkerUnloadDatasetRequest):
        """Drop a loaded dataset and its dataloader state."""
        state = _require_dataset(body.dataset_id)
        async with state.lock:
            del datasets[body.dataset_id]
        return {"status": "ok"}

    @app.post("/epoch/reset")
    async def reset_epoch(body: WorkerEpochResetRequest):
        """Reseed and advance the sampler to *epoch* for deterministic order."""
        state = _require_dataset(body.dataset_id)
        async with state.lock:
            seeding.set_random_seed(state.seed, key=f"data_worker_{config.rank}")
            state.epoch = body.epoch
            state.exhausted = False
            if state.sampler is not None:
                state.sampler.set_epoch(body.epoch)
        return {"status": "ok", "epoch": state.epoch}

    @app.post("/state/save")
    async def save_state(body: WorkerStateSaveRequest):
        """Pickle the dataloader state to <path>/worker_<rank>.pkl."""
        state = _require_dataset(body.dataset_id)
        async with state.lock:
            save_dir = Path(body.path)
            save_dir.mkdir(parents=True, exist_ok=True)
            save_path = save_dir / f"worker_{config.rank}.pkl"

            with save_path.open("wb") as f:
                pickle.dump(state.dataloader.state_dict(), f)

        return {"status": "ok", "path": str(save_path)}

    @app.post("/state/load")
    async def load_state(body: WorkerStateLoadRequest):
        """Restore dataloader state from <path>/worker_<rank>.pkl.

        NOTE(review): pickle.load on a checkpoint path — safe only as
        long as checkpoints come from trusted storage.
        """
        state = _require_dataset(body.dataset_id)
        async with state.lock:
            load_path = Path(body.path) / f"worker_{config.rank}.pkl"
            if not load_path.exists():
                raise HTTPException(
                    status_code=404,
                    detail=f"State file not found: {load_path}",
                )

            with load_path.open("rb") as f:
                state_dict = pickle.load(f)
            state.dataloader.load_state_dict(state_dict)
            state.exhausted = False

        return {"status": "ok", "path": str(load_path)}

    @app.delete("/data/clear")
    async def clear_data():
        """Compatibility stub; this worker holds no tensor shards."""
        return {"status": "ok", "tensor_shards": 0}

    return app
DataServiceConfig +from areal.infra.data_service.rdataset import RDataset from areal.utils import logging, perf_tracer, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller @@ -52,6 +54,8 @@ from areal.utils.stats_logger import StatsLogger if TYPE_CHECKING: + from datasets import Dataset + from areal.engine import ( FSDPPPOActor, FSDPPPOCritic, @@ -111,6 +115,9 @@ def __init__( self.scheduler = None if is_single_controller(): self.scheduler = self._init_scheduler() + self.data_controller: DataController | None = None + self._train_rdataset: RDataset | None = None + self._valid_rdataset: RDataset | None = None # Set seed. seeding.set_random_seed(config.seed, key=f"trainer{rank}") @@ -126,6 +133,23 @@ def __init__( self._amend_xccl_weight_update_envvar() + openai_cfg = config.rollout.openai + self._online_mode = openai_cfg is not None and openai_cfg.mode == "online" + + if self._online_mode and config.valid_dataset is not None: + raise ValueError( + "valid_dataset must not be set when using online RL mode " + "(openai.mode='online'). Online mode does not support " + "validation datasets." + ) + + # -- Dataset loading -------------------------------------------------- + if not self._online_mode and train_dataset is None: + raise ValueError( + "train_dataset must be provided unless using online RL mode " + "(openai.mode='online')." + ) + # Create models: actor, critic, ref — each with its own allocation. self.actor = self._create_train_engine(config.actor, self.actor_alloc) self.critic = None @@ -139,19 +163,21 @@ def __init__( ref_alloc = ModelAllocation.from_str(config.ref.backend, name="ref") self.ref = self._create_train_engine(config.ref, ref_alloc) - # Create dataloaders - self.train_dataset = train_dataset - self.valid_dataset = valid_dataset - if train_dataset is None: - # Online mode: require total_train_steps to compute steps_per_epoch. 
- # Without this, __len__()=1 causes every step to be treated as an - # epoch boundary, making Saver/RecoverHandler fire every step and - # corrupting the LR schedule. + self.teacher = None + if config.teacher is not None: + teacher_alloc = ModelAllocation.from_str( + config.teacher.backend, name="teacher" + ) + self.teacher = self._create_train_engine(config.teacher, teacher_alloc) + + steps_per_epoch: int | None = None + self.train_dataloader: StatefulDataLoader | _EmptyDataLoader + if self._online_mode: if config.total_train_steps is None: raise ValueError( - "total_train_steps must be set for online mode " - "(train_dataset is None). Both total_train_epochs and " - "total_train_steps are needed to compute steps_per_epoch." + "total_train_steps must be set for online mode. " + "Both total_train_epochs and total_train_steps are needed " + "to compute steps_per_epoch." ) steps_per_epoch = config.total_train_steps // config.total_train_epochs if steps_per_epoch < 1: @@ -165,14 +191,47 @@ def __init__( steps_per_epoch=steps_per_epoch, ) else: + assert train_dataset is not None + if is_single_controller() and isinstance(train_dataset, RDataset): + ds_cfg = DataServiceConfig.from_dataset_config(config.train_dataset) + assert self.scheduler is not None + controller = DataController(ds_cfg, self.scheduler) + controller.initialize( + role="data", num_dataset_workers=ds_cfg.num_workers + ) + self.data_controller = controller + train_dataset.connect( + controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_train", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=config.train_dataset.shuffle, + drop_last=config.train_dataset.drop_last, + ) + self._train_rdataset = train_dataset + self.train_dataloader = self._create_dataloader( train_dataset, dataset_config=self.config.train_dataset, rank=self.actor.data_parallel_rank, world_size=self.actor.data_parallel_world_size, ) - self.valid_dataloader = None + + 
self.valid_dataloader: StatefulDataLoader | None = None if self.config.valid_dataset is not None and valid_dataset is not None: + assert self.config.valid_dataset is not None + if is_single_controller() and isinstance(valid_dataset, RDataset): + assert self.data_controller is not None + valid_dataset.connect( + self.data_controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_valid", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=self.config.valid_dataset.shuffle, + drop_last=self.config.valid_dataset.drop_last, + ) + self._valid_rdataset = valid_dataset + self.valid_dataloader = self._create_dataloader( valid_dataset, dataset_config=self.config.valid_dataset, @@ -180,12 +239,24 @@ def __init__( world_size=self.actor.data_parallel_world_size, ) - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) + # -- FinetuneSpec ----------------------------------------------------- + if self._online_mode: + assert steps_per_epoch is not None + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=steps_per_epoch * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + else: + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(self.train_dataloader) + * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + # Initialize engines first — the scheduler must know about roles + # before the data controller can colocate with them. 
engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} self.actor.initialize(**engine_init_kwargs, role="actor") if self.critic is not None: @@ -193,12 +264,7 @@ def __init__( if self.ref is not None: self.ref.initialize(**engine_init_kwargs, role="ref") - self.teacher = None - if config.teacher is not None: - teacher_alloc = ModelAllocation.from_str( - config.teacher.backend, name="teacher" - ) - self.teacher = self._create_train_engine(config.teacher, teacher_alloc) + if self.teacher is not None: self.teacher.initialize(**engine_init_kwargs, role="teacher") # Save initial LoRA weights if enabled (for inference server pre-loading) @@ -208,11 +274,6 @@ def __init__( self.rollout = self._init_rollout( config.rollout, is_eval=False, lora_path=initial_lora_path ) - # Online mode detection: skip eval rollout for efficiency. - openai_cfg = config.rollout.openai - self._online_mode = train_dataset is None or ( - openai_cfg is not None and openai_cfg.mode == "online" - ) self.eval_rollout = None if not self._online_mode: @@ -529,6 +590,8 @@ def train( # Since all RTensor objects are affiliated IPs, # calling `clear_batches` once should be sufficient. 
self.actor.clear_batches(rollout_batch, adv_batch) + if self.data_controller is not None: + self.data_controller.clear_batches() with perf_tracer.trace_scope( "train.log_stats", @@ -546,6 +609,12 @@ def train( def close(self): self.saver.finalize() + if hasattr(self, "_train_rdataset") and self._train_rdataset is not None: + self._train_rdataset.close() + if hasattr(self, "_valid_rdataset") and self._valid_rdataset is not None: + self._valid_rdataset.close() + if hasattr(self, "data_controller") and self.data_controller is not None: + self.data_controller.destroy() self.stats_logger.close() if self.eval_rollout is not None: self.eval_rollout.destroy() diff --git a/areal/trainer/rw_trainer.py b/areal/trainer/rw_trainer.py index d456587bc1..99cec97597 100644 --- a/areal/trainer/rw_trainer.py +++ b/areal/trainer/rw_trainer.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist -from datasets import Dataset from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import FinetuneSpec, Scheduler, StepInfo @@ -22,6 +21,9 @@ SlurmScheduler, current_platform, ) +from areal.infra.data_service import DataController +from areal.infra.data_service.controller.config import DataServiceConfig +from areal.infra.data_service.rdataset import RDataset from areal.utils import logging, perf_tracer, seeding, stats_tracker from areal.utils.data import ( broadcast_tensor_container, @@ -38,6 +40,8 @@ from areal.utils.stats_logger import StatsLogger if TYPE_CHECKING: + from datasets import Dataset + from areal.engine import FSDPRWEngine, MegatronRWEngine from areal.experimental.engine.archon_engine import ArchonRWEngine from areal.trainer.rw.rw_engine import RWController @@ -86,6 +90,9 @@ def __init__( self.scheduler = None if is_single_controller(): self.scheduler = self._init_scheduler() + self.data_controller: DataController | None = None + self._train_rdataset: RDataset | None = None + self._valid_rdataset: RDataset | None = None # Set seed. 
seeding.set_random_seed(config.seed, key=f"trainer{rank}") @@ -93,26 +100,30 @@ def __init__( # Parse per-engine allocation. self.actor_alloc = ModelAllocation.from_str(config.actor.backend, name="actor") - # Create models. self.actor = self._create_actor(config.actor) - # Create dataloaders - self.train_dataset = train_dataset - self.valid_dataset = valid_dataset + if is_single_controller() and isinstance(train_dataset, RDataset): + ds_cfg = DataServiceConfig.from_dataset_config(config.train_dataset) + controller = DataController(ds_cfg, self.scheduler) + controller.initialize(role="data", num_dataset_workers=ds_cfg.num_workers) + self.data_controller = controller + + train_dataset.connect( + controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_train", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=config.train_dataset.shuffle, + drop_last=config.train_dataset.drop_last, + ) + self._train_rdataset = train_dataset + self.train_dataloader = self._create_dataloader( train_dataset, dataset_config=self.config.train_dataset, rank=self.actor.data_parallel_rank, world_size=self.actor.data_parallel_world_size, ) - self.valid_dataloader = None - if self.config.valid_dataset is not None and valid_dataset is not None: - self.valid_dataloader = self._create_dataloader( - valid_dataset, - dataset_config=self.config.valid_dataset, - rank=self.actor.data_parallel_rank, - world_size=self.actor.data_parallel_world_size, - ) ft_spec = FinetuneSpec( total_train_epochs=config.total_train_epochs, @@ -120,9 +131,30 @@ def __init__( train_batch_size=config.train_dataset.batch_size, ) - # Initialize models self.actor.initialize(addr=None, ft_spec=ft_spec, role="actor") + self.valid_dataloader: StatefulDataLoader | None = None + if config.valid_dataset is not None and valid_dataset is not None: + assert config.valid_dataset is not None + if is_single_controller() and isinstance(valid_dataset, RDataset): + assert self.data_controller is 
not None + valid_dataset.connect( + self.data_controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_valid", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=config.valid_dataset.shuffle, + drop_last=config.valid_dataset.drop_last, + ) + self._valid_rdataset = valid_dataset + + self.valid_dataloader = self._create_dataloader( + valid_dataset, + dataset_config=self.config.valid_dataset, + rank=self.actor.data_parallel_rank, + world_size=self.actor.data_parallel_world_size, + ) + # Set up evaluation self.evaluator = Evaluator(config.evaluator, ft_spec) @@ -239,6 +271,8 @@ def train(self): ), ): self.actor.clear_batches(batch) + if self.data_controller is not None: + self.data_controller.clear_batches() with perf_tracer.trace_scope( "train.log_stats", @@ -253,6 +287,12 @@ def train(self): def close(self): self.saver.finalize() + if hasattr(self, "_train_rdataset") and self._train_rdataset is not None: + self._train_rdataset.close() + if hasattr(self, "_valid_rdataset") and self._valid_rdataset is not None: + self._valid_rdataset.close() + if hasattr(self, "data_controller") and self.data_controller is not None: + self.data_controller.destroy() self.stats_logger.close() self.actor.destroy() perf_tracer.save(force=True) diff --git a/areal/trainer/sft_trainer.py b/areal/trainer/sft_trainer.py index 41376b5e2e..39ad96a837 100644 --- a/areal/trainer/sft_trainer.py +++ b/areal/trainer/sft_trainer.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING import torch.distributed as dist -from datasets import Dataset from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import FinetuneSpec, Scheduler, StepInfo @@ -21,6 +20,9 @@ SlurmScheduler, current_platform, ) +from areal.infra.data_service import DataController +from areal.infra.data_service.controller.config import DataServiceConfig +from areal.infra.data_service.rdataset import RDataset from areal.utils import logging, perf_tracer, seeding, stats_tracker 
from areal.utils.data import ( broadcast_tensor_container, @@ -38,6 +40,8 @@ from areal.utils.stats_logger import StatsLogger if TYPE_CHECKING: + from datasets import Dataset + from areal.engine import FSDPLMEngine, MegatronLMEngine from areal.experimental.engine.archon_engine import ArchonLMEngine from areal.trainer.sft.lm_engine import LMController @@ -54,7 +58,6 @@ def __init__( ): rank = int(os.getenv("RANK", "0")) if is_single_controller(): - # Set up file logging for controller process logging.setup_file_logging(StatsLogger.get_log_path(config.stats_logger)) self.config = config @@ -64,33 +67,38 @@ def __init__( self.scheduler = None if is_single_controller(): self.scheduler = self._init_scheduler() + self.data_controller: DataController | None = None + self._train_rdataset: RDataset | None = None + self._valid_rdataset: RDataset | None = None - # Set seed. seeding.set_random_seed(config.seed, key=f"trainer{rank}") - # Parse per-engine allocation. self.actor_alloc = ModelAllocation.from_str(config.actor.backend, name="actor") - # Create models. 
self.actor = self._create_actor(config.actor) - # Create dataloaders - self.train_dataset = train_dataset - self.valid_dataset = valid_dataset - self.train_dataloader = self._create_dataloader( + if is_single_controller() and isinstance(train_dataset, RDataset): + ds_cfg = DataServiceConfig.from_dataset_config(config.train_dataset) + controller = DataController(ds_cfg, self.scheduler) + controller.initialize(role="data", num_dataset_workers=ds_cfg.num_workers) + self.data_controller = controller + + train_dataset.connect( + controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_train", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=config.train_dataset.shuffle, + drop_last=config.train_dataset.drop_last, + ) + self._train_rdataset = train_dataset + + self.train_dataloader: StatefulDataLoader = self._create_dataloader( train_dataset, - dataset_config=self.config.train_dataset, + dataset_config=config.train_dataset, rank=self.actor.data_parallel_rank, world_size=self.actor.data_parallel_world_size, ) - self.valid_dataloader = None - if self.config.valid_dataset is not None and valid_dataset is not None: - self.valid_dataloader = self._create_dataloader( - valid_dataset, - dataset_config=self.config.valid_dataset, - rank=self.actor.data_parallel_rank, - world_size=self.actor.data_parallel_world_size, - ) ft_spec = FinetuneSpec( total_train_epochs=config.total_train_epochs, @@ -98,20 +106,34 @@ def __init__( train_batch_size=config.train_dataset.batch_size, ) - # Initialize models self.actor.initialize(addr=None, ft_spec=ft_spec, role="actor") - # Set up evaluation - self.evaluator = Evaluator(config.evaluator, ft_spec) + self.valid_dataloader: StatefulDataLoader | None = None + if config.valid_dataset is not None and valid_dataset is not None: + assert config.valid_dataset is not None + if is_single_controller() and isinstance(valid_dataset, RDataset): + assert self.data_controller is not None + 
valid_dataset.connect( + self.data_controller, + dataset_id=f"{config.experiment_name}_{config.trial_name}_valid", + tokenizer_or_processor_path=config.tokenizer_path, + seed=config.seed, + shuffle=config.valid_dataset.shuffle, + drop_last=config.valid_dataset.drop_last, + ) + self._valid_rdataset = valid_dataset - # Set up save as HF model + self.valid_dataloader = self._create_dataloader( + valid_dataset, + dataset_config=config.valid_dataset, + rank=self.actor.data_parallel_rank, + world_size=self.actor.data_parallel_world_size, + ) + + self.evaluator = Evaluator(config.evaluator, ft_spec) self.saver = Saver(config.saver, ft_spec) self.recover_handler = RecoverHandler(config.recover, ft_spec) - - # Set up statistics logging (wandb, tensoboard, etc.) self.stats_logger = StatsLogger(config, ft_spec) - - # Set up checkpointing for recover self.recover_info = self.recover_handler.load( self.actor, self.saver, @@ -119,7 +141,6 @@ def __init__( self.stats_logger, self.train_dataloader, ) - self._config_perf_tracer() def train(self): @@ -155,7 +176,6 @@ def train(self): ): batch = self._load_bcast_from(data_generator) - # Wait for async checkpoint staging to complete before modifying parameters self.saver.maybe_wait_for_staging() with ( @@ -217,6 +237,8 @@ def train(self): ), ): self.actor.clear_batches(batch) + if self.data_controller is not None: + self.data_controller.clear_batches() with perf_tracer.trace_scope( "train.log_stats", @@ -231,6 +253,12 @@ def train(self): def close(self): self.saver.finalize() + if self._train_rdataset is not None: + self._train_rdataset.close() + if self._valid_rdataset is not None: + self._valid_rdataset.close() + if hasattr(self, "data_controller") and self.data_controller is not None: + self.data_controller.destroy() self.stats_logger.close() self.actor.destroy() perf_tracer.save(force=True) @@ -262,7 +290,7 @@ def _init_scheduler(self) -> Scheduler: def _create_dataloader( self, - dataset: Dataset, + dataset, dataset_config: 
TrainDatasetConfig | ValidDatasetConfig, rank: int, world_size: int, @@ -308,7 +336,6 @@ def _load_bcast_from(self, data_generator): if is_single_controller(): return batch - # NOTE: data are identical across model+context parallel group batch = tensor_container_to(batch, current_platform.current_device()) batch = broadcast_tensor_container( batch, @@ -318,7 +345,6 @@ def _load_bcast_from(self, data_generator): return batch def _save_hf(self, epoch: int, epoch_step: int, global_step: int): - # Save as HF models for evaluation self.saver.save( self.actor, epoch, @@ -328,13 +354,11 @@ def _save_hf(self, epoch: int, epoch_step: int, global_step: int): processor=self.processor, ) - # Async mode: synchronization handled by AsyncCheckpointManager if not self.saver.is_async: dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int): - # Save recoverable checkpoints to_save: dict = dict(default=self.actor) step_info = StepInfo( global_step=global_step, @@ -383,7 +407,6 @@ def _evaluate( current_platform.synchronize() def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int): - # Upload statistics to the logger (e.g., wandb) stats = self.actor.export_stats() self.stats_logger.commit(epoch, epoch_step, global_step, stats) diff --git a/areal/utils/data.py b/areal/utils/data.py index 2979f0bd7c..d1cb4f8c60 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -11,7 +11,6 @@ import torch.distributed as dist import torch.nn.functional as F from einops import rearrange -from torch.utils.data import DistributedSampler from torchdata.stateful_dataloader import StatefulDataLoader from areal.api.cli_args import MicroBatchSpec, NormConfig @@ -1350,9 +1349,7 @@ def cycle_dataloader(dataloader: StatefulDataLoader, num_cycles: int = -1): """Cycle through a dataloader indefinitely.""" epoch = 0 while True: - if hasattr(dataloader, "sampler") and isinstance( - 
dataloader.sampler, DistributedSampler - ): + if hasattr(dataloader, "sampler") and hasattr(dataloader.sampler, "set_epoch"): dataloader.sampler.set_epoch(epoch) yield from dataloader epoch += 1 diff --git a/areal/utils/dataloader.py b/areal/utils/dataloader.py index 9b6bba5d72..49be3e7c92 100644 --- a/areal/utils/dataloader.py +++ b/areal/utils/dataloader.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any from datasets import Dataset from torch.utils.data import DistributedSampler @@ -28,12 +29,21 @@ def create_dataloader( f"batch size({dataset_config.batch_size}) must be divisible by world_size({world_size})!" ) - sampler_cls = DistributedSampler + from areal.infra.data_service.rdataset import RDataset, _PrefetchAwareSampler + drop_sampler_last = True if isinstance(dataset_config, ValidDatasetConfig): - sampler_cls = EvalDistributedSampler drop_sampler_last = False + if isinstance(dataset, RDataset) and isinstance(dataset_config, ValidDatasetConfig): + sampler_cls = _PrefetchAwareEvalSampler + elif isinstance(dataset, RDataset): + sampler_cls = _PrefetchAwareSampler + elif isinstance(dataset_config, ValidDatasetConfig): + sampler_cls = EvalDistributedSampler + else: + sampler_cls = DistributedSampler + return StatefulDataLoader( dataset, batch_size=dataset_config.batch_size // world_size, @@ -85,3 +95,17 @@ def __init__( if self.rank + (self.num_samples - 1) * self.num_replicas >= self.total_size: self.num_samples -= 1 + + +class _PrefetchAwareEvalSampler(EvalDistributedSampler): + def __init__(self, dataset: Any, *args: Any, **kwargs: Any) -> None: + super().__init__(dataset, *args, **kwargs) + self._rdataset = dataset + self._trigger_prefetch() + + def set_epoch(self, epoch: int) -> None: + super().set_epoch(epoch) + self._trigger_prefetch() + + def _trigger_prefetch(self) -> None: + self._rdataset._start_prefetch(list(super().__iter__())) diff --git a/areal/utils/recover.py b/areal/utils/recover.py index 4a5a48ba3d..ef6080aed1 
100644 --- a/areal/utils/recover.py +++ b/areal/utils/recover.py @@ -2,10 +2,9 @@ import json import os import pickle -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch.distributed as dist -from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoProcessor, PreTrainedTokenizerFast from areal.api import ( @@ -177,7 +176,7 @@ def dump( saver: Saver, evaluator: Evaluator, stats_logger: "StatsLogger", - dataloader: StatefulDataLoader, + dataloader: Any, tokenizer: PreTrainedTokenizerFast | None = None, processor: AutoProcessor | None = None, base_model_path: str | None = None, @@ -224,7 +223,7 @@ def load( saver: Saver, evaluator: Evaluator, stats_logger: "StatsLogger", - dataloader: StatefulDataLoader, + dataloader: Any, inference_engine: InferenceEngine | None = None, weight_update_meta: WeightUpdateMeta | None = None, inference_engine_update_from: str = "default", diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 228cc772c6..54e695b50c 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -5,6 +5,7 @@ import swanlab import torch.distributed as dist +import trackio import wandb from tensorboardX import SummaryWriter @@ -95,8 +96,6 @@ def init(self): self._trackio_enabled = False trackio_config = self.config.trackio if trackio_config.mode != "disabled": - import trackio - trackio.init( project=trackio_config.project or self.config.experiment_name, name=trackio_config.name or self.config.trial_name, @@ -127,8 +126,6 @@ def close(self): wandb.finish() swanlab.finish() if getattr(self, "_trackio_enabled", False): - import trackio - trackio.finish() if self.summary_writer is not None: self.summary_writer.close() @@ -153,8 +150,6 @@ def commit(self, epoch: int, step: int, global_step: int, data: dict | list[dict wandb.log(item, step=log_step + i) swanlab.log(item, step=log_step + i) if getattr(self, "_trackio_enabled", False): - import trackio - 
trackio.log(item, step=log_step + i) if self.summary_writer is not None: for key, val in item.items(): diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index d51c58866c..0cd05c3452 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -629,16 +629,20 @@ https://docs.vllm.ai/en/stable/api/index.html for detailed documentation. Configuration for training dataset loading and preprocessing. -| Parameter | Type | Default | Description | -| ------------- | --------------- | ------------ | -------------------------------------------------------------------------------- | -| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | -| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. | -| `batch_size` | integer | `1` | Batch size for the dataloader | -| `shuffle` | boolean | `True` | Whether to shuffle the dataset | -| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | -| `num_workers` | integer | `0` | Number of worker processes for data loading | -| `drop_last` | boolean | `True` | Drop the last incomplete batch | -| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| Parameter | Type | Default | Description | +| --------------------- | ---------------------------------------------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- | +| `split` | string | `"train"` | Dataset split to use, e.g., 'train', 'test'. | +| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | +| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. 
| +| `batch_size` | integer | `1` | Batch size for the dataloader | +| `shuffle` | boolean | `True` | Whether to shuffle the dataset | +| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | +| `num_workers` | integer | `0` | Number of worker processes for data loading | +| `num_dataset_workers` | integer | `1` | Number of remote data-service worker processes to launch when using scheduling_spec. | +| `drop_last` | boolean | `True` | Drop the last incomplete batch | +| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| `dataset_kwargs` | `dict` | **Required** | Additional keyword arguments for dataset loading. These are passed to the dataset loading function `get_custom_dataset`. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) \| None | **Required** | Scheduling spec for remote data loading workers. If set, dataset loading will be offloaded to a data service with remote workers. | (section-valid-dataset)= @@ -649,16 +653,20 @@ Configuration for validation dataset loading and preprocessing. It has different default values with `TrainDatasetConfig`. `shuffle` and `drop_last` default to False. -| Parameter | Type | Default | Description | -| ------------- | --------------- | ------------ | -------------------------------------------------------------------------------- | -| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | -| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. 
| -| `batch_size` | integer | `1` | Batch size for the dataloader | -| `shuffle` | boolean | `False` | Whether to shuffle the dataset | -| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | -| `num_workers` | integer | `0` | Number of worker processes for data loading | -| `drop_last` | boolean | `False` | Drop the last incomplete batch | -| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| Parameter | Type | Default | Description | +| --------------------- | ---------------------------------------------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- | +| `split` | string | `"test"` | Dataset split to use, e.g., 'train', 'test'. | +| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | +| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. | +| `batch_size` | integer | `1` | Batch size for the dataloader | +| `shuffle` | boolean | `False` | Whether to shuffle the dataset | +| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | +| `num_workers` | integer | `0` | Number of worker processes for data loading | +| `num_dataset_workers` | integer | `1` | Number of remote data-service worker processes to launch when using scheduling_spec. | +| `drop_last` | boolean | `False` | Drop the last incomplete batch | +| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| `dataset_kwargs` | `dict` | **Required** | Additional keyword arguments for dataset loading. These are passed to the dataset loading function `get_custom_dataset`. 
| +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) \| None | **Required** | Scheduling spec for remote data loading workers. If set, dataset loading will be offloaded to a data service with remote workers. | (section-cluster)= diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index afd64db4af..d2b10d77c3 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -627,16 +627,20 @@ https://docs.vllm.ai/en/stable/api/index.html for detailed documentation. Configuration for training dataset loading and preprocessing. -| Parameter | Type | Default | Description | -| ------------- | --------------- | ------------ | -------------------------------------------------------------------------------- | -| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | -| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. | -| `batch_size` | integer | `1` | Batch size for the dataloader | -| `shuffle` | boolean | `True` | Whether to shuffle the dataset | -| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | -| `num_workers` | integer | `0` | Number of worker processes for data loading | -| `drop_last` | boolean | `True` | Drop the last incomplete batch | -| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| Parameter | Type | Default | Description | +| --------------------- | ---------------------------------------------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- | +| `split` | string | `"train"` | Dataset split to use, e.g., 'train', 'test'. | +| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | +| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. 
| +| `batch_size` | integer | `1` | Batch size for the dataloader | +| `shuffle` | boolean | `True` | Whether to shuffle the dataset | +| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | +| `num_workers` | integer | `0` | Number of worker processes for data loading | +| `num_dataset_workers` | integer | `1` | Number of remote data-service worker processes to launch when using scheduling_spec. | +| `drop_last` | boolean | `True` | Drop the last incomplete batch | +| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| `dataset_kwargs` | `dict` | **Required** | Additional keyword arguments for dataset loading. These are passed to the dataset loading function `get_custom_dataset`. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) \| None | **Required** | Scheduling spec for remote data loading workers. If set, dataset loading will be offloaded to a data service with remote workers. | (section-valid-dataset)= @@ -647,16 +651,20 @@ Configuration for validation dataset loading and preprocessing. It has different default values with `TrainDatasetConfig`. `shuffle` and `drop_last` default to False. -| Parameter | Type | Default | Description | -| ------------- | --------------- | ------------ | -------------------------------------------------------------------------------- | -| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | -| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. 
| -| `batch_size` | integer | `1` | Batch size for the dataloader | -| `shuffle` | boolean | `False` | Whether to shuffle the dataset | -| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | -| `num_workers` | integer | `0` | Number of worker processes for data loading | -| `drop_last` | boolean | `False` | Drop the last incomplete batch | -| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| Parameter | Type | Default | Description | +| --------------------- | ---------------------------------------------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- | +| `split` | string | `"test"` | Dataset split to use, e.g., 'train', 'test'. | +| `path` | string | **Required** | Path to the dataset. Can be a local path or a HuggingFace dataset name. | +| `type` | string | **Required** | Type of training method, e.g., 'sft', 'rl', etc. | +| `batch_size` | integer | `1` | Batch size for the dataloader | +| `shuffle` | boolean | `False` | Whether to shuffle the dataset | +| `pin_memory` | boolean | `False` | Pin memory for faster data loading (set True for GPU training) | +| `num_workers` | integer | `0` | Number of worker processes for data loading | +| `num_dataset_workers` | integer | `1` | Number of remote data-service worker processes to launch when using scheduling_spec. | +| `drop_last` | boolean | `False` | Drop the last incomplete batch | +| `max_length` | integer \| None | `None` | Maximum token length of sequences in dataset. Longer sequences are filtered out. | +| `dataset_kwargs` | `dict` | **Required** | Additional keyword arguments for dataset loading. These are passed to the dataset loading function `get_custom_dataset`. 
"""Unit tests for the data-service gateway auth primitives.

Covers DatasetKeyRegistry (per-dataset API-key lifecycle), bearer-token
extraction from HTTP requests, and the admin-key guard.
"""

from __future__ import annotations

from types import SimpleNamespace

import pytest
from fastapi import HTTPException

from areal.infra.data_service.gateway.auth import (
    DatasetKeyRegistry,
    extract_bearer_token,
    require_admin_key,
)


class TestDatasetKeyRegistry:
    """Lifecycle tests: generate -> resolve -> revoke, plus admin checks."""

    def test_generate_key_returns_string(self):
        # Keys are opaque strings with a "ds-" prefix distinguishing them
        # from the admin key.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key = registry.generate_key("dataset-a")
        assert isinstance(key, str)
        assert key.startswith("ds-")

    def test_generate_key_unique(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key1 = registry.generate_key("dataset-a")
        key2 = registry.generate_key("dataset-b")
        assert key1 != key2

    def test_resolve_returns_dataset_id(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key = registry.generate_key("dataset-a")
        assert registry.resolve(key) == "dataset-a"

    def test_resolve_unknown_returns_none(self):
        # Unknown keys resolve to None rather than raising.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.resolve("ds-does-not-exist") is None

    def test_revoke_removes_key(self):
        # revoke() is keyed by dataset id and returns the revoked key.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key = registry.generate_key("dataset-a")
        assert registry.revoke("dataset-a") == key
        assert registry.resolve(key) is None

    def test_revoke_unknown_returns_none(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.revoke("dataset-missing") is None

    def test_is_admin_correct_key(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.is_admin("admin-key") is True

    def test_is_admin_wrong_key(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.is_admin("not-admin") is False

    def test_is_admin_timing_safe(self):
        # NOTE(review): these inputs exercise length and case differences,
        # but a functional test cannot actually prove constant-time
        # comparison — presumably the implementation uses
        # secrets/hmac.compare_digest; verify in the auth module.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.is_admin("admin-key") is True
        assert registry.is_admin("admin-keyx") is False
        assert registry.is_admin("admin-keY") is False

    def test_is_valid_dataset_key_after_generate(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key = registry.generate_key("dataset-a")
        assert registry.is_valid_dataset_key(key) is True

    def test_is_valid_dataset_key_unknown(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        assert registry.is_valid_dataset_key("ds-unknown") is False

    def test_is_valid_dataset_key_after_revoke(self):
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key = registry.generate_key("dataset-a")
        registry.revoke("dataset-a")
        assert registry.is_valid_dataset_key(key) is False

    def test_generate_revoke_generate_new_key(self):
        # Re-registering a dataset after revocation must mint a fresh key;
        # the old one must not come back.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        first_key = registry.generate_key("dataset-a")
        registry.revoke("dataset-a")
        second_key = registry.generate_key("dataset-a")
        assert second_key != first_key
        assert registry.resolve(second_key) == "dataset-a"

    def test_multiple_datasets_independent(self):
        # Revoking one dataset must not disturb another's key.
        registry = DatasetKeyRegistry(admin_api_key="admin-key")
        key_a = registry.generate_key("dataset-a")
        key_b = registry.generate_key("dataset-b")
        assert registry.resolve(key_a) == "dataset-a"
        assert registry.resolve(key_b) == "dataset-b"
        registry.revoke("dataset-a")
        assert registry.resolve(key_a) is None
        assert registry.resolve(key_b) == "dataset-b"


class TestExtractBearerToken:
    """extract_bearer_token must accept only well-formed Bearer headers."""

    def test_extract_bearer_token_with_valid_header(self):
        # SimpleNamespace stands in for a FastAPI Request; only .headers
        # is accessed.
        request = SimpleNamespace(headers={"authorization": "Bearer token-123"})
        assert extract_bearer_token(request) == "token-123"

    def test_extract_bearer_token_missing_header_raises_401(self):
        request = SimpleNamespace(headers={})
        with pytest.raises(HTTPException) as exc_info:
            extract_bearer_token(request)
        assert exc_info.value.status_code == 401

    def test_extract_bearer_token_basic_auth_raises_401(self):
        # Non-Bearer schemes (e.g. Basic) are rejected outright.
        request = SimpleNamespace(headers={"authorization": "Basic token-123"})
        with pytest.raises(HTTPException) as exc_info:
            extract_bearer_token(request)
        assert exc_info.value.status_code == 401

    def test_extract_bearer_token_empty_header_raises_401(self):
        request = SimpleNamespace(headers={"authorization": ""})
        with pytest.raises(HTTPException) as exc_info:
            extract_bearer_token(request)
        assert exc_info.value.status_code == 401


class TestRequireAdminKey:
    """require_admin_key returns the token on success, 403 otherwise."""

    def test_require_admin_key_accepts_valid_admin_token(self):
        request = SimpleNamespace(headers={"authorization": "Bearer admin-key"})
        assert require_admin_key(request, "admin-key") == "admin-key"

    def test_require_admin_key_rejects_non_admin_token(self):
        # A valid-looking but non-admin token is authenticated (401 not
        # raised) yet unauthorized -> 403.
        request = SimpleNamespace(headers={"authorization": "Bearer user-key"})
        with pytest.raises(HTTPException) as exc_info:
            require_admin_key(request, "admin-key")
        assert exc_info.value.status_code == 403
"""Unit tests for DataServiceConfig and DataController.

All network interaction is mocked: aiohttp.ClientSession is replaced by
_make_mock_aiohttp, and the async-task runner is patched where used.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

from areal.api.cli_args import (
    PPOConfig,
    SchedulingStrategy,
    SchedulingStrategyType,
    SFTConfig,
)
from areal.infra.data_service.controller.config import DataServiceConfig
from areal.infra.data_service.controller.controller import DataController


def _make_mock_aiohttp(status=200, json_data=None, text_data=""):
    """Build a (ClientSession-class mock, session mock) pair.

    The session and the response both support ``async with``; ``post``
    returns the canned response so tests can inspect call args afterwards.
    """
    mock_resp = MagicMock()
    mock_resp.status = status
    mock_resp.json = AsyncMock(return_value=json_data if json_data is not None else {})
    mock_resp.text = AsyncMock(return_value=text_data)
    mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
    mock_resp.__aexit__ = AsyncMock(return_value=False)

    mock_session = MagicMock()
    # post() is a plain MagicMock (not AsyncMock): the real aiohttp post
    # returns an async context manager, not an awaitable.
    mock_session.post = MagicMock(return_value=mock_resp)
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session.__aexit__ = AsyncMock(return_value=False)

    mock_cls = MagicMock(return_value=mock_session)
    return mock_cls, mock_session


class TestDataServiceConfig:
    """Defaults, overrides, and construction from dataset configs."""

    def test_default_values(self):
        cfg = DataServiceConfig()

        assert cfg.num_workers == 1
        assert cfg.setup_timeout == 120.0
        assert isinstance(cfg.scheduling_strategy, SchedulingStrategy)
        assert cfg.scheduling_strategy.type == SchedulingStrategyType.separation

    def test_custom_values(self):
        cfg = DataServiceConfig(
            num_workers=8,
            setup_timeout=300.0,
        )

        assert cfg.num_workers == 8
        assert cfg.setup_timeout == 300.0

    def test_scheduling_strategy_colocation(self):
        cfg = DataServiceConfig(
            scheduling_strategy=SchedulingStrategy(
                type=SchedulingStrategyType.colocation, target="rollout"
            ),
        )

        assert cfg.scheduling_strategy.type == SchedulingStrategyType.colocation
        assert cfg.scheduling_strategy.target == "rollout"
        # NOTE(review): fork is asserted True here without being passed —
        # presumably colocation implies fork by default; confirm against
        # SchedulingStrategy's definition.
        assert cfg.scheduling_strategy.fork is True

    def test_scheduling_strategy_separation(self):
        # Strategy type also accepts its plain-string spelling.
        cfg = DataServiceConfig(
            scheduling_strategy=SchedulingStrategy(type="separation"),
        )

        assert cfg.scheduling_strategy.type == "separation"
        assert cfg.scheduling_strategy.target is None

    def test_config_in_base_experiment_config(self):
        from areal.api.cli_args import TrainDatasetConfig

        ds_cfg = TrainDatasetConfig(path="dummy", type="rl")
        cfg = DataServiceConfig.from_dataset_config(ds_cfg)
        assert isinstance(cfg, DataServiceConfig)

    def test_from_dataset_config_uses_num_dataset_workers(self):
        from areal.api.cli_args import TrainDatasetConfig

        # num_workers (dataloader processes) and num_dataset_workers
        # (remote service workers) map to distinct fields.
        ds_cfg = TrainDatasetConfig(
            path="dummy",
            type="rl",
            num_workers=5,
            num_dataset_workers=7,
        )
        cfg = DataServiceConfig.from_dataset_config(ds_cfg)
        assert cfg.num_workers == 7
        assert cfg.dataloader_num_workers == 5

    def test_config_in_ppo_config(self):
        cfg = PPOConfig(experiment_name="exp", trial_name="trial")
        ds_cfg = DataServiceConfig.from_dataset_config(cfg.train_dataset)
        assert isinstance(ds_cfg, DataServiceConfig)

    def test_config_in_sft_config(self):
        cfg = SFTConfig(experiment_name="exp", trial_name="trial")
        ds_cfg = DataServiceConfig.from_dataset_config(cfg.train_dataset)
        assert isinstance(ds_cfg, DataServiceConfig)

    def test_config_in_rw_config(self):
        from areal.api.cli_args import TrainDatasetConfig

        ds_cfg = TrainDatasetConfig(path="dummy", type="rl")
        cfg = DataServiceConfig.from_dataset_config(ds_cfg)
        assert isinstance(cfg, DataServiceConfig)


class TestDataControllerInit:
    """Construction stores collaborators and starts with empty state."""

    def test_init_stores_config(self):
        cfg = DataServiceConfig()
        scheduler = MagicMock()

        controller = DataController(cfg, scheduler)

        assert controller.config is cfg

    def test_init_stores_scheduler(self):
        cfg = DataServiceConfig()
        scheduler = MagicMock()

        controller = DataController(cfg, scheduler)

        assert controller.scheduler is scheduler

    def test_init_empty_state(self):
        controller = DataController(DataServiceConfig(), MagicMock())

        assert controller.workers == []
        assert controller._gateway_addr == ""
        assert controller._datasets == {}


class TestDataControllerGatewayPost:
    """_gateway_post: auth header, JSON payload, and error propagation."""

    def test_gateway_post_sends_bearer_auth(self):
        controller = DataController(DataServiceConfig(), MagicMock())
        controller._gateway_addr = "http://gateway"

        mock_cls, mock_session = _make_mock_aiohttp(status=200, json_data={"ok": True})

        with patch(
            "areal.infra.data_service.controller.controller.aiohttp.ClientSession",
            mock_cls,
        ):
            result = controller._gateway_post("/v1/test", "api-key", {"x": 1})

        assert result == {"ok": True}
        _, kwargs = mock_session.post.call_args
        assert kwargs["headers"] == {"Authorization": "Bearer api-key"}

    def test_gateway_post_sends_json_payload(self):
        controller = DataController(DataServiceConfig(), MagicMock())
        controller._gateway_addr = "http://gateway"

        mock_cls, mock_session = _make_mock_aiohttp(status=200, json_data={"ok": True})

        with patch(
            "areal.infra.data_service.controller.controller.aiohttp.ClientSession",
            mock_cls,
        ):
            controller._gateway_post("/v1/test", "api-key", {"payload": "value"})

        _, kwargs = mock_session.post.call_args
        assert kwargs["json"] == {"payload": "value"}

    def test_gateway_post_raises_on_error(self):
        controller = DataController(DataServiceConfig(), MagicMock())
        controller._gateway_addr = "http://gateway"

        mock_cls, _ = _make_mock_aiohttp(status=500, text_data="boom")

        with patch(
            "areal.infra.data_service.controller.controller.aiohttp.ClientSession",
            mock_cls,
        ):
            # Manual try/except instead of pytest.raises keeps this file
            # free of a pytest import; the else branch guards against a
            # silently-passing call.
            try:
                controller._gateway_post("/v1/test", "api-key", {})
            except RuntimeError as exc:
                assert "returned 500" in str(exc)
            else:
                raise AssertionError("Expected RuntimeError for gateway error response")


class TestDataControllerRegisterDataset:
    """register/unregister bookkeeping with the async runner mocked out.

    NOTE(review): these patch run_async_task at its defining module
    (areal.infra.utils.concurrent). That only takes effect if the
    controller resolves it via attribute access (``concurrent.run_async_task``);
    if the controller does ``from ... import run_async_task``, the patch
    target must be the controller module instead — verify.
    """

    def test_register_returns_dataset_metadata(self):
        controller = DataController(DataServiceConfig(), MagicMock())

        # Canned gateway response the mocked runner returns.
        payload = {
            "api_key": "ds-key",
            "dataset_id": "test-ds",
            "dataset_size": 32,
            "num_workers": 4,
        }

        with patch(
            "areal.infra.utils.concurrent.run_async_task",
            return_value=payload,
        ) as mock_run:
            result = controller.register_dataset(
                dataset_id="test-ds",
                dataset_path="dummy",
                dataset_type="rl",
                drop_last=True,
            )

        assert mock_run.called
        assert result["api_key"] == "ds-key"
        assert result["dataset_id"] == "test-ds"
        # dataset_size from the gateway is surfaced as total_samples.
        assert result["total_samples"] == 32
        assert result["num_workers"] == 4
        # Local cache is keyed by the dataset API key.
        assert controller._datasets["ds-key"]["dataset_id"] == "test-ds"

    def test_register_stores_drop_last_flag(self):
        controller = DataController(DataServiceConfig(), MagicMock())

        payload = {
            "api_key": "ds-key",
            "dataset_id": "test-ds",
            "dataset_size": 30,
            "num_workers": 4,
        }
        with patch(
            "areal.infra.utils.concurrent.run_async_task",
            return_value=payload,
        ):
            controller.register_dataset(
                dataset_id="test-ds",
                dataset_path="dummy",
                dataset_type="rl",
                drop_last=False,
            )
        assert controller._datasets["ds-key"]["drop_last"] is False

    def test_unregister_removes_local_dataset_cache(self):
        controller = DataController(DataServiceConfig(), MagicMock())
        controller._datasets["key-1"] = {"dataset_id": "a"}
        controller._datasets["key-2"] = {"dataset_id": "b"}

        with patch("areal.infra.utils.concurrent.run_async_task"):
            # unregister is addressed by dataset id; only the matching
            # cache entry is dropped.
            controller.unregister_dataset("a")

        assert "key-1" not in controller._datasets
        assert "key-2" in controller._datasets
"""End-to-end harness for the data service: worker + router + gateway.

Servers run in-process on daemon threads via uvicorn; test artifacts
(datasets/models) are resolved from local storage or the HF Hub.
"""

from __future__ import annotations

# pyright: reportMissingImports=false
import os
import socket
import threading
import time
import uuid
from pathlib import Path
from typing import Any

import httpx
import pytest
import uvicorn
from huggingface_hub import snapshot_download

from areal.infra.data_service.gateway.app import create_gateway_app
from areal.infra.data_service.gateway.config import GatewayConfig
from areal.infra.data_service.router.app import create_router_app
from areal.infra.data_service.router.config import RouterConfig
from areal.infra.data_service.worker.app import create_worker_app
from areal.infra.data_service.worker.config import DataWorkerConfig

pytestmark = pytest.mark.slow

# Shared admin bearer token for router/gateway admin endpoints.
ADMIN_KEY = "areal-data-admin"
BATCH_SIZE = 256


def _resolve_path(local: str, hf_id: str, repo_type: str = "dataset") -> str:
    """Return a local artifact path, falling back to a HF Hub download.

    Skips the whole test when neither source is available (e.g. offline CI).
    """
    if os.path.exists(local):
        return local
    try:
        return snapshot_download(repo_id=hf_id, repo_type=repo_type)
    except Exception as exc:
        pytest.skip(f"Required test artifact unavailable: {hf_id} ({exc})")
    # pytest.skip raises, so this line is unreachable; it satisfies type
    # checkers that expect a return on every path.
    raise AssertionError("unreachable")


def _free_port() -> int:
    """Pick a currently-free TCP port on localhost.

    NOTE: the port is released before use, so another process could grab
    it in the gap — an accepted race for test setups.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return int(sock.getsockname()[1])


def _start_uvicorn(
    app: object, host: str, port: int
) -> tuple[uvicorn.Server, threading.Thread]:
    """Run an ASGI app on a background daemon thread; return (server, thread)."""
    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    return server, thread


def _wait_healthy(base_url: str, timeout: float = 10.0) -> None:
    """Poll GET {base_url}/health until 200 or timeout (then TimeoutError)."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            resp = httpx.get(f"{base_url}/health", timeout=1.0)
            if resp.status_code == 200:
                return
        except httpx.HTTPError:
            # Server not up yet; keep polling.
            pass
        time.sleep(0.05)
    raise TimeoutError(f"Service did not become healthy: {base_url}")


def _admin_headers() -> dict[str, str]:
    return {"Authorization": f"Bearer {ADMIN_KEY}"}


def _dataset_headers(api_key: str) -> dict[str, str]:
    return {"Authorization": f"Bearer {api_key}"}


def _register_dataset(
    client: httpx.Client,
    *,
    dataset_id: str,
    dataset_path: str,
    dataset_type: str = "rl",
    tokenizer_or_processor_path: str = "",
) -> dict[str, Any]:
    """Register a dataset via the gateway and return the response payload.

    Uses the admin key; a generous 120 s timeout covers first-time
    dataset/tokenizer loading on the worker.
    """
    resp = client.post(
        "/v1/datasets/register",
        headers=_admin_headers(),
        json={
            "dataset_id": dataset_id,
            "dataset_path": dataset_path,
            "dataset_type": dataset_type,
            "split": "test",
            "seed": 42,
            "shuffle": False,
            "tokenizer_or_processor_path": tokenizer_or_processor_path,
        },
        timeout=120.0,
    )
    assert resp.status_code == 200, f"register failed: {resp.text}"
    payload = resp.json()
    assert payload["dataset_id"] == dataset_id
    return payload


def _unique_dataset_id(prefix: str) -> str:
    # Random suffix keeps module-scoped fixture reuse collision-free.
    return f"{prefix}-{uuid.uuid4().hex[:8]}"


@pytest.fixture(scope="module")
def data_service_stack(tmp_path_factory: pytest.TempPathFactory):
    """Boot one worker, one router, and one gateway; yield their addresses.

    Teardown best-effort shuts down the gateway then stops servers in
    reverse start order.
    """
    dataset_path = _resolve_path(
        "/storage/openpsi/data/gsm8k", "openai/gsm8k", repo_type="dataset"
    )
    geometry3k_path = _resolve_path(
        "/storage/openpsi/data/hiyouga__geometry3k/",
        "hiyouga/geometry3k",
        repo_type="dataset",
    )
    model_path = _resolve_path(
        "/storage/openpsi/models/Qwen__Qwen3-0.6B",
        "Qwen/Qwen3-0.6B",
        repo_type="model",
    )
    vlm_model_path = _resolve_path(
        "/storage/openpsi/models/Qwen2.5-VL-3B-Instruct",
        "Qwen/Qwen2.5-VL-3B-Instruct",
        repo_type="model",
    )

    worker_port = _free_port()
    router_port = _free_port()
    gateway_port = _free_port()

    worker_addr = f"http://127.0.0.1:{worker_port}"
    router_addr = f"http://127.0.0.1:{router_port}"
    gateway_addr = f"http://127.0.0.1:{gateway_port}"
    state_dir = tmp_path_factory.mktemp("data-service-state")

    servers: list[tuple[uvicorn.Server, threading.Thread]] = []

    # Single worker (rank 0 of world size 1), in-process dataloading.
    worker_app = create_worker_app(
        DataWorkerConfig(
            host="127.0.0.1",
            port=worker_port,
            rank=0,
            world_size=1,
            dataloader_num_workers=0,
        )
    )
    servers.append(_start_uvicorn(worker_app, "127.0.0.1", worker_port))
    _wait_healthy(worker_addr)

    router_app = create_router_app(
        RouterConfig(
            host="127.0.0.1",
            port=router_port,
            admin_api_key=ADMIN_KEY,
            poll_interval=0.2,
        )
    )
    servers.append(_start_uvicorn(router_app, "127.0.0.1", router_port))
    _wait_healthy(router_addr)

    # The worker must be registered with the router before the gateway
    # starts routing through it.
    register_resp = httpx.post(
        f"{router_addr}/register",
        headers=_admin_headers(),
        json={"worker_addr": worker_addr},
        timeout=3.0,
    )
    assert register_resp.status_code == 200

    gateway_app = create_gateway_app(
        GatewayConfig(
            host="127.0.0.1",
            port=gateway_port,
            admin_api_key=ADMIN_KEY,
            router_addr=router_addr,
        )
    )
    servers.append(_start_uvicorn(gateway_app, "127.0.0.1", gateway_port))
    _wait_healthy(gateway_addr)

    yield {
        "worker_addr": worker_addr,
        "router_addr": router_addr,
        "gateway_addr": gateway_addr,
        "state_dir": state_dir,
        "dataset_path": dataset_path,
        "geometry3k_path": geometry3k_path,
        "model_path": model_path,
        "vlm_model_path": vlm_model_path,
    }

    # Best-effort graceful gateway shutdown; ignore network failures.
    try:
        httpx.post(
            f"{gateway_addr}/v1/shutdown",
            headers=_admin_headers(),
            timeout=3.0,
        )
    except httpx.HTTPError:
        pass

    for server, thread in reversed(servers):
        server.should_exit = True
        thread.join(timeout=5)


@pytest.fixture
def gateway_client(data_service_stack: dict[str, Any]):
    """Per-test httpx client pointed at the module-scoped gateway."""
    gateway_addr = str(data_service_stack["gateway_addr"])
    with httpx.Client(base_url=gateway_addr, timeout=30.0) as client:
        yield client


class TestServiceHealth:
    """Smoke checks on the freshly-booted stack."""

    def test_all_services_healthy(self, data_service_stack: dict[str, Any]):
        worker_addr = str(data_service_stack["worker_addr"])
        router_addr = str(data_service_stack["router_addr"])
        gateway_addr = str(data_service_stack["gateway_addr"])

        worker = httpx.get(f"{worker_addr}/health", timeout=3.0)
        router = httpx.get(f"{router_addr}/health", timeout=3.0)
        gateway = httpx.get(f"{gateway_addr}/health", timeout=3.0)

        assert worker.status_code == 200
        assert router.status_code == 200
        assert gateway.status_code == 200

    def test_router_shows_registered_worker(self, data_service_stack: dict[str, Any]):
        router_addr = str(data_service_stack["router_addr"])
        worker_addr = str(data_service_stack["worker_addr"])

        resp = httpx.get(
            f"{router_addr}/workers",
            headers=_admin_headers(),
            timeout=3.0,
        )
        assert resp.status_code == 200
        workers = resp.json()["workers"]
        assert workers == [{"addr": worker_addr, "healthy": True}]
class TestDatasetRegistration:
    """Gateway /v1/datasets/register for RL, SFT, and VLM datasets."""

    def test_register_rl_dataset_returns_key(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        payload = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("register-rl"),
            dataset_path=str(data_service_stack["dataset_path"]),
            dataset_type="rl",
        )
        # Dataset keys are namespaced with the "ds-" prefix.
        assert str(payload["api_key"]).startswith("ds-")
        assert payload["dataset_size"] > 0

    def test_register_sft_dataset_returns_key(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        # SFT datasets require a tokenizer path for preprocessing.
        payload = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("register-sft"),
            dataset_path=str(data_service_stack["dataset_path"]),
            dataset_type="sft",
            tokenizer_or_processor_path=str(data_service_stack["model_path"]),
        )
        assert str(payload["api_key"]).startswith("ds-")
        assert payload["dataset_size"] > 0

    def test_register_returns_dataset_size(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        payload = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("register-steps"),
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        assert payload["dataset_size"] > 0
        # steps_per_epoch is intentionally NOT part of the register
        # response; step accounting lives with the controller.
        assert "steps_per_epoch" not in payload

    def test_register_geometry3k_rl_dataset(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        # Multimodal dataset registered with a VLM processor path.
        payload = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("register-geo3k"),
            dataset_path=str(data_service_stack["geometry3k_path"]),
            dataset_type="rl",
            tokenizer_or_processor_path=str(data_service_stack["vlm_model_path"]),
        )
        assert str(payload["api_key"]).startswith("ds-")
        assert payload["dataset_size"] > 0


class TestBatchFetching:
    """Fetching samples by index through /v1/samples/fetch."""

    def test_fetch_batch_returns_data(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        reg = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("fetch-one"),
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        api_key = str(reg["api_key"])
        resp = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert resp.status_code == 200
        payload = resp.json()["samples"]
        assert len(payload) == 1
        # Truthiness check: the sample must be a non-empty object.
        assert payload[0]

    def test_fetch_multiple_batches(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        reg = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("fetch-multi"),
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        api_key = str(reg["api_key"])

        batches = []
        for idx in range(3):
            resp = gateway_client.post(
                "/v1/samples/fetch",
                headers=_dataset_headers(api_key),
                json={"indices": [idx]},
            )
            assert resp.status_code == 200
            batches.append(resp.json()["samples"][0])

        # Distinct indices must yield distinct samples (shuffle=False).
        assert batches[0] != batches[1]
        assert batches[1] != batches[2]

    def test_fetch_after_epoch_advance(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        reg = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("fetch-epoch"),
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        api_key = str(reg["api_key"])
        before = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert before.status_code == 200
        assert len(before.json()["samples"]) == 1

        # Advancing the epoch resets each worker's iterator; the stack
        # has exactly one worker.
        reset = gateway_client.post(
            "/v1/epochs/advance",
            headers=_dataset_headers(api_key),
            json={"epoch": 1},
        )
        assert reset.status_code == 200
        assert reset.json()["workers_reset"] == 1

        after = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert after.status_code == 200
        body = after.json()["samples"]
        assert len(body) == 1
        assert body[0]


class TestStatePersistence:
    """Dataloader state save/load round-trip through the gateway."""

    def test_state_save_and_load(
        self,
        gateway_client: httpx.Client,
        data_service_stack: dict[str, Any],
    ):
        reg = _register_dataset(
            gateway_client,
            dataset_id=_unique_dataset_id("state"),
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        api_key = str(reg["api_key"])
        state_dir = Path(str(data_service_stack["state_dir"]))

        save = gateway_client.post(
            "/v1/state/save",
            headers=_dataset_headers(api_key),
            json={"path": str(state_dir)},
        )
        assert save.status_code == 200
        assert save.json()["status"] == "ok"
        # One pickle per worker rank; the stack has rank 0 only.
        assert (state_dir / "worker_0.pkl").exists()

        load = gateway_client.post(
            "/v1/state/load",
            headers=_dataset_headers(api_key),
            json={"path": str(state_dir)},
        )
        assert load.status_code == 200
        assert load.json()["status"] == "ok"


class TestDatasetUnregistration:
    """Unregistering a dataset must revoke its API key immediately."""

    def test_unregister_revokes_key(
        self, gateway_client: httpx.Client, data_service_stack: dict[str, Any]
    ):
        dataset_id = _unique_dataset_id("unregister")
        reg = _register_dataset(
            gateway_client,
            dataset_id=dataset_id,
            dataset_path=str(data_service_stack["dataset_path"]),
        )
        api_key = str(reg["api_key"])

        unreg = gateway_client.post(
            "/v1/datasets/unregister",
            headers=_admin_headers(),
            json={"dataset_id": dataset_id},
        )
        assert unreg.status_code == 200

        # The revoked key must now fail authentication, not authorization.
        rejected = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert rejected.status_code == 401


class TestFullLifecycle:
    """register -> fetch -> epoch advance -> save/load -> unregister."""

    @pytest.mark.parametrize("dataset_type", ["rl", "sft"])
    def test_complete_lifecycle(
        self,
        gateway_client: httpx.Client,
        data_service_stack: dict[str, Any],
        dataset_type: str,
    ):
        dataset_id = _unique_dataset_id(f"full-{dataset_type}")
        tokenizer_or_processor_path = (
            str(data_service_stack["model_path"]) if dataset_type == "sft" else ""
        )
        reg = _register_dataset(
            gateway_client,
            dataset_id=dataset_id,
            dataset_path=str(data_service_stack["dataset_path"]),
            dataset_type=dataset_type,
            tokenizer_or_processor_path=tokenizer_or_processor_path,
        )
        api_key = str(reg["api_key"])
        state_dir = Path(str(data_service_stack["state_dir"])) / f"{dataset_id}-state"

        for _ in range(3):
            resp = gateway_client.post(
                "/v1/samples/fetch",
                headers=_dataset_headers(api_key),
                json={"indices": [0]},
            )
            assert resp.status_code == 200
            assert resp.json()["samples"][0] is not None

        reset = gateway_client.post(
            "/v1/epochs/advance",
            headers=_dataset_headers(api_key),
            json={"epoch": 2},
        )
        assert reset.status_code == 200
        assert reset.json()["workers_reset"] == 1

        after_reset = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert after_reset.status_code == 200
        batch_item = after_reset.json()["samples"][0]
        assert batch_item is not None

        save = gateway_client.post(
            "/v1/state/save",
            headers=_dataset_headers(api_key),
            json={"path": str(state_dir)},
        )
        assert save.status_code == 200

        load = gateway_client.post(
            "/v1/state/load",
            headers=_dataset_headers(api_key),
            json={"path": str(state_dir)},
        )
        assert load.status_code == 200

        unreg = gateway_client.post(
            "/v1/datasets/unregister",
            headers=_admin_headers(),
            json={"dataset_id": dataset_id},
        )
        assert unreg.status_code == 200

        rejected = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert rejected.status_code == 401

    def test_geometry3k_rl_lifecycle(
        self,
        gateway_client: httpx.Client,
        data_service_stack: dict[str, Any],
    ):
        # Condensed lifecycle for the multimodal (VLM) dataset.
        dataset_id = _unique_dataset_id("full-geo3k")
        reg = _register_dataset(
            gateway_client,
            dataset_id=dataset_id,
            dataset_path=str(data_service_stack["geometry3k_path"]),
            dataset_type="rl",
            tokenizer_or_processor_path=str(data_service_stack["vlm_model_path"]),
        )
        api_key = str(reg["api_key"])
        steps = int(reg["dataset_size"])
        assert steps > 0

        resp = gateway_client.post(
            "/v1/samples/fetch",
            headers=_dataset_headers(api_key),
            json={"indices": [0]},
        )
        assert resp.status_code == 200
        assert resp.json()["samples"][0]

        gateway_client.post(
            "/v1/datasets/unregister",
            headers=_admin_headers(),
            json={"dataset_id": dataset_id},
        )
_start_uvicorn(app, host, port): + config = uvicorn.Config(app, host=host, port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + return server, thread + + +def _wait_healthy(base_url: str, timeout: float = 10.0): + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = httpx.get(f"{base_url}/health", timeout=1.0) + if resp.status_code == 200: + return + except httpx.HTTPError: + pass + time.sleep(0.05) + raise TimeoutError(f"Service did not become healthy: {base_url}") + + +def _make_dataset(n_samples: int) -> Dataset: + return Dataset.from_dict( + { + "idx": list(range(n_samples)), + "text": [f"sample_{i}" for i in range(n_samples)], + } + ) + + +def _identity_collate(samples): + return samples + + +def _local_dataloader_info( + dataset: Dataset, + batch_size: int, + num_workers: int, + drop_last: bool = True, + shuffle: bool = False, +) -> dict: + """Simulate new data-service behavior: workers use batch_size=1, + controller accumulates into batches of batch_size with drop_last. 
+ """ + from torch.utils.data import DistributedSampler + + from areal.utils.dataloader import EvalDistributedSampler + + all_samples: list = [] + shard_sizes: list[int] = [] + + for rank in range(num_workers): + sampler_cls = DistributedSampler if drop_last else EvalDistributedSampler + sampler = sampler_cls( + dataset, + num_replicas=num_workers, + rank=rank, + shuffle=shuffle, + drop_last=drop_last, + ) + shard_sizes.append(sampler.num_samples) + dl = StatefulDataLoader( + dataset, + batch_size=1, + sampler=sampler, + collate_fn=_identity_collate, + drop_last=False, + ) + for item in dl: + all_samples.append( + item[0] if isinstance(item, list) and len(item) == 1 else item + ) + + total_samples = len(all_samples) + + all_batches: list[list] = [] + for i in range(0, total_samples, batch_size): + chunk = all_samples[i : i + batch_size] + if len(chunk) == batch_size: + all_batches.append(chunk) + elif not drop_last: + all_batches.append(chunk) + + if drop_last: + steps = total_samples // batch_size + else: + steps = (total_samples + batch_size - 1) // batch_size + + return { + "total_steps": steps, + "shard_sizes": shard_sizes, + "all_batches": all_batches, + "total_samples": total_samples, + } + + +class _DataServiceStack: + """Starts worker(s), router, gateway in-process for testing.""" + + def __init__(self, num_workers: int, batch_size: int): + self.num_workers = num_workers + self.batch_size = batch_size + self.servers: list[uvicorn.Server] = [] + self.worker_urls: list[str] = [] + self.router_url = "" + self.gateway_url = "" + + def start(self): + host = "127.0.0.1" + + for rank in range(self.num_workers): + port = _free_port() + app = create_worker_app( + DataWorkerConfig( + rank=rank, + world_size=self.num_workers, + dataloader_num_workers=0, + ) + ) + server, _ = _start_uvicorn(app, host, port) + self.servers.append(server) + self.worker_urls.append(f"http://{host}:{port}") + + for url in self.worker_urls: + _wait_healthy(url) + + router_port = 
_free_port() + router_app = create_router_app( + RouterConfig( + host=host, + port=router_port, + admin_api_key=ADMIN_KEY, + ) + ) + router_server, _ = _start_uvicorn(router_app, host, router_port) + self.servers.append(router_server) + self.router_url = f"http://{host}:{router_port}" + _wait_healthy(self.router_url) + + gw_port = _free_port() + gw_app = create_gateway_app( + GatewayConfig( + host=host, + port=gw_port, + router_addr=self.router_url, + admin_api_key=ADMIN_KEY, + forward_timeout=30.0, + router_timeout=5.0, + ) + ) + gw_server, _ = _start_uvicorn(gw_app, host, gw_port) + self.servers.append(gw_server) + self.gateway_url = f"http://{host}:{gw_port}" + _wait_healthy(self.gateway_url) + + with httpx.Client(timeout=5.0) as client: + for wurl in self.worker_urls: + resp = client.post( + f"{self.router_url}/register", + json={"worker_addr": wurl}, + headers={"Authorization": f"Bearer {ADMIN_KEY}"}, + ) + assert resp.status_code == 200 + + def stop(self): + for server in reversed(self.servers): + server.should_exit = True + + def register_dataset( + self, + client: httpx.Client, + dataset_path: str, + dataset_id: str = "test", + dataset_type: str = "rl", + shuffle: bool = False, + ) -> dict: + resp = client.post( + f"{self.gateway_url}/v1/datasets/register", + json={ + "dataset_id": dataset_id, + "dataset_path": dataset_path, + "dataset_type": dataset_type, + "seed": 42, + "shuffle": shuffle, + }, + headers={"Authorization": f"Bearer {ADMIN_KEY}"}, + timeout=30.0, + ) + assert resp.status_code == 200, resp.text + return resp.json() + + def advance_epoch(self, client: httpx.Client, api_key: str, epoch: int): + resp = client.post( + f"{self.gateway_url}/v1/epochs/advance", + json={"epoch": epoch}, + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10.0, + ) + assert resp.status_code == 200, resp.text + return resp.json() + + +class _GatewayControllerAdapter: + def __init__(self, gateway_url: str, admin_key: str): + self._gateway_url = gateway_url + 
self._admin_key = admin_key + + def register_dataset( + self, + dataset_id: str, + dataset_path: str, + dataset_type: str, + dataset_kwargs: dict | None = None, + tokenizer_or_processor_path: str = "", + split: str = "train", + seed: int = 42, + shuffle: bool = False, + drop_last: bool = True, + max_length: int | None = None, + ) -> dict: + with httpx.Client(timeout=30.0) as client: + resp = client.post( + f"{self._gateway_url}/v1/datasets/register", + headers={"Authorization": f"Bearer {self._admin_key}"}, + json={ + "dataset_id": dataset_id, + "dataset_path": dataset_path, + "dataset_type": dataset_type, + "dataset_kwargs": dataset_kwargs or {}, + "tokenizer_or_processor_path": tokenizer_or_processor_path, + "split": split, + "seed": seed, + "shuffle": shuffle, + "drop_last": drop_last, + "max_length": max_length, + }, + ) + assert resp.status_code == 200, resp.text + payload = resp.json() + payload["total_samples"] = payload["dataset_size"] + return payload + + def unregister_dataset(self, dataset_id: str) -> None: + with httpx.Client(timeout=10.0) as client: + resp = client.post( + f"{self._gateway_url}/v1/datasets/unregister", + headers={"Authorization": f"Bearer {self._admin_key}"}, + json={"dataset_id": dataset_id}, + ) + assert resp.status_code == 200, resp.text + + def _gateway_post(self, endpoint: str, api_key: str, payload: dict): + with httpx.Client(timeout=30.0) as client: + resp = client.post( + f"{self._gateway_url}{endpoint}", + headers={"Authorization": f"Bearer {api_key}"}, + json=payload, + ) + assert resp.status_code == 200, resp.text + return resp.json() + + +def _collect_epoch_indices(dl: StatefulDataLoader, epoch: int) -> list[int]: + if hasattr(dl, "sampler") and hasattr(dl.sampler, "set_epoch"): + dl.sampler.set_epoch(epoch) + indices: list[int] = [] + for batch in dl: + for item in batch: + indices.append(int(item["idx"])) + return indices + + +@pytest.fixture +def gsm8k_path(tmp_path): + """Create a minimal synthetic dataset mimicking 
GSM8K structure.""" + ds = _make_dataset(100) + path = str(tmp_path / "test_dataset") + ds.save_to_disk(path) + return path + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (10, 3, 1), + (10, 3, 2), + (10, 5, 1), + (10, 5, 2), + (100, 32, 1), + (100, 32, 4), + (7, 3, 2), + (15, 4, 3), + ], +) +def test_steps_per_epoch_matches_local( + n_samples: int, batch_size: int, num_workers: int +): + """steps_per_epoch = total_samples // batch_size (drop_last=True default).""" + dataset = _make_dataset(n_samples) + local_info = _local_dataloader_info(dataset, batch_size, num_workers) + + assert local_info["total_steps"] == local_info["total_samples"] // batch_size + assert local_info["total_steps"] == len(local_info["all_batches"]) + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (10, 3, 1), + (10, 3, 2), + (100, 32, 1), + (100, 32, 4), + ], +) +def test_local_dataloader_epoch_iteration( + n_samples: int, batch_size: int, num_workers: int +): + """Local DataLoader yields exactly len(dl) batches per epoch across all shards.""" + + dataset = _make_dataset(n_samples) + local_info = _local_dataloader_info(dataset, batch_size, num_workers) + + for batch in local_info["all_batches"]: + assert len(batch) == batch_size + + assert len(local_info["all_batches"]) == local_info["total_steps"] + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (20, 5, 1), + (20, 5, 2), + (50, 10, 3), + ], +) +def test_data_service_epoch_matches_local( + tmp_path, n_samples: int, batch_size: int, num_workers: int +): + dataset = _make_dataset(n_samples) + ds_path = str(tmp_path / "ds") + dataset.save_to_disk(ds_path) + + stack = _DataServiceStack(num_workers=num_workers, batch_size=batch_size) + stack.start() + + try: + cfg = TrainDatasetConfig( + path=ds_path, + type="rl", + batch_size=batch_size, + shuffle=False, + drop_last=True, + num_workers=0, + ) + local_dl = create_dataloader(dataset, rank=0, world_size=1, 
dataset_config=cfg) + + controller: Any = _GatewayControllerAdapter(stack.gateway_url, ADMIN_KEY) + rdataset = RDataset(path=ds_path, type="rl", split="train") + rdataset.connect( + controller, + dataset_id=f"epoch-local-{uuid.uuid4().hex[:8]}", + shuffle=False, + drop_last=True, + ) + remote_dl = create_dataloader( + rdataset, rank=0, world_size=1, dataset_config=cfg + ) + + try: + remote_indices = _collect_epoch_indices(remote_dl, 0) + local_indices = _collect_epoch_indices(local_dl, 0) + assert remote_indices == local_indices[: len(remote_indices)] + finally: + rdataset.close() + finally: + stack.stop() + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers,num_epochs", + [ + (20, 5, 1, 3), + (20, 5, 2, 2), + ], +) +def test_data_service_multi_epoch( + tmp_path, n_samples: int, batch_size: int, num_workers: int, num_epochs: int +): + dataset = _make_dataset(n_samples) + ds_path = str(tmp_path / "ds") + dataset.save_to_disk(ds_path) + + stack = _DataServiceStack(num_workers=num_workers, batch_size=batch_size) + stack.start() + + try: + cfg = TrainDatasetConfig( + path=ds_path, + type="rl", + batch_size=batch_size, + shuffle=True, + drop_last=True, + num_workers=0, + ) + local_dl = create_dataloader(dataset, rank=0, world_size=1, dataset_config=cfg) + + controller: Any = _GatewayControllerAdapter(stack.gateway_url, ADMIN_KEY) + rdataset = RDataset(path=ds_path, type="rl", split="train") + rdataset.connect( + controller, + dataset_id=f"epoch-multi-{uuid.uuid4().hex[:8]}", + shuffle=True, + drop_last=True, + ) + remote_dl = create_dataloader( + rdataset, rank=0, world_size=1, dataset_config=cfg + ) + + try: + for epoch in range(num_epochs): + remote_epoch = _collect_epoch_indices(remote_dl, epoch) + local_epoch = _collect_epoch_indices(local_dl, epoch) + assert remote_epoch == local_epoch + finally: + rdataset.close() + finally: + stack.stop() + + +# --- Unit tests for sharding math (no service needed) --- + + +@pytest.mark.parametrize( + 
"n_samples,batch_size,num_workers", + [ + (10, 3, 1), + (10, 3, 2), + (10, 3, 3), + (7, 3, 2), + (7, 3, 3), + (100, 32, 4), + (1, 1, 1), + (5, 10, 1), + (5, 10, 2), + ], +) +def test_train_drop_last_no_incomplete_batches( + n_samples: int, batch_size: int, num_workers: int +): + """With drop_last=True, every batch has exactly batch_size samples.""" + dataset = _make_dataset(n_samples) + info = _local_dataloader_info(dataset, batch_size, num_workers, drop_last=True) + for batch in info["all_batches"]: + assert len(batch) == batch_size + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (10, 3, 1), + (10, 3, 2), + (10, 3, 3), + (7, 3, 2), + (7, 3, 3), + (100, 32, 4), + (1, 1, 1), + (5, 10, 1), + (5, 10, 2), + ], +) +def test_valid_drop_last_false_preserves_all_data( + n_samples: int, batch_size: int, num_workers: int +): + """With drop_last=False, ALL samples appear in the output batches.""" + dataset = _make_dataset(n_samples) + info = _local_dataloader_info(dataset, batch_size, num_workers, drop_last=False) + + total_yielded = sum(len(b) for b in info["all_batches"]) + assert total_yielded == n_samples, ( + f"Expected all {n_samples} samples, got {total_yielded}. " + f"shard_sizes={info['shard_sizes']}" + ) + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (7, 3, 2), + (7, 3, 3), + (10, 4, 3), + (11, 5, 3), + ], +) +def test_uneven_shard_steps_sum_is_correct( + n_samples: int, batch_size: int, num_workers: int +): + """Computed total_steps == actual number of batches yielded.""" + dataset = _make_dataset(n_samples) + for drop_last in (True, False): + info = _local_dataloader_info( + dataset, batch_size, num_workers, drop_last=drop_last + ) + assert info["total_steps"] == len(info["all_batches"]), ( + f"drop_last={drop_last}: total_steps={info['total_steps']} " + f"but got {len(info['all_batches'])} batches. 
" + f"shard_sizes={info['shard_sizes']}" + ) + + +@pytest.mark.parametrize("n_samples", [10, 7, 15, 100]) +@pytest.mark.parametrize("batch_size", [3, 5, 32]) +def test_single_worker_matches_single_process_dataloader( + n_samples: int, batch_size: int +): + """With 1 worker, data service steps == single-process DataLoader steps.""" + dataset = _make_dataset(n_samples) + for drop_last in (True, False): + info = _local_dataloader_info( + dataset, batch_size, num_workers=1, drop_last=drop_last + ) + dl = StatefulDataLoader( + dataset, + batch_size=batch_size, + drop_last=drop_last, + collate_fn=_identity_collate, + ) + assert info["total_steps"] == len(dl), ( + f"drop_last={drop_last}: 1-worker service steps={info['total_steps']} " + f"!= single-process steps={len(dl)}" + ) + + +@pytest.mark.parametrize( + "n_samples,batch_size,num_workers", + [ + (12, 2, 4), + (10, 3, 2), + (10, 3, 3), + (7, 3, 2), + (100, 32, 4), + ], +) +def test_multi_worker_matches_single_process( + n_samples: int, batch_size: int, num_workers: int +): + """Multi-worker data service yields same steps as single-process DataLoader.""" + dataset = _make_dataset(n_samples) + for drop_last in (True, False): + multi = _local_dataloader_info( + dataset, batch_size, num_workers, drop_last=drop_last + ) + dl = StatefulDataLoader( + dataset, + batch_size=batch_size, + drop_last=drop_last, + collate_fn=_identity_collate, + ) + assert multi["total_steps"] == len(dl), ( + f"drop_last={drop_last}, workers={num_workers}: " + f"multi-worker steps={multi['total_steps']} != " + f"single-process steps={len(dl)}" + ) diff --git a/tests/infra/data_service/test_gateway.py b/tests/infra/data_service/test_gateway.py new file mode 100644 index 0000000000..833fdd4578 --- /dev/null +++ b/tests/infra/data_service/test_gateway.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +# pyright: reportMissingImports=false +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +import pytest_asyncio + 
+from areal.infra.data_service.gateway.app import create_gateway_app +from areal.infra.data_service.gateway.config import GatewayConfig + +ADMIN_KEY = "test-admin-key" +WORKER_ADDR = "http://worker-1:8000" +WORKER_ADDR_2 = "http://worker-2:8000" +MODULE = "areal.infra.data_service.gateway.app" + + +@pytest.fixture +def config(): + return GatewayConfig( + host="127.0.0.1", + port=18090, + admin_api_key=ADMIN_KEY, + router_addr="http://mock-router:8091", + ) + + +@pytest_asyncio.fixture +async def client(config): + app = create_gateway_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +def admin_headers(): + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +async def _register_dataset(client, dataset_id: str = "train-sample") -> dict: + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR] + mock_broadcast.return_value = [ + { + "addr": WORKER_ADDR, + "status": 200, + "data": {"steps_per_epoch": 10, "dataset_size": 100}, + } + ] + resp = await client.post( + "/v1/datasets/register", + json={"dataset_id": dataset_id, "dataset_path": "/tmp/sample.jsonl"}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + return resp.json() + + +class TestGatewayHealth: + @pytest.mark.asyncio + async def test_health_no_auth_required(self, client): + resp = await client.get("/health") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + @pytest.mark.asyncio + async def test_health_returns_router_addr(self, client, config): + resp = await client.get("/health") + assert resp.status_code == 200 + assert resp.json()["router_addr"] == config.router_addr + + +class TestGatewayAuth: + @pytest.mark.asyncio + async def test_fetch_samples_no_auth_401(self, client): + resp = 
await client.post("/v1/samples/fetch", json={"indices": [0]}) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_fetch_samples_bad_key_401(self, client): + resp = await client.post( + "/v1/samples/fetch", + json={"indices": [0]}, + headers={"Authorization": "Bearer unknown-key"}, + ) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_admin_endpoint_with_dataset_key_403(self, client): + resp = await client.post( + "/v1/datasets/register", + json={"dataset_id": "d1", "dataset_path": "/tmp/data.jsonl"}, + headers={"Authorization": "Bearer ds-not-admin"}, + ) + assert resp.status_code == 403 + + +class TestDatasetRegistration: + @pytest.mark.asyncio + async def test_register_dataset_returns_api_key(self, client): + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + mock_broadcast.return_value = [ + { + "addr": WORKER_ADDR, + "status": 200, + "data": {"steps_per_epoch": 12, "dataset_size": 120}, + }, + { + "addr": WORKER_ADDR_2, + "status": 200, + "data": {"steps_per_epoch": 12, "dataset_size": 120}, + }, + ] + + resp = await client.post( + "/v1/datasets/register", + json={ + "dataset_id": "dataset-a", + "dataset_path": "/tmp/a.jsonl", + }, + headers=admin_headers(), + ) + + assert resp.status_code == 200 + payload = resp.json() + assert payload["api_key"].startswith("ds-") + assert payload["dataset_id"] == "dataset-a" + assert payload["dataset_size"] == 240 + + @pytest.mark.asyncio + async def test_register_then_fetch_uses_dataset_key(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-b") + dataset_key = reg_payload["api_key"] + + mock_client = AsyncMock() + mock_client.post.return_value = httpx.Response( + 200, + json={"samples": [{"text": "hello"}]}, + ) + mock_cm = AsyncMock() + 
mock_cm.__aenter__.return_value = mock_client + + with ( + patch(f"{MODULE}._query_router", new_callable=AsyncMock) as mock_route, + patch(f"{MODULE}.httpx.AsyncClient", return_value=mock_cm), + ): + mock_route.return_value = WORKER_ADDR + resp = await client.post( + "/v1/samples/fetch", + json={"indices": [0]}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 200 + call_args = mock_client.post.await_args + assert call_args[0][0] == f"{WORKER_ADDR}/v1/samples/fetch" + assert call_args[1]["json"]["dataset_id"] == "dataset-b" + + @pytest.mark.asyncio + async def test_unregister_revokes_key(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-c") + dataset_key = reg_payload["api_key"] + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR] + mock_broadcast.return_value = [ + {"addr": WORKER_ADDR, "status": 200, "data": {}} + ] + resp = await client.post( + "/v1/datasets/unregister", + json={"dataset_id": "dataset-c"}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + + after_revoke = await client.post( + "/v1/samples/fetch", + json={"indices": [0]}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + assert after_revoke.status_code == 401 + + @pytest.mark.asyncio + async def test_register_failure_rolls_back_successful_workers(self, client): + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + mock_broadcast.side_effect = [ + [ + {"addr": WORKER_ADDR, "status": 200, "data": {}}, + {"addr": WORKER_ADDR_2, "status": 500, "error": "boom"}, + ], + [ + {"addr": WORKER_ADDR, "status": 200, "data": 
{}}, + ], + ] + + resp = await client.post( + "/v1/datasets/register", + json={ + "dataset_id": "dataset-rollback", + "dataset_path": "/tmp/a.jsonl", + }, + headers=admin_headers(), + ) + + assert resp.status_code == 502 + assert mock_broadcast.await_count == 2 + rollback_call = mock_broadcast.await_args_list[1] + assert rollback_call.args[0] == [WORKER_ADDR] + assert rollback_call.args[1] == "/datasets/unload" + assert rollback_call.args[2] == {"dataset_id": "dataset-rollback"} + + @pytest.mark.asyncio + async def test_reregister_revokes_old_key(self, client): + first = await _register_dataset(client, dataset_id="dataset-rekey") + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR] + mock_broadcast.return_value = [ + { + "addr": WORKER_ADDR, + "status": 200, + "data": {"steps_per_epoch": 10, "dataset_size": 100}, + } + ] + second_resp = await client.post( + "/v1/datasets/register", + json={ + "dataset_id": "dataset-rekey", + "dataset_path": "/tmp/sample.jsonl", + }, + headers=admin_headers(), + ) + + assert second_resp.status_code == 200 + second = second_resp.json() + assert first["api_key"] != second["api_key"] + + old_key_resp = await client.post( + "/v1/samples/fetch", + json={"indices": [0]}, + headers={"Authorization": f"Bearer {first['api_key']}"}, + ) + assert old_key_resp.status_code == 401 + + +class TestBroadcastEndpoints: + @pytest.mark.asyncio + async def test_epoch_advance_broadcasts_to_all_workers(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-f") + dataset_key = reg_payload["api_key"] + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, 
WORKER_ADDR_2] + mock_broadcast.return_value = [ + {"addr": WORKER_ADDR, "status": 200, "data": {}}, + {"addr": WORKER_ADDR_2, "status": 200, "data": {}}, + ] + resp = await client.post( + "/v1/epochs/advance", + json={"epoch": 7}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 200 + assert resp.json()["workers_reset"] == 2 + + @pytest.mark.asyncio + async def test_epoch_advance_no_workers_returns_503(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-noworkers") + dataset_key = reg_payload["api_key"] + + with patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers: + mock_workers.return_value = [] + resp = await client.post( + "/v1/epochs/advance", + json={"epoch": 7}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 503 + + @pytest.mark.asyncio + async def test_epoch_advance_worker_failure_returns_502(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-partial") + dataset_key = reg_payload["api_key"] + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + mock_broadcast.return_value = [ + {"addr": WORKER_ADDR, "status": 200, "data": {}}, + {"addr": WORKER_ADDR_2, "status": 500, "error": "boom"}, + ] + resp = await client.post( + "/v1/epochs/advance", + json={"epoch": 7}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 502 + + @pytest.mark.asyncio + async def test_state_save_broadcasts_to_all_workers(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-g") + dataset_key = reg_payload["api_key"] + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + 
f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + mock_broadcast.return_value = [ + {"addr": WORKER_ADDR, "status": 200, "data": {}}, + {"addr": WORKER_ADDR_2, "status": 200, "data": {}}, + ] + resp = await client.post( + "/v1/state/save", + json={"path": "/tmp/ckpt"}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 200 + assert resp.json() == {"status": "ok", "path": "/tmp/ckpt"} + + @pytest.mark.asyncio + async def test_state_load_broadcasts_to_all_workers(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-h") + dataset_key = reg_payload["api_key"] + + with ( + patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers, + patch( + f"{MODULE}._broadcast_to_workers", new_callable=AsyncMock + ) as mock_broadcast, + ): + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + mock_broadcast.return_value = [ + {"addr": WORKER_ADDR, "status": 200, "data": {}}, + {"addr": WORKER_ADDR_2, "status": 200, "data": {}}, + ] + resp = await client.post( + "/v1/state/load", + json={"path": "/tmp/ckpt"}, + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + +class TestStatusAndWorkers: + @pytest.mark.asyncio + async def test_workers_returns_router_workers(self, client): + with patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers: + mock_workers.return_value = [WORKER_ADDR, WORKER_ADDR_2] + resp = await client.get("/v1/workers", headers=admin_headers()) + + assert resp.status_code == 200 + assert resp.json() == { + "workers": [{"addr": WORKER_ADDR}, {"addr": WORKER_ADDR_2}] + } + + @pytest.mark.asyncio + async def test_status_returns_dataset_id(self, client): + reg_payload = await _register_dataset(client, dataset_id="dataset-status") + dataset_key = 
reg_payload["api_key"] + + with patch( + f"{MODULE}._query_router", new_callable=AsyncMock + ) as mock_query_router: + mock_query_router.side_effect = RuntimeError("router unavailable") + resp = await client.get( + "/v1/status", + headers={"Authorization": f"Bearer {dataset_key}"}, + ) + + assert resp.status_code == 200 + payload = resp.json() + assert payload["status"] == "ok" + assert payload["dataset_id"] == "dataset-status" + + +class TestShutdown: + @pytest.mark.asyncio + async def test_shutdown_requires_admin_key(self, client): + resp = await client.post( + "/v1/shutdown", + headers={"Authorization": "Bearer ds-not-admin"}, + ) + assert resp.status_code == 403 + + @pytest.mark.asyncio + async def test_shutdown_returns_ok(self, client): + with patch( + f"{MODULE}._get_all_worker_addrs", new_callable=AsyncMock + ) as mock_workers: + mock_workers.return_value = [] + resp = await client.post("/v1/shutdown", headers=admin_headers()) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/tests/infra/data_service/test_guard.py b/tests/infra/data_service/test_guard.py new file mode 100644 index 0000000000..90bc816f72 --- /dev/null +++ b/tests/infra/data_service/test_guard.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import subprocess +from unittest.mock import MagicMock, patch + +import pytest + +from areal.infra.data_service.guard.app import ( + GuardState, + cleanup_forked_children, + create_app, +) + + +@pytest.fixture() +def state() -> GuardState: + s = GuardState() + s.server_host = "10.0.0.1" + s.experiment_name = "test-exp" + s.trial_name = "test-trial" + s.role = "test-role" + s.worker_index = 0 + return s + + +@pytest.fixture() +def client(state: GuardState): + app = create_app(state) + app.config["TESTING"] = True + with app.test_client() as c: + yield c + + +def _make_mock_process(pid: int = 12345, running: bool = True) -> MagicMock: + proc = MagicMock(spec=subprocess.Popen) + proc.pid = pid + 
proc.poll.return_value = None if running else 0 + return proc + + +def test_health_returns_200(client): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] == "healthy" + assert data["forked_children"] == 0 + + +@patch("areal.infra.rpc.guard.app.find_free_ports") +def test_alloc_ports_success(mock_find, client, state: GuardState): + mock_find.return_value = [9001, 9002] + resp = client.post("/alloc_ports", json={"count": 2}) + assert resp.status_code == 200 + data = resp.get_json() + assert data["ports"] == [9001, 9002] + assert data["host"] == "10.0.0.1" + assert state.allocated_ports == {9001, 9002} + + +@patch("areal.infra.rpc.guard.app.run_with_streaming_logs") +def test_fork_raw_command_success(mock_run, client, state: GuardState): + mock_proc = _make_mock_process(pid=42) + mock_run.return_value = mock_proc + + resp = client.post( + "/fork", + json={ + "role": "worker", + "worker_index": 1, + "raw_cmd": ["python", "-m", "module", "--port", "8001"], + }, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] == "success" + assert data["host"] == "10.0.0.1" + assert data["pid"] == 42 + assert ("worker", 1) in state.forked_children_map + + +@patch("areal.infra.rpc.guard.app.kill_process_tree") +def test_kill_known_worker(mock_kill, client, state: GuardState): + mock_proc = _make_mock_process(pid=123) + state.forked_children.append(mock_proc) + state.forked_children_map[("test", 0)] = mock_proc + + resp = client.post("/kill_forked_worker", json={"role": "test", "worker_index": 0}) + assert resp.status_code == 200 + assert ("test", 0) not in state.forked_children_map + mock_kill.assert_called_once_with(123, timeout=3, graceful=True) + + +@patch("areal.infra.rpc.guard.app.kill_process_tree") +def test_cleanup_kills_all_running_children(mock_kill, state: GuardState): + proc1 = _make_mock_process(pid=100) + proc2 = _make_mock_process(pid=200) + state.forked_children = [proc1, 
proc2] + state.forked_children_map = {("a", 0): proc1, ("b", 0): proc2} + + cleanup_forked_children(state) + + assert mock_kill.call_count == 2 + assert state.forked_children == [] + assert state.forked_children_map == {} diff --git a/tests/infra/data_service/test_performance.py b/tests/infra/data_service/test_performance.py new file mode 100644 index 0000000000..4adce1908e --- /dev/null +++ b/tests/infra/data_service/test_performance.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import time +from unittest.mock import patch + +import httpx +import pytest + +from areal.infra.data_service.worker.app import create_worker_app +from areal.infra.data_service.worker.config import DataWorkerConfig + +DATASET_ID = "perf-test" +WORKER_CONFIG = DataWorkerConfig( + host="127.0.0.1", + port=0, + rank=0, + world_size=1, + dataloader_num_workers=0, +) + + +def _make_mock_dataset(n: int): + from datasets import Dataset + + return Dataset.from_dict( + {"text": [f"sample_{i}" for i in range(n)], "label": list(range(n))} + ) + + +def _load_payload(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "dataset_id": DATASET_ID, + "dataset_path": "test/dataset", + "dataset_type": "rl", + "seed": 42, + "shuffle": False, + } + payload.update(overrides) + return payload + + +@pytest.mark.asyncio +class TestWorkerSampleFetchPerformance: + async def test_sample_fetch_throughput(self): + n = 100 + with ( + patch( + "areal.infra.data_service.worker.app._get_custom_dataset" + ) as mock_get, + patch( + "areal.infra.data_service.worker.app.load_hf_processor_and_tokenizer" + ) as mock_load, + ): + mock_get.return_value = _make_mock_dataset(n) + mock_load.return_value = (None, None) + + app = create_worker_app(WORKER_CONFIG) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test" + ) as client: + resp = await client.post("/datasets/load", json=_load_payload()) + assert resp.status_code == 200 + + 
batch_size = 10 + t0 = time.perf_counter() + for start in range(0, n, batch_size): + indices = list(range(start, min(start + batch_size, n))) + resp = await client.post( + "/v1/samples/fetch", + json={"dataset_id": DATASET_ID, "indices": indices}, + ) + assert resp.status_code == 200 + assert len(resp.json()["samples"]) == len(indices) + elapsed = time.perf_counter() - t0 + + per_item_ms = (elapsed / n) * 1000 + assert per_item_ms < 50, ( + f"Worker sample fetch: {per_item_ms:.1f}ms per item " + f"(expected < 50ms for in-memory mock data via ASGI)" + ) + + async def test_single_large_sample_fetch(self): + n = 50 + with ( + patch( + "areal.infra.data_service.worker.app._get_custom_dataset" + ) as mock_get, + patch( + "areal.infra.data_service.worker.app.load_hf_processor_and_tokenizer" + ) as mock_load, + ): + mock_get.return_value = _make_mock_dataset(n) + mock_load.return_value = (None, None) + + app = create_worker_app(WORKER_CONFIG) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test" + ) as client: + resp = await client.post("/datasets/load", json=_load_payload()) + assert resp.status_code == 200 + + indices = list(range(n)) + t0 = time.perf_counter() + resp = await client.post( + "/v1/samples/fetch", + json={"dataset_id": DATASET_ID, "indices": indices}, + ) + elapsed = time.perf_counter() - t0 + + assert resp.status_code == 200 + assert len(resp.json()["samples"]) == n + + assert elapsed < 5.0, ( + f"Single sample fetch of {n} items took {elapsed:.2f}s (expected < 5s)" + ) diff --git a/tests/infra/data_service/test_router.py b/tests/infra/data_service/test_router.py new file mode 100644 index 0000000000..b9cd74a498 --- /dev/null +++ b/tests/infra/data_service/test_router.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import asyncio + +import httpx +import pytest +import pytest_asyncio + +from areal.infra.data_service.router.app import create_router_app +from 
areal.infra.data_service.router.config import RouterConfig + +ADMIN_KEY = "test-admin-key" +WORKER_1 = "http://worker-1:8000" +WORKER_2 = "http://worker-2:8000" +WORKER_3 = "http://worker-3:8000" + + +@pytest.fixture +def config(): + return RouterConfig( + host="127.0.0.1", + port=18091, + admin_api_key=ADMIN_KEY, + poll_interval=999, + ) + + +@pytest_asyncio.fixture +async def client(config): + app = create_router_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +def admin_headers(): + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +class TestHealth: + @pytest.mark.asyncio + async def test_health_returns_200_with_worker_count(self, client): + resp = await client.get("/health") + assert resp.status_code == 200 + payload = resp.json() + assert payload["status"] == "ok" + assert payload["workers"] == 0 + + @pytest.mark.asyncio + async def test_health_shows_healthy_count(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + await client.post( + "/register", json={"worker_addr": WORKER_2}, headers=admin_headers() + ) + + resp = await client.get("/health") + assert resp.status_code == 200 + payload = resp.json() + assert payload["workers"] == 2 + assert payload["healthy"] == 2 + + +class TestWorkerRegistration: + @pytest.mark.asyncio + async def test_register_worker_success(self, client): + resp = await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + health = await client.get("/health") + assert health.json()["workers"] == 1 + + @pytest.mark.asyncio + async def test_register_duplicate_noop(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + resp = await client.post( + "/register", json={"worker_addr": WORKER_1}, 
headers=admin_headers() + ) + assert resp.status_code == 200 + + health = await client.get("/health") + assert health.json()["workers"] == 1 + + @pytest.mark.asyncio + async def test_register_no_auth_401(self, client): + resp = await client.post("/register", json={"worker_addr": WORKER_1}) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_register_wrong_key_403(self, client): + resp = await client.post( + "/register", + json={"worker_addr": WORKER_1}, + headers={"Authorization": "Bearer wrong-key"}, + ) + assert resp.status_code == 403 + + @pytest.mark.asyncio + async def test_unregister_worker_removes(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + resp = await client.post( + "/unregister", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + health = await client.get("/health") + assert health.json()["workers"] == 0 + + +class TestRouting: + @pytest.mark.asyncio + async def test_route_round_robin_cycles(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + await client.post( + "/register", json={"worker_addr": WORKER_2}, headers=admin_headers() + ) + + picks = [] + for _ in range(4): + resp = await client.post("/route", headers=admin_headers()) + assert resp.status_code == 200 + picks.append(resp.json()["worker_addr"]) + + assert picks == [WORKER_1, WORKER_2, WORKER_1, WORKER_2] + + @pytest.mark.asyncio + async def test_route_no_workers_503(self, client): + resp = await client.post("/route", headers=admin_headers()) + assert resp.status_code == 503 + + @pytest.mark.asyncio + async def test_route_skips_unhealthy(self, client): + await client.post( + "/register", + json={"worker_addr": WORKER_1}, + headers=admin_headers(), + ) + await client.post( + "/register", + json={"worker_addr": WORKER_2}, + headers=admin_headers(), + ) + + app = 
client._transport.app + app.state.worker_healthy[WORKER_2] = False + + for _ in range(3): + resp = await client.post("/route", headers=admin_headers()) + assert resp.status_code == 200 + assert resp.json()["worker_addr"] == WORKER_1 + + @pytest.mark.asyncio + async def test_route_no_auth_401(self, client): + resp = await client.post("/route") + assert resp.status_code == 401 + + +class TestWorkersList: + @pytest.mark.asyncio + async def test_workers_returns_all_registered(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + await client.post( + "/register", json={"worker_addr": WORKER_2}, headers=admin_headers() + ) + await client.post( + "/register", json={"worker_addr": WORKER_3}, headers=admin_headers() + ) + + resp = await client.get("/workers", headers=admin_headers()) + assert resp.status_code == 200 + workers = resp.json()["workers"] + assert {w["addr"] for w in workers} == {WORKER_1, WORKER_2, WORKER_3} + + @pytest.mark.asyncio + async def test_workers_shows_health_status(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + + resp = await client.get("/workers", headers=admin_headers()) + assert resp.status_code == 200 + workers = resp.json()["workers"] + assert workers == [{"addr": WORKER_1, "healthy": True}] + + @pytest.mark.asyncio + async def test_workers_no_auth_401(self, client): + resp = await client.get("/workers") + assert resp.status_code == 401 + + +class TestConcurrentRouting: + @pytest.mark.asyncio + async def test_concurrent_routes_distribute_evenly(self, client): + await client.post( + "/register", json={"worker_addr": WORKER_1}, headers=admin_headers() + ) + await client.post( + "/register", json={"worker_addr": WORKER_2}, headers=admin_headers() + ) + + async def route_once() -> str: + resp = await client.post("/route", headers=admin_headers()) + assert resp.status_code == 200 + return resp.json()["worker_addr"] + + routes 
= await asyncio.gather(*(route_once() for _ in range(10))) + assert routes.count(WORKER_1) == 5 + assert routes.count(WORKER_2) == 5 diff --git a/tests/infra/data_service/test_trainer_compat.py b/tests/infra/data_service/test_trainer_compat.py new file mode 100644 index 0000000000..5a895f5809 --- /dev/null +++ b/tests/infra/data_service/test_trainer_compat.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import inspect +from pathlib import Path + +import pytest + +from areal.infra.data_service.controller.config import DataServiceConfig +from areal.utils.data import cycle_dataloader + + +class _SamplerWithSetEpoch: + def __init__(self): + self.epochs: list[int] = [] + + def set_epoch(self, epoch: int) -> None: + self.epochs.append(epoch) + + +class _FiniteDataloader: + def __init__(self, batches_per_epoch: int = 2, with_sampler: bool = True): + self.batches_per_epoch = batches_per_epoch + if with_sampler: + self.sampler = _SamplerWithSetEpoch() + + def __iter__(self): + for i in range(self.batches_per_epoch): + yield {"batch": i} + + +class _DataloaderWithSamplerNoSetEpoch: + class _Sampler: + pass + + def __init__(self): + self.sampler = self._Sampler() + + def __iter__(self): + yield {"batch": 0} + + +class TestCycleDataloader: + def test_cycle_yields_correct_batches(self): + dl = _FiniteDataloader(batches_per_epoch=3) + gen = cycle_dataloader(dl, num_cycles=2) + all_batches = list(gen) + assert len(all_batches) == 6 + + def test_cycle_calls_set_epoch(self): + dl = _FiniteDataloader(batches_per_epoch=2) + gen = cycle_dataloader(dl, num_cycles=3) + list(gen) + assert dl.sampler.epochs == [0, 1, 2] + + def test_cycle_no_sampler_does_not_crash(self): + dl = _FiniteDataloader(batches_per_epoch=2, with_sampler=False) + gen = cycle_dataloader(dl, num_cycles=1) + all_batches = list(gen) + assert len(all_batches) == 2 + + def test_cycle_sampler_without_set_epoch(self): + dl = _DataloaderWithSamplerNoSetEpoch() + gen = cycle_dataloader(dl, num_cycles=1) + 
all_batches = list(gen) + assert len(all_batches) == 1 + + def test_infinite_cycle_generates_beyond_one_epoch(self): + dl = _FiniteDataloader(batches_per_epoch=2) + gen = cycle_dataloader(dl) + collected = [] + for i, batch in enumerate(gen): + collected.append(batch) + if i >= 5: + break + assert len(collected) == 6 + assert dl.sampler.epochs == [0, 1, 2] + + +class TestDataServiceConfig: + def test_from_dataset_config_defaults(self): + from areal.api.cli_args import TrainDatasetConfig + + ds_cfg = TrainDatasetConfig(path="dummy", type="sft") + svc_cfg = DataServiceConfig.from_dataset_config(ds_cfg) + + assert svc_cfg.num_workers >= 1 + assert svc_cfg.num_workers == 1 + assert svc_cfg.dataloader_num_workers == 0 + assert svc_cfg.scheduling_strategy.type == "separation" + + def test_from_dataset_config_custom_workers(self): + from areal.api.cli_args import TrainDatasetConfig + + ds_cfg = TrainDatasetConfig( + path="dummy", + type="sft", + num_workers=4, + num_dataset_workers=3, + ) + svc_cfg = DataServiceConfig.from_dataset_config(ds_cfg) + assert svc_cfg.num_workers == 3 + assert svc_cfg.dataloader_num_workers == 4 + + def test_from_dataset_config_zero_workers_defaults_to_one(self): + from areal.api.cli_args import TrainDatasetConfig + + ds_cfg = TrainDatasetConfig(path="dummy", type="sft", num_workers=0) + svc_cfg = DataServiceConfig.from_dataset_config(ds_cfg) + assert svc_cfg.num_workers == 1 + assert svc_cfg.dataloader_num_workers == 0 + + def test_from_dataset_config_allows_disabling_data_service(self): + from areal.api.cli_args import TrainDatasetConfig + + ds_cfg = TrainDatasetConfig(path="dummy", type="sft", scheduling_spec=None) + svc_cfg = DataServiceConfig.from_dataset_config(ds_cfg) + assert svc_cfg.scheduling_spec is None + + def test_dataset_config_has_scheduling_spec(self): + from areal.api.cli_args import TrainDatasetConfig + + ds_cfg = TrainDatasetConfig(path="dummy", type="sft") + assert ds_cfg.scheduling_spec is not None + + +class 
TestTrainerDataServicePath: + def test_data_controller_importable(self): + from areal.infra.data_service import DataController, RDataset + + assert DataController is not None + assert RDataset is not None + + def test_data_controller_config_importable(self): + assert DataServiceConfig is not None + + def test_rdataset_has_required_protocol(self): + from areal.infra.data_service import RDataset + + dataset = RDataset(path="dummy", type="rl", split="train") + + assert hasattr(dataset, "connect") + assert hasattr(dataset, "close") + assert hasattr(dataset, "__len__") + assert hasattr(dataset, "__getitem__") + assert dataset.connected is False + + def test_get_custom_dataset_respects_scheduling_spec_none(self, monkeypatch): + from areal.api.cli_args import TrainDatasetConfig + from areal.dataset import get_custom_dataset + + sentinel = object() + + def _fake_custom_dataset(**_kwargs): + return sentinel + + monkeypatch.setenv("AREAL_SPMD_MODE", "0") + monkeypatch.setattr("areal.dataset._get_custom_dataset", _fake_custom_dataset) + + cfg = TrainDatasetConfig(path="dummy", type="sft", scheduling_spec=None) + dataset = get_custom_dataset(split="train", dataset_config=cfg) + + assert dataset is sentinel + + +class TestGenericDatasetFallback: + def test_none_split_uses_first_available_split(self, tmp_path: Path): + from datasets import Dataset, DatasetDict + + from areal.dataset import _get_custom_dataset + + dataset_path = tmp_path / "dict-ds-none-split" + dataset = DatasetDict({"train": Dataset.from_dict({"x": [1, 2]})}) + dataset.save_to_disk(str(dataset_path)) + + loaded = _get_custom_dataset( + path=str(dataset_path), + type="sft", + split=None, + ) + assert len(loaded) == 2 + + def test_explicit_missing_split_raises_error(self, tmp_path: Path): + from datasets import Dataset, DatasetDict + + from areal.dataset import _get_custom_dataset + + dataset_path = tmp_path / "dict-ds" + dataset = DatasetDict({"train": Dataset.from_dict({"x": [1, 2]})}) + 
dataset.save_to_disk(str(dataset_path)) + + with pytest.raises(ValueError, match="Requested split 'test' not found"): + _get_custom_dataset( + path=str(dataset_path), + type="sft", + split="test", + ) + + +class TestEmptyDataLoaderCompat: + def test_empty_dataloader_still_works(self): + from areal.trainer.rl_trainer import _EmptyDataLoader + + dataloader = _EmptyDataLoader(batch_size=2, steps_per_epoch=3) + + assert len(dataloader) == 3 + assert dataloader.batch_size == 2 + + batches = [] + for batch in dataloader: + batches.append(batch) + if len(batches) >= 3: + break + assert len(batches) == 3 + + def test_empty_dataloader_state_dict(self): + from areal.trainer.rl_trainer import _EmptyDataLoader + + dataloader = _EmptyDataLoader() + assert dataloader.state_dict() == {} + dataloader.load_state_dict({"some": "state"}) + + +class TestTrainerSignature: + def test_sft_trainer_accepts_dataset_params(self): + from areal.trainer.sft_trainer import SFTTrainer + + sig = inspect.signature(SFTTrainer.__init__) + params = list(sig.parameters.keys()) + assert "train_dataset" in params + assert "valid_dataset" in params + assert "config" in params + + def test_rw_trainer_accepts_dataset_params(self): + from areal.trainer.rw_trainer import RWTrainer + + sig = inspect.signature(RWTrainer.__init__) + params = list(sig.parameters.keys()) + assert "train_dataset" in params + assert "valid_dataset" in params + assert "config" in params + + def test_ppo_trainer_accepts_dataset_params(self): + from areal.trainer.rl_trainer import PPOTrainer + + sig = inspect.signature(PPOTrainer.__init__) + params = list(sig.parameters.keys()) + assert "train_dataset" in params + assert "valid_dataset" in params + assert "config" in params diff --git a/tests/infra/data_service/test_worker.py b/tests/infra/data_service/test_worker.py new file mode 100644 index 0000000000..b75fe33ebd --- /dev/null +++ b/tests/infra/data_service/test_worker.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +# 
pyright: reportMissingImports=false +from pathlib import Path +from unittest.mock import patch + +import httpx +import pytest +import pytest_asyncio +from datasets import Dataset + +from areal.infra.data_service.worker.app import create_worker_app +from areal.infra.data_service.worker.config import DataWorkerConfig + +DATASET_ID = "test-train" + + +@pytest.fixture +def config() -> DataWorkerConfig: + return DataWorkerConfig( + host="127.0.0.1", + port=0, + rank=0, + world_size=1, + dataloader_num_workers=0, + ) + + +def _make_mock_dataset(n: int = 20) -> Dataset: + return Dataset.from_dict( + { + "text": [f"sample_{i}" for i in range(n)], + "label": list(range(n)), + } + ) + + +def _load_payload(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "dataset_id": DATASET_ID, + "dataset_path": "test/dataset", + "dataset_type": "rl", + "seed": 42, + "shuffle": False, + } + payload.update(overrides) + return payload + + +@pytest_asyncio.fixture +async def client(config: DataWorkerConfig): + app = create_worker_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +@pytest_asyncio.fixture +async def loaded_client(config: DataWorkerConfig): + with ( + patch("areal.infra.data_service.worker.app._get_custom_dataset") as mock_get, + patch( + "areal.infra.data_service.worker.app.load_hf_processor_and_tokenizer" + ) as mock_load, + ): + ds = _make_mock_dataset(8) + mock_get.return_value = ds + mock_load.return_value = (None, None) + + app = create_worker_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + resp = await c.post("/datasets/load", json=_load_payload()) + assert resp.status_code == 200 + yield c + + +@pytest.mark.asyncio +class TestWorkerHealth: + async def test_health_returns_200(self, client: httpx.AsyncClient): + resp = await client.get("/health") + assert 
resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["datasets"] == 0 + + async def test_health_shows_dataset_count(self, loaded_client: httpx.AsyncClient): + resp = await loaded_client.get("/health") + assert resp.status_code == 200 + assert resp.json()["datasets"] == 1 + + +@pytest.mark.asyncio +class TestDatasetLoading: + async def test_load_dataset_returns_steps_per_epoch(self, config: DataWorkerConfig): + with ( + patch( + "areal.infra.data_service.worker.app._get_custom_dataset" + ) as mock_get, + ): + ds = _make_mock_dataset(20) + mock_get.return_value = ds + + app = create_worker_app(config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test" + ) as c: + resp = await c.post("/datasets/load", json=_load_payload()) + + assert resp.status_code == 200 + data = resp.json() + assert data["steps_per_epoch"] > 0 + assert data["dataset_size"] == 20 + + async def test_load_dataset_duplicate_409(self, loaded_client: httpx.AsyncClient): + resp = await loaded_client.post("/datasets/load", json=_load_payload()) + assert resp.status_code == 409 + + async def test_unload_dataset_removes(self, loaded_client: httpx.AsyncClient): + resp = await loaded_client.post( + "/datasets/unload", json={"dataset_id": DATASET_ID} + ) + assert resp.status_code == 200 + + health = await loaded_client.get("/health") + assert health.status_code == 200 + assert health.json()["datasets"] == 0 + + async def test_unload_unknown_dataset_404(self, client: httpx.AsyncClient): + resp = await client.post("/datasets/unload", json={"dataset_id": "unknown"}) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestSampleFetch: + async def test_fetch_samples_returns_data(self, loaded_client: httpx.AsyncClient): + resp = await loaded_client.post( + "/v1/samples/fetch", + json={"dataset_id": DATASET_ID, "indices": [0, 1]}, + ) + assert resp.status_code == 200 + data = resp.json() + assert 
len(data["samples"]) == 2 + + async def test_fetch_samples_unknown_dataset_404(self, client: httpx.AsyncClient): + resp = await client.post( + "/v1/samples/fetch", + json={"dataset_id": "unknown", "indices": [0]}, + ) + assert resp.status_code == 404 + + async def test_fetch_samples_returns_distinct_items( + self, loaded_client: httpx.AsyncClient + ): + resp = await loaded_client.post( + "/v1/samples/fetch", + json={"dataset_id": DATASET_ID, "indices": [0, 1, 2]}, + ) + assert resp.status_code == 200 + samples = resp.json()["samples"] + assert len(samples) == 3 + assert samples[0] != samples[1] + + +@pytest.mark.asyncio +class TestEpochReset: + async def test_epoch_reset_updates_epoch(self, loaded_client: httpx.AsyncClient): + reset = await loaded_client.post( + "/epoch/reset", json={"dataset_id": DATASET_ID, "epoch": 1} + ) + assert reset.status_code == 200 + assert reset.json()["epoch"] == 1 + + async def test_epoch_reset_unknown_dataset_404(self, client: httpx.AsyncClient): + resp = await client.post( + "/epoch/reset", json={"dataset_id": "unknown", "epoch": 1} + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestStatePersistence: + async def test_state_save_creates_file( + self, loaded_client: httpx.AsyncClient, tmp_path: Path + ): + resp = await loaded_client.post( + "/state/save", json={"dataset_id": DATASET_ID, "path": str(tmp_path)} + ) + assert resp.status_code == 200 + out = resp.json() + assert out["status"] == "ok" + assert (tmp_path / "worker_0.pkl").exists() + + async def test_state_load_restores( + self, loaded_client: httpx.AsyncClient, tmp_path: Path + ): + save = await loaded_client.post( + "/state/save", json={"dataset_id": DATASET_ID, "path": str(tmp_path)} + ) + assert save.status_code == 200 + + load = await loaded_client.post( + "/state/load", json={"dataset_id": DATASET_ID, "path": str(tmp_path)} + ) + assert load.status_code == 200 + assert load.json()["status"] == "ok" + + async def test_state_load_missing_file_404( + 
self, loaded_client: httpx.AsyncClient, tmp_path: Path + ): + missing = tmp_path / "does-not-exist" + resp = await loaded_client.post( + "/state/load", json={"dataset_id": DATASET_ID, "path": str(missing)} + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestTensorShardEndpoints: + async def test_data_clear_returns_ok(self, client: httpx.AsyncClient): + resp = await client.delete("/data/clear") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["tensor_shards"] == 0 + + async def test_data_shard_not_found_404(self, client: httpx.AsyncClient): + resp = await client.get("/data/nonexistent") + assert resp.status_code == 404 diff --git a/tests/sft/entrypoint.py b/tests/sft/entrypoint.py index e4d2f7b224..4b6a2f731b 100644 --- a/tests/sft/entrypoint.py +++ b/tests/sft/entrypoint.py @@ -7,6 +7,7 @@ from areal import SFTTrainer from areal.api.cli_args import SFTConfig, load_expr_config from areal.dataset import get_custom_dataset +from areal.infra.data_service.rdataset import RDataset from areal.utils.hf_utils import load_hf_tokenizer @@ -24,6 +25,8 @@ def _export_and_commit_stats(self, epoch, epoch_step, global_step): def main() -> None: + os.environ["AREAL_SPMD_MODE"] = "0" + config, _ = load_expr_config(sys.argv[1:], SFTConfig) tokenizer = load_hf_tokenizer(config.tokenizer_path) @@ -32,6 +35,10 @@ def main() -> None: dataset_config=config.train_dataset, tokenizer=tokenizer, ) + assert isinstance(train_dataset, RDataset), ( + "SFT integration test expects get_custom_dataset() to return RDataset " + "(single-controller data service path)." + ) with MinimalSFTTrainer( config,