Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-areal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ jobs:
VIRTUAL_ENV: /AReaL/.venv
run: |
export PATH="/AReaL/.venv/bin:$PATH"
pytest -m "(not slow or ci) and not ${EXCLUDE_BACKEND}" --durations=20 -s -vv tests/test_*.py tests/experimental/
pytest -m "(not slow or ci) and not ${EXCLUDE_BACKEND}" --durations=20 -s -vv tests/test_*.py tests/experimental/ tests/infra/

- name: Run SFT integration tests
env:
Expand Down
16 changes: 15 additions & 1 deletion areal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,21 @@
current_platform,
workflow_context,
)
from .trainer import PPOTrainer, RWTrainer, SFTTrainer


def __getattr__(name: str):
    """Module-level lazy loader (PEP 562) for the trainer classes.

    Defers the (heavy) ``areal.trainer`` import until one of the trainer
    names is first accessed, then caches the resolved classes in the module
    globals so subsequent lookups bypass this hook entirely.
    """
    lazy_names = ("PPOTrainer", "RWTrainer", "SFTTrainer")
    if name not in lazy_names:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .trainer import PPOTrainer, RWTrainer, SFTTrainer

    resolved = {
        "PPOTrainer": PPOTrainer,
        "RWTrainer": RWTrainer,
        "SFTTrainer": SFTTrainer,
    }
    # Cache so this __getattr__ is only hit once per name.
    globals().update(resolved)
    return resolved[name]


__all__ = [
"PPOTrainer",
Expand Down
30 changes: 30 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2223,6 +2223,10 @@ class SchedulerConfig:
class _DatasetConfig:
"""Configuration for dataset loading and preprocessing."""

split: str = field(
default="train",
metadata={"help": "Dataset split to use, e.g., 'train', 'test'."},
)
path: str = field(
default=MISSING,
metadata={
Expand All @@ -2248,6 +2252,12 @@ class _DatasetConfig:
num_workers: int = field(
default=0, metadata={"help": "Number of worker processes for data loading"}
)
num_dataset_workers: int = field(
default=1,
metadata={
"help": "Number of remote data-service worker processes to launch when using scheduling_spec."
},
)
drop_last: bool = field(
default=True, metadata={"help": "Drop the last incomplete batch"}
)
Expand All @@ -2257,6 +2267,22 @@ class _DatasetConfig:
"help": "Maximum token length of sequences in dataset. Longer sequences are filtered out."
},
)
dataset_kwargs: dict[str, Any] = field(
default_factory=dict,
metadata={
"help": "Additional keyword arguments for dataset loading. "
"These are passed to the dataset loading function `get_custom_dataset`."
},
)
scheduling_spec: SchedulingSpec | None = field(
default_factory=lambda: SchedulingSpec(
cpu=1, gpu=0, mem=10, cmd="python3 -m areal.infra.rpc.guard"
),
metadata={
"help": "Scheduling spec for remote data loading workers. "
"If set, dataset loading will be offloaded to a data service with remote workers."
},
)


@dataclass
Expand All @@ -2272,6 +2298,10 @@ class ValidDatasetConfig(_DatasetConfig):
`shuffle` and `drop_last` default to False.
"""

split: str = field(
default="test",
metadata={"help": "Dataset split to use, e.g., 'train', 'test'."},
)
shuffle: bool = field(
default=False, metadata={"help": "Whether to shuffle the dataset"}
)
Expand Down
2 changes: 1 addition & 1 deletion areal/api/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class LocalInfServerInfo:

host: str
port: int
process: subprocess.Popen
process: subprocess.Popen | None


@dataclass
Expand Down
51 changes: 46 additions & 5 deletions areal/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from areal.infra.data_service.rdataset import RDataset

VALID_DATASETS = [
"gsm8k",
"clevr_count_70k",
Expand Down Expand Up @@ -120,10 +122,32 @@ def _get_custom_dataset(
**kwargs,
)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)
# Fallback: try loading as a generic HuggingFace dataset from disk.
# This supports arbitrary datasets saved via dataset.save_to_disk().
try:
from datasets import DatasetDict, load_from_disk

dataset = load_from_disk(path)
if isinstance(dataset, DatasetDict):
if split is not None:
if split in dataset:
return dataset[split]
available = list(dataset.keys())
raise ValueError(
f"Requested split '{split}' not found in DatasetDict at {path}. "
f"Available splits: {available}"
)
available = list(dataset.keys())
if available:
return dataset[available[0]]
raise ValueError(f"Empty DatasetDict at {path}")
return dataset
except Exception as load_err:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
f"Also failed to load from disk: {load_err}"
)


def get_custom_dataset(
Expand All @@ -132,7 +156,24 @@ def get_custom_dataset(
tokenizer: Optional["PreTrainedTokenizerFast"] = None,
processor: Optional["ProcessorMixin"] = None,
**kwargs,
) -> "Dataset":
) -> "Dataset | RDataset":
from areal.utils.environ import is_single_controller

if (
is_single_controller()
and dataset_config is not None
and dataset_config.scheduling_spec is not None
):
from areal.infra.data_service.rdataset import RDataset

return RDataset(
path=dataset_config.path,
type=dataset_config.type,
split=split,
max_length=dataset_config.max_length,
dataset_kwargs=getattr(dataset_config, "dataset_kwargs", None),
)

if dataset_config is not None:
return _get_custom_dataset(
path=dataset_config.path,
Expand Down
9 changes: 9 additions & 0 deletions areal/infra/data_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from areal.infra.data_service.controller.config import DataServiceConfig
from areal.infra.data_service.controller.controller import DataController
from areal.infra.data_service.rdataset import RDataset

__all__ = [
"DataController",
"DataServiceConfig",
"RDataset",
]
Empty file.
39 changes: 39 additions & 0 deletions areal/infra/data_service/controller/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from dataclasses import dataclass, field

from areal.api.cli_args import SchedulingSpec, SchedulingStrategy


@dataclass
class DataServiceConfig:
    """Internal config for the data service controller.

    Constructed from ``_DatasetConfig`` fields by the trainer.
    Not exposed to end users directly.
    """

    # Number of remote data-service worker processes to launch.
    num_workers: int = 1
    # Resource request + launch command for each remote worker.
    scheduling_spec: SchedulingSpec = field(
        default_factory=lambda: SchedulingSpec(
            cpu=1, gpu=0, mem=10, cmd="python3 -m areal.infra.rpc.guard"
        ),
    )
    # Always separation — data controller starts before engines.
    scheduling_strategy: SchedulingStrategy = field(
        default_factory=lambda: SchedulingStrategy(type="separation"),
    )
    # Seconds to wait for the workers to come up during setup.
    setup_timeout: float = 120.0
    # ``num_workers`` forwarded to each worker's local DataLoader.
    dataloader_num_workers: int = 4

    @staticmethod
    def from_dataset_config(dataset_config) -> DataServiceConfig:
        """Build from a ``_DatasetConfig`` instance.

        ``_DatasetConfig.scheduling_spec`` is typed ``SchedulingSpec | None``;
        a ``None`` there must not leak into the non-optional
        ``scheduling_spec`` field here, so fall back to this class's default
        spec in that case. Attribute access is defensive (``getattr``) to
        tolerate older/partial dataset configs, mirroring the
        ``getattr(dataset_config, "dataset_kwargs", None)`` pattern used by
        the dataset loader.
        """
        spec = getattr(dataset_config, "scheduling_spec", None)
        if spec is None:
            # Same default as the field's default_factory above.
            spec = SchedulingSpec(
                cpu=1, gpu=0, mem=10, cmd="python3 -m areal.infra.rpc.guard"
            )
        return DataServiceConfig(
            num_workers=max(getattr(dataset_config, "num_dataset_workers", 1), 1),
            scheduling_spec=spec,
            dataloader_num_workers=max(getattr(dataset_config, "num_workers", 0), 0),
        )


__all__ = ["DataServiceConfig"]
Loading
Loading