Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xtuner/v1/rl/trainer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.model.compose.base import BaseComposeConfig
from xtuner.v1.rl.weight_update import TrainRolloutMode
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import get_logger

Expand Down Expand Up @@ -292,7 +293,7 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"):
def update_rollout_info(self, info_dict):
    """Send ``info_dict`` to every worker and wait for all of them.

    ``info_dict`` is unpacked into keyword arguments of each worker's remote
    ``update_rollout_info``; ``ray.get`` then blocks until every call returns.
    """
    pending = []
    for worker in self.workers:
        pending.append(worker.update_rollout_info.remote(**info_dict))  # type: ignore[attr-defined]
    ray.get(pending)

def set_train_rollout_mode(self, train_rollout_mode: str):
def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode):
    """Apply ``train_rollout_mode`` on every worker and wait for completion.

    The same value is passed positionally to each worker's remote
    ``set_train_rollout_mode``; ``ray.get`` blocks until all calls return.
    """
    pending = [w.set_train_rollout_mode.remote(train_rollout_mode) for w in self.workers]
    ray.get(pending)

def update_weights(self):
Expand Down
951 changes: 0 additions & 951 deletions xtuner/v1/rl/trainer/update_weighter.py

This file was deleted.

23 changes: 20 additions & 3 deletions xtuner/v1/rl/trainer/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from xtuner.v1.profiler import profiling_memory, profiling_time
from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, kl_penalty
from xtuner.v1.rl.utils import SingleAcceleratorWorker
from xtuner.v1.rl.weight_update import TrainRolloutMode, UpdateWeighter
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import (
XTUNER_DETERMINISTIC,
Expand All @@ -59,7 +60,6 @@
)

from ..rollout_is import merge_rollout_is_metrics
from .update_weighter import UpdateWeighter


DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices
Expand Down Expand Up @@ -200,7 +200,7 @@ class WorkerLogItem(TypedDict):
sft_train_metrics: NotRequired[dict[str, float]]


class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter):
class TrainingWorker(SingleAcceleratorWorker):
_SAVE_WEIGHTS_DIR = "weights"
_SAVE_SFT_DATALOADER_DIR = "sft_dataloader"
_SAVE_SFT_TRAIN_STATE_PATH = "sft_train_state.json"
Expand Down Expand Up @@ -268,7 +268,24 @@ def __init__(
if hasattr(worker_cfg.model_cfg.text_config, "mtp_config"):
self.mtp_config = worker_cfg.model_cfg.text_config.mtp_config

self._init_update_weighter()
self.update_weighter = UpdateWeighter(
rank=self.rank,
logger=self.logger,
config=self.config,
engine=self._engine,
)

@ray_method
def update_rollout_info(self, *args, **kwargs):
    """Pass all arguments through to ``self.update_weighter.update_rollout_info``.

    Returns whatever the delegate returns.
    """
    weighter = self.update_weighter
    return weighter.update_rollout_info(*args, **kwargs)

@ray_method
def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode):
    """Forward ``train_rollout_mode`` to ``self.update_weighter``.

    Returns whatever ``UpdateWeighter.set_train_rollout_mode`` returns.
    """
    weighter = self.update_weighter
    return weighter.set_train_rollout_mode(train_rollout_mode)

@ray_method
def update_weights(self):
    """Trigger ``self.update_weighter.update_weights()`` and return its result."""
    weighter = self.update_weighter
    return weighter.update_weights()

def _init_sft(self, worker_cfg: WorkerConfig):
self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
Expand Down
46 changes: 46 additions & 0 deletions xtuner/v1/rl/weight_update/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from .client import RolloutWeightUpdateClient
from .data import (
DeviceMeshRaw,
RolloutBackend,
RolloutEngineInfo,
RolloutWeightUpdateInfo,
ServiceUrlMap,
TrainRolloutMode,
WeightTransportType,
WeightUpdateBatch,
)
from .exporter import WeightExporter
from .transport import (
IPCBackendAdapter,
IPCWeightTransport,
LMDeployIPCBackendAdapter,
NCCLBackendAdapter,
NCCLWeightTransport,
SGLangIPCBackendAdapter,
SGLangNCCLBackendAdapter,
WeightTransport,
)
from .update_weighter import UpdateWeighter


__all__ = [
"DeviceMeshRaw",
"IPCBackendAdapter",
"IPCWeightTransport",
"LMDeployIPCBackendAdapter",
"NCCLBackendAdapter",
"NCCLWeightTransport",
"RolloutBackend",
"RolloutEngineInfo",
"RolloutWeightUpdateInfo",
"RolloutWeightUpdateClient",
"SGLangIPCBackendAdapter",
"SGLangNCCLBackendAdapter",
"ServiceUrlMap",
"TrainRolloutMode",
"UpdateWeighter",
"WeightExporter",
"WeightTransportType",
"WeightUpdateBatch",
"WeightTransport",
]
31 changes: 31 additions & 0 deletions xtuner/v1/rl/weight_update/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Any

import requests


class RolloutWeightUpdateClient:
    """Thin HTTP client for the weight-update endpoints of a rollout server.

    Each public method performs one blocking ``requests.post`` against
    ``{url}/<endpoint>`` with a JSON body and returns the ``requests.Response``
    unchanged; checking the status code is left to the caller.
    """

    def __init__(self, api_key: list[str] | str | None):
        """Store the credential used to build the ``Authorization`` header.

        Args:
            api_key: A single key, a list of keys, or ``None`` for no
                authentication.
        """
        self.api_key = api_key

    def _bearer_token(self) -> str | None:
        """Reduce ``self.api_key`` to the one token that goes on the wire.

        A string (or ``None``) is returned as is. For a list, the first
        non-empty entry is used, since interpolating the list itself would
        produce ``Bearer ['k1', 'k2']``.
        NOTE(review): assumes any key in the list is accepted by the server —
        confirm against the rollout backend's auth configuration.
        """
        key = self.api_key
        if key is None or isinstance(key, str):
            return key
        return next((candidate for candidate in key if candidate), None)

    def _headers(self) -> dict[str, str]:
        """Build the request headers, adding bearer auth when a key is available."""
        headers = {"Content-Type": "application/json"}
        token = self._bearer_token()
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        return headers

    def _post(self, url: str, endpoint: str, payload: dict[str, Any]):
        """POST ``payload`` as JSON to ``{url}/{endpoint}`` with the shared headers."""
        return requests.post(f"{url}/{endpoint}", headers=self._headers(), json=payload)

    def collective_rpc(self, url: str, payload: dict[str, Any]):
        """POST ``payload`` to ``{url}/collective_rpc``."""
        return self._post(url, "collective_rpc", payload)

    def update_weights(self, url: str, endpoint: str, payload: dict[str, Any]):
        """POST ``payload`` to the caller-chosen ``{url}/{endpoint}``."""
        return self._post(url, endpoint, payload)

    def update_weights_from_tensor(self, url: str, payload: dict[str, Any]):
        """POST ``payload`` to ``{url}/update_weights_from_tensor``.

        A falsy payload (``None`` or empty) is sent as an empty JSON object.
        """
        return self._post(url, "update_weights_from_tensor", payload or {})

    def init_weights_update_group(self, url: str, payload: dict[str, Any]):
        """POST ``payload`` to ``{url}/init_weights_update_group``."""
        return self._post(url, "init_weights_update_group", payload)

    def update_weights_from_distributed(self, url: str, payload: dict[str, Any]):
        """POST ``payload`` to ``{url}/update_weights_from_distributed``."""
        return self._post(url, "update_weights_from_distributed", payload)
41 changes: 41 additions & 0 deletions xtuner/v1/rl/weight_update/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Literal, TypeAlias

import torch
from torch.distributed.device_mesh import DeviceMesh


DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices.
ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs.
RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count)
TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode.
RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend.
WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types.


@dataclass
class RolloutWeightUpdateInfo:
    """Rollout-side settings bundled together for weight updates.

    Every field has a default, so an instance can be created empty and
    filled in later. Mutable defaults are built per instance via
    ``default_factory``.
    """

    # Credential(s) for the rollout server; None means no key.
    api_key: list[str] | str | None = None
    rollout_device_mesh: DeviceMesh | None = None
    rollout_url: str | None = None
    # Rollout inference backend.
    backend: RolloutBackend | None = None
    # NOTE(review): presumably tensor-/expert-parallel sizes — confirm with callers.
    tp: int = 1
    ep: int = 1
    # Train and rollout deployment mode.
    train_rollout_mode: TrainRolloutMode | None = None
    # Which weight transport to use.
    transport_type: WeightTransportType | None = None
    rollout_cfg_info: dict = field(default_factory=dict)
    # Starts with a single "update_weights" -> "update_weights" entry.
    endpoints: dict[str, str] = field(default_factory=lambda: dict(update_weights="update_weights"))
    # List of lists of device mesh indices.
    rollout_engine_rank_mesh_array: DeviceMeshRaw = field(default_factory=list)
    # Rollout rank -> server URL.
    rollout_server_url_dict: ServiceUrlMap = field(default_factory=dict)
    worker_server_urls_status: dict[str, bool] = field(default_factory=dict)


@dataclass
class WeightUpdateBatch:
    """One bucket of weights destined for the rollout workers.

    Only ``state_dict`` is required; both flags default to ``False``.
    """

    # Tensors in this bucket, keyed by string name.
    state_dict: dict[str, torch.Tensor]
    # NOTE(review): presumably whether the training side uses expert parallelism — confirm.
    train_enable_ep: bool = False
    # NOTE(review): presumably marks the final bucket of an update — confirm with the sender.
    finished: bool = False
finished: bool = False
Loading
Loading