diff --git a/areal/infra/data_service/controller/controller.py b/areal/infra/data_service/controller/controller.py index 3624aa3226..198fab753e 100644 --- a/areal/infra/data_service/controller/controller.py +++ b/areal/infra/data_service/controller/controller.py @@ -123,7 +123,7 @@ async def _async_initialize( guard_addr_0 = guard_addrs[0] try: - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=False) as session: # Wave 1: Fork all DataWorkers + Router in parallel worker_tasks = [ self._async_fork_on_guard( @@ -211,6 +211,7 @@ async def _register_workers() -> None: "DataController initialization failed, rolling back", exc_info=True, ) + # Kill forked services concurrently if self._forked_services: await self._async_kill_forked_services( list(reversed(self._forked_services)) @@ -340,7 +341,7 @@ async def _clear_one(session: aiohttp.ClientSession, addr: str) -> None: except Exception: logger.debug("Failed to clear batches on %s", addr) - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=False) as session: await asyncio.gather( *(_clear_one(session, addr) for addr in self._worker_addrs), return_exceptions=True, @@ -367,7 +368,6 @@ def destroy(self) -> None: exc, ) - # Kill forked services concurrently if self._forked_services: run_async_task( self._async_kill_forked_services, @@ -488,7 +488,7 @@ async def _kill_one( exc, ) - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=False) as session: await asyncio.gather( *(_kill_one(session, *svc) for svc in services), return_exceptions=True, @@ -514,7 +514,7 @@ async def _async_gateway_post( url = f"{self._gateway_addr}{endpoint}" try: async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=timeout) + timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False ) as session: async with session.post( url, diff --git a/areal/infra/data_service/gateway/app.py b/areal/infra/data_service/gateway/app.py index 33b15e650f..46b1472b59 100644 --- a/areal/infra/data_service/gateway/app.py +++ b/areal/infra/data_service/gateway/app.py @@ -4,7 +4,9 @@ from __future__ import annotations -import httpx +import asyncio + +import aiohttp from fastapi import FastAPI, HTTPException, Request from areal.infra.data_service.gateway.auth import ( @@ -19,48 +21,58 @@ async def _query_router(router_addr: str, admin_key: str, timeout: float) -> str: - """Get a worker address from the router via round-robin.""" - async with httpx.AsyncClient(timeout=timeout) as client: - resp = await client.post( + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False + ) as session: + async with session.post( f"{router_addr}/route", + json={}, headers={"Authorization": f"Bearer {admin_key}"}, - ) - if resp.status_code != 200: - raise HTTPException(status_code=502, detail=f"Router error: {resp.text}") - return resp.json()["worker_addr"] + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=502, detail=f"Router error: {await resp.text()}" + ) + return (await resp.json())["worker_addr"] async def _get_all_worker_addrs( router_addr: str, admin_key: str, timeout: float ) -> list[str]: - """Get all worker addresses from the router.""" - async with httpx.AsyncClient(timeout=timeout) as client: - resp = await client.get( + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False + ) as session: + async with session.get( f"{router_addr}/workers", headers={"Authorization": f"Bearer {admin_key}"}, - ) - if resp.status_code != 200: - raise HTTPException(status_code=502, detail=f"Router error: {resp.text}") - return [w["addr"] for w in resp.json()["workers"]] + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=502, detail=f"Router error: {await resp.text()}" + ) + return [w["addr"] for w in (await resp.json())["workers"]] async def _broadcast_to_workers( worker_addrs: list[str], endpoint: str, payload: dict, timeout: float ) -> list[dict]: - """Broadcast a POST request to all workers and collect responses.""" - results: list[dict] = [] - async with httpx.AsyncClient(timeout=timeout) as client: - for addr in worker_addrs: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), trust_env=False + ) as session: + + async def _send_post(addr: str) -> dict: try: - resp = await client.post(f"{addr}{endpoint}", json=payload) - try: - data = resp.json() - except Exception: - data = {"raw": resp.text} - results.append({"addr": addr, "status": resp.status_code, "data": data}) + async with session.post(f"{addr}{endpoint}", json=payload) as resp: + try: + data = await resp.json() + except Exception: + data = {"raw": await resp.text()} + return {"addr": addr, "status": resp.status, "data": data} except Exception as exc: - results.append({"addr": addr, "status": 500, "error": str(exc)}) - return results + return {"addr": addr, "status": 500, "error": str(exc)} + + results = await asyncio.gather(*(_send_post(addr) for addr in worker_addrs)) + return list(results) def create_gateway_app(config: GatewayConfig) -> FastAPI: @@ -75,7 +87,6 @@ def _resolve_dataset_key(token: str) -> str: return dataset_id def _check_broadcast_results(results: list[dict], operation: str) -> None: - """Raise HTTPException if any worker failed during a broadcast.""" failed = [r for r in results if r["status"] != 200] if failed: details = ", ".join( @@ -104,7 +115,7 @@ async def register_dataset(request: Request): "dataset_id", f"{body.get('split', 'train')}-{body.get('dataset_path', 'unknown').split('/')[-1]}", ) - # Broadcast /datasets/load to all workers + worker_addrs = await _get_all_worker_addrs( config.router_addr, config.admin_api_key, @@ -236,17 +247,20 @@ async def fetch_samples(request: Request): config.admin_api_key, config.router_timeout, ) - async with httpx.AsyncClient(timeout=config.forward_timeout) as client: - resp = await client.post( + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=config.forward_timeout), + trust_env=False, + ) as session: + async with session.post( f"{worker_addr}/v1/samples/fetch", json={"dataset_id": dataset_id, "indices": indices}, - ) - if resp.status_code != 200: - raise HTTPException( - status_code=502, - detail=f"Worker fetch_samples error: {resp.text}", - ) - return resp.json() + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=502, + detail=f"Worker fetch_samples error: {await resp.text()}", + ) + return await resp.json() # ===== Consumer: Epoch Advance ===== @app.post("/v1/epochs/advance") @@ -331,12 +345,15 @@ async def status(request: Request): config.admin_api_key, config.router_timeout, ) - async with httpx.AsyncClient(timeout=config.forward_timeout) as client: - resp = await client.get(f"{worker_addr}/health") - if resp.status_code == 200: - payload = resp.json() - payload["dataset_id"] = dataset_id - return payload + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=config.forward_timeout), + trust_env=False, + ) as session: + async with session.get(f"{worker_addr}/health") as resp: + if resp.status == 200: + payload = await resp.json() + payload["dataset_id"] = dataset_id + return payload except Exception: pass return {"status": "ok", "dataset_id": dataset_id} diff --git a/areal/infra/data_service/router/app.py b/areal/infra/data_service/router/app.py index 6338589b4a..709dc2b43d 100644 --- a/areal/infra/data_service/router/app.py +++ b/areal/infra/data_service/router/app.py @@ -7,10 +7,11 @@ import importlib from contextlib import asynccontextmanager +import aiohttp + from areal.infra.data_service.router.config import RouterConfig from areal.utils import logging -httpx = importlib.import_module("httpx") _fastapi = importlib.import_module("fastapi") FastAPI = _fastapi.FastAPI HTTPException = _fastapi.HTTPException @@ -52,17 +53,23 @@ def create_router_app(config: RouterConfig) -> FastAPI: lock = asyncio.Lock() async def _poll_workers() -> None: - while True: - for addr in list(registered_workers): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=config.worker_health_timeout), + trust_env=False, + ) as session: + + async def _check_health(addr: str) -> None: try: - async with httpx.AsyncClient( - timeout=config.worker_health_timeout - ) as client: - resp = await client.get(f"{addr}/health") - worker_healthy[addr] = resp.status_code == 200 + async with session.get(f"{addr}/health") as resp: + worker_healthy[addr] = resp.status == 200 except Exception: worker_healthy[addr] = False - await asyncio.sleep(config.poll_interval) + + while True: + await asyncio.gather( + *(_check_health(addr) for addr in list(registered_workers)) + ) + await asyncio.sleep(config.poll_interval) @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/areal/infra/data_service/worker/app.py b/areal/infra/data_service/worker/app.py index b103850ec1..b83d00c377 100644 --- a/areal/infra/data_service/worker/app.py +++ b/areal/infra/data_service/worker/app.py @@ -87,45 +87,51 @@ async def load_dataset(body: WorkerLoadDatasetRequest): detail=f"Dataset {body.dataset_id} is already loaded", ) - tokenizer = None - processor = None - if body.tokenizer_or_processor_path: - processor, tokenizer = load_hf_processor_and_tokenizer( - body.tokenizer_or_processor_path + def _load_sync(): + _tokenizer = None + _processor = None + if body.tokenizer_or_processor_path: + _processor, _tokenizer = load_hf_processor_and_tokenizer( + body.tokenizer_or_processor_path + ) + + seeding.set_random_seed(body.seed, key=f"data_worker_{config.rank}") + + # Workers must load real datasets, not RDataset proxies. + # Call _get_custom_dataset directly to bypass the is_single_controller() + # gate in get_custom_dataset() that would create an RDataset. + _dataset = _get_custom_dataset( + path=body.dataset_path, + type=body.dataset_type, + split=body.split, + max_length=body.max_length, + tokenizer=_tokenizer, + processor=_processor, + **body.dataset_kwargs, ) - seeding.set_random_seed(body.seed, key=f"data_worker_{config.rank}") - - # Workers must load real datasets, not RDataset proxies. - # Call _get_custom_dataset directly to bypass the is_single_controller() - # gate in get_custom_dataset() that would create an RDataset. - dataset = _get_custom_dataset( - path=body.dataset_path, - type=body.dataset_type, - split=body.split, - max_length=body.max_length, - tokenizer=tokenizer, - processor=processor, - **body.dataset_kwargs, - ) + _sampler_cls = ( + DistributedSampler if body.drop_last else EvalDistributedSampler + ) + _sampler = _sampler_cls( + _dataset, + num_replicas=config.world_size, + rank=config.rank, + shuffle=body.shuffle, + drop_last=body.drop_last, + ) - sampler_cls = DistributedSampler if body.drop_last else EvalDistributedSampler - sampler = sampler_cls( - dataset, - num_replicas=config.world_size, - rank=config.rank, - shuffle=body.shuffle, - drop_last=body.drop_last, - ) + _dataloader = StatefulDataLoader( + _dataset, + batch_size=1, + num_workers=config.dataloader_num_workers, + sampler=_sampler, + drop_last=False, + collate_fn=_identity_collate, + ) + return _dataset, _sampler, _dataloader - dataloader = StatefulDataLoader( - dataset, - batch_size=1, - num_workers=config.dataloader_num_workers, - sampler=sampler, - drop_last=False, - collate_fn=_identity_collate, - ) + dataset, sampler, dataloader = await asyncio.to_thread(_load_sync) datasets[body.dataset_id] = _DatasetState( dataset_id=body.dataset_id,