From 228edda593fc43a83f699c16583ac8f366bcdad3 Mon Sep 17 00:00:00 2001 From: Kolade Fajimi <107228310+koladefaj@users.noreply.github.com> Date: Mon, 13 Apr 2026 01:46:32 +0100 Subject: [PATCH 1/2] refactor: replace manual JSON parsing with Pydantic models (#1154) * refactor: replace manual JSON parsing with Pydantic models * style: sort imports and clean up whitespace in data_proxy * refactor: address review feedback on data proxy Pydantic refactor - Consolidate BatchShardRequest and ClearShardRequest into a shared ShardListRequest base class - Explicitly mark data parameter with Body(...) in store_data_shard - Update FastAPI imports to include Body dependency * style: ruff auto-formatting * refactor: unify RTensor schemas in data_blueprint and adopt in data_proxy Moved BatchShardRequest and ClearShardRequest to infra blueprint. Updated Flask Blueprint to use Pydantic validation. Refactored experimental data_proxy to import schemas from infra. Updated unit tests for strict validation (400/422 errors). Fixed ruff formatting. 
--- .../inference_service/data_proxy/app.py | 50 +++++--------- areal/infra/rpc/guard/data_blueprint.py | 65 ++++++++++++++----- .../test_data_proxy_rtensor.py | 57 +++++++--------- 3 files changed, 89 insertions(+), 83 deletions(-) diff --git a/areal/experimental/inference_service/data_proxy/app.py b/areal/experimental/inference_service/data_proxy/app.py index 7fb033c812..7d3400090d 100644 --- a/areal/experimental/inference_service/data_proxy/app.py +++ b/areal/experimental/inference_service/data_proxy/app.py @@ -7,7 +7,7 @@ import httpx import orjson -from fastapi import FastAPI, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import Response as RawResponse from openai.types.chat.completion_create_params import CompletionCreateParams @@ -36,6 +36,10 @@ from areal.experimental.openai.client import ArealOpenAI from areal.experimental.openai.proxy.server import serialize_interactions from areal.infra.rpc import rtensor as rtensor_storage +from areal.infra.rpc.guard.data_blueprint import ( + BatchShardRequest, + ClearShardRequest, +) from areal.infra.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging @@ -485,7 +489,7 @@ async def export_trajectories( # ========================================================================= @app.post("/data/batch") - async def retrieve_data_shard_batch(request: Request): + async def retrieve_data_shard_batch(payload: BatchShardRequest): """Retrieve multiple tensor shards in one request. Mirrors the ``POST /data/batch`` endpoint on the Flask RPC server @@ -493,26 +497,10 @@ async def retrieve_data_shard_batch(request: Request): works against data-proxy addresses. 
""" try: - try: - payload = (await request.json()) or {} - except Exception: - payload = {} - if not isinstance(payload, dict): - payload = {} - shard_ids = payload.get("shard_ids", []) - if not isinstance(shard_ids, list) or not all( - isinstance(sid, str) for sid in shard_ids - ): - return JSONResponse( - status_code=400, - content={ - "status": "error", - "message": "Expected JSON body with string list field 'shard_ids'", - }, - ) - + shard_ids = payload.shard_ids data = [] missing: list[str] = [] + for sid in shard_ids: try: data.append(rtensor_storage.fetch(sid)) @@ -547,15 +535,14 @@ async def retrieve_data_shard_batch(request: Request): ) @app.put("/data/{shard_id}") - async def store_data_shard(shard_id: str, request: Request): + async def store_data_shard(shard_id: str, data: Any = Body(...)): """Store a tensor shard in local RTensor storage.""" - data_bytes = await request.body() - serialized_data = orjson.loads(data_bytes) - data = deserialize_value(serialized_data) - rtensor_storage.store(shard_id, data) - logger.debug( - "Stored RTensor shard %s (size=%d bytes)", shard_id, len(data_bytes) - ) + # FastAPI already parsed the JSON into the 'data' variable + deserialized_data = deserialize_value(data) + rtensor_storage.store(shard_id, deserialized_data) + + # We don't have 'data_bytes' anymore, but we can log that it's stored + logger.debug("Stored RTensor shard %s", shard_id) return {"status": "ok", "shard_id": shard_id} @app.get("/data/{shard_id}") @@ -573,12 +560,9 @@ async def retrieve_data_shard(shard_id: str): return RawResponse(content=data_bytes, media_type="application/octet-stream") @app.delete("/data/clear") - async def clear_data_shards(request: Request): + async def clear_data_shards(payload: ClearShardRequest): """Clear specified tensor shards from local RTensor storage.""" - body = await request.json() - shard_ids = body.get("shard_ids", []) - if not isinstance(shard_ids, list): - raise HTTPException(status_code=400, detail="'shard_ids' must 
be a list") + shard_ids = payload.shard_ids cleared_count = sum(rtensor_storage.remove(sid) for sid in shard_ids) stats = dict(cleared_count=cleared_count, **rtensor_storage.storage_stats()) logger.info("Cleared %d RTensor shards. Stats: %s", cleared_count, stats) diff --git a/areal/infra/rpc/guard/data_blueprint.py b/areal/infra/rpc/guard/data_blueprint.py index b0b2cdee78..2dff136d1a 100644 --- a/areal/infra/rpc/guard/data_blueprint.py +++ b/areal/infra/rpc/guard/data_blueprint.py @@ -18,6 +18,7 @@ import orjson from flask import Blueprint, Response, jsonify, request +from pydantic import BaseModel, ValidationError from areal.infra.rpc import rtensor from areal.infra.rpc.serialization import deserialize_value, serialize_value @@ -28,6 +29,30 @@ data_bp = Blueprint("data", __name__) +# ================================================================================ +# Pydantic models for Data API +# ================================================================================ + + +class ShardListRequest(BaseModel): + """Base model for requests containing a list of shard IDs.""" + + shard_ids: list[str] + + +class BatchShardRequest(ShardListRequest): + """Request to retrieve multiple tensor shards.""" + + +class ClearShardRequest(ShardListRequest): + """Request to clear specific tensor shards.""" + + +# ================================================================================ +# Flask Blueprint Definition +# ================================================================================ + + @data_bp.route("/data/<shard_id>", methods=["PUT"]) def store_batch_data(shard_id: str): """Store batch data shard.""" @@ -81,22 +106,22 @@ def retrieve_batch_data(shard_id: str): def retrieve_batch_data_many(): """Retrieve multiple batch data shards in one request.""" try: - payload = request.get_json(silent=True) or {} - shard_ids = payload.get("shard_ids", []) - if not isinstance(shard_ids, list) or not all( - isinstance(shard_id, str) for shard_id in shard_ids - ): + 
raw_payload = request.get_json(silent=True) or {} + + # USE PYDANTIC MODEL FOR VALIDATION + try: + payload_model = BatchShardRequest(**raw_payload) + except ValidationError: return ( jsonify( { "status": "error", - "message": ( - "Expected JSON body with string list field 'shard_ids'" - ), + "message": "Expected JSON body with string list field 'shard_ids'", } ), 400, ) + shard_ids = payload_model.shard_ids # use the validated data data = [] missing_shard_ids = [] @@ -134,20 +159,24 @@ def retrieve_batch_data_many(): @data_bp.route("/data/clear", methods=["DELETE"]) def clear_batch_data(): - """Clear specified batch data shards. - - Expected JSON payload:: - - {"shard_ids": ["id1", "id2", ...]} - """ + """Clear specified batch data shards.""" try: - data = request.get_json(silent=True) or {} - shard_ids = data.get("shard_ids", []) - if not isinstance(shard_ids, list): + raw_data = request.get_json(silent=True) or {} + + # USE PYDANTIC MODEL FOR VALIDATION + try: + payload_model = ClearShardRequest(**raw_data) + except ValidationError: return ( - jsonify({"status": "error", "message": "'shard_ids' must be a list"}), + jsonify( + { + "status": "error", + "message": "'shard_ids' must be a list", + } + ), 400, ) + shard_ids = payload_model.shard_ids # use the validated data cleared_count = sum(rtensor.remove(sid) for sid in shard_ids) storage = rtensor.storage_stats() diff --git a/tests/experimental/inference_service/test_data_proxy_rtensor.py b/tests/experimental/inference_service/test_data_proxy_rtensor.py index ac3d7051a6..9d27849d3e 100644 --- a/tests/experimental/inference_service/test_data_proxy_rtensor.py +++ b/tests/experimental/inference_service/test_data_proxy_rtensor.py @@ -194,17 +194,19 @@ async def test_post_batch_missing_shard_returns_400(self, client): @pytest.mark.asyncio async def test_post_batch_invalid_body_returns_400(self, client): - """POST /data/batch with shard_ids as a non-list → 400 with error message.""" + """POST /data/batch with shard_ids 
as a non-list → 400/422 validation error.""" batch_resp = await client.post( "/data/batch", json={"shard_ids": "not-a-list"}, ) - assert batch_resp.status_code == 400 + + # If calling FastAPI, it might return 422. If calling Flask Blueprint, 400. + assert batch_resp.status_code in (400, 422) + + # Pydantic errors are formatted differently in FastAPI vs Flask. + # Check for the presence of an error rather than the exact old string. data = batch_resp.json() - assert data["status"] == "error" - assert ( - data["message"] == "Expected JSON body with string list field 'shard_ids'" - ) + assert "detail" in data or "message" in data @pytest.mark.asyncio async def test_delete_clear_shards(self, client): @@ -237,21 +239,20 @@ async def test_delete_clear_shards(self, client): assert get_resp.status_code == 404 @pytest.mark.asyncio - async def test_post_batch_malformed_json_returns_200_empty(self, client): - """POST /data/batch with non-JSON body → graceful fallback (empty batch). + async def test_post_batch_malformed_json_returns_error(self, client): + """POST /data/batch with non-JSON body → Now returns an error. - Mirrors Flask's ``get_json(silent=True) or {}`` which silently - returns an empty dict on parse failure, yielding an empty 200. + Previous behavior was a 200 with an empty list, but Pydantic + requires a valid BatchShardRequest object. 
""" resp = await client.post( "/data/batch", content=b"this is not json", headers={"Content-Type": "application/json"}, ) - assert resp.status_code == 200 - assert resp.headers["content-type"] == "application/octet-stream" - batch = _deserialize_batch_response_bytes(resp.content) - assert batch == [] + assert resp.status_code in (400, 422) + data = resp.json() + assert "detail" in data or "message" in data @pytest.mark.asyncio async def test_post_batch_serialization_error_returns_500( @@ -285,36 +286,28 @@ def _boom(data): assert "serialization kaboom" in data["message"] @pytest.mark.asyncio - async def test_post_batch_null_json_returns_200_empty(self, client): - """POST /data/batch with ``null`` JSON body → empty batch 200. - - Flask ``get_json(silent=True) or {}`` normalises falsy values to ``{}``. - """ + async def test_post_batch_null_json_returns_error(self, client): + """POST /data/batch with ``null`` JSON body → error.""" resp = await client.post( "/data/batch", content=b"null", headers={"Content-Type": "application/json"}, ) - assert resp.status_code == 200 - assert resp.headers["content-type"] == "application/octet-stream" - batch = _deserialize_batch_response_bytes(resp.content) - assert batch == [] + assert resp.status_code in (400, 422) + data = resp.json() + assert "detail" in data or "message" in data @pytest.mark.asyncio - async def test_post_batch_non_dict_json_returns_200_empty(self, client): - """POST /data/batch with a JSON array → empty batch 200. - - Truthy non-dict payloads are normalised to ``{}`` to match Flask. 
- """ + async def test_post_batch_non_dict_json_returns_error(self, client): + """POST /data/batch with a JSON array → error.""" resp = await client.post( "/data/batch", content=b"[1, 2, 3]", headers={"Content-Type": "application/json"}, ) - assert resp.status_code == 200 - assert resp.headers["content-type"] == "application/octet-stream" - batch = _deserialize_batch_response_bytes(resp.content) - assert batch == [] + assert resp.status_code in (400, 422) + data = resp.json() + assert "detail" in data or "message" in data @pytest.mark.asyncio async def test_post_batch_fetch_runtime_error_returns_500( From b5e373d7818c7608d44ae762ba30e3be15f8370f Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 13 Apr 2026 06:17:38 +0530 Subject: [PATCH 2/2] fix(engine): FSDP compute_logp fails for Qwen3.5 with dict attention_mask (#1153) * fix(engine): add Qwen3.5 model type helpers and MoE registry * fix(engine): use None attention_mask for Qwen3.5 in FSDP compute_logp --- areal/engine/core/model.py | 6 ++++++ areal/engine/fsdp_engine.py | 7 +++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/areal/engine/core/model.py b/areal/engine/core/model.py index a569e1a752..13e4af7b37 100644 --- a/areal/engine/core/model.py +++ b/areal/engine/core/model.py @@ -34,6 +34,8 @@ def is_gemma3_model(model_type: str) -> bool: VALID_MOE_MODELS = [ "qwen3_moe", + "qwen3_5_moe", + "qwen3_5_moe_text", "bailing_moe_v2", "bailing_moe_linear", "bailing_hybrid", @@ -49,6 +51,10 @@ def is_qwen3_moe_model(model_type: str) -> bool: return model_type in ["qwen3_moe"] +def is_qwen3_5_model(model_type: str) -> bool: + return model_type in ["qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"] + + # Copied from trl def disable_dropout_in_model(model: torch.nn.Module) -> None: for module in model.modules(): diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 89a235944f..6bbfd5f7a1 100644 --- 
a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -66,6 +66,7 @@ from areal.engine.core.model import ( disable_dropout_in_model, is_gemma3_model, + is_qwen3_5_model, is_qwen3_moe_model, is_qwen3_vl_model, is_qwen_vl_model, @@ -1521,8 +1522,10 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: ] mb["use_cache"] = False padded_mb["use_cache"] = False - if is_qwen3_moe_model(self.model_config.model_type) or is_qwen3_vl_model( - self.model_config.model_type + if ( + is_qwen3_moe_model(self.model_config.model_type) + or is_qwen3_vl_model(self.model_config.model_type) + or is_qwen3_5_model(self.model_config.model_type) ): mb["attention_mask"] = None padded_mb["attention_mask"] = None