Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions areal/infra/data_service/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def _async_initialize(
guard_addr_0 = guard_addrs[0]

try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=False) as session:
# Wave 1: Fork all DataWorkers + Router in parallel
worker_tasks = [
self._async_fork_on_guard(
Expand Down Expand Up @@ -211,6 +211,7 @@ async def _register_workers() -> None:
"DataController initialization failed, rolling back",
exc_info=True,
)
# Kill forked services concurrently
if self._forked_services:
await self._async_kill_forked_services(
list(reversed(self._forked_services))
Expand Down Expand Up @@ -340,7 +341,7 @@ async def _clear_one(session: aiohttp.ClientSession, addr: str) -> None:
except Exception:
logger.debug("Failed to clear batches on %s", addr)

async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=False) as session:
await asyncio.gather(
*(_clear_one(session, addr) for addr in self._worker_addrs),
return_exceptions=True,
Expand All @@ -367,7 +368,6 @@ def destroy(self) -> None:
exc,
)

# Kill forked services concurrently
if self._forked_services:
run_async_task(
self._async_kill_forked_services,
Expand Down Expand Up @@ -488,7 +488,7 @@ async def _kill_one(
exc,
)

async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=False) as session:
await asyncio.gather(
*(_kill_one(session, *svc) for svc in services),
return_exceptions=True,
Expand All @@ -514,7 +514,7 @@ async def _async_gateway_post(
url = f"{self._gateway_addr}{endpoint}"
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=timeout)
timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False
) as session:
async with session.post(
url,
Expand Down
105 changes: 61 additions & 44 deletions areal/infra/data_service/gateway/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from __future__ import annotations

import httpx
import asyncio

import aiohttp
from fastapi import FastAPI, HTTPException, Request

from areal.infra.data_service.gateway.auth import (
Expand All @@ -19,48 +21,58 @@


async def _query_router(router_addr: str, admin_key: str, timeout: float) -> str:
"""Get a worker address from the router via round-robin."""
async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post(
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False
) as session:
async with session.post(
f"{router_addr}/route",
json={},
headers={"Authorization": f"Bearer {admin_key}"},
)
if resp.status_code != 200:
raise HTTPException(status_code=502, detail=f"Router error: {resp.text}")
return resp.json()["worker_addr"]
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=502, detail=f"Router error: {await resp.text()}"
)
return (await resp.json())["worker_addr"]


async def _get_all_worker_addrs(
router_addr: str, admin_key: str, timeout: float
) -> list[str]:
"""Get all worker addresses from the router."""
async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.get(
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False
) as session:
async with session.get(
f"{router_addr}/workers",
headers={"Authorization": f"Bearer {admin_key}"},
)
if resp.status_code != 200:
raise HTTPException(status_code=502, detail=f"Router error: {resp.text}")
return [w["addr"] for w in resp.json()["workers"]]
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=502, detail=f"Router error: {await resp.text()}"
)
return [w["addr"] for w in (await resp.json())["workers"]]


async def _broadcast_to_workers(
worker_addrs: list[str], endpoint: str, payload: dict, timeout: float
) -> list[dict]:
"""Broadcast a POST request to all workers and collect responses."""
results: list[dict] = []
async with httpx.AsyncClient(timeout=timeout) as client:
for addr in worker_addrs:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False
) as session:

async def _send_post(addr: str) -> dict:
try:
resp = await client.post(f"{addr}{endpoint}", json=payload)
try:
data = resp.json()
except Exception:
data = {"raw": resp.text}
results.append({"addr": addr, "status": resp.status_code, "data": data})
async with session.post(f"{addr}{endpoint}", json=payload) as resp:
try:
data = await resp.json()
except Exception:
data = {"raw": await resp.text()}
return {"addr": addr, "status": resp.status, "data": data}
except Exception as exc:
results.append({"addr": addr, "status": 500, "error": str(exc)})
return results
return {"addr": addr, "status": 500, "error": str(exc)}

results = await asyncio.gather(*(_send_post(addr) for addr in worker_addrs))
return list(results)


def create_gateway_app(config: GatewayConfig) -> FastAPI:
Expand All @@ -75,7 +87,6 @@ def _resolve_dataset_key(token: str) -> str:
return dataset_id

def _check_broadcast_results(results: list[dict], operation: str) -> None:
"""Raise HTTPException if any worker failed during a broadcast."""
failed = [r for r in results if r["status"] != 200]
if failed:
details = ", ".join(
Expand Down Expand Up @@ -104,7 +115,7 @@ async def register_dataset(request: Request):
"dataset_id",
f"{body.get('split', 'train')}-{body.get('dataset_path', 'unknown').split('/')[-1]}",
)
# Broadcast /datasets/load to all workers

worker_addrs = await _get_all_worker_addrs(
config.router_addr,
config.admin_api_key,
Expand Down Expand Up @@ -236,17 +247,20 @@ async def fetch_samples(request: Request):
config.admin_api_key,
config.router_timeout,
)
async with httpx.AsyncClient(timeout=config.forward_timeout) as client:
resp = await client.post(
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=config.forward_timeout),
trust_env=False,
) as session:
async with session.post(
f"{worker_addr}/v1/samples/fetch",
json={"dataset_id": dataset_id, "indices": indices},
)
if resp.status_code != 200:
raise HTTPException(
status_code=502,
detail=f"Worker fetch_samples error: {resp.text}",
)
return resp.json()
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=502,
detail=f"Worker fetch_samples error: {await resp.text()}",
)
return await resp.json()

# ===== Consumer: Epoch Advance =====
@app.post("/v1/epochs/advance")
Expand Down Expand Up @@ -331,12 +345,15 @@ async def status(request: Request):
config.admin_api_key,
config.router_timeout,
)
async with httpx.AsyncClient(timeout=config.forward_timeout) as client:
resp = await client.get(f"{worker_addr}/health")
if resp.status_code == 200:
payload = resp.json()
payload["dataset_id"] = dataset_id
return payload
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=config.forward_timeout),
trust_env=False,
) as session:
async with session.get(f"{worker_addr}/health") as resp:
if resp.status == 200:
payload = await resp.json()
payload["dataset_id"] = dataset_id
return payload
except Exception:
pass
return {"status": "ok", "dataset_id": dataset_id}
Expand Down
25 changes: 16 additions & 9 deletions areal/infra/data_service/router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import importlib
from contextlib import asynccontextmanager

import aiohttp

from areal.infra.data_service.router.config import RouterConfig
from areal.utils import logging

httpx = importlib.import_module("httpx")
_fastapi = importlib.import_module("fastapi")
FastAPI = _fastapi.FastAPI
HTTPException = _fastapi.HTTPException
Expand Down Expand Up @@ -52,17 +53,23 @@ def create_router_app(config: RouterConfig) -> FastAPI:
lock = asyncio.Lock()

async def _poll_workers() -> None:
while True:
for addr in list(registered_workers):
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=config.worker_health_timeout),
trust_env=False,
) as session:

async def _check_health(addr: str) -> None:
try:
async with httpx.AsyncClient(
timeout=config.worker_health_timeout
) as client:
resp = await client.get(f"{addr}/health")
worker_healthy[addr] = resp.status_code == 200
async with session.get(f"{addr}/health") as resp:
worker_healthy[addr] = resp.status == 200
except Exception:
worker_healthy[addr] = False
await asyncio.sleep(config.poll_interval)

while True:
await asyncio.gather(
*(_check_health(addr) for addr in list(registered_workers))
)
await asyncio.sleep(config.poll_interval)

@asynccontextmanager
async def lifespan(app: FastAPI):
Expand Down
76 changes: 41 additions & 35 deletions areal/infra/data_service/worker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,45 +87,51 @@ async def load_dataset(body: WorkerLoadDatasetRequest):
detail=f"Dataset {body.dataset_id} is already loaded",
)

tokenizer = None
processor = None
if body.tokenizer_or_processor_path:
processor, tokenizer = load_hf_processor_and_tokenizer(
body.tokenizer_or_processor_path
def _load_sync():
_tokenizer = None
_processor = None
if body.tokenizer_or_processor_path:
_processor, _tokenizer = load_hf_processor_and_tokenizer(
body.tokenizer_or_processor_path
)

seeding.set_random_seed(body.seed, key=f"data_worker_{config.rank}")

# Workers must load real datasets, not RDataset proxies.
# Call _get_custom_dataset directly to bypass the is_single_controller()
# gate in get_custom_dataset() that would create an RDataset.
_dataset = _get_custom_dataset(
path=body.dataset_path,
type=body.dataset_type,
split=body.split,
max_length=body.max_length,
tokenizer=_tokenizer,
processor=_processor,
**body.dataset_kwargs,
)

seeding.set_random_seed(body.seed, key=f"data_worker_{config.rank}")

# Workers must load real datasets, not RDataset proxies.
# Call _get_custom_dataset directly to bypass the is_single_controller()
# gate in get_custom_dataset() that would create an RDataset.
dataset = _get_custom_dataset(
path=body.dataset_path,
type=body.dataset_type,
split=body.split,
max_length=body.max_length,
tokenizer=tokenizer,
processor=processor,
**body.dataset_kwargs,
)
_sampler_cls = (
DistributedSampler if body.drop_last else EvalDistributedSampler
)
_sampler = _sampler_cls(
_dataset,
num_replicas=config.world_size,
rank=config.rank,
shuffle=body.shuffle,
drop_last=body.drop_last,
)

sampler_cls = DistributedSampler if body.drop_last else EvalDistributedSampler
sampler = sampler_cls(
dataset,
num_replicas=config.world_size,
rank=config.rank,
shuffle=body.shuffle,
drop_last=body.drop_last,
)
_dataloader = StatefulDataLoader(
_dataset,
batch_size=1,
num_workers=config.dataloader_num_workers,
sampler=_sampler,
drop_last=False,
collate_fn=_identity_collate,
)
return _dataset, _sampler, _dataloader

dataloader = StatefulDataLoader(
dataset,
batch_size=1,
num_workers=config.dataloader_num_workers,
sampler=sampler,
drop_last=False,
collate_fn=_identity_collate,
)
dataset, sampler, dataloader = await asyncio.to_thread(_load_sync)

datasets[body.dataset_id] = _DatasetState(
dataset_id=body.dataset_id,
Expand Down
Loading