From e55311993c58eab98642096e1bd89eca598006fc Mon Sep 17 00:00:00 2001 From: Gursimran Date: Fri, 17 Apr 2026 00:45:09 -0700 Subject: [PATCH 1/4] feat(engine): lora support for MoE models (single node/ cross node) (#1159) * feat(engine): lora support for MoE models; single node + cross node support * fix(engine): resolve Ruff formatting in vLLM worker extension Keep the vLLM worker extension aligned with Ruff formatting so formatting checks stop failing for a non-functional import spacing issue. --------- Co-authored-by: Wentai Zhang --- areal/engine/megatron_engine.py | 4 +- areal/engine/megatron_utils/megatron.py | 6 +- areal/engine/megatron_utils/megatron_lora.py | 102 +++++++++- .../engine/vllm_ext/vllm_worker_extension.py | 4 +- docs/en/reference/lora.md | 17 +- docs/zh/reference/lora.md | 21 +- .../math/gsm8k_grpo_megatron_lora_moe.yaml | 191 ++++++++++++++++++ 7 files changed, 325 insertions(+), 20 deletions(-) create mode 100644 examples/math/gsm8k_grpo_megatron_lora_moe.yaml diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8e12f8d4ae..6f733cf750 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1456,7 +1456,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: converted_named_tensors = [] for name, param in get_named_parameters(self.model, num_moe_experts): - if ".experts." in name: + if ".experts." in name and not self.config.use_lora: continue if self.config.use_lora and ( ".adapter." 
not in name or not getattr(param, "requires_grad", False) @@ -1474,7 +1474,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # Only pipeline parallel heads CAN contain named tensors here if converted_named_tensors: self._update_bucket_weights_from_distributed(meta, converted_named_tensors) - elif self.is_pipeline_parallel_head() and not self.config.use_lora: + elif self.config.use_lora and self.is_pipeline_parallel_head(): self.logger.warning( "No tensors were collected for distributed update at version %s.", meta.version, diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index c5261f9fb6..8f5d8b7718 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -17,7 +17,10 @@ get_block_size_from_config, quantize_params, ) -from areal.engine.megatron_utils.megatron_lora import convert_qwen3_lora_to_hf +from areal.engine.megatron_utils.megatron_lora import ( + convert_qwen3_lora_to_hf, + convert_qwen3_moe_lora_to_hf, +) def _all_gather_and_concat( @@ -763,6 +766,7 @@ def convert_bailingmoe_to_hf( _CONVERSION_FN_REGISTRY = { "qwen3_lora": convert_qwen3_lora_to_hf, "qwen2_lora": convert_qwen3_lora_to_hf, + "qwen3_moe_lora": convert_qwen3_moe_lora_to_hf, "qwen3_moe": convert_qwen3moe_to_hf, "qwen2": convert_qwen2_to_hf, "qwen3": convert_qwen2_to_hf, diff --git a/areal/engine/megatron_utils/megatron_lora.py b/areal/engine/megatron_utils/megatron_lora.py index 89387181e2..fb1d538d22 100644 --- a/areal/engine/megatron_utils/megatron_lora.py +++ b/areal/engine/megatron_utils/megatron_lora.py @@ -121,6 +121,103 @@ def convert_qwen3_lora_to_hf( return [] +def convert_qwen3_moe_lora_to_hf( + tf_config, + name: str, + tensor: torch.Tensor, +) -> list[tuple[str, torch.Tensor]]: + # Reuse non-MoE conversion for attention and dense MLP paths. 
+ converted = convert_qwen3_lora_to_hf(tf_config, name, tensor) + if converted: + return converted + + grouped_expert_pattern = ( + r"(?:^|.*\.)decoder\.layers\.(\d+)\.mlp\.experts\." + r"(linear_fc1|linear_fc2)\.adapter\.(linear_in|linear_out)\.weight$" + ) + match = re.match(grouped_expert_pattern, name) + if match is not None: + layer_idx, module_name, adapter_part = match.groups() + num_experts = getattr(tf_config, "num_moe_experts", None) + if num_experts is None: + num_experts = getattr(tf_config, "num_experts", None) + if num_experts is None: + return [] + + outputs: list[tuple[str, torch.Tensor]] = [] + for expert_idx in range(num_experts): + base_prefix = ( + f"base_model.model.model.layers.{layer_idx}.mlp.experts.{expert_idx}" + ) + + if module_name == "linear_fc2": + hf_base = f"{base_prefix}.down_proj" + suffix = ( + "lora_A.default.weight" + if adapter_part == "linear_in" + else "lora_B.default.weight" + ) + outputs.append((f"{hf_base}.{suffix}", tensor)) + continue + + gate_base = f"{base_prefix}.gate_proj" + up_base = f"{base_prefix}.up_proj" + if adapter_part == "linear_in": + outputs.extend( + [ + (f"{gate_base}.lora_A.default.weight", tensor), + (f"{up_base}.lora_A.default.weight", tensor), + ] + ) + continue + + gate_b, up_b = tensor.chunk(2, dim=0) + outputs.extend( + [ + (f"{gate_base}.lora_B.default.weight", gate_b.contiguous()), + (f"{up_base}.lora_B.default.weight", up_b.contiguous()), + ] + ) + + return outputs + + expert_pattern = ( + r"(?:^|.*\.)decoder\.layers\.(\d+)\.mlp\.experts\." 
+ r"(linear_fc1|linear_fc2)\.adapter\.(linear_in|linear_out)\.weight(\d+)$" + ) + match = re.match(expert_pattern, name) + if match is None: + return [] + + layer_idx, module_name, adapter_part, expert_idx = match.groups() + base_prefix = f"base_model.model.model.layers.{layer_idx}.mlp.experts.{expert_idx}" + + if module_name == "linear_fc2": + hf_base = f"{base_prefix}.down_proj" + suffix = ( + "lora_A.default.weight" + if adapter_part == "linear_in" + else "lora_B.default.weight" + ) + return [(f"{hf_base}.{suffix}", tensor)] + + if module_name == "linear_fc1": + gate_base = f"{base_prefix}.gate_proj" + up_base = f"{base_prefix}.up_proj" + if adapter_part == "linear_in": + return [ + (f"{gate_base}.lora_A.default.weight", tensor), + (f"{up_base}.lora_A.default.weight", tensor), + ] + gate_b, up_b = tensor.chunk(2, dim=0) + return [ + (f"{gate_base}.lora_B.default.weight", gate_b.contiguous()), + (f"{up_base}.lora_B.default.weight", up_b.contiguous()), + ] + + return [] + + def _infer_target_modules_from_adapter_weights(weight_keys: Iterable[str]) -> list[str]: """ Infer PEFT target_modules from adapter weight parameter names. 
@@ -235,7 +332,10 @@ def save_hf_adapter( # Export adapter weights adapter_state: dict[str, torch.Tensor] = {} for name, tensor in self.export_adapter_weights( - model, cpu=True, show_progress=show_progress + # cpu=True may reduce memory pressure but hangs for MoE models using slurm + model, + cpu=False, + show_progress=False, ): adapter_state[f"base_model.model.{name}"] = tensor.clone().float() diff --git a/areal/engine/vllm_ext/vllm_worker_extension.py b/areal/engine/vllm_ext/vllm_worker_extension.py index c773839f12..30f762329c 100644 --- a/areal/engine/vllm_ext/vllm_worker_extension.py +++ b/areal/engine/vllm_ext/vllm_worker_extension.py @@ -209,7 +209,7 @@ def areal_update_weight_lora_xccl(self): async_op=False, ) - received_weights[name] = tensor + received_weights[name] = tensor.cpu() logger.info(f"Received {len(received_weights)} LoRA parameters via XCCL") @@ -259,7 +259,7 @@ def areal_update_weight_lora_xccl(self): lora_model_id=self.areal_lora_int_id, tensors=merged_weights, peft_helper=peft_helper, - device=self.model_runner.device, + device="cpu", dtype=self.model_runner.lora_manager.lora_config.lora_dtype, model_vocab_size=model_vocab_size, weights_mapper=getattr( diff --git a/docs/en/reference/lora.md b/docs/en/reference/lora.md index 99e7dfd1ae..d8c00c3eb6 100644 --- a/docs/en/reference/lora.md +++ b/docs/en/reference/lora.md @@ -29,12 +29,18 @@ The current LoRA support matrix in AReaL is: | Megatron | ✅ | ❌ | | Archon | ❌ | ❌ | -Example scripts: +**Example scripts:** -| Engine | Example script | -| -------- | --------------------------------------------- | -| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | -| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | +| Engine | Example script | +| ------------ | ------------------------------------------------- | +| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | +| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | +| Megatron MoE | `examples/math/gsm8k_grpo_megatron_lora_moe.yaml` | + +For 
Megatron + vLLM, AReaL now supports: + +- LoRA fine-tuning on MoE architectures such as Qwen3 MoE with XCCL-based LoRA weight updates. +- Cross-node LoRA training when the Megatron and rollout groups span multiple nodes. ## Core LoRA Parameters @@ -50,5 +56,4 @@ Example scripts: - Start with `r=16` or `r=32` for most models, then tune upward only if needed. - Keep `target_modules` consistent with your model architecture naming. -- Currently only dense models (non MoE) are supported. - For Megatron backend, LoRA requires `megatron-bridge` instead of `mbridge`. diff --git a/docs/zh/reference/lora.md b/docs/zh/reference/lora.md index 71b8344ef6..dba6abe6d4 100644 --- a/docs/zh/reference/lora.md +++ b/docs/zh/reference/lora.md @@ -1,6 +1,6 @@ # LoRA 参考 -LoRA 是一种参数高效的微调技术,会在预训练权重中注入可训练的低秩矩阵, 通常作用在线性层附近。与全参数微调相比,LoRA 可以显著降低显存占用和 计算开销,从而让大模型的 +LoRA 是一种参数高效的微调技术,会在预训练权重中注入可训练的低秩矩阵, 通常作用在线性层附近。与全参数微调相比,LoRA 可以显著降低显存占用和计算开销, 从而让大模型的 RL 微调在硬件资源有限的条件下也更具可行性。 在 AReaL 中,LoRA 尤其适用于以下场景: @@ -8,7 +8,7 @@ RL 微调在硬件资源有限的条件下也更具可行性。 - 在相对有限的硬件条件下进行超大模型的强化学习训练,例如使用 8 x 80 GB GPU 训练 70B+ 规模模型, - 由于显存压力更低,可以支持更大的 batch size, - 模型迁移与部署更加简单,因为只需要保存和分发 LoRA adapter, -- \[Future\] 更高效地并行微调多个 LoRA adapter,以提升硬件利用率 (参见 RFC +- \[Future\] 更高效地并行微调多个 LoRA adapter,以提升硬件利用率(参见 RFC [#609](https://github.com/inclusionAI/AReaL/issues/609))。 本文档说明如何在 RL 训练中启用 LoRA,并配置相关参数。 @@ -23,12 +23,18 @@ AReaL 当前的 LoRA 支持矩阵如下: | Megatron | ✅ | ❌ | | Archon | ❌ | ❌ | -示例脚本: +**示例脚本:** -| Engine | Example script | -| -------- | --------------------------------------------- | -| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | -| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | +| Engine | Example script | +| ------------ | ------------------------------------------------- | +| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | +| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | +| Megatron MoE | `examples/math/gsm8k_grpo_megatron_lora_moe.yaml` | + +对于 Megatron + vLLM,AReaL 现在支持: + +- 在 Qwen3 MoE 等 MoE 架构上进行 
LoRA 微调,并通过 XCCL 更新 LoRA 权重。 +- 当 Megatron 与 rollout group 横跨多个节点时进行跨节点 LoRA 训练。 ## 核心 LoRA 参数 @@ -44,5 +50,4 @@ AReaL 当前的 LoRA 支持矩阵如下: - 可先从 `r=16` 或 `r=32` 开始,再按效果和资源逐步调参。 - `target_modules` 需与具体模型的层命名保持一致。 -- 当前仅支持 dense 模型(非 MoE)。 - 对于 Megatron 后端,LoRA 需要使用 `megatron-bridge`,而不是 `mbridge`。 diff --git a/examples/math/gsm8k_grpo_megatron_lora_moe.yaml b/examples/math/gsm8k_grpo_megatron_lora_moe.yaml new file mode 100644 index 0000000000..8ce555d901 --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_lora_moe.yaml @@ -0,0 +1,191 @@ +experiment_name: gsm8k-grpo-megatron-lora-moe +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "vllm:d1p1t2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + use_lora: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + lora_name: "lora-gsm8k" + +actor: + backend: "megatron:(attn:d1p6t1c1|ffn:d1p6t1e1)" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-30B-A3B-Base + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 3e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: cosine + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: 
-0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + megatron: + bridge_type: megatron-bridge + weight_update_mode: xccl + use_lora: ${rollout.use_lora} + peft_type: lora + lora_rank: 32 + lora_alpha: 32 + target_modules: [linear_qkv, linear_proj, linear_fc1, linear_fc2] + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + enforce_eager: true + enable_lora: ${rollout.use_lora} + max_lora_rank: ${actor.lora_rank} + +# datasets +train_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 16 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + 
fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false From e8c1e1fd9b336370834091ced4c4351a28c213ce Mon Sep 17 00:00:00 2001 From: xiao <102247755+Wangxiaoxiaoa@users.noreply.github.com> Date: Fri, 17 Apr 2026 18:30:21 +0800 Subject: [PATCH 2/4] fix: handle integer device ids in ray rpc server (#1199) --- areal/infra/rpc/ray_rpc_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/areal/infra/rpc/ray_rpc_server.py b/areal/infra/rpc/ray_rpc_server.py index 12f28eab1e..ce6d9ed892 100644 --- a/areal/infra/rpc/ray_rpc_server.py +++ b/areal/infra/rpc/ray_rpc_server.py @@ -45,6 +45,11 @@ def _get_device(self): return current_platform.current_device() + def _get_device_type(self) -> str: + from areal.infra.platforms import current_platform + + return current_platform.device_type + def _should_broadcast_payload( self, engine: TrainEngine | InferenceEngine, @@ -181,7 +186,7 @@ def call( if ( isinstance(engine, TrainEngine) and engine.initialized - and self._get_device().type != "cpu" + and self._get_device_type() != "cpu" ): from areal.infra.platforms import current_platform From 8965973baa2c158f482d45d3e06e2c50fb7f4918 Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Fri, 17 Apr 2026 18:57:41 +0800 Subject: [PATCH 3/4] chore: add @CormickKneey as maintainer and agent_service codeowner (#1201) Add Han Jiang (@CormickKneey) to the GOVERNANCE.md maintainers table and assign codeowner for 
/areal/experimental/agent_service/ in CODEOWNERS. --- .github/CODEOWNERS | 1 + GOVERNANCE.md | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f987027ff7..b5c9ac2c55 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,6 +5,7 @@ /areal/api/ @garrett4wade /areal/engine/ @rchardx /areal/experimental/inference_service @nuzant +/areal/experimental/agent_service/ @CormickKneey /areal/infra/ @garrett4wade /areal/models/ @rchardx /areal/trainer/ @garrett4wade diff --git a/GOVERNANCE.md b/GOVERNANCE.md index 191e6f87e9..50e8074206 100644 --- a/GOVERNANCE.md +++ b/GOVERNANCE.md @@ -23,6 +23,7 @@ project. | Zhiyu Mei | AReaL Team, Ant Group | @nuzant | | Xujie Shen | AReaL Team, Ant Group | @fishcrap | | Tongkai Yang | AReaL Team, Ant Group | @fredy12 | +| Han Jiang | AReaL Team, Ant Group | @CormickKneey | ### Lead Maintainer (BDFL) From f3d7e50ac81970633e3f03b7541d42c067a29844 Mon Sep 17 00:00:00 2001 From: xiao <102247755+Wangxiaoxiaoa@users.noreply.github.com> Date: Fri, 17 Apr 2026 18:58:30 +0800 Subject: [PATCH 4/4] fix: serialize ray object refs in rpc payloads (#1198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: serialize ray object refs in rpc payloads * fix(infra): use inline ray imports in rpc serialization Ray is a mandatory dependency -- remove unnecessary try-except guards and move imports to runtime/inline per review feedback. 
Key changes: - Remove top-level try-except import ray in serialization.py - Add inline import ray.cloudpickle in SerializedRayObjectRef methods - Add inline import ray in serialize_value for ObjectRef isinstance check - Remove skipif decorator and try-except in test_serialization.py Refs: #1198 --------- Co-authored-by: 博惟 --- areal/infra/rpc/serialization.py | 37 ++++++++++++++++++++++++++++++++ tests/test_serialization.py | 14 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/areal/infra/rpc/serialization.py b/areal/infra/rpc/serialization.py index 8a5a572120..5b460fcfe8 100644 --- a/areal/infra/rpc/serialization.py +++ b/areal/infra/rpc/serialization.py @@ -248,6 +248,26 @@ def to_image(self) -> "ImageObject": return image +class SerializedRayObjectRef(BaseModel): + """Pydantic model for serialized ray.ObjectRef handles.""" + + type: Literal["ray_object_ref"] = Field(default="ray_object_ref") + data: str + + @classmethod + def from_object_ref(cls, ref: Any) -> "SerializedRayObjectRef": + import ray.cloudpickle + + payload = ray.cloudpickle.dumps(ref) + return cls(data=base64.b64encode(payload).decode("utf-8")) + + def to_object_ref(self) -> Any: + import ray.cloudpickle + + payload = base64.b64decode(self.data.encode("utf-8")) + return ray.cloudpickle.loads(payload) + + class SerializedDataclass(BaseModel): """Pydantic model for serialized dataclass with metadata. @@ -569,6 +589,13 @@ def serialize_value(value: Any) -> Any: if ImageObject is not None and isinstance(value, ImageObject): return SerializedPILImage.from_image(value).model_dump() + # Handle Ray object references when HTTP RPC needs to carry RTensor shard + # handles across processes. 
+ import ray + + if isinstance(value, ray.ObjectRef): + return SerializedRayObjectRef.from_object_ref(value).model_dump() + # Handle dataclass instances (check before dict, as dataclasses can be dict-like) # Note: is_dataclass returns True for both classes and instances, so check it's not a type if is_dataclass(value) and not isinstance(value, type): @@ -698,6 +725,16 @@ def deserialize_value(value: Any) -> Any: f"Failed to deserialize PIL image, treating as regular dict: {e}" ) + # Check for SerializedRayObjectRef marker + if value.get("type") == "ray_object_ref": + try: + serialized_ref = SerializedRayObjectRef.model_validate(value) + return serialized_ref.to_object_ref() + except Exception as e: + logger.warning( + f"Failed to deserialize ray.ObjectRef, treating as regular dict: {e}" + ) + # Check for SerializedTensor marker if value.get("type") == "tensor": try: diff --git a/tests/test_serialization.py b/tests/test_serialization.py index e1006c8186..bcea605163 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import ray import torch from PIL import Image from transformers import AutoTokenizer @@ -176,6 +177,19 @@ def test_nested_structure(self): assert torch.equal(deserialized["list"][0], payload["list"][0]) assert deserialized["meta"]["text"] == "value" + def test_ray_object_ref_roundtrip(self): + """Ray ObjectRef handles should round-trip through RPC serialization.""" + ray.init(local_mode=True, ignore_reinit_error=True) + try: + original = ray.put({"value": 123}) + serialized = serialize_value({"ref": original}) + assert serialized["ref"]["type"] == "ray_object_ref" + + deserialized = deserialize_value(serialized) + assert ray.get(deserialized["ref"]) == {"value": 123} + finally: + ray.shutdown() + @pytest.mark.skipif( not hasattr(torch, "cuda") or not torch.cuda.is_available(), reason="CUDA not available",