Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions areal/engine/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def is_gemma3_model(model_type: str) -> bool:

VALID_MOE_MODELS = [
"qwen3_moe",
"qwen3_5_moe",
"qwen3_5_moe_text",
"bailing_moe_v2",
"bailing_moe_linear",
"bailing_hybrid",
Expand All @@ -49,6 +51,10 @@ def is_qwen3_moe_model(model_type: str) -> bool:
return model_type in ["qwen3_moe"]


def is_qwen3_5_model(model_type: str) -> bool:
return model_type in ["qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"]


# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
Expand Down
7 changes: 5 additions & 2 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from areal.engine.core.model import (
disable_dropout_in_model,
is_gemma3_model,
is_qwen3_5_model,
is_qwen3_moe_model,
is_qwen3_vl_model,
is_qwen_vl_model,
Expand Down Expand Up @@ -1521,8 +1522,10 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
]
mb["use_cache"] = False
padded_mb["use_cache"] = False
if is_qwen3_moe_model(self.model_config.model_type) or is_qwen3_vl_model(
self.model_config.model_type
if (
is_qwen3_moe_model(self.model_config.model_type)
or is_qwen3_vl_model(self.model_config.model_type)
or is_qwen3_5_model(self.model_config.model_type)
):
mb["attention_mask"] = None
padded_mb["attention_mask"] = None
Expand Down
50 changes: 17 additions & 33 deletions areal/experimental/inference_service/data_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import httpx
import orjson
from fastapi import FastAPI, HTTPException, Request
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import Response as RawResponse
from openai.types.chat.completion_create_params import CompletionCreateParams
Expand Down Expand Up @@ -36,6 +36,10 @@
from areal.experimental.openai.client import ArealOpenAI
from areal.experimental.openai.proxy.server import serialize_interactions
from areal.infra.rpc import rtensor as rtensor_storage
from areal.infra.rpc.guard.data_blueprint import (
BatchShardRequest,
ClearShardRequest,
)
from areal.infra.rpc.serialization import deserialize_value, serialize_value
from areal.utils import logging

Expand Down Expand Up @@ -485,34 +489,18 @@ async def export_trajectories(
# =========================================================================

@app.post("/data/batch")
async def retrieve_data_shard_batch(request: Request):
async def retrieve_data_shard_batch(payload: BatchShardRequest):
"""Retrieve multiple tensor shards in one request.

Mirrors the ``POST /data/batch`` endpoint on the Flask RPC server
(``rpc_server.py``) so that ``HttpRTensorBackend._fetch_shard_group``
works against data-proxy addresses.
"""
try:
try:
payload = (await request.json()) or {}
except Exception:
payload = {}
if not isinstance(payload, dict):
payload = {}
shard_ids = payload.get("shard_ids", [])
if not isinstance(shard_ids, list) or not all(
isinstance(sid, str) for sid in shard_ids
):
return JSONResponse(
status_code=400,
content={
"status": "error",
"message": "Expected JSON body with string list field 'shard_ids'",
},
)

shard_ids = payload.shard_ids
data = []
missing: list[str] = []

for sid in shard_ids:
try:
data.append(rtensor_storage.fetch(sid))
Expand Down Expand Up @@ -547,15 +535,14 @@ async def retrieve_data_shard_batch(request: Request):
)

@app.put("/data/{shard_id}")
async def store_data_shard(shard_id: str, request: Request):
async def store_data_shard(shard_id: str, data: Any = Body(...)):
"""Store a tensor shard in local RTensor storage."""
data_bytes = await request.body()
serialized_data = orjson.loads(data_bytes)
data = deserialize_value(serialized_data)
rtensor_storage.store(shard_id, data)
logger.debug(
"Stored RTensor shard %s (size=%d bytes)", shard_id, len(data_bytes)
)
# FastAPI already parsed the JSON into the 'data' variable
deserialized_data = deserialize_value(data)
rtensor_storage.store(shard_id, deserialized_data)

# We don't have 'data_bytes' anymore, but we can log that it's stored
logger.debug("Stored RTensor shard %s", shard_id)
return {"status": "ok", "shard_id": shard_id}

@app.get("/data/{shard_id}")
Expand All @@ -573,12 +560,9 @@ async def retrieve_data_shard(shard_id: str):
return RawResponse(content=data_bytes, media_type="application/octet-stream")

@app.delete("/data/clear")
async def clear_data_shards(request: Request):
async def clear_data_shards(payload: ClearShardRequest):
"""Clear specified tensor shards from local RTensor storage."""
body = await request.json()
shard_ids = body.get("shard_ids", [])
if not isinstance(shard_ids, list):
raise HTTPException(status_code=400, detail="'shard_ids' must be a list")
shard_ids = payload.shard_ids
cleared_count = sum(rtensor_storage.remove(sid) for sid in shard_ids)
stats = dict(cleared_count=cleared_count, **rtensor_storage.storage_stats())
logger.info("Cleared %d RTensor shards. Stats: %s", cleared_count, stats)
Expand Down
65 changes: 47 additions & 18 deletions areal/infra/rpc/guard/data_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import orjson
from flask import Blueprint, Response, jsonify, request
from pydantic import BaseModel, ValidationError

from areal.infra.rpc import rtensor
from areal.infra.rpc.serialization import deserialize_value, serialize_value
Expand All @@ -28,6 +29,30 @@
data_bp = Blueprint("data", __name__)


# ================================================================================
# Pydantic models for Data API
# ================================================================================


class ShardListRequest(BaseModel):
"""Base model for requests containing a list of shard IDs."""

shard_ids: list[str]


class BatchShardRequest(ShardListRequest):
"""Request to retrieve multiple tensor shards."""


class ClearShardRequest(ShardListRequest):
"""Request to clear specific tensor shards."""


# ================================================================================
# Flask Blueprint Definition
# ================================================================================


@data_bp.route("/data/<shard_id>", methods=["PUT"])
def store_batch_data(shard_id: str):
"""Store batch data shard."""
Expand Down Expand Up @@ -81,22 +106,22 @@ def retrieve_batch_data(shard_id: str):
def retrieve_batch_data_many():
"""Retrieve multiple batch data shards in one request."""
try:
payload = request.get_json(silent=True) or {}
shard_ids = payload.get("shard_ids", [])
if not isinstance(shard_ids, list) or not all(
isinstance(shard_id, str) for shard_id in shard_ids
):
raw_payload = request.get_json(silent=True) or {}

# USE PYDANTIC MODEL FOR VALIDATION
try:
payload_model = BatchShardRequest(**raw_payload)
except ValidationError:
return (
jsonify(
{
"status": "error",
"message": (
"Expected JSON body with string list field 'shard_ids'"
),
"message": "Expected JSON body with string list field 'shard_ids'",
}
),
400,
)
shard_ids = payload_model.shard_ids # use the validated data

data = []
missing_shard_ids = []
Expand Down Expand Up @@ -134,20 +159,24 @@ def retrieve_batch_data_many():

@data_bp.route("/data/clear", methods=["DELETE"])
def clear_batch_data():
"""Clear specified batch data shards.

Expected JSON payload::

{"shard_ids": ["id1", "id2", ...]}
"""
"""Clear specified batch data shards."""
try:
data = request.get_json(silent=True) or {}
shard_ids = data.get("shard_ids", [])
if not isinstance(shard_ids, list):
raw_data = request.get_json(silent=True) or {}

# USE PYDANTIC MODEL FOR VALIDATION
try:
payload_model = ClearShardRequest(**raw_data)
except ValidationError:
return (
jsonify({"status": "error", "message": "'shard_ids' must be a list"}),
jsonify(
{
"status": "error",
"message": "'shard_ids' must be a list",
}
),
400,
)
shard_ids = payload_model.shard_ids # use the validated data

cleared_count = sum(rtensor.remove(sid) for sid in shard_ids)
storage = rtensor.storage_stats()
Expand Down
57 changes: 25 additions & 32 deletions tests/experimental/inference_service/test_data_proxy_rtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,19 @@ async def test_post_batch_missing_shard_returns_400(self, client):

@pytest.mark.asyncio
async def test_post_batch_invalid_body_returns_400(self, client):
"""POST /data/batch with shard_ids as a non-list → 400 with error message."""
"""POST /data/batch with shard_ids as a non-list → 400/422 validation error."""
batch_resp = await client.post(
"/data/batch",
json={"shard_ids": "not-a-list"},
)
assert batch_resp.status_code == 400

# If calling FastAPI, it might return 422. If calling Flask Blueprint, 400.
assert batch_resp.status_code in (400, 422)

# Pydantic errors are formatted differently in FastAPI vs Flask.
# Check for the presence of an error rather than the exact old string.
data = batch_resp.json()
assert data["status"] == "error"
assert (
data["message"] == "Expected JSON body with string list field 'shard_ids'"
)
assert "detail" in data or "message" in data

@pytest.mark.asyncio
async def test_delete_clear_shards(self, client):
Expand Down Expand Up @@ -237,21 +239,20 @@ async def test_delete_clear_shards(self, client):
assert get_resp.status_code == 404

@pytest.mark.asyncio
async def test_post_batch_malformed_json_returns_200_empty(self, client):
"""POST /data/batch with non-JSON body → graceful fallback (empty batch).
async def test_post_batch_malformed_json_returns_error(self, client):
"""POST /data/batch with non-JSON body → Now returns an error.

Mirrors Flask's ``get_json(silent=True) or {}`` which silently
returns an empty dict on parse failure, yielding an empty 200.
Previous behavior was a 200 with an empty list, but Pydantic
requires a valid BatchShardRequest object.
"""
resp = await client.post(
"/data/batch",
content=b"this is not json",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200
assert resp.headers["content-type"] == "application/octet-stream"
batch = _deserialize_batch_response_bytes(resp.content)
assert batch == []
assert resp.status_code in (400, 422)
data = resp.json()
assert "detail" in data or "message" in data

@pytest.mark.asyncio
async def test_post_batch_serialization_error_returns_500(
Expand Down Expand Up @@ -285,36 +286,28 @@ def _boom(data):
assert "serialization kaboom" in data["message"]

@pytest.mark.asyncio
async def test_post_batch_null_json_returns_200_empty(self, client):
"""POST /data/batch with ``null`` JSON body → empty batch 200.

Flask ``get_json(silent=True) or {}`` normalises falsy values to ``{}``.
"""
async def test_post_batch_null_json_returns_error(self, client):
"""POST /data/batch with ``null`` JSON body → error."""
resp = await client.post(
"/data/batch",
content=b"null",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200
assert resp.headers["content-type"] == "application/octet-stream"
batch = _deserialize_batch_response_bytes(resp.content)
assert batch == []
assert resp.status_code in (400, 422)
data = resp.json()
assert "detail" in data or "message" in data

@pytest.mark.asyncio
async def test_post_batch_non_dict_json_returns_200_empty(self, client):
"""POST /data/batch with a JSON array → empty batch 200.

Truthy non-dict payloads are normalised to ``{}`` to match Flask.
"""
async def test_post_batch_non_dict_json_returns_error(self, client):
"""POST /data/batch with a JSON array → error."""
resp = await client.post(
"/data/batch",
content=b"[1, 2, 3]",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200
assert resp.headers["content-type"] == "application/octet-stream"
batch = _deserialize_batch_response_bytes(resp.content)
assert batch == []
assert resp.status_code in (400, 422)
data = resp.json()
assert "detail" in data or "message" in data

@pytest.mark.asyncio
async def test_post_batch_fetch_runtime_error_returns_500(
Expand Down
Loading