refactor rollout weight update flow#1828
Conversation
Transport / Adapter 重构草图

当前问题

当前 PR 里 Adapter 的名字是对的,但 Interface 还偏浅:
这些都让 Adapter 知道太多 Transport 信息。改法是:属于通用传输的信息上升到 Transport,属于推理引擎协议的信息下沉到 Adapter。

Module 分工

UpdateWeighter
示例: class UpdateWeighter:
def update_weights(self) -> None:
    """Export the weights and hand the batch iterator to the transport.

    Builds a ``WeightExporter`` from this object's ``config``, ``engine`` and
    ``rollout_info``, obtains a transport from ``_get_transport()`` and passes
    the exporter's ``iter_batches()`` result to ``send_all``.
    """
    exporter_kwargs = {
        "config": self.config,
        "engine": self.engine,
        "rollout_info": self.rollout_info,
    }
    weight_exporter = WeightExporter(**exporter_kwargs)
    sender = self._get_transport()
    # Do not wrap this in list(...): weight batches have to be sent while
    # they are still being exported.
    batch_stream = weight_exporter.iter_batches()
    sender.send_all(batch_stream)
WeightExporter
示例: class WeightExporter:
def __init__(self, *, config, engine, rollout_info):
    """Store the three keyword-only arguments on the instance.

    Args:
        config: Kept unchanged as ``self.config``.
        engine: Kept unchanged as ``self.engine``.
        rollout_info: Kept unchanged as ``self.rollout_info``.
    """
    stored = (
        ("config", config),
        ("engine", engine),
        ("rollout_info", rollout_info),
    )
    for attr_name, attr_value in stored:
        setattr(self, attr_name, attr_value)
def iter_batches(self) -> Iterable[WeightUpdateBatch]:
...重点是保持 TransportTransport 持有通用传输生命周期,可以持有完整
示例: class IPCWeightTransport:
def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
    """Send every weight batch through the IPC path, one batch at a time.

    ``adapter.before_update()`` runs once before the loop and
    ``adapter.after_update()`` always runs afterwards, including when a batch
    raises.

    Args:
        batches: Iterable of weight batches; it is consumed lazily.
    """
    adapter = self.adapter
    adapter.before_update()
    try:
        for weight_batch in batches:
            payload = adapter.build_local_payload(weight_batch)
            gathered = self._gather_if_needed(payload.data)
            request_kwargs = {
                "tp": self.rollout_info.tp,
                "serialized_named_tensors": gathered,
                "load_format": payload.load_format,
            }
            update_request = adapter.build_request(weight_batch, **request_kwargs)

            # Only the head rank posts the request to its local rollout.
            if self._is_engine_parallel_head():
                self._post_local_rollout(update_request)

            # The barrier decision comes from the adapter-built request and is
            # checked on every rank, not just the head.
            if update_request.needs_engine_parallel_barrier:
                self._barrier_engine_parallel()
    finally:
        adapter.after_update()
示例: class NCCLWeightTransport:
def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
for batch in batches:
if not batch.state_dict:
continue
if not self._is_train_head_rank():
self._barrier_train_update_group()
continue
self.nccl_group.ensure_started()
request = self.adapter.build_request(batch, group=self.nccl_group)
for endpoint in self.rollout_info.active_engine_endpoints:
post_json(endpoint.url, request.endpoint, request.body, api_key=self.rollout_info.api_key)
self._barrier_train_update_group()AdapterAdapter 只表达推理引擎协议差异。 LMDeployIPCAdapterLMDeploy 的特殊点是 flattened bucket 和 IPC tensor cache。这些不应该放在通用 IPC transport 里,而应该下沉到 示例: class LMDeployIPCAdapter(IPCBackendAdapter):
def __init__(self, bucket_size_bytes: int):
    """Create the flattened-bucket cache used by this adapter.

    Args:
        bucket_size_bytes: Passed straight to ``LMDeployFlattenedBucketCache``.
    """
    bucket_cache = LMDeployFlattenedBucketCache(bucket_size_bytes)
    self.flattened_bucket_cache = bucket_cache
def before_update(self) -> None:
    """Call ``open()`` on the flattened-bucket cache."""
    bucket_cache = self.flattened_bucket_cache
    bucket_cache.open()
def after_update(self) -> None:
    """Call ``close()`` on the flattened-bucket cache."""
    bucket_cache = self.flattened_bucket_cache
    bucket_cache.close()
def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
    """Serialize one batch's state dict into a ``LocalPayload``.

    When the state dict is non-empty and
    ``lmdeploy_supports_flattened_bucket()`` is true, the state dict is first
    flattened through the bucket cache and the payload is tagged with
    ``load_format="flattened_bucket"``. Otherwise the state dict is serialized
    as-is and no load format is passed.

    Args:
        batch: Weight batch whose ``state_dict`` is serialized.

    Returns:
        The payload wrapping the serialized data.
    """
    state_dict = batch.state_dict
    # Keep the short-circuit: the capability probe runs only for a non-empty
    # state dict.
    if not (state_dict and lmdeploy_supports_flattened_bucket()):
        return LocalPayload(data=lmdeploy_serialize_state_dict(state_dict))

    flattened = self.flattened_bucket_cache.flatten(state_dict)
    serialized = lmdeploy_serialize_state_dict(flattened)
    return LocalPayload(data=serialized, load_format="flattened_bucket")
def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
body = {"serialized_named_tensors": serialized_named_tensors, "finished": batch.finished}
if load_format is not None:
body["load_format"] = load_format
return IPCRequest(
endpoint="update_weights",
body=body,
needs_engine_parallel_barrier=batch.finished or (batch.train_enable_ep and tp > 1),
)这里 SGLangIPCAdapterSGLang colocate IPC 的特殊点是 torch reductions patch、SGLang serializer、 示例: class SGLangIPCAdapter(IPCBackendAdapter):
def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
    """Serialize one batch's state dict into a ``LocalPayload`` for SGLang.

    All work happens inside ``patched_sglang_torch_reductions()``. When
    ``sglang_supports_flattened_bucket()`` is true and the state dict is
    non-empty, the items are flattened into a bucket and the payload is tagged
    with ``load_format="flattened_bucket"``. Otherwise the state dict items are
    serialized directly and no load format is passed.

    Args:
        batch: Weight batch whose ``state_dict`` items are serialized.

    Returns:
        The payload wrapping the serialized data.
    """
    with patched_sglang_torch_reductions():
        # The capability probe is evaluated first, then the state dict.
        if not (sglang_supports_flattened_bucket() and batch.state_dict):
            return LocalPayload(data=sglang_serialize(batch.state_dict.items()))

        bucket = sglang_flattened_bucket(batch.state_dict.items())
        bucket_fields = {
            "flattened_tensor": bucket.tensor,
            "metadata": bucket.metadata,
        }
        return LocalPayload(
            data=sglang_serialize(bucket_fields),
            load_format="flattened_bucket",
        )
def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
if tp == 1:
serialized_named_tensors = [serialized_named_tensors]
body = {"serialized_named_tensors": serialized_named_tensors, "flush_cache": False}
if load_format is not None:
body["load_format"] = load_format
return IPCRequest(endpoint="update_weights_from_tensor", body=body)SGLangNCCLAdapterSGLang disaggregated NCCL 的特殊点有两个:
这些 endpoint 不应该放在 client 或通用 group helper 里,而应该下沉到 SGLang adapter。 示例: class SGLangNCCLAdapter(NCCLBackendAdapter):
def build_group_init_requests(self, *, active_engine_endpoints, rendezvous):
    """Build one ``init_weights_update_group`` request per active endpoint.

    Each request carries the shared rendezvous fields plus a per-endpoint
    ``rank_offset``. Offsets start at 1 and advance by each endpoint's
    ``engine_size``.

    Args:
        active_engine_endpoints: Iterable of endpoints exposing ``url`` and
            ``engine_size``.
        rendezvous: Object exposing ``master_address``, ``master_port``,
            ``world_size``, ``group_name`` and ``backend``.

    Returns:
        A list of ``NCCLGroupInitRequest`` objects in endpoint order.
    """

    def _with_rank_offsets():
        # Offsets start at 1 (presumably rank 0 is the training side -- TODO confirm).
        offset = 1
        for engine_endpoint in active_engine_endpoints:
            yield engine_endpoint, offset
            offset += engine_endpoint.engine_size

    return [
        NCCLGroupInitRequest(
            url=engine_endpoint.url,
            endpoint="init_weights_update_group",
            body={
                "master_address": rendezvous.master_address,
                "master_port": rendezvous.master_port,
                "rank_offset": offset,
                "world_size": rendezvous.world_size,
                "group_name": rendezvous.group_name,
                "backend": rendezvous.backend,
            },
        )
        for engine_endpoint, offset in _with_rank_offsets()
    ]
def build_request(self, batch, *, group) -> NCCLRequest:
flattened = sglang_flattened_bucket(batch.state_dict.items())
group.broadcast_tensor(flattened.tensor)
return NCCLRequest(
endpoint="update_weights_from_distributed",
body={
"names": flattened.names,
"dtypes": flattened.dtypes,
"shapes": flattened.shapes,
"group_name": group.group_name,
"load_format": "flattened_bucket",
},
)HTTP helper 原则不要让 client 提供这些方法: client.collective_rpc(...)
client.update_weights(...)
client.update_weights_from_tensor(...)
client.init_weights_update_group(...)
client.update_weights_from_distributed(...)这些方法名全是推理引擎协议,放在 client 会混合 vLLM、LMDeploy、SGLang 的 Interface。 更简洁的是低层 helper: def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
response = requests.post(f"{url}/{endpoint}", headers=headers, json=payload)
response.raise_for_status()
return response.json()endpoint 和 body 由 Adapter 产出: request = adapter.build_request(...)
post_json(url, request.endpoint, request.body, api_key=rollout_info.api_key)这样 helper 只负责 HTTP,Adapter 负责推理引擎协议。 上升 / 下沉规则
最终的 Interface 更小: class IPCBackendAdapter:
def before_update(self) -> None: ...
def after_update(self) -> None: ...
def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload: ...
def build_request(..., tp, serialized_named_tensors, load_format) -> IPCRequest: ...
class NCCLBackendAdapter:
def build_group_init_requests(..., active_engine_endpoints, rendezvous) -> list[NCCLGroupInitRequest]: ...
def check_group_init_result(self, result: dict) -> None: ...
def build_request(self, batch, *, group) -> NCCLRequest: ...这能让 Adapter 只关注不同推理引擎的差异点,同时让 Transport 保持通用传输语义。 |
|
|
||
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: | ||
| sock.bind(("", 0)) | ||
| master_port = int(sock.getsockname()[1]) |
There was a problem hiding this comment.
建议通过显式配置来指定端口。这种动态获取端口的方式在 socket 关闭后端口即被释放,真正使用之前可能已被其他进程占用,很容易产生 address already in use 的报错:
[Errno 98] error while attempting to bind on address ('10.102.241.42', 42169): address already in use
| return | ||
|
|
||
| transport.ensure_group() | ||
| if transport.group is None: |
There was a problem hiding this comment.
这里是否应该改为 `assert transport.group is not None`?
|
|
||
| assert transport.executor is not None | ||
| assert transport.group_name is not None | ||
| with transport.update_lock: |
| DEVICE_MODULE = get_torch_device_module() | ||
|
|
||
|
|
||
| class WeightExporter: |
| ) | ||
| return engine_info | ||
|
|
||
| def ensure_group(self): |
There was a problem hiding this comment.
NCCL group 是持久化的,但 update_rollout_info() 可能被反复调用更新 rollout metadata。如果 rollout worker fail & recover 后 URL/status 变化,已有 group/engine_urls 不会重建。旧代码也有类似问题,但这次重构可以顺手把“metadata 变化时 reset NCCL transport”这个语义补上。
| self.rollout_info.rollout_cfg_info["api_key"] = self.rollout_info.api_key | ||
| self.rollout_info.rollout_cfg_info["backend"] = self.rollout_info.backend | ||
|
|
||
| def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): |
There was a problem hiding this comment.
简化接口:去掉整个调用链路的 set_train_rollout_mode 接口,合并到 update_rollout_info 中。两者都是设置 rollout_info
| @@ -0,0 +1,249 @@ | |||
| from __future__ import annotations | |||
本次 PR 重构了权重更新流程,新增了以下文件:
- data.py:定义共享数据结构、类型别名和 update batch。
- client.py:封装 rollout engine 的 HTTP update 接口。
- exporter.py:从训练模型中导出 HuggingFace 风格权重 batch。
- transport.py:实现 IPC/NCCL 传输,以及不同 rollout backend 的适配器。
- update_weighter.py:对外提供 update weight 编排逻辑,是上游主要调用入口。
- __init__.py:导出公共接口。

后续若要新增 RDMA 的更新方式,需要新增 RDMAWeightTransport;若要新增对 LMDeploy 的支持,需要增加 LMdeployNCCLBackendAdapter。