From 8fbfc8b8bff17205e5ddd64a3be2343a80dd9366 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 May 2026 15:55:20 +0800 Subject: [PATCH] refactor rollout weight update flow --- xtuner/v1/rl/trainer/controller.py | 3 +- xtuner/v1/rl/trainer/update_weighter.py | 951 ------------------ xtuner/v1/rl/trainer/worker.py | 23 +- xtuner/v1/rl/weight_update/__init__.py | 46 + xtuner/v1/rl/weight_update/client.py | 31 + xtuner/v1/rl/weight_update/data.py | 41 + xtuner/v1/rl/weight_update/exporter.py | 300 ++++++ xtuner/v1/rl/weight_update/transport.py | 714 +++++++++++++ xtuner/v1/rl/weight_update/update_weighter.py | 249 +++++ xtuner/v1/train/rl_trainer.py | 4 + 10 files changed, 1407 insertions(+), 955 deletions(-) delete mode 100644 xtuner/v1/rl/trainer/update_weighter.py create mode 100644 xtuner/v1/rl/weight_update/__init__.py create mode 100644 xtuner/v1/rl/weight_update/client.py create mode 100644 xtuner/v1/rl/weight_update/data.py create mode 100644 xtuner/v1/rl/weight_update/exporter.py create mode 100644 xtuner/v1/rl/weight_update/transport.py create mode 100644 xtuner/v1/rl/weight_update/update_weighter.py diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index e3dcf1917..8f0c629f6 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -7,6 +7,7 @@ from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.rl.weight_update import TrainRolloutMode from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import get_logger @@ -292,7 +293,7 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): def update_rollout_info(self, info_dict): ray.get([worker.update_rollout_info.remote(**info_dict) for worker in self.workers]) # type: ignore[attr-defined] - def set_train_rollout_mode(self, train_rollout_mode: str): + def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode): ray.get([worker.set_train_rollout_mode.remote(train_rollout_mode) for worker in self.workers]) def update_weights(self): diff --git a/xtuner/v1/rl/trainer/update_weighter.py b/xtuner/v1/rl/trainer/update_weighter.py deleted file mode 100644 index 23cede80d..000000000 --- a/xtuner/v1/rl/trainer/update_weighter.py +++ /dev/null @@ -1,951 +0,0 @@ -import json -import os -import socket -from concurrent.futures import ThreadPoolExecutor -from datetime import timedelta -from itertools import chain -from threading import Lock -from typing import Any, Dict, List, TypeAlias, cast - -import requests -import torch -import torch.distributed as dist -import tqdm -from packaging.version import parse as parse_version -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.distributed_c10d import ( - Backend, - PrefixStore, - Store, - _new_process_group_helper, - _world, - default_pg_timeout, - rendezvous, -) -from torch.distributed.tensor import DTensor - -from xtuner.v1.model.compose.base import BaseComposeConfig -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration -from xtuner.v1.model.moe.moe import MoE -from xtuner.v1.rl.rollout.worker import RolloutConfig -from xtuner.v1.utils import ( - get_device, - get_torch_device_module, - monkey_unpatch_torch_reductions, - ray_method, -) -from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec - - -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs -RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) -DEVICE = get_device() -DEVICE_MODULE = get_torch_device_module() - - -class UpdateWeighter: - rank: int - logger: Any - config: Any - - def _init_update_weighter(self): - # Used to update weight to rollout engine - self.rollout_device_mesh: DeviceMesh | None = None - self.rollout_url: str | None = None - self.rollout_cfg_info: dict = dict() - self.endpoints: dict[str, str] = dict() - self.endpoints["update_weights"] = "update_weights" - - self.rollout_engine_rank_mesh_array: DeviceMeshRaw = [] - self.rollout_server_url_dict: ServiceUrlMap = {} - self.worker_server_urls_status: dict[str, bool] = {} - - self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict() - self._default_ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - self._ipc_tensor_bytes_dict_by_dtype: dict[torch.dtype, int] = {} - self._update_params_ipc_tensor_dict_by_dtype: dict[torch.dtype, torch.Tensor] = {} - self._last_update_params_ipc_tensor_dtype: torch.dtype | None = None - self._update_params_ipc_event = None - self._sglang_disagg_group: dist.ProcessGroup | None = None - self._sglang_disagg_group_name: str | None = None - self._sglang_disagg_engine_urls: list[str] = [] - self._sglang_disagg_executor: ThreadPoolExecutor | None = None - self._train_update_sync_group: dist.ProcessGroup | None = None - self._sglang_disagg_update_lock = Lock() - self.use_fake_weight_update = ( - False # 仅在 lmdeploy 后端的 disaggregated 模式下使用,表示是否使用 fake 接口进行权重更新 - ) - - def _hook_compare_test_sent_and_received_weight_hash( - self, - result: dict[str, Any], - *, - bucket_idx: int | None = None, - names: list[str] | None = None, - ) -> None: - """Test hook for comparing sent and received weight hashes. - - This hook is intentionally a no-op in production code and is expected to be overridden in unit tests that need - to compare training-side sent hashes with rollout-side received hashes returned by SGLang. - """ - return - - @ray_method - def update_rollout_info( - self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - rollout_config: RolloutConfig, - worker_server_urls_status: Dict[str, bool], - api_server_url: str | None = None, - ): - """Update the rollout information for the training worker.""" - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." - if self.rollout_device_mesh is None: - self.rollout_device_mesh = DeviceMesh( - "cpu", - mesh=engine_rank_mesh_array, - mesh_dim_names=("engine_instance", "engine_parallel"), - ) - rollout_server_url = server_url_dict.get(self.rank, "") - if worker_server_urls_status.get(rollout_server_url, "False") is False: - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_url = None - else: - self.rollout_url = rollout_server_url - - self.rollout_engine_rank_mesh_array = [[int(rank) for rank in ranks] for ranks in engine_rank_mesh_array] - self.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} - self.worker_server_urls_status = worker_server_urls_status - - self.rollout_cfg_info["tp"] = tp - self.rollout_cfg_info["ep"] = ep - self.rollout_cfg_info["api_key"] = rollout_config.api_key - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - self.rollout_cfg_info["backend"] = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - self.rollout_cfg_info["backend"] = "vllm" - else: - self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( - "lmdeploy_backend", "pytorch" - ) - - @ray_method - def set_train_rollout_mode(self, train_rollout_mode: str): - mode = train_rollout_mode.lower() - if mode == "colocate": - self.is_train_rollout_colocated = True - elif mode == "disaggregated": - self.is_train_rollout_colocated = False - - backend = self.rollout_cfg_info.get("backend", "").lower() - if backend == "vllm": - raise NotImplementedError("Disaggregated train-rollout mode is not supported for vLLM backend.") - - elif backend == "pytorch" or backend == "turbomind": - self.logger.warning( - "Disaggregated train-rollout mode for lmdeploy backend is not fully supported yet. " - "A fake no-op interface will be used temporarily.", - ) - self.use_fake_weight_update = True # 后续 fake 接口可根据这个标志跳过实际同步 - - elif backend == "sglang": - self.use_fake_weight_update = False - else: - raise ValueError( - f"Unsupported rollout backend for disaggregated mode: {backend!r}. " - "Expected 'vllm', 'pytorch', 'turbomind' or 'sglang'." - ) - - else: - raise ValueError( - f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." - ) - - if self.is_train_rollout_colocated: - self._reset_sglang_disagg_group() - - def _reset_sglang_disagg_group(self): - if self._sglang_disagg_executor is not None: - self._sglang_disagg_executor.shutdown(wait=False, cancel_futures=True) - try: - if self._sglang_disagg_group is not None: - dist.destroy_process_group(self._sglang_disagg_group) - except Exception: - pass - self._sglang_disagg_group = None - self._sglang_disagg_group_name = None - self._sglang_disagg_engine_urls = [] - self._sglang_disagg_executor = None - - def _get_train_update_sync_group(self) -> dist.ProcessGroup: - if self._train_update_sync_group is None: - ranks = list(range(dist.get_world_size())) - self._train_update_sync_group = dist.new_group(ranks=ranks, backend="gloo") - return self._train_update_sync_group - - @ray_method - def update_weights(self): - """Update the model weights.""" - if not hasattr(self, "is_train_rollout_colocated"): - raise RuntimeError( - "train/rollout mode is not set. Please call set_train_rollout_mode() before update_weights()." - ) - - if self.is_train_rollout_colocated: - self._update_weights_colocated() - else: - self._update_weights_disaggregated() - - def _update_weights_colocated(self): - DEVICE_MODULE.empty_cache() - self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True) - if self.rollout_cfg_info.get("backend") == "turbomind": - self._update_weights_by_layer() - else: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="language_model", final_update=False) - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) - else: - self._update_weights_hf_generator(final_update=True) - self._update_params_ipc_tensor_dict_by_dtype = {} - self._last_update_params_ipc_tensor_dtype = None - self._update_params_ipc_event = None - - DEVICE_MODULE.empty_cache() - - def _update_weights_disaggregated(self): - if self.use_fake_weight_update: - self.logger.warning( - "Using fake weight update interface, no actual weight synchronization will happen. This is only for testing purposes and should not be used in production." - ) - return - - DEVICE_MODULE.empty_cache() - try: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="language_model", final_update=False) - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) - else: - self._update_weights_hf_generator(final_update=True) - finally: - DEVICE_MODULE.empty_cache() - - def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int): - fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED) - model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size - if not fused_param_groups: - return - - def _get_hf_params( - fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], - ) -> tuple[list[torch.Tensor], list[str]]: - hf_keys_list: list[str] = [] - hf_tensor_list: list[torch.Tensor] = [] - - for fsdp_tensor, load_spec in fsdp_tensor_list: - hf_keys = load_spec.hf_keys - if model_ep_size > 1 and model.ep_mesh is not None: - if load_spec.name not in self._global_hf_keys_mapping_cache: - global_hf_keys: list[list[str] | None] = [None] * model_ep_size - dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group()) - global_hf_keys_gathered = cast(list[list[str]], global_hf_keys) - self._global_hf_keys_mapping_cache[load_spec.name] = list( - chain.from_iterable(global_hf_keys_gathered) - ) - hf_keys = self._global_hf_keys_mapping_cache[load_spec.name] - - fused_full_tensor = fsdp_tensor.bfloat16() - if isinstance(fused_full_tensor, DTensor): - fused_full_tensor = fused_full_tensor.full_tensor() - dim = cast(int, load_spec.dim) - num_split = len(hf_keys) - hf_tensor_size = fused_full_tensor.shape[dim] / num_split - assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer" - hf_tensor_size = int(hf_tensor_size) - - hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim) - assert num_split % target_ep_size == 0, ( - f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}" - ) - start_idx = (num_split // target_ep_size) * target_ep_rank - end_idx = (num_split // target_ep_size) * (target_ep_rank + 1) - - hf_keys_list.extend(hf_keys[start_idx:end_idx]) - hf_tensor_list.extend(hf_tensor[start_idx:end_idx]) - - hf_tensor_list = [ - model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) - ] - - return hf_tensor_list, hf_keys_list - - safetensor_size = 0 - dtype = torch.bfloat16 - tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] - - for param, load_spec in fused_param_groups: - tensor_size = dtype.itemsize * param.numel() // target_ep_size - if safetensor_size + tensor_size > bucket_size and tensor_list: - hf_params, name_list = _get_hf_params(tensor_list) - yield name_list, hf_params - safetensor_size = tensor_size - name_list = load_spec.hf_keys.copy() - tensor_list = [(param, load_spec)] - continue - safetensor_size += tensor_size - tensor_list.append((param, load_spec)) - - if tensor_list: - hf_params, name_list = _get_hf_params(tensor_list) - yield name_list, hf_params - - @torch.no_grad() - def _update_weights_hf_generator(self, submodule=None, final_update=False): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - if submodule: - model = getattr(model, submodule) - - dtype = torch.bfloat16 - bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - same_gen = model._get_same_hf_param( - model._group_param_by_load_spec(LoadEnum.SAME), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - ) - - train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 - if train_enable_ep: - if self.rollout_cfg_info["ep"] > 1: - fused_gen = self._rl_get_fused_ep_hf_param( - model, - target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0], - target_ep_size=self.rollout_device_mesh["engine_parallel"].size(), - bucket_size=bucket_size, - ) - else: - fused_gen = self._rl_get_fused_ep_hf_param( - model, - target_ep_rank=0, - target_ep_size=1, - bucket_size=bucket_size, - ) - else: - fused_gen = model._get_fused_hf_param( - model._group_param_by_load_spec(LoadEnum.FUSED), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - update_weights_for_rl=True, - ) - shard_gen = model._get_shard_hf_param( - model._group_param_by_load_spec(LoadEnum.SHARD), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - ) - - for name_list, fused_param_list in fused_gen: - state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} - self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) - del state_dict, name_list, fused_param_list - - for name_list, param_list in chain(same_gen, shard_gen): - state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} - self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) - del state_dict, name_list, param_list - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update: - self.request_update_params({}, train_enable_ep=train_enable_ep, finished=True) - - if self.is_train_rollout_colocated: - dist.barrier() - else: - dist.barrier(group=self._get_train_update_sync_group()) - DEVICE_MODULE.empty_cache() - return - - def _update_weights_by_layer(self): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - DEVICE_MODULE.empty_cache() - - if isinstance(model.config, BaseComposeConfig): - # TODO: support float8 for vision compose model - dtype = torch.bfloat16 - else: - if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - dtype = torch.float8_e4m3fn - else: - dtype = torch.bfloat16 - - def get_params(tensor_list, name_list, save_dtype): - _tensor_list, _spec_list = list(zip(*tensor_list)) - fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - if save_dtype == torch.float8_e4m3fn: - fsdp_unshard_tensor_list, name_list = model._to_float8( - fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype - ) - return fsdp_unshard_tensor_list, name_list - - saved_list = [] - is_qwen3vl = False - if isinstance(model.config, BaseComposeConfig): - language_model = model.language_model - if isinstance(model, Qwen3VLForConditionalGeneration): - is_qwen3vl = True - else: - language_model = model - - if is_qwen3vl: - vision_hf_prefix = "model.visual." - projector_hf_prefix = "model.visual." - else: - vision_hf_prefix = "model.vision_tower." - projector_hf_prefix = "model.multi_modal_projector." - - for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): - tensor_list = [] - name_list = [] - for sub_name, param in layer.state_dict().items(): - if isinstance(model.config, BaseComposeConfig): - saved_list.append(f"language_model.layers.{i}.{sub_name}") - else: - saved_list.append(f"layers.{i}.{sub_name}") - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") - - if isinstance(model.config, BaseComposeConfig): - name = f"model.language_model.layers.{i}.{sub_name}" - else: - name = f"model.layers.{i}.{sub_name}" - - if ".experts." in name and ".mlp.experts." not in name: - name = name.replace(".experts.", ".mlp.experts.") - if ".gate." in name and ".mlp.gate." not in name: - name = name.replace(".gate.", ".mlp.gate.") - name_list.append(name) - tensor_list.append((local_tensor, load_spec)) - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - for name, param in model.state_dict().items(): - if name in saved_list: - continue - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = model.load_spec_mapping.get(name) - - if isinstance(model.config, BaseComposeConfig): - if "vision_tower." in name: - name = name.replace("vision_tower.", vision_hf_prefix) - elif "multi_modal_projector." in name: - name = name.replace("multi_modal_projector.", projector_hf_prefix) - elif name == "language_model.norm.weight": - name = "model.language_model.norm.weight" - elif name == "language_model.embed_tokens.weight": - name = "model.language_model.embed_tokens.weight" - elif name == "language_model.lm_head.weight": - name = "lm_head.weight" - else: - if name == "norm.weight": - name = "model.norm.weight" - elif name == "embed_tokens.weight": - name = "model.embed_tokens.weight" - tensor_list = [(local_tensor, load_spec)] - name_list = [name] - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"): - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - @staticmethod - def _compute_state_dict_bytes(state_dict: Dict[str, torch.Tensor]) -> int: - total_bytes = 0 - for tensor in state_dict.values(): - total_bytes += tensor.numel() * tensor.element_size() - return total_bytes - - @staticmethod - def _init_external_process_group( - backend: str | Backend | None = None, - init_method: str | None = None, - timeout: timedelta | None = None, - world_size: int = -1, - rank: int = -1, - store: Store | None = None, - group_name: str | None = None, - pg_options: Any | None = None, - ) -> dist.ProcessGroup: - assert (store is None) or (init_method is None), "Cannot specify both store and init_method." - if store is not None: - assert world_size > 0, "world_size must be positive if using store" - assert rank >= 0, "rank must be non-negative if using store" - elif init_method is None: - init_method = "env://" - - backend = Backend(backend) if backend else Backend("undefined") - if timeout is None: - timeout = default_pg_timeout - - if store is None: - assert init_method is not None - rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) - store, rank, world_size = next(rendezvous_iterator) - store.set_timeout(timeout) - if group_name is not None: - store = PrefixStore(group_name, store) - - pg_options_param_name = ( - "backend_options" if parse_version(torch.__version__) >= parse_version("2.6") else "pg_options" - ) - pg, _ = _new_process_group_helper( - world_size, - rank, - [], - backend, - store, - group_name=group_name, - **{pg_options_param_name: pg_options}, - timeout=timeout, - ) - _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} - return pg - - @staticmethod - def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): - return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) - - def _build_lmdeploy_flattened_tensor_data(self, state_dict: dict, flattened_tensor_bucket_cls) -> dict: - # LMDeploy flattened buckets require all tensors in one bucket to share a dtype. - state_dict_dtype = state_dict[next(iter(state_dict))].dtype - update_params_ipc_tensor = self._update_params_ipc_tensor_dict_by_dtype.get(state_dict_dtype, None) - state_dict_bytes = self._compute_state_dict_bytes(state_dict) - ipc_tensor_bytes = self._ipc_tensor_bytes_dict_by_dtype.get( - state_dict_dtype, - self._default_ipc_tensor_bytes, - ) - dtype_changed = ( - self._last_update_params_ipc_tensor_dtype is not None - and state_dict_dtype != self._last_update_params_ipc_tensor_dtype - ) - need_resize = state_dict_bytes > ipc_tensor_bytes - send_ipc_tensor = dtype_changed or need_resize or update_params_ipc_tensor is None - - if update_params_ipc_tensor is not None: - self._update_params_ipc_event.wait() - if need_resize: - torch.cuda.synchronize() - - if update_params_ipc_tensor is None or need_resize: - ipc_tensor_bytes = max(ipc_tensor_bytes, state_dict_bytes) - self._ipc_tensor_bytes_dict_by_dtype[state_dict_dtype] = ipc_tensor_bytes - update_params_ipc_tensor = self._create_ipc_tensor( - ipc_tensor_bytes, - state_dict_dtype, - ) - self._update_params_ipc_tensor_dict_by_dtype[state_dict_dtype] = update_params_ipc_tensor - - flattened_tensor_bucket = flattened_tensor_bucket_cls( - named_tensors=list(state_dict.items()), - flattened_tensor=update_params_ipc_tensor, - ) - flattened_tensor_data = { - "metadata": flattened_tensor_bucket.get_metadata(), - "require_clone": False, - } - self._update_params_ipc_event.record() - self._last_update_params_ipc_tensor_dtype = state_dict_dtype - - if send_ipc_tensor: - flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() - flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() - return flattened_tensor_data - - def _get_sglang_disagg_engine_info(self) -> RolloutEngineInfo: - engine_info: RolloutEngineInfo = [] - seen_urls: set[str] = set() - rank_to_engine_size: dict[int, int] = {} - for engine_ranks in self.rollout_engine_rank_mesh_array: - engine_size = len(engine_ranks) - for rank in engine_ranks: - rank_to_engine_size[int(rank)] = engine_size - - for rank, url in sorted(self.rollout_server_url_dict.items(), key=lambda item: int(item[0])): - rank = int(rank) - if not url or url in seen_urls: - continue - if self.worker_server_urls_status.get(url, False) is False: - continue - seen_urls.add(url) - engine_info.append( - ( - rank, - url, - rank_to_engine_size.get( - rank, - max(self.rollout_cfg_info["tp"], self.rollout_cfg_info["ep"]), - ), - ) - ) - return engine_info - - def _ensure_sglang_disagg_group(self): - if self._sglang_disagg_group is not None: - return - engine_info = self._get_sglang_disagg_engine_info() - if not engine_info: - self.logger.error("No active rollout engine url, cannot init sglang weight update group") - return - - os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" - backend = "nccl" - - master_address = None - master_port = None - # get address and port for weight-update - try: - import ray - - master_address = ray.util.get_node_ip_address() - except Exception: - master_address = socket.gethostbyname(socket.gethostname()) - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - master_port = int(sock.getsockname()[1]) - - group_name = f"xtuner_sglang_weight_update_{self.rank}" - world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 - - self._sglang_disagg_executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) - init_futures = [] - rank_offset = 1 - for _, url, engine_size in engine_info: - payload = { - "master_address": master_address, - "master_port": master_port, - "rank_offset": rank_offset, - "world_size": world_size, - "group_name": group_name, - "backend": backend, - } - init_futures.append( - self._sglang_disagg_executor.submit( - requests.post, - f"{url}/init_weights_update_group", - json=payload, - ) - ) - rank_offset += engine_size - - self._sglang_disagg_group = self._init_external_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) - - for init_future in init_futures: - response = init_future.result() - response.raise_for_status() - result = response.json() - assert result.get("success", True), ( - f"SGLang init_weights_update_group failed: {result.get('message', result)}" - ) - - self._sglang_disagg_group_name = group_name - self._sglang_disagg_engine_urls = [url for _, url, _ in engine_info] - - def _request_update_params_sglang_disaggregated(self, state_dict): - if not state_dict: - return - - train_sync_group = self._get_train_update_sync_group() - head_rank = 0 - if dist.get_rank() != head_rank: - dist.barrier(group=train_sync_group) - return - - self._ensure_sglang_disagg_group() - if self._sglang_disagg_group is None: - dist.barrier(group=train_sync_group) - return - - assert self._sglang_disagg_executor is not None - assert self._sglang_disagg_group_name is not None - with self._sglang_disagg_update_lock: - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - except Exception as e: - raise RuntimeError( - "Disaggregated update_weights currently only supports sglang builds " - "that provide `sglang.srt.model_executor.model_runner.FlattenedTensorBucket`." - ) from e - - names = list(state_dict.keys()) - tensors = [ - tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() - ] - payload = { - "names": names, - "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], - "shapes": [list(tensor.shape) for tensor in tensors], - "group_name": self._sglang_disagg_group_name, - "load_format": "flattened_bucket", - } - update_futures = [ - self._sglang_disagg_executor.submit( - requests.post, - f"{url}/update_weights_from_distributed", - json=payload, - ) - for url in self._sglang_disagg_engine_urls - ] - assert self._sglang_disagg_group is not None - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) - flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() - - dist.broadcast(flattened_tensor, src=0, group=self._sglang_disagg_group) - DEVICE_MODULE.synchronize() - for update_future in update_futures: - response = update_future.result() - response.raise_for_status() - result = response.json() - self._hook_compare_test_sent_and_received_weight_hash( - result, - names=names, - ) - assert result.get("success", True), ( - f"SGLang update_weights_from_distributed failed: {result.get('message', result)}" - ) - dist.barrier(group=train_sync_group) - - @ray_method - def request_update_params(self, state_dict, train_enable_ep=False, finished=False): - """Send a request to update the parameters on the rollout workers. - - This method serializes the state dictionary and sends it to the - appropriate rollout worker via an HTTP request. - - Args: - state_dict (dict | list): The state dictionary containing the model - parameters to update. - train_enable_ep (bool): Whether the training engine enables expert parallelism. - Defaults to False. - finished (bool): A flag indicating whether this is the final - batch of updates. Defaults to False. - """ - - if self.rollout_cfg_info["backend"] == "sglang" and not self.is_train_rollout_colocated: - self._request_update_params_sglang_disaggregated(state_dict) - return - - cpu_mesh = self.rollout_device_mesh["engine_parallel"] - cpu_group = cpu_mesh.get_group() - head_rank = cpu_mesh.mesh[0].item() - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") - return - - if self.rollout_cfg_info["backend"] == "vllm": - - def serialize_state_dict(state_dict: dict) -> str: - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - - serialized_data = [None] * self.rollout_cfg_info["tp"] - dist.gather_object( - serialize_state_dict(state_dict), - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - } - data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished)) - data = dict(method="update_weight_npu_ipc", args=[data_]) - response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - return - - if self.rollout_cfg_info["backend"] == "pytorch": - # TODO(chenchiyu): remove lmdeploy related code - from lmdeploy.utils import serialize_state_dict - - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket and state_dict: - flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( - state_dict, - FlattenedTensorBucket, - ) - tp_serialized_data = serialize_state_dict(flattened_tensor_data) - else: - tp_serialized_data = serialize_state_dict(state_dict) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - elif self.rollout_cfg_info["backend"] == "pytorch": - if use_flattened_tensor_bucket and state_dict: - flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( - state_dict, - FlattenedTensorBucket, - ) - serialized_data = serialize_state_dict(flattened_tensor_data) - else: - serialized_data = serialize_state_dict(state_dict) - else: - # for turbomind backend, only head_rank should serialize data - serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None - else: - # sglang - from sglang.srt.utils import MultiprocessingSerializer - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions - - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - # NOTE: xtuner目前去掉sglang的patch也不会出问题,但为了保险起见,还是保留patch逻辑,并且在update_weights结束后unpatch - monkey_patch_torch_reductions() - state_dict = state_dict.items() - if self.rollout_cfg_info["tp"] == 1: - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - else: - serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - - serialized_data = [serialized_data] - else: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - else: - tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", - } - if self.rollout_cfg_info["backend"] == "sglang": - payload = { - "serialized_named_tensors": serialized_data, - "flush_cache": False, - } - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - if use_flattened_tensor_bucket: - payload["load_format"] = "flattened_bucket" - - url = f"{self.rollout_url}/update_weights_from_tensor" - response = requests.post(url, json=payload or {}) - response.raise_for_status() - else: - data = dict(serialized_named_tensors=serialized_data, finished=finished) - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if use_flattened_tensor_bucket and state_dict: - data["load_format"] = "flattened_bucket" - response = requests.post( - f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data - ) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - # TODO(chenchiyu): narrow this condition - if finished or ( - self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep and self.rollout_cfg_info["tp"] > 1 - ): - # This barrier is aim to make each tp head rank sync with other ranks in engine_parallel group - # which could not be barrier by `fsdp_foreach_allgather` of the next state dict. (Happens in same_gen, shard not tested) - # Without barrier, some ranks in engine_parallel group would not wait for current iter data ipc event recording in lmdeploy. - # They would write next iter state_dict into the ipc tensor before lmdeploy load current iter weight. - dist.barrier(group=cpu_group) - - monkey_unpatch_torch_reductions() - return diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index ebbb5c755..ddccd29d3 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -48,6 +48,7 @@ from xtuner.v1.profiler import profiling_memory, profiling_time from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, kl_penalty from xtuner.v1.rl.utils import SingleAcceleratorWorker +from xtuner.v1.rl.weight_update import TrainRolloutMode, UpdateWeighter from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( XTUNER_DETERMINISTIC, @@ -59,7 +60,6 @@ ) from ..rollout_is import merge_rollout_is_metrics -from .update_weighter import UpdateWeighter DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices @@ -200,7 +200,7 @@ class WorkerLogItem(TypedDict): sft_train_metrics: NotRequired[dict[str, float]] -class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter): +class TrainingWorker(SingleAcceleratorWorker): _SAVE_WEIGHTS_DIR = "weights" _SAVE_SFT_DATALOADER_DIR = "sft_dataloader" _SAVE_SFT_TRAIN_STATE_PATH = "sft_train_state.json" @@ -268,7 +268,24 @@ def __init__( if hasattr(worker_cfg.model_cfg.text_config, "mtp_config"): self.mtp_config = worker_cfg.model_cfg.text_config.mtp_config - self._init_update_weighter() + self.update_weighter = UpdateWeighter( + rank=self.rank, + logger=self.logger, + config=self.config, + engine=self._engine, + ) + + @ray_method + def update_rollout_info(self, *args, **kwargs): + return self.update_weighter.update_rollout_info(*args, **kwargs) + + @ray_method + def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode): + return self.update_weighter.set_train_rollout_mode(train_rollout_mode) + + @ray_method + def update_weights(self): + return self.update_weighter.update_weights() def _init_sft(self, worker_cfg: WorkerConfig): self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py new file mode 100644 index 000000000..36d497d5a --- /dev/null +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -0,0 +1,46 @@ +from .client import RolloutWeightUpdateClient +from .data import ( + DeviceMeshRaw, + RolloutBackend, + RolloutEngineInfo, + RolloutWeightUpdateInfo, + ServiceUrlMap, + TrainRolloutMode, + WeightTransportType, + WeightUpdateBatch, +) +from .exporter import WeightExporter +from .transport import ( + IPCBackendAdapter, + IPCWeightTransport, + LMDeployIPCBackendAdapter, + NCCLBackendAdapter, + NCCLWeightTransport, + SGLangIPCBackendAdapter, + SGLangNCCLBackendAdapter, + WeightTransport, +) +from .update_weighter import UpdateWeighter + + +__all__ = [ + "DeviceMeshRaw", + "IPCBackendAdapter", + "IPCWeightTransport", + "LMDeployIPCBackendAdapter", + "NCCLBackendAdapter", + "NCCLWeightTransport", + "RolloutBackend", + "RolloutEngineInfo", + "RolloutWeightUpdateInfo", + "RolloutWeightUpdateClient", + "SGLangIPCBackendAdapter", + "SGLangNCCLBackendAdapter", + "ServiceUrlMap", + "TrainRolloutMode", + "UpdateWeighter", + "WeightExporter", + "WeightTransportType", + "WeightUpdateBatch", + "WeightTransport", +] diff --git a/xtuner/v1/rl/weight_update/client.py b/xtuner/v1/rl/weight_update/client.py new file mode 100644 index 000000000..85ed9414e --- /dev/null +++ b/xtuner/v1/rl/weight_update/client.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any + +import requests + + +class RolloutWeightUpdateClient: + def __init__(self, api_key: list[str] | str | None): + self.api_key = api_key + + def _headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self.api_key is not None: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + def collective_rpc(self, url: str, payload: dict[str, Any]): + return requests.post(f"{url}/collective_rpc", headers=self._headers(), json=payload) + + def update_weights(self, url: str, endpoint: str, payload: dict[str, Any]): + return requests.post(f"{url}/{endpoint}", headers=self._headers(), json=payload) + + def update_weights_from_tensor(self, url: str, payload: dict[str, Any]): + return requests.post(f"{url}/update_weights_from_tensor", json=payload or {}) + + def init_weights_update_group(self, url: str, payload: dict[str, Any]): + return requests.post(f"{url}/init_weights_update_group", json=payload) + + def update_weights_from_distributed(self, url: str, payload: dict[str, Any]): + return requests.post(f"{url}/update_weights_from_distributed", json=payload) diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py new file mode 100644 index 000000000..320a1dc63 --- /dev/null +++ b/xtuner/v1/rl/weight_update/data.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Literal, TypeAlias + +import torch +from torch.distributed.device_mesh import DeviceMesh + + +DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices. +ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs. +RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) +TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode. +RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend. +WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. + + +@dataclass +class RolloutWeightUpdateInfo: + api_key: list[str] | str | None = None + rollout_device_mesh: DeviceMesh | None = None + rollout_url: str | None = None + backend: RolloutBackend | None = None + tp: int = 1 + ep: int = 1 + train_rollout_mode: TrainRolloutMode | None = None + transport_type: WeightTransportType | None = None + rollout_cfg_info: dict = field(default_factory=dict) + endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"}) + rollout_engine_rank_mesh_array: DeviceMeshRaw = field(default_factory=list) + rollout_server_url_dict: ServiceUrlMap = field(default_factory=dict) + worker_server_urls_status: dict[str, bool] = field(default_factory=dict) + + +@dataclass +class WeightUpdateBatch: + """A single bucket of weights to send to rollout workers.""" + + state_dict: dict[str, torch.Tensor] + train_enable_ep: bool = False + finished: bool = False diff --git a/xtuner/v1/rl/weight_update/exporter.py b/xtuner/v1/rl/weight_update/exporter.py new file mode 100644 index 000000000..bb6df2834 --- /dev/null +++ b/xtuner/v1/rl/weight_update/exporter.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from itertools import chain +from typing import Any, cast + +import torch +import torch.distributed as dist +import tqdm +from torch.distributed.tensor import DTensor + +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration +from xtuner.v1.model.moe.moe import MoE +from xtuner.v1.utils import get_device, get_torch_device_module +from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec + +from .data import RolloutWeightUpdateInfo, WeightUpdateBatch + + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +class WeightExporter: + def __init__( + self, + *, + config: Any, + engine: Any, + rollout_info: RolloutWeightUpdateInfo, + global_hf_keys_mapping_cache: dict[str, list[str]], + ): + self.config = config + self._engine = engine + self.rollout_info = rollout_info + self._global_hf_keys_mapping_cache = global_hf_keys_mapping_cache + + def _get_hf_params( + self, + model, + model_ep_size: int, + target_ep_size: int, + target_ep_rank: int, + fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], + ) -> tuple[list[torch.Tensor], list[str]]: + hf_keys_list: list[str] = [] + hf_tensor_list: list[torch.Tensor] = [] + + for fsdp_tensor, load_spec in fsdp_tensor_list: + hf_keys = load_spec.hf_keys + if model_ep_size > 1 and model.ep_mesh is not None: + # Each train EP rank owns only part of the HF key list; gather the global + # mapping once so rollout EP ranks can receive the right slice. + if load_spec.name not in self._global_hf_keys_mapping_cache: + global_hf_keys: list[list[str] | None] = [None] * model_ep_size + dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group()) + global_hf_keys_gathered = cast(list[list[str]], global_hf_keys) + self._global_hf_keys_mapping_cache[load_spec.name] = list( + chain.from_iterable(global_hf_keys_gathered) + ) + hf_keys = self._global_hf_keys_mapping_cache[load_spec.name] + + fused_full_tensor = fsdp_tensor.bfloat16() + if isinstance(fused_full_tensor, DTensor): + fused_full_tensor = fused_full_tensor.full_tensor() + # FUSED load specs pack multiple HF tensors along load_spec.dim; split them + # back into HF tensors before selecting the target rollout EP shard. + dim = cast(int, load_spec.dim) + num_split = len(hf_keys) + hf_tensor_size = fused_full_tensor.shape[dim] / num_split + assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer" + hf_tensor_size = int(hf_tensor_size) + + hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim) + assert num_split % target_ep_size == 0, ( + f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}" + ) + start_idx = (num_split // target_ep_size) * target_ep_rank + end_idx = (num_split // target_ep_size) * (target_ep_rank + 1) + + hf_keys_list.extend(hf_keys[start_idx:end_idx]) + hf_tensor_list.extend(hf_tensor[start_idx:end_idx]) + + hf_tensor_list = [ + model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) + ] + + return hf_tensor_list, hf_keys_list + + def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int): + fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED) + model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size + if not fused_param_groups: + return + + safetensor_size = 0 + dtype = torch.bfloat16 + tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] + + for param, load_spec in fused_param_groups: + tensor_size = dtype.itemsize * param.numel() // target_ep_size + if safetensor_size + tensor_size > bucket_size and tensor_list: + hf_params, name_list = self._get_hf_params( + model, + model_ep_size=model_ep_size, + target_ep_size=target_ep_size, + target_ep_rank=target_ep_rank, + fsdp_tensor_list=tensor_list, + ) + yield name_list, hf_params + safetensor_size = tensor_size + # Kept to mirror the legacy generator layout; the next iteration rebuilds + # name_list from tensor_list before yielding. + name_list = load_spec.hf_keys.copy() + tensor_list = [(param, load_spec)] + continue + safetensor_size += tensor_size + tensor_list.append((param, load_spec)) + + if tensor_list: + hf_params, name_list = self._get_hf_params( + model=model, + model_ep_size=model_ep_size, + target_ep_size=target_ep_size, + target_ep_rank=target_ep_rank, + fsdp_tensor_list=tensor_list, + ) + yield name_list, hf_params + + @torch.no_grad() + def iter_hf_batches(self, submodule=None, final_update=False): + """Update the model weights.""" + rollout_device_mesh = self.rollout_info.rollout_device_mesh + assert rollout_device_mesh is not None + + model = self._engine.model + if submodule: + model = getattr(model, submodule) + + dtype = torch.bfloat16 + bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + same_gen = model._get_same_hf_param( + model._group_param_by_load_spec(LoadEnum.SAME), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + ) + + train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 + if train_enable_ep: + # Remap train EP shards to the rollout EP topology. Non-EP rollout receives + # the full fused tensor slice as target_ep_size=1. + if self.rollout_info.ep > 1: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=rollout_device_mesh["engine_parallel"].get_coordinate()[0], + target_ep_size=rollout_device_mesh["engine_parallel"].size(), + bucket_size=bucket_size, + ) + else: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=0, + target_ep_size=1, + bucket_size=bucket_size, + ) + else: + fused_gen = model._get_fused_hf_param( + model._group_param_by_load_spec(LoadEnum.FUSED), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + update_weights_for_rl=True, + ) + shard_gen = model._get_shard_hf_param( + model._group_param_by_load_spec(LoadEnum.SHARD), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + ) + + for name_list, fused_param_list in fused_gen: + state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} + yield WeightUpdateBatch(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, fused_param_list + + for name_list, param_list in chain(same_gen, shard_gen): + state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} + yield WeightUpdateBatch(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, param_list + + # pytorch and vLLM use an empty final update as an end marker; SGLang and + # turbomind do not consume this marker. + if self.rollout_info.backend in ("pytorch", "vllm") and final_update: + yield WeightUpdateBatch({}, train_enable_ep=train_enable_ep, finished=True) + + DEVICE_MODULE.empty_cache() + + @torch.no_grad() + def iter_layer_batches(self): + """Update the model weights.""" + assert self.rollout_info.rollout_device_mesh is not None + + model = self._engine.model + DEVICE_MODULE.empty_cache() + + if isinstance(model.config, BaseComposeConfig): + # TODO: support float8 for vision compose model. + dtype = torch.bfloat16 + else: + if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): + dtype = torch.float8_e4m3fn + else: + dtype = torch.bfloat16 + + def get_params(tensor_list, name_list, save_dtype): + _tensor_list, _spec_list = list(zip(*tensor_list)) + fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) + if save_dtype == torch.float8_e4m3fn: + fsdp_unshard_tensor_list, name_list = model._to_float8( + fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype + ) + return fsdp_unshard_tensor_list, name_list + + saved_list = [] + is_qwen3vl = False + if isinstance(model.config, BaseComposeConfig): + language_model = model.language_model + if isinstance(model, Qwen3VLForConditionalGeneration): + is_qwen3vl = True + else: + language_model = model + + if is_qwen3vl: + vision_hf_prefix = "model.visual." + projector_hf_prefix = "model.visual." + else: + vision_hf_prefix = "model.vision_tower." + projector_hf_prefix = "model.multi_modal_projector." + + for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): + tensor_list = [] + name_list = [] + for sub_name, param in layer.state_dict().items(): + if isinstance(model.config, BaseComposeConfig): + saved_list.append(f"language_model.layers.{i}.{sub_name}") + else: + saved_list.append(f"layers.{i}.{sub_name}") + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") + + if isinstance(model.config, BaseComposeConfig): + name = f"model.language_model.layers.{i}.{sub_name}" + else: + name = f"model.layers.{i}.{sub_name}" + + if ".experts." in name and ".mlp.experts." not in name: + name = name.replace(".experts.", ".mlp.experts.") + if ".gate." in name and ".mlp.gate." not in name: + name = name.replace(".gate.", ".mlp.gate.") + name_list.append(name) + tensor_list.append((local_tensor, load_spec)) + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + yield WeightUpdateBatch(state_dict) + + for name, param in model.state_dict().items(): + if name in saved_list: + continue + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = model.load_spec_mapping.get(name) + + if isinstance(model.config, BaseComposeConfig): + if "vision_tower." in name: + name = name.replace("vision_tower.", vision_hf_prefix) + elif "multi_modal_projector." in name: + name = name.replace("multi_modal_projector.", projector_hf_prefix) + elif name == "language_model.norm.weight": + name = "model.language_model.norm.weight" + elif name == "language_model.embed_tokens.weight": + name = "model.language_model.embed_tokens.weight" + elif name == "language_model.lm_head.weight": + name = "lm_head.weight" + else: + if name == "norm.weight": + name = "model.norm.weight" + elif name == "embed_tokens.weight": + name = "model.embed_tokens.weight" + tensor_list = [(local_tensor, load_spec)] + name_list = [name] + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + yield WeightUpdateBatch(state_dict) + + if self.rollout_info.backend in ("pytorch", "vllm"): + yield WeightUpdateBatch({}, finished=True) + + DEVICE_MODULE.empty_cache() diff --git a/xtuner/v1/rl/weight_update/transport.py b/xtuner/v1/rl/weight_update/transport.py new file mode 100644 index 000000000..b36d82ce3 --- /dev/null +++ b/xtuner/v1/rl/weight_update/transport.py @@ -0,0 +1,714 @@ +from __future__ import annotations + +import json +import os +import socket +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from threading import Lock +from typing import Any, Callable + +import torch +import torch.distributed as dist +from packaging.version import parse as parse_version +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) + +from xtuner.v1.utils import ( + get_device, + get_torch_device_module, + monkey_unpatch_torch_reductions, +) + +from .client import RolloutWeightUpdateClient +from .data import RolloutEngineInfo, RolloutWeightUpdateInfo, WeightUpdateBatch + + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +class WeightTransport(ABC): + def before_update(self) -> None: + return + + @abstractmethod + def send(self, batch: WeightUpdateBatch) -> None: + raise NotImplementedError + + def after_update(self) -> None: + return + + def teardown(self) -> None: + return + + +class IPCBackendAdapter: + def before_serialize(self, transport: IPCWeightTransport, batch: WeightUpdateBatch) -> None: + return + + def serialize( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> Any: + raise NotImplementedError + + def send_request( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + serialized_data: Any, + ) -> None: + raise NotImplementedError + + def postprocess( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + ) -> None: + return + + def after_serialize(self, transport: IPCWeightTransport, batch: WeightUpdateBatch) -> None: + return + + def send(self, transport: IPCWeightTransport, batch: WeightUpdateBatch) -> None: + """Send a request to update the parameters on the rollout workers. + + This method serializes the state dictionary and sends it to the appropriate rollout worker through the backend- + specific IPC adapter. + """ + cpu_mesh = transport.cpu_mesh + cpu_group = cpu_mesh.get_group() + head_rank = cpu_mesh.mesh[0].item() + + # Template method for IPC updates: all ranks serialize/gather, only the + # engine-parallel head rank sends the rollout HTTP request. + self.before_serialize(transport, batch) + try: + serialized_data = self.serialize(transport, batch, cpu_group, head_rank) + if dist.get_rank() == head_rank: + self.send_request(transport, batch, serialized_data) + self.postprocess(transport, batch, cpu_group) + finally: + self.after_serialize(transport, batch) + + +class VLLMIPCBackendAdapter(IPCBackendAdapter): + @staticmethod + def _serialize_state_dict(state_dict: dict) -> str: + import base64 + from io import BytesIO + from multiprocessing.reduction import ForkingPickler + + from torch.multiprocessing.reductions import reduce_tensor + + data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] + buf = BytesIO() + ForkingPickler(buf).dump(data) + buf.seek(0) + return base64.b64encode(buf.read()).decode("utf-8") + + def serialize( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> list[Any]: + info = transport.rollout_info + serialized_data = [None] * info.tp + dist.gather_object( + self._serialize_state_dict(batch.state_dict), + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + return serialized_data + + def send_request( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + serialized_data: list[Any], + ) -> None: + info = transport.rollout_info + data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=batch.finished)) + data = dict(method="update_weight_npu_ipc", args=[data_]) + assert info.rollout_url is not None + response = transport.client.collective_rpc(info.rollout_url, data) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + def postprocess( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + ) -> None: + if batch.finished: + dist.barrier(group=cpu_group) + + +class LMDeployIPCBackendAdapter(IPCBackendAdapter): + @staticmethod + def _compute_state_dict_bytes(state_dict: dict[str, torch.Tensor]) -> int: + total_bytes = 0 + for tensor in state_dict.values(): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + + @staticmethod + def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): + return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) + + def build_flattened_tensor_data( + self, + transport: IPCWeightTransport, + state_dict: dict, + flattened_tensor_bucket_cls, + ) -> dict: + assert transport._update_params_ipc_event is not None + # LMDeploy flattened buckets require all tensors in one bucket to share a dtype. + state_dict_dtype = state_dict[next(iter(state_dict))].dtype + # LMDeploy can reuse the same IPC tensor across batches. A new handle is + # sent only when dtype changes, capacity is insufficient, or this is the first batch. + update_params_ipc_tensor = transport._update_params_ipc_tensor_dict_by_dtype.get(state_dict_dtype, None) + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + ipc_tensor_bytes = transport._ipc_tensor_bytes_dict_by_dtype.get( + state_dict_dtype, + transport._default_ipc_tensor_bytes, + ) + dtype_changed = ( + transport._last_update_params_ipc_tensor_dtype is not None + and state_dict_dtype != transport._last_update_params_ipc_tensor_dtype + ) + need_resize = state_dict_bytes > ipc_tensor_bytes + send_ipc_tensor = dtype_changed or need_resize or update_params_ipc_tensor is None + + if update_params_ipc_tensor is not None: + # Wait until rollout has consumed the previous IPC tensor before reusing it. + transport._update_params_ipc_event.wait() + if need_resize: + # Synchronize before replacing a too-small IPC tensor to avoid freeing + # storage that may still be referenced by the rollout process. + DEVICE_MODULE.synchronize() + + if update_params_ipc_tensor is None or need_resize: + ipc_tensor_bytes = max(ipc_tensor_bytes, state_dict_bytes) + transport._ipc_tensor_bytes_dict_by_dtype[state_dict_dtype] = ipc_tensor_bytes + update_params_ipc_tensor = self._create_ipc_tensor( + ipc_tensor_bytes, + state_dict_dtype, + ) + transport._update_params_ipc_tensor_dict_by_dtype[state_dict_dtype] = update_params_ipc_tensor + + flattened_tensor_bucket = flattened_tensor_bucket_cls( + named_tensors=list(state_dict.items()), + flattened_tensor=update_params_ipc_tensor, + ) + flattened_tensor_data = { + "metadata": flattened_tensor_bucket.get_metadata(), + "require_clone": False, + } + transport._update_params_ipc_event.record() + transport._last_update_params_ipc_tensor_dtype = state_dict_dtype + + if send_ipc_tensor: + # Subsequent batches with the same cached IPC tensor only need metadata; the + # tensor handle and event handle are resent only when the cached buffer changes. + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = transport._update_params_ipc_event.ipc_handle() + return flattened_tensor_data + + def serialize( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> Any: + from lmdeploy.utils import serialize_state_dict + + info = transport.rollout_info + state_dict = batch.state_dict + + try: + from lmdeploy.utils import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + FlattenedTensorBucket = None + + if info.tp > 1: + serialized_data = [None] * info.tp + if use_flattened_tensor_bucket and state_dict: + flattened_tensor_data = self.build_flattened_tensor_data( + transport, + state_dict, + FlattenedTensorBucket, + ) + tp_serialized_data = serialize_state_dict(flattened_tensor_data) + else: + tp_serialized_data = serialize_state_dict(state_dict) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + else: + if use_flattened_tensor_bucket and state_dict: + flattened_tensor_data = self.build_flattened_tensor_data( + transport, + state_dict, + FlattenedTensorBucket, + ) + serialized_data = serialize_state_dict(flattened_tensor_data) + else: + serialized_data = serialize_state_dict(state_dict) + return serialized_data, use_flattened_tensor_bucket + + def send_request( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + serialized_data: tuple[Any, bool], + ) -> None: + info = transport.rollout_info + state_dict = batch.state_dict + serialized_named_tensors, use_flattened_tensor_bucket = serialized_data + data = dict(serialized_named_tensors=serialized_named_tensors, finished=batch.finished) + if use_flattened_tensor_bucket and state_dict: + data["load_format"] = "flattened_bucket" + assert info.rollout_url is not None + response = transport.client.update_weights(info.rollout_url, info.endpoints["update_weights"], data) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + def postprocess( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + ) -> None: + info = transport.rollout_info + # TODO(chenchiyu): narrow this condition. + if batch.finished or (batch.train_enable_ep and info.tp > 1): + # Make each TP head rank sync with other ranks in engine_parallel group. + # FSDP all-gather of the next state_dict cannot cover this case, so without + # this barrier some ranks could overwrite the IPC tensor before LMDeploy loads it. + dist.barrier(group=cpu_group) + + +class SGLangIPCBackendAdapter(IPCBackendAdapter): + def before_serialize(self, transport: IPCWeightTransport, batch: WeightUpdateBatch) -> None: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + + # NOTE: XTuner currently also works without the SGLang patch in some cases, + # but keep the patch/unpatch pair for compatibility with SGLang serialization. + # SGLang overrides torch tensor reduction for multiprocessing serialization. + monkey_patch_torch_reductions() + + def serialize( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> tuple[list[Any], bool]: + from sglang.srt.utils import MultiprocessingSerializer + + info = transport.rollout_info + state_dict = batch.state_dict + + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + FlattenedTensorBucket = None + + state_items = state_dict.items() + if info.tp == 1: + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_items) + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": flattened_tensor_bucket.get_metadata(), + } + serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + serialized_data = MultiprocessingSerializer.serialize(state_items, output_str=True) + serialized_data = [serialized_data] + else: + serialized_data = [None] * info.tp + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_items) + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": flattened_tensor_bucket.get_metadata(), + } + tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + tp_serialized_data = MultiprocessingSerializer.serialize(state_items, output_str=True) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + return serialized_data, use_flattened_tensor_bucket + + def send_request( + self, + transport: IPCWeightTransport, + batch: WeightUpdateBatch, + serialized_data: tuple[list[Any], bool], + ) -> None: + info = transport.rollout_info + serialized_named_tensors, use_flattened_tensor_bucket = serialized_data + payload = { + "serialized_named_tensors": serialized_named_tensors, + "flush_cache": False, + } + if use_flattened_tensor_bucket: + payload["load_format"] = "flattened_bucket" + + assert info.rollout_url is not None + response = transport.client.update_weights_from_tensor(info.rollout_url, payload) + response.raise_for_status() + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + def after_serialize(self, transport: IPCWeightTransport, batch: WeightUpdateBatch) -> None: + monkey_unpatch_torch_reductions() + + +class IPCWeightTransport(WeightTransport): + def __init__( + self, + *, + rank: int, + logger: Any, + config: Any, + rollout_info: RolloutWeightUpdateInfo, + ): + self.rank = rank + self.logger = logger + self.config = config + self.rollout_info = rollout_info + self.client = RolloutWeightUpdateClient(rollout_info.api_key) + self._default_ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + self._ipc_tensor_bytes_dict_by_dtype: dict[torch.dtype, int] = {} + self._update_params_ipc_tensor_dict_by_dtype: dict[torch.dtype, torch.Tensor] = {} + self._last_update_params_ipc_tensor_dtype: torch.dtype | None = None + self._update_params_ipc_event = None + self._adapter = self._build_adapter() + + @property + def cpu_mesh(self): + assert self.rollout_info.rollout_device_mesh is not None + return self.rollout_info.rollout_device_mesh["engine_parallel"] + + def _build_adapter(self) -> IPCBackendAdapter: + backend = self.rollout_info.backend + if backend == "vllm": + return VLLMIPCBackendAdapter() + if backend == "sglang": + return SGLangIPCBackendAdapter() + return LMDeployIPCBackendAdapter() + + def before_update(self) -> None: + DEVICE_MODULE.empty_cache() + self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True) + + def after_update(self) -> None: + self._update_params_ipc_tensor_dict_by_dtype = {} + self._last_update_params_ipc_tensor_dtype = None + self._update_params_ipc_event = None + DEVICE_MODULE.empty_cache() + + def send(self, batch: WeightUpdateBatch) -> None: + if self.rollout_info.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + return + self._adapter.send(self, batch) + + +class NCCLBackendAdapter: + def send(self, transport: NCCLWeightTransport, batch: WeightUpdateBatch) -> None: + raise NotImplementedError + + +class SGLangNCCLBackendAdapter(NCCLBackendAdapter): + def send(self, transport: NCCLWeightTransport, batch: WeightUpdateBatch) -> None: + state_dict = batch.state_dict + if not state_dict: + return + + train_sync_group = transport.get_train_update_sync_group() + head_rank = 0 + # Disaggregated SGLang update is driven by train rank 0. Other train ranks + # only wait so optimizer/rollout steps stay aligned. + if dist.get_rank() != head_rank: + dist.barrier(group=train_sync_group) + return + + transport.ensure_group() + if transport.group is None: + dist.barrier(group=train_sync_group) + return + + assert transport.executor is not None + assert transport.group_name is not None + with transport.update_lock: + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + except Exception as e: + raise RuntimeError( + "Disaggregated update_weights currently only supports sglang builds " + "that provide `sglang.srt.model_executor.model_runner.FlattenedTensorBucket`." + ) from e + + names = list(state_dict.keys()) + tensors = [ + tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() + ] + payload = { + "names": names, + "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], + "shapes": [list(tensor.shape) for tensor in tensors], + "group_name": transport.group_name, + "load_format": "flattened_bucket", + } + # Notify rollout engines first so they can join the external NCCL group and + # prepare receive buffers described by names/dtypes/shapes. + update_futures = [ + transport.executor.submit( + transport.client.update_weights_from_distributed, + url, + payload, + ) + for url in transport.engine_urls + ] + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) + flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() + + dist.broadcast(flattened_tensor, src=0, group=transport.group) + DEVICE_MODULE.synchronize() + for update_future in update_futures: + response = update_future.result() + response.raise_for_status() + result = response.json() + transport.hook_compare_test_sent_and_received_weight_hash( + result, + names=names, + ) + assert result.get("success", True), ( + f"SGLang update_weights_from_distributed failed: {result.get('message', result)}" + ) + dist.barrier(group=train_sync_group) + + +class NCCLWeightTransport(WeightTransport): + def __init__(self, *, rank: int, logger: Any, rollout_info: RolloutWeightUpdateInfo): + self.rank = rank + self.logger = logger + self.rollout_info = rollout_info + self.client = RolloutWeightUpdateClient(rollout_info.api_key) + self.group: dist.ProcessGroup | None = None + self.group_name: str | None = None + self.engine_urls: list[str] = [] + self.executor: ThreadPoolExecutor | None = None + self.train_update_sync_group: dist.ProcessGroup | None = None + self.update_lock = Lock() + self.hook_compare_test_sent_and_received_weight_hash: Callable[..., None] = lambda result, **kwargs: None + self._adapter = self._build_adapter() + + def _build_adapter(self) -> NCCLBackendAdapter: + backend = self.rollout_info.backend + if backend == "sglang": + return SGLangNCCLBackendAdapter() + raise ValueError(f"Unsupported NCCL weight update backend: {backend!r}") + + def get_train_update_sync_group(self) -> dist.ProcessGroup: + if self.train_update_sync_group is None: + ranks = list(range(dist.get_world_size())) + self.train_update_sync_group = dist.new_group(ranks=ranks, backend="gloo") + return self.train_update_sync_group + + def get_engine_info(self) -> RolloutEngineInfo: + engine_info: RolloutEngineInfo = [] + seen_urls: set[str] = set() + rank_to_engine_size: dict[int, int] = {} + for engine_ranks in self.rollout_info.rollout_engine_rank_mesh_array: + engine_size = len(engine_ranks) + for rank in engine_ranks: + rank_to_engine_size[int(rank)] = engine_size + + for rank, url in sorted( + self.rollout_info.rollout_server_url_dict.items(), + key=lambda item: int(item[0]), + ): + rank = int(rank) + # Active server URLs are engine entrypoints, not one endpoint per rollout rank. + # Deduplicate URLs and skip workers marked unhealthy by the rollout controller. + if not url or url in seen_urls: + continue + if self.rollout_info.worker_server_urls_status.get(url, False) is False: + continue + seen_urls.add(url) + engine_info.append( + ( + rank, + url, + rank_to_engine_size.get( + rank, + max(self.rollout_info.tp, self.rollout_info.ep), + ), + ) + ) + return engine_info + + def ensure_group(self): + if self.group is not None: + return + engine_info = self.get_engine_info() + if not engine_info: + self.logger.error("No active rollout engine url, cannot init sglang weight update group") + return + + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + backend = "nccl" + + # Get address and port for the external weight-update process group. + try: + import ray + + master_address = ray.util.get_node_ip_address() + except Exception: + master_address = socket.gethostbyname(socket.gethostname()) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + master_port = int(sock.getsockname()[1]) + + group_name = f"xtuner_sglang_weight_update_{self.rank}" + # Train rank 0 is external group rank 0. Rollout engine ranks are assigned + # contiguous offsets starting from rank 1. + world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 + + self.executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) + init_futures = [] + rank_offset = 1 + for _, url, engine_size in engine_info: + payload = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + } + init_futures.append( + self.executor.submit( + self.client.init_weights_update_group, + url, + payload, + ) + ) + rank_offset += engine_size + + self.group = self._init_external_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + + for init_future in init_futures: + response = init_future.result() + response.raise_for_status() + result = response.json() + assert result.get("success", True), ( + f"SGLang init_weights_update_group failed: {result.get('message', result)}" + ) + + self.group_name = group_name + self.engine_urls = [url for _, url, _ in engine_info] + + def send(self, batch: WeightUpdateBatch) -> None: + self._adapter.send(self, batch) + + def teardown(self) -> None: + if self.executor is not None: + self.executor.shutdown(wait=False, cancel_futures=True) + try: + if self.group is not None: + dist.destroy_process_group(self.group) + except Exception: + pass + self.group = None + self.group_name = None + self.engine_urls = [] + self.executor = None + + @staticmethod + def _init_external_process_group( + backend: str | Backend | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str | None = None, + pg_options: Any | None = None, + ) -> dist.ProcessGroup: + # Build a process group that includes external rollout processes, which + # cannot be represented by dist.new_group over the current training world. + assert (store is None) or (init_method is None), "Cannot specify both store and init_method." + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + if timeout is None: + timeout = default_pg_timeout + + if store is None: + assert init_method is not None + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + if group_name is not None: + store = PrefixStore(group_name, store) + + pg_options_param_name = ( + "backend_options" if parse_version(torch.__version__) >= parse_version("2.6") else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py new file mode 100644 index 000000000..225fa1c6d --- /dev/null +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.utils import get_torch_device_module + +from .data import DeviceMeshRaw, RolloutBackend, RolloutWeightUpdateInfo, ServiceUrlMap, TrainRolloutMode +from .exporter import WeightExporter +from .transport import IPCWeightTransport, NCCLWeightTransport, WeightTransport + + +DEVICE_MODULE = get_torch_device_module() + + +class UpdateWeighter: + def __init__(self, *, rank: int, logger: Any, config: Any, engine: Any): + self.rank = rank + self.logger = logger + self.config = config + self._engine = engine + # Used to update weight to rollout engine. + self.rollout_info = RolloutWeightUpdateInfo() + self._global_hf_keys_mapping_cache: dict[str, list[str]] = {} + self.is_train_rollout_colocated: bool | None = None + # Only used by currently unsupported LMDeploy disaggregated modes. + self.use_fake_weight_update = False + self._transport: WeightTransport | None = None + + @staticmethod + def _normalize_rollout_backend(backend: str) -> RolloutBackend: + backend = backend.lower() + if backend not in ("sglang", "vllm", "pytorch", "turbomind"): + raise ValueError( + f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." + ) + return cast(RolloutBackend, backend) + + @staticmethod + def _normalize_train_rollout_mode(train_rollout_mode: str) -> TrainRolloutMode: + mode = train_rollout_mode.lower() + if mode not in ("colocate", "disaggregated"): + raise ValueError( + f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." + ) + return cast(TrainRolloutMode, mode) + + def _hook_compare_test_sent_and_received_weight_hash( + self, + result: dict[str, Any], + *, + bucket_idx: int | None = None, + names: list[str] | None = None, + ) -> None: + """Test hook for comparing sent and received weight hashes. + + This hook is intentionally a no-op in production code and is expected to be overridden in unit tests that need + to compare training-side sent hashes with rollout-side received hashes returned by SGLang. + """ + return + + def update_rollout_info( + self, + engine_rank_mesh_array: DeviceMeshRaw, + server_url_dict: ServiceUrlMap, + rollout_config: RolloutConfig, + worker_server_urls_status: dict[str, bool], + api_server_url: str | None = None, + ): + """Update the rollout information for the training worker.""" + tp = rollout_config.tensor_parallel_size + ep = rollout_config.expert_parallel_size + assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." + if self.rollout_info.rollout_device_mesh is None: + self.rollout_info.rollout_device_mesh = DeviceMesh( + "cpu", + mesh=engine_rank_mesh_array, + mesh_dim_names=("engine_instance", "engine_parallel"), + ) + rollout_server_url = server_url_dict.get(self.rank, "") + if worker_server_urls_status.get(rollout_server_url, "False") is False: + self.logger.error(f"Rollout server url {rollout_server_url} is not available.") + self.rollout_info.rollout_url = None + else: + self.rollout_info.rollout_url = rollout_server_url + + self.rollout_info.rollout_engine_rank_mesh_array = [ + [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array + ] + self.rollout_info.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} + self.rollout_info.worker_server_urls_status = worker_server_urls_status + + # Backend selection follows rollout launcher precedence: explicit SGLang/vLLM env vars win, + # otherwise the LMDeploy backend decides between pytorch and turbomind. + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = "sglang" + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = "vllm" + else: + backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") + + self.rollout_info.tp = tp + self.rollout_info.ep = ep + self.rollout_info.api_key = rollout_config.api_key + self.rollout_info.backend = self._normalize_rollout_backend(backend) + + # Keep the legacy dict synchronized while downstream code migrates to typed fields. + self.rollout_info.rollout_cfg_info["tp"] = self.rollout_info.tp + self.rollout_info.rollout_cfg_info["ep"] = self.rollout_info.ep + self.rollout_info.rollout_cfg_info["api_key"] = self.rollout_info.api_key + self.rollout_info.rollout_cfg_info["backend"] = self.rollout_info.backend + + def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): + mode = self._normalize_train_rollout_mode(train_rollout_mode) + self.rollout_info.train_rollout_mode = mode + if mode == "colocate": + self.is_train_rollout_colocated = True + self.use_fake_weight_update = False + self.rollout_info.transport_type = "ipc" + elif mode == "disaggregated": + self.is_train_rollout_colocated = False + self.rollout_info.transport_type = "nccl" + + backend = self.rollout_info.backend + if backend == "vllm": + raise NotImplementedError("Disaggregated train-rollout mode is not supported for vLLM backend.") + if backend == "pytorch" or backend == "turbomind": + self.logger.warning( + "Disaggregated train-rollout mode for lmdeploy backend is not fully supported yet. " + "A fake no-op interface will be used temporarily.", + ) + # Fake update lets the training loop skip real synchronization for unsupported modes. + self.use_fake_weight_update = True + elif backend == "sglang": + self.use_fake_weight_update = False + else: + raise ValueError( + f"Unsupported rollout backend for disaggregated mode: {backend!r}. " + "Expected 'vllm', 'pytorch', 'turbomind' or 'sglang'." + ) + + # IPC transports are per-update and cheap to recreate, while NCCL transports keep an + # external process group alive for disaggregated updates. + if self.is_train_rollout_colocated: + self._reset_transport() + + def update_weights(self): + """Update the model weights.""" + if self.is_train_rollout_colocated is None: + raise RuntimeError( + "train/rollout mode is not set. Please call set_train_rollout_mode() before update_weights()." + ) + + if self.use_fake_weight_update: + train_rollout_mode = self.rollout_info.train_rollout_mode or ( + "colocate" if self.is_train_rollout_colocated else "disaggregated" + ) + backend = self.rollout_info.backend or "unknown" + self.logger.warning( + "Using fake weight update interface, no actual weight synchronization will happen. " + "This is only for testing purposes and should not be used in production. " + f"train_rollout_mode={train_rollout_mode}, backend={backend}." + ) + return + + transport = self._get_transport() + exporter = WeightExporter( + config=self.config, + engine=self._engine, + rollout_info=self.rollout_info, + global_hf_keys_mapping_cache=self._global_hf_keys_mapping_cache, + ) + + transport.before_update() + DEVICE_MODULE.empty_cache() + try: + for batches, sync_group in self._iter_export_batch_groups(exporter): + self._send_exported_batches(transport, batches, sync_group=sync_group) + finally: + transport.after_update() + DEVICE_MODULE.empty_cache() + + def _iter_export_batch_groups(self, exporter: WeightExporter): + # Export path depends on rollout protocol: turbomind consumes layer-wise batches, + # compose models update submodules in order, and plain models use HF-style batches. + if self.is_train_rollout_colocated and self.rollout_info.backend == "turbomind": + yield exporter.iter_layer_batches(), "colocated" + return + + if isinstance(self.config.model_cfg, BaseComposeConfig): + # Only the last compose submodule sends the final update marker. + submodules = ( + ("language_model", False), + ("vision_tower", False), + ("multi_modal_projector", True), + ) + for submodule, final_update in submodules: + yield exporter.iter_hf_batches(submodule=submodule, final_update=final_update), "current" + return + + yield exporter.iter_hf_batches(final_update=True), "current" + + def _send_exported_batches(self, transport: WeightTransport, batches, *, sync_group: str) -> None: + for batch in batches: + transport.send(batch) + self._barrier_after_export(transport, sync_group=sync_group) + DEVICE_MODULE.empty_cache() + + def _barrier_after_export(self, transport: WeightTransport, *, sync_group: str) -> None: + # Colocated IPC synchronizes all training ranks, while disaggregated NCCL uses a + # dedicated CPU sync group to avoid coupling with the external NCCL group. + if self.is_train_rollout_colocated or sync_group == "colocated": + dist.barrier() + return + if isinstance(transport, NCCLWeightTransport): + dist.barrier(group=transport.get_train_update_sync_group()) + return + dist.barrier() + + def _get_transport(self) -> WeightTransport: + if self.rollout_info.transport_type == "ipc": + return IPCWeightTransport( + rank=self.rank, + logger=self.logger, + config=self.config, + rollout_info=self.rollout_info, + ) + + if self.rollout_info.transport_type == "nccl" and self._transport is None: + transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=self.rollout_info) + transport.hook_compare_test_sent_and_received_weight_hash = ( + self._hook_compare_test_sent_and_received_weight_hash + ) + self._transport = transport + if self._transport is None: + raise RuntimeError( + f"Weight transport is not initialized. transport_type={self.rollout_info.transport_type!r}." + ) + return self._transport + + def _reset_transport(self): + if self._transport is not None: + self._transport.teardown() + self._transport = None diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 25d8eb173..223c0a9be 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -1533,6 +1533,10 @@ async def _fit(self): if self._enable_initial_evaluate: await self._run_initial_evaluate() + # _run_initial_evaluate() 在最后会执行 rollout controller 的 pause_generation() + # SGLang 下会把生成暂停/abort,导致下面的正式生产卡住。 + # 而共卡会在每次生产前先continue_produce。 + await self.agent_loop_manager.continue_produce(model_step=self._cur_step) self._benchmark_start_time_s = time.perf_counter() self._benchmark_training_samples = 0