Skip to content

refactor rollout weight update flow#1828

Open
PengchengShi00 wants to merge 1 commit into
InternLM:mainfrom
PengchengShi00:refactor-update-weight
Open

refactor rollout weight update flow#1828
PengchengShi00 wants to merge 1 commit into
InternLM:mainfrom
PengchengShi00:refactor-update-weight

Conversation

@PengchengShi00
Copy link
Copy Markdown
Contributor

重构了权重更新流程,重构新增了下面几个文件

  • data.py:定义共享数据结构、类型别名和` update batch。
  • client.py:封装 rollout engine 的 HTTP update 接口。
  • exporter.py:从训练模型中导出 HuggingFace 风格权重 batch。
  • transport.py:实现 IPC/NCCL 传输,以及不同 rollout backend 的适配器。
  • update_weighter.py:对外提供 update weight 编排逻辑,是上游主要调用入口。
  • __init__.py:导出公共接口。

后续新增RDMA的更新方式需要新增 RDMAWeightTransport,新增对LMdeploy支持需要增加LMdeployNCCLBackendAdapter

@jayhenry
Copy link
Copy Markdown
Collaborator

jayhenry commented May 27, 2026

Transport / Adapter 重构草图

当前问题

当前 PR 里 Adapter 的名字是对的,但 Interface 还偏浅:

  • IPCBackendAdapter 接收完整 IPCWeightTransport,能访问 IPC event、buffer cache、rollout info、client、process group。
  • SGLangNCCLBackendAdapter 依赖完整 NCCLWeightTransport,能访问 external group、executor、group name、engine urls、client。
  • RolloutWeightUpdateClient 混合了不同推理引擎接口:collective_rpcupdate_weightsupdate_weights_from_tensorinit_weights_update_groupupdate_weights_from_distributed

这些都让 Adapter 知道太多 Transport 信息。改法是:该上升到 Transport 的通用传输信息上升;该下沉到 Adapter 的推理引擎协议信息下沉。

Module 分工

UpdateWeighter

UpdateWeighter 只负责编排:

  1. 用现有 rollout_info 创建 WeightExporter
  2. 用同一个 rollout_info 创建 Transport
  3. 把 exporter 产出的 batch iterable 交给 transport

示例:

class UpdateWeighter:
    def update_weights(self) -> None:
        exporter = WeightExporter(
            config=self.config,
            engine=self.engine,
            rollout_info=self.rollout_info,
        )
        transport = self._get_transport()

        # 不要 list(exporter.iter_batches()),权重 batch 需要边导出边发送。
        transport.send_all(exporter.iter_batches())

UpdateWeighter 不知道:

  • LMDeploy IPC event 怎么复用
  • SGLang NCCL group 怎么初始化
  • HTTP endpoint 名字是什么
  • payload 怎么组织

WeightExporter

WeightExporter 仍然可以接收当前 rollout_info,减少重构面。

示例:

class WeightExporter:
    def __init__(self, *, config, engine, rollout_info):
        self.config = config
        self.engine = engine
        self.rollout_info = rollout_info

    def iter_batches(self) -> Iterable[WeightUpdateBatch]:
        ...

重点是保持 iter_batches() 流式产出。这样 transport 可以边拿 batch 边发送,避免权重 batch 提前全部驻留显存/内存。

Transport

Transport 持有通用传输生命周期,可以持有完整 rollout_info。但 Transport 调 Adapter 时,只传 Adapter 需要的字段,不把完整 rollout_info 传下去。

IPCWeightTransport 负责:

  • 调 adapter 的 update 生命周期 hook
  • 每个 batch 调 adapter 生成本地 payload
  • TP gather
  • head rank post
  • engine parallel barrier

示例:

class IPCWeightTransport:
    def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
        self.adapter.before_update()
        try:
            for batch in batches:
                local_payload = self.adapter.build_local_payload(batch)
                tensors = self._gather_if_needed(local_payload.data)
                request = self.adapter.build_request(
                    batch,
                    tp=self.rollout_info.tp,
                    serialized_named_tensors=tensors,
                    load_format=local_payload.load_format,
                )
                if self._is_engine_parallel_head():
                    self._post_local_rollout(request)
                if request.needs_engine_parallel_barrier:
                    self._barrier_engine_parallel()
        finally:
            self.adapter.after_update()

NCCLWeightTransport 负责:

  • train head rank 判断
  • external NCCL group 生命周期
  • 调 adapter 完成 backend-specific broadcast/request
  • 向 rollout engine post
  • train update barrier

示例:

class NCCLWeightTransport:
    def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
        for batch in batches:
            if not batch.state_dict:
                continue
            if not self._is_train_head_rank():
                self._barrier_train_update_group()
                continue

            self.nccl_group.ensure_started()
            request = self.adapter.build_request(batch, group=self.nccl_group)
            for endpoint in self.rollout_info.active_engine_endpoints:
                post_json(endpoint.url, request.endpoint, request.body, api_key=self.rollout_info.api_key)
            self._barrier_train_update_group()

Adapter

Adapter 只表达推理引擎协议差异。

LMDeployIPCAdapter

LMDeploy 的特殊点是 flattened bucket 和 IPC tensor cache。这些不应该放在通用 IPC transport 里,而应该下沉到 LMDeployIPCAdapter

示例:

class LMDeployIPCAdapter(IPCBackendAdapter):
    def __init__(self, bucket_size_bytes: int):
        self.flattened_bucket_cache = LMDeployFlattenedBucketCache(bucket_size_bytes)

    def before_update(self) -> None:
        self.flattened_bucket_cache.open()

    def after_update(self) -> None:
        self.flattened_bucket_cache.close()

    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
        if batch.state_dict and lmdeploy_supports_flattened_bucket():
            data = self.flattened_bucket_cache.flatten(batch.state_dict)
            return LocalPayload(
                data=lmdeploy_serialize_state_dict(data),
                load_format="flattened_bucket",
            )
        return LocalPayload(data=lmdeploy_serialize_state_dict(batch.state_dict))

    def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
        body = {"serialized_named_tensors": serialized_named_tensors, "finished": batch.finished}
        if load_format is not None:
            body["load_format"] = load_format
        return IPCRequest(
            endpoint="update_weights",
            body=body,
            needs_engine_parallel_barrier=batch.finished or (batch.train_enable_ep and tp > 1),
        )

这里 LMDeployFlattenedBucketCache 承接旧实现里的 _update_params_ipc_event、per-dtype buffer cache、event wait/record、ipc handle 发送策略。通用 IPC transport 不知道这些细节。

SGLangIPCAdapter

SGLang colocate IPC 的特殊点是 torch reductions patch、SGLang serializer、update_weights_from_tensor payload。

示例:

class SGLangIPCAdapter(IPCBackendAdapter):
    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
        with patched_sglang_torch_reductions():
            if sglang_supports_flattened_bucket() and batch.state_dict:
                flattened = sglang_flattened_bucket(batch.state_dict.items())
                return LocalPayload(
                    data=sglang_serialize({
                        "flattened_tensor": flattened.tensor,
                        "metadata": flattened.metadata,
                    }),
                    load_format="flattened_bucket",
                )
            return LocalPayload(data=sglang_serialize(batch.state_dict.items()))

    def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
        if tp == 1:
            serialized_named_tensors = [serialized_named_tensors]
        body = {"serialized_named_tensors": serialized_named_tensors, "flush_cache": False}
        if load_format is not None:
            body["load_format"] = load_format
        return IPCRequest(endpoint="update_weights_from_tensor", body=body)

SGLangNCCLAdapter

SGLang disaggregated NCCL 的特殊点有两个:

  • rollout side group init endpoint 是 init_weights_update_group
  • weight update endpoint 是 update_weights_from_distributed

这些 endpoint 不应该放在 client 或通用 group helper 里,而应该下沉到 SGLang adapter。

示例:

class SGLangNCCLAdapter(NCCLBackendAdapter):
    def build_group_init_requests(self, *, active_engine_endpoints, rendezvous):
        rank_offset = 1
        requests = []
        for endpoint in active_engine_endpoints:
            requests.append(NCCLGroupInitRequest(
                url=endpoint.url,
                endpoint="init_weights_update_group",
                body={
                    "master_address": rendezvous.master_address,
                    "master_port": rendezvous.master_port,
                    "rank_offset": rank_offset,
                    "world_size": rendezvous.world_size,
                    "group_name": rendezvous.group_name,
                    "backend": rendezvous.backend,
                },
            ))
            rank_offset += endpoint.engine_size
        return requests

    def build_request(self, batch, *, group) -> NCCLRequest:
        flattened = sglang_flattened_bucket(batch.state_dict.items())
        group.broadcast_tensor(flattened.tensor)
        return NCCLRequest(
            endpoint="update_weights_from_distributed",
            body={
                "names": flattened.names,
                "dtypes": flattened.dtypes,
                "shapes": flattened.shapes,
                "group_name": group.group_name,
                "load_format": "flattened_bucket",
            },
        )

HTTP helper 原则

不要让 client 提供这些方法:

client.collective_rpc(...)
client.update_weights(...)
client.update_weights_from_tensor(...)
client.init_weights_update_group(...)
client.update_weights_from_distributed(...)

这些方法名全是推理引擎协议,放在 client 会混合 vLLM、LMDeploy、SGLang 的 Interface。

更简洁的是低层 helper:

def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    response = requests.post(f"{url}/{endpoint}", headers=headers, json=payload)
    response.raise_for_status()
    return response.json()

endpoint 和 body 由 Adapter 产出:

request = adapter.build_request(...)
post_json(url, request.endpoint, request.body, api_key=rollout_info.api_key)

这样 helper 只负责 HTTP,Adapter 负责推理引擎协议。

上升 / 下沉规则

  • 上升到 Transport:TP gather、head rank post、barrier、external NCCL group 生命周期、HTTP post 执行。
  • 留在 rollout_info:backend、transport type、tp/ep、api key、rollout URL、engine endpoints、device mesh。
  • 下沉到 Adapter:推理引擎 endpoint、payload schema、序列化格式、flattened bucket、SGLang group init 请求。
  • 下沉到 backend-specific helper:LMDeploy IPC event/cache、SGLang torch patch、SGLang FlattenedTensorBucket。

最终的 Interface 更小:

class IPCBackendAdapter:
    def before_update(self) -> None: ...
    def after_update(self) -> None: ...
    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload: ...
    def build_request(..., tp, serialized_named_tensors, load_format) -> IPCRequest: ...


class NCCLBackendAdapter:
    def build_group_init_requests(..., active_engine_endpoints, rendezvous) -> list[NCCLGroupInitRequest]: ...
    def check_group_init_result(self, result: dict) -> None: ...
    def build_request(self, batch, *, group) -> NCCLRequest: ...

这能让 Adapter 只关注不同推理引擎的差异点,同时让 Transport 保持通用传输语义。


with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
master_port = int(sock.getsockname()[1])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

通过写入明确的配置来设定端口,这种动态拿端口很容易产生 address already in use的报错

[Errno 98] error while attempting to bind on address ('10.102.241.42', 42169): address already in use

return

transport.ensure_group()
if transport.group is None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该改为 assert transport.group is not None ?


assert transport.executor is not None
assert transport.group_name is not None
with transport.update_lock:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以删掉 update_lock ?

DEVICE_MODULE = get_torch_device_module()


class WeightExporter:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename WeightIterator

)
return engine_info

def ensure_group(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NCCL group 是持久化的,但 update_rollout_info() 可能被反复调用更新 rollout metadata。如果 rollout worker fail & recover 后 URL/status 变化,已有 group/engine_urls 不会重建。旧代码也有类似问题,但这次重构可以顺手把“metadata 变化时 reset NCCL transport”这个语义补上。

self.rollout_info.rollout_cfg_info["api_key"] = self.rollout_info.api_key
self.rollout_info.rollout_cfg_info["backend"] = self.rollout_info.backend

def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

简化接口:去掉整个调用链路的 set_train_rollout_mode 接口,合并到 update_rollout_info 中。两者都是设置 rollout_info

@@ -0,0 +1,249 @@
from __future__ import annotations
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

更新对应单测

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants