From 1982c407e87e10fb2450f36a9d197b1c58d57234 Mon Sep 17 00:00:00 2001 From: iflow-cli Date: Wed, 19 Nov 2025 09:52:15 +0000 Subject: [PATCH] feat: completed task XLUT20251119310498000001598006 --- roll/configs/base_config.py | 10 ++++++-- roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py | 16 +++++++----- roll/pipeline/rlvr/rlvr_pipeline.py | 26 +++++++++++--------- roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 23 +++++++++++------ 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 350f69cc5..d2d5e1cfd 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -309,6 +309,11 @@ class PPOConfig(BaseConfig): default_factory=WorkerConfig, metadata={"help": "Configuration for the reference role."} ) + use_reference: bool = field( + default=True, + metadata={"help": "Whether to use reference model for KL divergence computation. If False, reference model will not be initialized."} + ) + async_generation_ratio: float = field( default=0, metadata={ @@ -389,11 +394,12 @@ def __post_init__(self): if ( self.actor_train.model_args.model_name_or_path is None or self.actor_infer.model_args.model_name_or_path is None - or self.reference.model_args.model_name_or_path is None + or (self.use_reference and self.reference.model_args.model_name_or_path is None) ): self.actor_train.model_args.model_name_or_path = self.pretrain self.actor_infer.model_args.model_name_or_path = self.pretrain - self.reference.model_args.model_name_or_path = self.pretrain + if self.use_reference: + self.reference.model_args.model_name_or_path = self.pretrain if self.critic.model_args.model_name_or_path is None: self.critic.model_args.model_name_or_path = self.reward_pretrain diff --git a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py index de38bd157..cf69dc7f3 100644 --- a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py @@ -222,7 +222,7 @@ def __init__(self, pipeline_config: RLVRConfig): worker_config=self.pipeline_config.actor_infer, ) # use unwrapped model as reference for lora training - if not self.is_lora: + if not self.is_lora and self.pipeline_config.use_reference: self.reference: Any = Cluster( name=self.pipeline_config.reference.name, worker_cls=self.pipeline_config.reference.worker_cls, @@ -361,15 +361,19 @@ def run(self): with Timer(name="cal_ref_log_probs_reward", logger=None) as cal_timer: if self.is_lora: - batch.meta_info["disable_adapter"] = True - batch.meta_info["is_offload_states"] = False ref_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( batch, blocking=False ) else: - ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs( - batch, blocking=False - ) + if self.pipeline_config.use_reference: + ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs( + batch, blocking=False + ) + else: + # When reference model is disabled, use actor's own log probabilities as reference + ref_log_probs_refs: List[ray.ObjectRef] = self.actor_infer.compute_log_probs( + batch, blocking=False + ) rewards_refs: List[ray.ObjectRef] = self.reward.compute_rewards(batch, blocking=False) ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 4f702f0b6..6ae715ea9 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -216,7 +216,7 @@ def __init__(self, pipeline_config: RLVRConfig): ) download_clusters = [self.actor_train, self.actor_infer] # use unwrapped model as reference for lora training - if not self.is_lora: + if not self.is_lora and self.pipeline_config.use_reference: self.reference: Any = Cluster( name=self.pipeline_config.reference.name, worker_cls=self.pipeline_config.reference.worker_cls, @@ -544,16 +544,20 @@ def run(self): batch.meta_info["is_offload_states"] = False ref_log_probs = self.actor_train.compute_log_probs(batch, blocking=True) else: - if self.pipeline_config.reference.use_dynamic_batching_in_infer: - batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, - self.reference.dp_size, - self.pipeline_config.reference.max_tokens_per_microbatch_in_infer, - self.pipeline_config.reference.sequence_length_round_in_infer, - "reference/compute_log_probs", - ) - metrics_mgr.add_metrics(dynamic_batching_metrics) - ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + if self.pipeline_config.use_reference: + if self.pipeline_config.reference.use_dynamic_batching_in_infer: + batch, dynamic_batching_metrics = dynamic_batching_shard( + batch, + self.reference.dp_size, + self.pipeline_config.reference.max_tokens_per_microbatch_in_infer, + self.pipeline_config.reference.sequence_length_round_in_infer, + "reference/compute_log_probs", + ) + metrics_mgr.add_metrics(dynamic_batching_metrics) + ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + else: + # When reference model is disabled, use actor's own log probabilities as reference + ref_log_probs = self.actor_infer.compute_log_probs(batch, blocking=True) metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") batch = batch.union(ref_log_probs) diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index eb689c1c8..88f5ac725 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -313,13 +313,16 @@ def __init__(self, pipeline_config: RLVRConfig): resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_infer, ) - self.reference: Any = Cluster( - name=self.pipeline_config.reference.name, - worker_cls=self.pipeline_config.reference.worker_cls, - resource_manager=self.resource_manager, - worker_config=self.pipeline_config.reference, - ) - download_clusters = [self.actor_train, self.actor_infer, self.reference] + if self.pipeline_config.use_reference: + self.reference: Any = Cluster( + name=self.pipeline_config.reference.name, + worker_cls=self.pipeline_config.reference.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.reference, + ) + download_clusters = [self.actor_train, self.actor_infer, self.reference] + else: + download_clusters = [self.actor_train, self.actor_infer] if self.pipeline_config.adv_estimator == "gae": self.critic: Any = Cluster( name=self.pipeline_config.critic.name, @@ -542,7 +545,11 @@ def run(self): batch.meta_info["_broadcast_non_tensor_batch"]= True with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: - ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + if self.pipeline_config.use_reference: + ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + else: + # When reference model is disabled, use actor's own log probabilities as reference + ref_log_probs = self.actor_infer.compute_log_probs(batch, blocking=True) metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") batch = batch.union(ref_log_probs)