From 5a95cab127a42dde308cdbd162d7e48635dab535 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 18 Jun 2025 10:14:02 -0400 Subject: [PATCH 001/195] apo initial --- compose_rl/algorithms/online/callback.py | 221 ++++++++++-------- compose_rl/algorithms/online/model.py | 5 + compose_rl/algorithms/online/model_methods.py | 66 +++++- compose_rl/data/prompt_data.py | 9 + 4 files changed, 200 insertions(+), 101 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 0eef5b4a..4dca6023 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -589,6 +589,7 @@ def iteration_start(self, state: State, logger: Logger): del logger # unused batch = self._get_next_iter_prompts() + batch = state.device.batch_to_device(batch) if self.vllm_engines is not None: @@ -648,7 +649,7 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id']: + if key in ['prompt_len', 'verified_answer', 'prompt_id', 'vstar']: curr_values.append(batch[key]) continue @@ -678,6 +679,8 @@ def _get_next_iter_prompts(self): else: if key == 'verified_answer': ret_batch[key] = list(flatten(curr_values)) + elif key == 'vstar': + ret_batch[key] = list(flatten(curr_values)) else: # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values @@ -870,109 +873,137 @@ def _resolve_outputs( env_outs['right_padded_attn_mask'] = torch.logical_not( torch.eq(env_outs['obs'], self.pad_token_idx), # type: ignore ) + if self.actor_critic.loss_type != OnPolicyEnum.APO: + # Now that rewards are resolved, we can compute advantages + if self.actor_critic.loss_type == OnPolicyEnum.PPO: + env_outs['advantages'] = compute_advantages( + rewards=env_outs['rewards'], + values=env_outs['values'], + gamma=self.gamma, + lambda_gae=self.lambda_gae, + ) + elif self.actor_critic.loss_type == OnPolicyEnum.GRPO: + # compute GRPO advantages + prompt_id = env_outs['prompt_id'] + rewards = env_outs['rewards'] + + # Flatten the rewards by summing on sequence length/action_mask + flat_rewards = masked_sum( + rewards, + env_outs['action_mask'], + dim=-1, + ) - # Now that rewards are resolved, we can compute advantages - if self.actor_critic.loss_type == OnPolicyEnum.PPO: - env_outs['advantages'] = compute_advantages( - rewards=env_outs['rewards'], - values=env_outs['values'], - gamma=self.gamma, - lambda_gae=self.lambda_gae, - ) - elif self.actor_critic.loss_type == OnPolicyEnum.GRPO: - # compute GRPO advantages - prompt_id = env_outs['prompt_id'] - rewards = env_outs['rewards'] - - # Flatten the rewards by summing on sequence length/action_mask - flat_rewards = masked_sum( - rewards, - env_outs['action_mask'], - dim=-1, - ) + # Get unique prompt IDs and their indices + unique_prompt_ids, inverse_indices = torch.unique( + prompt_id, + return_inverse=True, + ) - # Get unique prompt IDs and their indices - unique_prompt_ids, inverse_indices = torch.unique( - prompt_id, - return_inverse=True, - ) + # Use scatter to compute means and standard deviations + # First, we'll create a tensor to track counts, sums, and sum of squares + n_unique = len(unique_prompt_ids) + counts = torch.zeros(n_unique, device=prompt_id.device) + sums = torch.zeros(n_unique, device=prompt_id.device) + sum_squares = torch.zeros(n_unique, device=prompt_id.device) + + # Use scatter_add to accumulate values + counts.scatter_add_( + 0, + inverse_indices, + torch.ones_like(flat_rewards), + ) + sums.scatter_add_(0, inverse_indices, flat_rewards) + sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2) + + # Compute means and standard deviations + means = sums / counts + variances = (sum_squares / counts) - (means**2) + stds = torch.sqrt(variances) + + # Map back to original tensor shape + mean_rewards = means[inverse_indices] + std_rewards = stds[inverse_indices] + + # Calculate GRPO advantage + grpo_advantage = (flat_rewards - mean_rewards) + # Only normalize the advantage if flag is set + if self.actor_critic.normalize_advantage: + grpo_advantage /= (std_rewards + 1e-4) + + # Create advantages of the same shape as original rewards + advantages = torch.zeros_like(rewards) + # Copy the flat grpo_advantage according to action_mask + expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( + env_outs['action_mask'], + ) + advantages = torch.where( + env_outs['action_mask'].bool(), + expanded_advantages, + advantages, + ) + env_outs['advantages'] = advantages + else: + raise ValueError( + f'Invalid loss type: {self.actor_critic.loss_type}. ' + + 'Valid options are: ppo, grpo.', + ) - # Use scatter to compute means and standard deviations - # First, we'll create a tensor to track counts, sums, and sum of squares - n_unique = len(unique_prompt_ids) - counts = torch.zeros(n_unique, device=prompt_id.device) - sums = torch.zeros(n_unique, device=prompt_id.device) - sum_squares = torch.zeros(n_unique, device=prompt_id.device) - - # Use scatter_add to accumulate values - counts.scatter_add_( - 0, - inverse_indices, - torch.ones_like(flat_rewards), - ) - sums.scatter_add_(0, inverse_indices, flat_rewards) - sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2) - - # Compute means and standard deviations - means = sums / counts - variances = (sum_squares / counts) - (means**2) - stds = torch.sqrt(variances) - - # Map back to original tensor shape - mean_rewards = means[inverse_indices] - std_rewards = stds[inverse_indices] - - # Calculate GRPO advantage - grpo_advantage = (flat_rewards - mean_rewards) - # Only normalize the advantage if flag is set - if self.actor_critic.normalize_advantage: - grpo_advantage /= (std_rewards + 1e-4) - - # Create advantages of the same shape as original rewards - advantages = torch.zeros_like(rewards) - # Copy the flat grpo_advantage according to action_mask - expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( + batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( + env_outs['advantages'], env_outs['action_mask'], ) - advantages = torch.where( - env_outs['action_mask'].bool(), - expanded_advantages, - advantages, + + mean_ift = masked_mean( + env_outs['ift_kl'], + env_outs['action_mask'], ) - env_outs['advantages'] = advantages + self.kl_ift.append(mean_ift.cpu()) + + iter_batch.update(env_outs) + + iter_batch.update({ + 'max_gen_len': + torch.ones(self.iter_batch_size).to(torch.int32) * + self.max_gen_len, + 'adv_masked_mean': + torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(), + 'adv_masked_var': + torch.ones(self.iter_batch_size) * batch_adv_var.cpu(), + 'ift_kl_scalar': + torch.ones(self.iter_batch_size) * self.kl_ctl.value, + 'reward_std': + torch.ones(self.iter_batch_size) * + env_outs['rewards'].std().to('cpu'), + }) else: - raise ValueError( - f'Invalid loss type: {self.actor_critic.loss_type}. ' + - 'Valid options are: ppo, grpo.', - ) + # Adding dummy advantages + env_outs['advantages'] = torch.ones_like(env_outs['action_mask']) + env_outs - batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( - env_outs['advantages'], - env_outs['action_mask'], - ) + mean_ift = masked_mean( + env_outs['ift_kl'], + env_outs['action_mask'], + ) + self.kl_ift.append(mean_ift.cpu()) + + iter_batch.update(env_outs) + + iter_batch.update({ + 'max_gen_len': + torch.ones(self.iter_batch_size).to(torch.int32) * + self.max_gen_len, + 'adv_masked_mean': + torch.ones(self.iter_batch_size), + 'adv_masked_var': + torch.ones(self.iter_batch_size), + 'ift_kl_scalar': + torch.ones(self.iter_batch_size) * self.kl_ctl.value, + 'reward_std': + torch.ones(self.iter_batch_size) * + env_outs['rewards'].std().to('cpu'), + }) - mean_ift = masked_mean( - env_outs['ift_kl'], - env_outs['action_mask'], - ) - self.kl_ift.append(mean_ift.cpu()) - - iter_batch.update(env_outs) - - iter_batch.update({ - 'max_gen_len': - torch.ones(self.iter_batch_size).to(torch.int32) * - self.max_gen_len, - 'adv_masked_mean': - torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(), - 'adv_masked_var': - torch.ones(self.iter_batch_size) * batch_adv_var.cpu(), - 'ift_kl_scalar': - torch.ones(self.iter_batch_size) * self.kl_ctl.value, - 'reward_std': - torch.ones(self.iter_batch_size) * - env_outs['rewards'].std().to('cpu'), - }) # Moving minibatches to CPU to not take additional GPU memory for k, v in iter_batch.items(): diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 9e7acda5..db908ea1 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -104,6 +104,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, + beta = self.config.beta, #added beta add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, @@ -217,6 +218,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, + beta = self.config.beta, #added beta parameter add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, @@ -255,6 +257,7 @@ def __init__( length_normalize_policy_loss: bool = True, policy_clip_ratio: float = 0.15, policy_clip_high_ratio: float | None = None, + beta: float = 1e-3, #added beta compute_kl_loss: bool = True, target_kl: float = 0.1, kl_estimator: str = 'k3', @@ -283,6 +286,7 @@ def __init__( self.policy_clip_high_ratio = policy_clip_high_ratio self.compute_kl_loss = compute_kl_loss self.target_kl = target_kl + self.beta = beta self.kl_estimator = kl_estimator self.kl_clip_range = kl_clip_range @@ -306,6 +310,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): loss_type=self.loss_type, policy_clip_ratio=self.policy_clip_ratio, policy_clip_high_ratio=self.policy_clip_high_ratio, + beta = self.beta, #added beta length_normalize_policy_loss=self.length_normalize_policy_loss, add_direct_kl_loss=self.compute_kl_loss, kl_estimator=self.kl_estimator, diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index aba0f5b4..73bf802a 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -14,12 +14,14 @@ class OnPolicyEnum(Enum): PPO = 'ppo' GRPO = 'grpo' + APO = 'apo' #add A-star PO class ALGORITHM_TYPE(set, Enum): - CRITIC_FREE = {OnPolicyEnum.GRPO} + CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO} ACTOR_CRITIC = {OnPolicyEnum.PPO} CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO} + REGRESSION = {OnPolicyEnum.APO} #regression based loss, maybe REBEL also here? @dataclass @@ -88,6 +90,7 @@ def prepare_critic_values_for_training( ) values *= action_mask + #TODO: why zero padding? if zero_pad: zero_pad_tensor = torch.zeros((bs, 1), device=values.device, @@ -153,7 +156,7 @@ def critic_loss( ) -> MutableMapping: if loss_type == OnPolicyEnum.PPO: advantages = batch['advantages'] - v_preds = outputs['values'][:, :-1] * batch['action_mask'] + v_preds = outputs['values'][:, :-1] * batch['action_mask'] #TODO: why shift -1? eos token? why no shift in log_probs? v_preds = v_preds.to(advantages.dtype) values = batch['values'][:, :-1] * batch['action_mask'] @@ -215,9 +218,10 @@ def critic_loss( def policy_loss( advantages: torch.Tensor, - outputs: MutableMapping, - batch: MutableMapping, + outputs: MutableMapping, #outputs contain things that are from forward pass with gradient graph turned on + batch: MutableMapping, #batch contains data w/o gradient, e.g., it can contain old_log_probs ; todo: check if it contains ref_log_probs loss_type: OnPolicyEnum, + beta: float = 1e-3, #beta used in rebel and APO (assume it is beta *( log pi - log pi_ref ), so smaller beta means less KL to pi_ref) policy_clip_ratio: float = 0.15, policy_clip_high_ratio: float | None = None, length_normalize_policy_loss: bool = True, @@ -236,7 +240,7 @@ def policy_loss( kl_clip_range=kl_clip_range, ) online_ift_kl_dict = utils.approx_kl( - log_p=batch['ift_log_probs'], + log_p=batch['ift_log_probs'], log_q=outputs['online_log_probs'], kl_clip_range=kl_clip_range, ) @@ -261,7 +265,7 @@ def policy_loss( batch['action_mask'], ) - ratio = torch.exp(online_log_probs - old_log_probs) + ratio = torch.exp(online_log_probs - old_log_probs) #pi/pi_old policy_loss_1 = -advantages * ratio # Use the same clip ratio for both sides if clip high ratio is not provided @@ -342,6 +346,54 @@ def policy_loss( utils.sample_wise_masked_mean(advantages, batch['action_mask']), } return policy_dict + + elif loss_type in ALGORITHM_TYPE.REGRESSION: + print("="* 10 + "Inside Model Methods Loss" + "="*10) + print(f"Batch Keys: {batch.keys()}") + print(batch['vstar'].shape) + + print(f"Output Keys: {outputs.keys()}") + for k, v in outputs.items(): + print(f"{k}: {v.shape}") + + #assume batch contains (1) V-star values (key 'vstar), (2) rewards (key 'rewards'), (3) ref_log_probs + online_log_probs = outputs['online_log_probs'] + ref_log_probs = batch['ift_log_probs'] + log_probs_diff = online_log_probs - ref_log_probs + old_entropies = batch['old_entropies'] + + #compute KL to pi_ref to keep track the divergence to \pi_ref + policy_kl_dict = utils.approx_kl( + log_p= ref_log_probs, + log_q= online_log_probs, #log_q - log_p = log pi - log pi_ref + kl_clip_range=kl_clip_range, + ) + with torch.no_grad(): + policy_kl = utils.masked_mean(policy_kl_dict[kl_estimator],batch['action_mask']) #plain average over all tokens (KL to pi_ref) + + #compute the policy class + maksed_log_probs_diff = utils.masked_sum(log_probs_diff, batch['action_mask'], dim = -1) #size: (batch_size,) + vstars = batch['vstar'] #TODO: (batch_size, ) + rewards = utils.masked_sum(batch['rewards'], batch['action_mask'], dim = -1) #TODO: check, (Batch_size, ) + assert vstars.size() == rewards.size() == maksed_log_probs_diff.size() #should have the same shape which is (batch_size, ) + + policy_loss = ((beta*maksed_log_probs_diff - (rewards - vstars))**2).mean() + policy_dict = { + 'loss/policy_loss': + policy_loss, + 'kl/policy_kl': #TODO: add more KLs + policy_kl, + 'gen/gen_length': + batch['action_mask'].sum(dim=1).to(torch.float32), + 'gen/entropy': + old_entropies, + 'rewards/mean': + torch.mean(rewards), #compute the average reward of the current batch + 'vstars/mean': + torch.mean(vstars), #compute the average of the vstar of the current batch + } + return policy_dict + else: raise ValueError(f'Policy loss not implemented for {loss_type}') @@ -353,6 +405,7 @@ def online_rl_loss( value_clip_range: float = 0.2, value_loss_weight: float = 0.2, policy_clip_ratio: float = 0.15, + beta: float = 1e-3, #added a beta parameter here for APO and REBEL policy_clip_high_ratio: float | None = None, length_normalize_policy_loss: bool = True, add_direct_kl_loss: bool = False, @@ -423,6 +476,7 @@ def online_rl_loss( outputs=outputs, batch=batch, loss_type=loss_type, + beta = beta, policy_clip_ratio=policy_clip_ratio, policy_clip_high_ratio=policy_clip_high_ratio, length_normalize_policy_loss=length_normalize_policy_loss, diff --git a/compose_rl/data/prompt_data.py b/compose_rl/data/prompt_data.py index 07b8622d..686a0ddb 100644 --- a/compose_rl/data/prompt_data.py +++ b/compose_rl/data/prompt_data.py @@ -40,6 +40,7 @@ def prompt_dataset_collate_fn( mlm_probability=0.0, ) + keys = batch[0].keys() collated_batch: dict[str, torch.Tensor] = {} for key in keys: @@ -50,6 +51,9 @@ def prompt_dataset_collate_fn( if key == 'prompt_id': collated_batch[key] = torch.tensor(cur_values) continue + if key == 'vstar': + collated_batch[key] = torch.tensor(cur_values) + continue if key in ['verified_answer']: collated_batch[key] = list( # pyright: ignore[reportGeneralTypeIssues] utils.flatten(cur_values), @@ -124,5 +128,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]: _answer = '' item_dict['verified_answer'] = _answer # type: ignore + + #vstar + vstar = sample.get('vstar', None) + if vstar is not None: + item_dict['vstar'] = torch.Tensor([vstar]) return item_dict From d89cd64b18c439bed242761fb54155a1f2a1c538 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 18 Jun 2025 10:14:47 -0400 Subject: [PATCH 002/195] . --- yamls/local_apo.yaml | 171 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 yamls/local_apo.yaml diff --git a/yamls/local_apo.yaml b/yamls/local_apo.yaml new file mode 100644 index 00000000..5d50da8f --- /dev/null +++ b/yamls/local_apo.yaml @@ -0,0 +1,171 @@ +seed: 4667 + +max_seq_len: 10240 +model: + name: hf_critic_free_lm + pretrained: true + use_auth_token: true + loss_type: apo + beta: 1e-3 # TODO: need to + normalize_advantage: true + length_normalize_policy_loss: true + policy_clip_ratio: 0.2 + compute_kl_loss: false # turn off kl in loss + target_kl: 0.1 + kl_estimator: k3 + kl_clip_range: 40.0 + use_flash_attention_2: true + pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct + +#loggers: +# mlflow: +# tags: +# run: apo_test +# project: apo +# tracking_uri: databricks +# experiment_name: null # TODO: add mlflow experiment name + +callbacks: + on_policy_rl: {} + lr_monitor: {} + memory_monitor: {} + +optimizer: + lr: 7.0e-07 + name: decoupled_adamw + betas: + - 0.9 + - 0.95 + weight_decay: 0 + +precision: amp_bf16 +scheduler: + name: cosine_with_warmup + alpha_f: 0.01 + t_warmup: 10iter + +tokenizer: + name: ${variables.tokenizer_name} + kwargs: + padding: longest + pad_token: <|endoftext|> + truncation: true + padding_side: left + model_max_length: ${max_seq_len} + trust_remote_code: true + +variables: + tokenizer_name: Qwen/Qwen2.5-1.5B-Instruct + + reference_model: + precision: amp_bf16 + pretrained: true + model_config: + name: hf_causal_lm + pretrained: true + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct + + kl_controller: # Turn off reward KL + kl_ctl_type: fixed + init_kl_coef: 0.0 + + # The non-train-FSDP config + non_train_fsdp_config: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: DEFAULT + activation_checkpointing: true + activation_cpu_offload: false + verbose: false + limit_all_gathers: true + state_dict_type: sharded + use_orig_params: true + + generation_kwargs: + top_p: 1 + top_k: 0 + do_sample: true + use_cache: true + temperature: 1 + + gamma: 1 + lambda_gae: 1 + + num_train_nodes: 1 + num_batches_per_update: 4 + generations_per_prompt: 1 + device_generate_batch_size: 1 + epoch_per_iteration: 1 + + buffer: + name: MinibatchRolloutBuffer + max_buffer_size: ${variables.num_batches_per_update} + + rewards: + bad_generation_end: + reward: -1 + eos_penalty: true + reward_type: bad_generation_end + math_verifier: + reward_type: math_verifier + reward: 4 + math_format_verifier: + reward_type: math_format_verifier + reward: 1 + + global_seed: 17 + max_gen_len: 32 + eos_token_ids: + - 151643 + - 151645 + + +algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1.0e-02 + +autoresume: false +save_folder: null # TODO: fill this in +save_overwrite: true +save_interval: 25iter +save_num_checkpoints_to_keep: 1 + +fsdp_config: + verbose: false + cpu_offload: false + mixed_precision: PURE + state_dict_type: sharded + use_orig_params: true + forward_prefetch: true + backward_prefetch: BACKWARD_PRE + sharding_strategy: FULL_SHARD + activation_checkpointing: true + activation_cpu_offload: false + activation_checkpointing_reentrant: false + +train_loader: + name: prompt + dataset: + local: data/gsm8k + split: train + # remote: + shuffle: true + max_gen_len: ${variables.max_gen_len} + max_seq_len: ${max_seq_len} + shuffle_seed: ${variables.global_seed} + download_timeout: 3600 + drop_last: true + num_workers: 8 + +log_config: true +dist_timeout: 3600 +progress_bar: false +eval_interval: 25iter +max_duration: 1000iter +log_to_console: true +python_log_level: debug +console_log_interval: 1ba +global_train_batch_size: 8 +device_train_microbatch_size: 1 From 1d925d21a5c68271e6f9c205dd1d06b66290b048 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 18 Jun 2025 13:07:27 -0400 Subject: [PATCH 003/195] vllm fix --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index b844480d..6f5cb4ff 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -45,7 +45,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor -from compose_rl.algorithms.online.model_methods import OnPolicyEnum +from compose_rl.algorithms.online.model_methods import ALGORITHM_TYPE, OnPolicyEnum log = logging.getLogger(__name__) @@ -362,7 +362,7 @@ def should_update_torch_module( if parsed_module_name not in valid_non_leaf_module_names: return False - if loss_type == OnPolicyEnum.GRPO: + if loss_type in ALGORITHM_TYPE.CRITIC_FREE: return True if loss_type == OnPolicyEnum.PPO and 'lm_backbone' in full_param_name: From e9dbad9b45e05bd624fc55c663e4e963693822f6 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 18 Jun 2025 13:08:17 -0400 Subject: [PATCH 004/195] critic free --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 6f5cb4ff..5262dc0c 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -394,7 +394,7 @@ def broadcast_to_vllm( count, num_params = 0, len( list(model.model.lm_backbone.named_parameters()), # type: ignore ) - elif loss_type == OnPolicyEnum.GRPO: + elif loss_type in ALGORITHM_TYPE.CRITIC_FREE: # Directly use the model params count, num_params = 0, len( list(model.model.named_parameters()), # type: ignore From 934e864247f9f91f402cfe1fb5f7214762a18e0c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Fri, 20 Jun 2025 10:13:46 -0400 Subject: [PATCH 005/195] cleanup --- compose_rl/algorithms/online/callback.py | 7 +++---- compose_rl/algorithms/online/model_methods.py | 21 +++++++------------ 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 4dca6023..3dfdabca 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -41,6 +41,7 @@ ComposerMPTPolicyLM, ) from compose_rl.algorithms.online.model_methods import ( + ALGORITHM_TYPE, OnPolicyEnum, ) from compose_rl.algorithms.online.reward_manager import ( @@ -873,7 +874,7 @@ def _resolve_outputs( env_outs['right_padded_attn_mask'] = torch.logical_not( torch.eq(env_outs['obs'], self.pad_token_idx), # type: ignore ) - if self.actor_critic.loss_type != OnPolicyEnum.APO: + if self.actor_critic.loss_type not in ALGORITHM_TYPE.REGRESSION: # Now that rewards are resolved, we can compute advantages if self.actor_critic.loss_type == OnPolicyEnum.PPO: env_outs['advantages'] = compute_advantages( @@ -977,9 +978,7 @@ def _resolve_outputs( env_outs['rewards'].std().to('cpu'), }) else: - # Adding dummy advantages - env_outs['advantages'] = torch.ones_like(env_outs['action_mask']) - env_outs + # APO and REBEL mean_ift = masked_mean( env_outs['ift_kl'], diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 73bf802a..9f11786c 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -348,14 +348,6 @@ def policy_loss( return policy_dict elif loss_type in ALGORITHM_TYPE.REGRESSION: - print("="* 10 + "Inside Model Methods Loss" + "="*10) - print(f"Batch Keys: {batch.keys()}") - print(batch['vstar'].shape) - - print(f"Output Keys: {outputs.keys()}") - for k, v in outputs.items(): - print(f"{k}: {v.shape}") - #assume batch contains (1) V-star values (key 'vstar), (2) rewards (key 'rewards'), (3) ref_log_probs online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] @@ -373,9 +365,9 @@ def policy_loss( #compute the policy class maksed_log_probs_diff = utils.masked_sum(log_probs_diff, batch['action_mask'], dim = -1) #size: (batch_size,) - vstars = batch['vstar'] #TODO: (batch_size, ) - rewards = utils.masked_sum(batch['rewards'], batch['action_mask'], dim = -1) #TODO: check, (Batch_size, ) - assert vstars.size() == rewards.size() == maksed_log_probs_diff.size() #should have the same shape which is (batch_size, ) + vstars = batch['vstar'] + rewards = utils.masked_sum(batch['rewards'], batch['action_mask'], dim = -1) + assert vstars.size() == rewards.size() == maksed_log_probs_diff.size() # should have the same shape which is (batch_size, ) policy_loss = ((beta*maksed_log_probs_diff - (rewards - vstars))**2).mean() policy_dict = { @@ -438,7 +430,9 @@ def online_rl_loss( # tensors in `outputs` are recomputed at the start of each step in the epoch. return_dict = {} - advantages = batch['advantages'] + advantages = None + if loss_type not in ALGORITHM_TYPE.REGRESSION: + advantages = batch['advantages'] # 1. Critic Loss if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: @@ -468,7 +462,8 @@ def online_rl_loss( return_dict.update(**value_dict) - advantages = advantages.detach() + if advantages is not None: + advantages = advantages.detach() # 2. Policy Loss policy_dict = policy_loss( From b7cb34e9bc28f989677e880406105c07f1ae9679 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Fri, 20 Jun 2025 10:16:53 -0400 Subject: [PATCH 006/195] local run --- yamls/local_apo.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/yamls/local_apo.yaml b/yamls/local_apo.yaml index 5d50da8f..e89cd8bd 100644 --- a/yamls/local_apo.yaml +++ b/yamls/local_apo.yaml @@ -11,7 +11,7 @@ model: length_normalize_policy_loss: true policy_clip_ratio: 0.2 compute_kl_loss: false # turn off kl in loss - target_kl: 0.1 + target_kl: 100000 # turn off early stopping kl_estimator: k3 kl_clip_range: 40.0 use_flash_attention_2: true @@ -31,18 +31,18 @@ callbacks: memory_monitor: {} optimizer: - lr: 7.0e-07 + lr: 1.0e-6 name: decoupled_adamw betas: - 0.9 - 0.95 - weight_decay: 0 + weight_decay: 1.0e-7 precision: amp_bf16 scheduler: - name: cosine_with_warmup - alpha_f: 0.01 - t_warmup: 10iter + name: constant_with_warmup + t_max: 12800ba + t_warmup: 0ba tokenizer: name: ${variables.tokenizer_name} @@ -124,7 +124,7 @@ variables: algorithms: gradient_clipping: clipping_type: norm - clipping_threshold: 1.0e-02 + clipping_threshold: 1.0 autoresume: false save_folder: null # TODO: fill this in From 28014a9835b4ebd4945fd0ff49eb6f3d4fa69bbb Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 21 Jun 2025 11:09:23 -0400 Subject: [PATCH 007/195] added timeout in rlvr utils --- compose_rl/utils/rlvr_utils.py | 72 ++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index 4d4c888b..3bce864c 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -4,6 +4,7 @@ import logging import re from typing import Any +import signal import sympy from sympy.parsing.latex import parse_latex @@ -70,34 +71,55 @@ def remove_boxed(s: str) -> str: return s.strip('{}') +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + def is_equiv(x1: str, x2: str) -> bool: """Checks mathematical equivalence between two normalized LaTeX strings.""" try: - try: - parsed_x1 = parse_latex(x1) - parsed_x2 = parse_latex(x2) - except ( - sympy.parsing.latex. # pyright: ignore[reportGeneralTypeIssues] - errors.LaTeXParsingError, - sympy.SympifyError, - TypeError, - ): - log.debug(f"couldn't parse one of {x1} or {x2}") - return False - - try: - diff = parsed_x1 - parsed_x2 # pyright: ignore[reportOptionalOperand] - except TypeError: - log.debug(f"couldn't subtract {x1} and {x2}") - return False - - try: - return sympy.simplify(diff) == 0 - except ValueError: - log.debug( - f'Had some trouble simplifying when comparing {x1} and {x2}', - ) - return False + with timeout(seconds = 5): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex. # pyright: ignore[reportGeneralTypeIssues] + errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + log.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 # pyright: ignore[reportOptionalOperand] + except TypeError: + log.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + return sympy.simplify(diff) == 0 + except ValueError: + log.debug( + f'Had some trouble simplifying when comparing {x1} and {x2}', + ) + return False + + except TimeoutError: + log.debug(f"Timed out comparing {x1} and {x2}") + return False except ImportError as e: log.error(e) raise From ed4e554cde8424389e8008c594fa4b67fcb3c448 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 21 Jun 2025 11:10:08 -0400 Subject: [PATCH 008/195] added timeout in rlvr utils --- compose_rl/utils/rlvr_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index 3bce864c..ec84edd8 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -88,6 +88,12 @@ def __exit__(self, type, value, traceback): def is_equiv(x1: str, x2: str) -> bool: + print("##########################") + print("############################") + print("############################") + print("############################") + print("############################") + """Checks mathematical equivalence between two normalized LaTeX strings.""" try: with timeout(seconds = 5): From de6400c7e3144753fd9eea4e811f15e70f492b0a Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 21 Jun 2025 11:20:53 -0400 Subject: [PATCH 009/195] . --- compose_rl/utils/rlvr_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index ec84edd8..c003ed2f 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -88,12 +88,7 @@ def __exit__(self, type, value, traceback): def is_equiv(x1: str, x2: str) -> bool: - print("##########################") - print("############################") - print("############################") - print("############################") - print("############################") - + """Checks mathematical equivalence between two normalized LaTeX strings.""" try: with timeout(seconds = 5): From 6b4c385b8ac39965ff5b152ef2965e8062f085d1 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 23 Jun 2025 13:57:59 -0400 Subject: [PATCH 010/195] fix some comments issues in local yaml --- yamls/local_apo.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/yamls/local_apo.yaml b/yamls/local_apo.yaml index e89cd8bd..43317353 100644 --- a/yamls/local_apo.yaml +++ b/yamls/local_apo.yaml @@ -6,18 +6,18 @@ model: pretrained: true use_auth_token: true loss_type: apo - beta: 1e-3 # TODO: need to + beta: 1e-3 # TODO: need to normalize_advantage: true length_normalize_policy_loss: true policy_clip_ratio: 0.2 - compute_kl_loss: false # turn off kl in loss - target_kl: 100000 # turn off early stopping + compute_kl_loss: false # turn off kl in loss + target_kl: 100000 # turn off early stopping kl_estimator: k3 kl_clip_range: 40.0 use_flash_attention_2: true pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct -#loggers: +# loggers: # mlflow: # tags: # run: apo_test @@ -67,7 +67,7 @@ variables: use_flash_attention_2: true pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct - kl_controller: # Turn off reward KL + kl_controller: # Turn off reward KL kl_ctl_type: fixed init_kl_coef: 0.0 From b2ae469872fa25bd810128815b8a85f68debe6ad Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 23 Jun 2025 14:56:53 -0400 Subject: [PATCH 011/195] . --- compose_rl/algorithms/online/callback.py | 8 +- .../online/generation_utils/vllm_utils.py | 5 +- compose_rl/algorithms/online/model.py | 5 +- compose_rl/algorithms/online/model_methods.py | 79 +++++++++++-------- compose_rl/data/prompt_data.py | 3 +- compose_rl/utils/rlvr_utils.py | 10 +-- 6 files changed, 67 insertions(+), 43 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 3dfdabca..b9a706e7 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -650,7 +650,12 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id', 'vstar']: + if key in [ + 'prompt_len', + 'verified_answer', + 'prompt_id', + 'vstar', + ]: curr_values.append(batch[key]) continue @@ -1003,7 +1008,6 @@ def _resolve_outputs( env_outs['rewards'].std().to('cpu'), }) - # Moving minibatches to CPU to not take additional GPU memory for k, v in iter_batch.items(): if hasattr(v, 'cpu'): diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5262dc0c..c681e7f4 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -45,7 +45,10 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor -from compose_rl.algorithms.online.model_methods import ALGORITHM_TYPE, OnPolicyEnum +from compose_rl.algorithms.online.model_methods import ( + ALGORITHM_TYPE, + OnPolicyEnum, +) log = logging.getLogger(__name__) diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index db908ea1..4196bac8 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -104,7 +104,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, - beta = self.config.beta, #added beta + beta=self.config.beta, add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, @@ -218,7 +218,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, - beta = self.config.beta, #added beta parameter + beta = self.config.beta, #added beta parameter add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, @@ -276,6 +276,7 @@ def __init__( target_kl (float): The target KL value. Default: ``0.1``. kl_estimator (str): The KL estimator to use. Default: ``'k3'``. kl_clip_range (float): The KL clip range. Default: ``40.0``. + beta (float): pi_ref KL hyperparameter for APO. Default: ``1e-3`` """ super().__init__(**kwargs) self.policy_kl = [] diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 9f11786c..066c306a 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -14,14 +14,16 @@ class OnPolicyEnum(Enum): PPO = 'ppo' GRPO = 'grpo' - APO = 'apo' #add A-star PO + APO = 'apo' #add A-star PO class ALGORITHM_TYPE(set, Enum): CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO} ACTOR_CRITIC = {OnPolicyEnum.PPO} CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO} - REGRESSION = {OnPolicyEnum.APO} #regression based loss, maybe REBEL also here? + REGRESSION = { + OnPolicyEnum.APO, + } #regression based loss, maybe REBEL also here? @dataclass @@ -156,7 +158,9 @@ def critic_loss( ) -> MutableMapping: if loss_type == OnPolicyEnum.PPO: advantages = batch['advantages'] - v_preds = outputs['values'][:, :-1] * batch['action_mask'] #TODO: why shift -1? eos token? why no shift in log_probs? + v_preds = outputs['values'][:, :-1] * batch[ + 'action_mask' + ] #TODO: why shift -1? eos token? why no shift in log_probs? v_preds = v_preds.to(advantages.dtype) values = batch['values'][:, :-1] * batch['action_mask'] @@ -218,10 +222,10 @@ def critic_loss( def policy_loss( advantages: torch.Tensor, - outputs: MutableMapping, #outputs contain things that are from forward pass with gradient graph turned on - batch: MutableMapping, #batch contains data w/o gradient, e.g., it can contain old_log_probs ; todo: check if it contains ref_log_probs + outputs: MutableMapping, + batch: MutableMapping, loss_type: OnPolicyEnum, - beta: float = 1e-3, #beta used in rebel and APO (assume it is beta *( log pi - log pi_ref ), so smaller beta means less KL to pi_ref) + beta: float = 1e-3, policy_clip_ratio: float = 0.15, policy_clip_high_ratio: float | None = None, length_normalize_policy_loss: bool = True, @@ -240,7 +244,7 @@ def policy_loss( kl_clip_range=kl_clip_range, ) online_ift_kl_dict = utils.approx_kl( - log_p=batch['ift_log_probs'], + log_p=batch['ift_log_probs'], log_q=outputs['online_log_probs'], kl_clip_range=kl_clip_range, ) @@ -265,7 +269,7 @@ def policy_loss( batch['action_mask'], ) - ratio = torch.exp(online_log_probs - old_log_probs) #pi/pi_old + ratio = torch.exp(online_log_probs - old_log_probs) #pi/pi_old policy_loss_1 = -advantages * ratio # Use the same clip ratio for both sides if clip high ratio is not provided @@ -346,11 +350,11 @@ def policy_loss( utils.sample_wise_masked_mean(advantages, batch['action_mask']), } return policy_dict - + elif loss_type in ALGORITHM_TYPE.REGRESSION: #assume batch contains (1) V-star values (key 'vstar), (2) rewards (key 'rewards'), (3) ref_log_probs online_log_probs = outputs['online_log_probs'] - ref_log_probs = batch['ift_log_probs'] + ref_log_probs = batch['ift_log_probs'] log_probs_diff = online_log_probs - ref_log_probs old_entropies = batch['old_entropies'] @@ -361,30 +365,42 @@ def policy_loss( kl_clip_range=kl_clip_range, ) with torch.no_grad(): - policy_kl = utils.masked_mean(policy_kl_dict[kl_estimator],batch['action_mask']) #plain average over all tokens (KL to pi_ref) + policy_kl = utils.masked_mean( + policy_kl_dict[kl_estimator], + batch['action_mask'], + ) #plain average over all tokens (KL to pi_ref) #compute the policy class - maksed_log_probs_diff = utils.masked_sum(log_probs_diff, batch['action_mask'], dim = -1) #size: (batch_size,) + maksed_log_probs_diff = utils.masked_sum( + log_probs_diff, + batch['action_mask'], + dim=-1, + ) #size: (batch_size,) vstars = batch['vstar'] - rewards = utils.masked_sum(batch['rewards'], batch['action_mask'], dim = -1) - assert vstars.size() == rewards.size() == maksed_log_probs_diff.size() # should have the same shape which is (batch_size, ) - - policy_loss = ((beta*maksed_log_probs_diff - (rewards - vstars))**2).mean() + rewards = utils.masked_sum( + batch['rewards'], + batch['action_mask'], + dim=-1, + ) + assert vstars.size() == rewards.size() == maksed_log_probs_diff.size( + ) # should have the same shape which is (batch_size, ) + + policy_loss = ((beta * maksed_log_probs_diff - + (rewards - vstars))**2).mean() policy_dict = { - 'loss/policy_loss': - policy_loss, - 'kl/policy_kl': #TODO: add more KLs - policy_kl, - 'gen/gen_length': - batch['action_mask'].sum(dim=1).to(torch.float32), - 'gen/entropy': - old_entropies, - 'rewards/mean': - torch.mean(rewards), #compute the average reward of the current batch - 'vstars/mean': - torch.mean(vstars), #compute the average of the vstar of the current batch + 'loss/policy_loss': policy_loss, + 'kl/policy_kl': #TODO: add more KLs + policy_kl, + 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), + 'gen/entropy': old_entropies, + 'rewards/mean': torch.mean( + rewards, + ), #compute the average reward of the current batch + 'vstars/mean': torch.mean( + vstars, + ), #compute the average of the vstar of the current batch } - return policy_dict + return policy_dict else: raise ValueError(f'Policy loss not implemented for {loss_type}') @@ -397,7 +413,7 @@ def online_rl_loss( value_clip_range: float = 0.2, value_loss_weight: float = 0.2, policy_clip_ratio: float = 0.15, - beta: float = 1e-3, #added a beta parameter here for APO and REBEL + beta: float = 1e-3, policy_clip_high_ratio: float | None = None, length_normalize_policy_loss: bool = True, add_direct_kl_loss: bool = False, @@ -418,6 +434,7 @@ def online_rl_loss( add_direct_kl_loss (bool): Whether to add the KL loss directly to the loss. Default: ``False``. kl_estimator (str): The KL estimator to use. Default: ``'k1'``. kl_clip_range (float): The clip range for the KL divergence. Default: ``40.0``. + beta (float): pi_ref KL hyperparameter for APO. Default: ``1e-3``. """ # log_probs: [bs, gen_len] log probability of each action # action_mask: [bs, gen_len] action mask @@ -471,7 +488,7 @@ def online_rl_loss( outputs=outputs, batch=batch, loss_type=loss_type, - beta = beta, + beta=beta, policy_clip_ratio=policy_clip_ratio, policy_clip_high_ratio=policy_clip_high_ratio, length_normalize_policy_loss=length_normalize_policy_loss, diff --git a/compose_rl/data/prompt_data.py b/compose_rl/data/prompt_data.py index 686a0ddb..22da7501 100644 --- a/compose_rl/data/prompt_data.py +++ b/compose_rl/data/prompt_data.py @@ -40,7 +40,6 @@ def prompt_dataset_collate_fn( mlm_probability=0.0, ) - keys = batch[0].keys() collated_batch: dict[str, torch.Tensor] = {} for key in keys: @@ -128,7 +127,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: _answer = '' item_dict['verified_answer'] = _answer # type: ignore - + #vstar vstar = sample.get('vstar', None) if vstar is not None: diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index c003ed2f..b7012816 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -3,8 +3,8 @@ import logging import re -from typing import Any import signal +from typing import Any import sympy from sympy.parsing.latex import parse_latex @@ -72,7 +72,8 @@ def remove_boxed(s: str) -> str: class timeout: - def __init__(self, seconds=1, error_message="Timeout"): + + def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds self.error_message = error_message @@ -88,10 +89,9 @@ def __exit__(self, type, value, traceback): def is_equiv(x1: str, x2: str) -> bool: - """Checks mathematical equivalence between two normalized LaTeX strings.""" try: - with timeout(seconds = 5): + with timeout(seconds=5): try: parsed_x1 = parse_latex(x1) parsed_x2 = parse_latex(x2) @@ -117,7 +117,7 @@ def is_equiv(x1: str, x2: str) -> bool: f'Had some trouble simplifying when comparing {x1} and {x2}', ) return False - + except TimeoutError: log.debug(f"Timed out comparing {x1} and {x2}") return False From 4b514378a396b4aa68c8c3e5b2094e67e8d88b06 Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Mon, 23 Jun 2025 15:08:35 -0400 Subject: [PATCH 012/195] Update compose_rl/algorithms/online/callback.py Co-authored-by: bcui-db <141345999+bcui-db@users.noreply.github.com> --- compose_rl/algorithms/online/callback.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index b9a706e7..64db8b65 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -683,9 +683,7 @@ def _get_next_iter_prompts(self): if isinstance(curr_values[0], torch.Tensor): ret_batch[key] = torch.cat(curr_values) else: - if key == 'verified_answer': - ret_batch[key] = list(flatten(curr_values)) - elif key == 'vstar': + if key in ['verified_answer', 'vstar']: ret_batch[key] = list(flatten(curr_values)) else: # this is an edge case that we will not hit currently, but just handling it as needed From d4ffc5981956644a74cec18d9d4db1767974dbf4 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 23 Jun 2025 15:09:33 -0400 Subject: [PATCH 013/195] clean comments --- compose_rl/algorithms/online/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 4196bac8..ed1055c9 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -218,7 +218,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, - beta = self.config.beta, #added beta parameter + beta = self.config.beta, add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, @@ -257,7 +257,7 @@ def __init__( length_normalize_policy_loss: bool = True, policy_clip_ratio: float = 0.15, policy_clip_high_ratio: float | None = None, - beta: float = 1e-3, #added beta + beta: float = 1e-3, compute_kl_loss: bool = True, target_kl: float = 0.1, kl_estimator: str = 'k3', @@ -311,7 +311,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): loss_type=self.loss_type, policy_clip_ratio=self.policy_clip_ratio, policy_clip_high_ratio=self.policy_clip_high_ratio, - beta = self.beta, #added beta + beta = self.beta, length_normalize_policy_loss=self.length_normalize_policy_loss, add_direct_kl_loss=self.compute_kl_loss, kl_estimator=self.kl_estimator, From 1fd7e2d056945bde33737f7aacdad52aa59c35dd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 23 Jun 2025 15:10:11 -0400 Subject: [PATCH 014/195] comment cleanup --- compose_rl/algorithms/online/model_methods.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 066c306a..8eedca50 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -92,7 +92,6 @@ def prepare_critic_values_for_training( ) values *= action_mask - #TODO: why zero padding? if zero_pad: zero_pad_tensor = torch.zeros((bs, 1), device=values.device, From c769975b953c504648202aa05c1c38763759e5d7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 23 Jun 2025 15:18:56 -0400 Subject: [PATCH 015/195] cleanup of callback --- compose_rl/algorithms/online/callback.py | 59 ++++++++---------------- compose_rl/algorithms/online/model.py | 2 +- compose_rl/utils/rlvr_utils.py | 2 +- 3 files changed, 20 insertions(+), 43 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 64db8b65..198b235b 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -958,54 +958,31 @@ def _resolve_outputs( env_outs['action_mask'], ) - mean_ift = masked_mean( - env_outs['ift_kl'], - env_outs['action_mask'], - ) - self.kl_ift.append(mean_ift.cpu()) - - iter_batch.update(env_outs) - iter_batch.update({ - 'max_gen_len': - torch.ones(self.iter_batch_size).to(torch.int32) * - self.max_gen_len, 'adv_masked_mean': torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(), 'adv_masked_var': torch.ones(self.iter_batch_size) * batch_adv_var.cpu(), - 'ift_kl_scalar': - torch.ones(self.iter_batch_size) * self.kl_ctl.value, - 'reward_std': - torch.ones(self.iter_batch_size) * - env_outs['rewards'].std().to('cpu'), - }) - else: - # APO and REBEL - - mean_ift = masked_mean( - env_outs['ift_kl'], - env_outs['action_mask'], - ) - self.kl_ift.append(mean_ift.cpu()) - - iter_batch.update(env_outs) - - iter_batch.update({ - 'max_gen_len': - torch.ones(self.iter_batch_size).to(torch.int32) * - self.max_gen_len, - 'adv_masked_mean': - torch.ones(self.iter_batch_size), - 'adv_masked_var': - torch.ones(self.iter_batch_size), - 'ift_kl_scalar': - torch.ones(self.iter_batch_size) * self.kl_ctl.value, - 'reward_std': - torch.ones(self.iter_batch_size) * - env_outs['rewards'].std().to('cpu'), }) + mean_ift = masked_mean( + env_outs['ift_kl'], + env_outs['action_mask'], + ) + self.kl_ift.append(mean_ift.cpu()) + + iter_batch.update(env_outs) + + iter_batch.update({ + 'max_gen_len': + torch.ones(self.iter_batch_size).to(torch.int32) * + self.max_gen_len, + 'ift_kl_scalar': + torch.ones(self.iter_batch_size) * self.kl_ctl.value, + 'reward_std': + torch.ones(self.iter_batch_size) * + env_outs['rewards'].std().to('cpu'), + }) # Moving minibatches to CPU to not take additional GPU memory for k, v in iter_batch.items(): if hasattr(v, 'cpu'): diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index ed1055c9..ea0ea534 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -311,7 +311,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): loss_type=self.loss_type, policy_clip_ratio=self.policy_clip_ratio, policy_clip_high_ratio=self.policy_clip_high_ratio, - beta = self.beta, + beta=self.beta, length_normalize_policy_loss=self.length_normalize_policy_loss, add_direct_kl_loss=self.compute_kl_loss, kl_estimator=self.kl_estimator, diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index b7012816..ce725f91 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -119,7 +119,7 @@ def is_equiv(x1: str, x2: str) -> bool: return False except TimeoutError: - log.debug(f"Timed out comparing {x1} and {x2}") + log.debug(f'Timed out comparing {x1} and {x2}') return False except ImportError as e: log.error(e) From 0d275ec5773011f8ec0defb463b4a84a2b6d1be7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 23 Jun 2025 15:22:15 -0400 Subject: [PATCH 016/195] undo rlvr update --- compose_rl/utils/rlvr_utils.py | 73 ++++++++++++---------------------- 1 file changed, 25 insertions(+), 48 deletions(-) diff --git a/compose_rl/utils/rlvr_utils.py b/compose_rl/utils/rlvr_utils.py index ce725f91..4d4c888b 100644 --- a/compose_rl/utils/rlvr_utils.py +++ b/compose_rl/utils/rlvr_utils.py @@ -3,7 +3,6 @@ import logging import re -import signal from typing import Any import sympy @@ -71,56 +70,34 @@ def remove_boxed(s: str) -> str: return s.strip('{}') -class timeout: - - def __init__(self, seconds=1, error_message='Timeout'): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - def is_equiv(x1: str, x2: str) -> bool: """Checks mathematical equivalence between two normalized LaTeX strings.""" try: - with timeout(seconds=5): - try: - parsed_x1 = parse_latex(x1) - parsed_x2 = parse_latex(x2) - except ( - sympy.parsing.latex. # pyright: ignore[reportGeneralTypeIssues] - errors.LaTeXParsingError, - sympy.SympifyError, - TypeError, - ): - log.debug(f"couldn't parse one of {x1} or {x2}") - return False - - try: - diff = parsed_x1 - parsed_x2 # pyright: ignore[reportOptionalOperand] - except TypeError: - log.debug(f"couldn't subtract {x1} and {x2}") - return False - - try: - return sympy.simplify(diff) == 0 - except ValueError: - log.debug( - f'Had some trouble simplifying when comparing {x1} and {x2}', - ) - return False - - except TimeoutError: - log.debug(f'Timed out comparing {x1} and {x2}') - return False + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex. # pyright: ignore[reportGeneralTypeIssues] + errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + log.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 # pyright: ignore[reportOptionalOperand] + except TypeError: + log.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + return sympy.simplify(diff) == 0 + except ValueError: + log.debug( + f'Had some trouble simplifying when comparing {x1} and {x2}', + ) + return False except ImportError as e: log.error(e) raise From 202b0a2dfd773e2d2221db0b0e38591ad77e82bd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 23 Jun 2025 15:33:35 -0400 Subject: [PATCH 017/195] fix --- compose_rl/algorithms/online/model_methods.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 8eedca50..801728c8 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -220,7 +220,7 @@ def critic_loss( def policy_loss( - advantages: torch.Tensor, + advantages: torch.Tensor | None, outputs: MutableMapping, batch: MutableMapping, loss_type: OnPolicyEnum, @@ -233,6 +233,7 @@ def policy_loss( ) -> MutableMapping: if loss_type in ALGORITHM_TYPE.CLIPPED_PG: + assert advantages is not None online_log_probs, old_log_probs = outputs['online_log_probs'], batch[ 'old_log_probs'] old_entropies = batch['old_entropies'] @@ -365,7 +366,7 @@ def policy_loss( ) with torch.no_grad(): policy_kl = utils.masked_mean( - policy_kl_dict[kl_estimator], + policy_kl_dict[kl_estimator], # pyright: ignore batch['action_mask'], ) #plain average over all tokens (KL to pi_ref) @@ -388,8 +389,7 @@ def policy_loss( (rewards - vstars))**2).mean() policy_dict = { 'loss/policy_loss': policy_loss, - 'kl/policy_kl': #TODO: add more KLs - policy_kl, + 'kl/policy_kl': policy_kl, 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), 'gen/entropy': old_entropies, 'rewards/mean': torch.mean( From d6d81ca851d2c4bb89b37be2c6e90e1fdfcbff86 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 16:07:32 -0400 Subject: [PATCH 018/195] apo offline initial implementaiton --- .../algorithms/offline/model_methods.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 32aa3a01..14407f83 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -33,6 +33,7 @@ class PairwiseOfflineEnum(Enum): REBEL = 'rebel' IPO = 'ipo' KTO = 'kto' + APO = 'apo' # Not a pair-wise preference algorithm def pairwise_offline_forward( @@ -188,8 +189,8 @@ def pairwise_offline_loss( sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to be added to DPO type loss. """ - policy_chosen_logp = outputs['policy_chosen_logp'] - policy_rejected_logp = outputs['policy_rejected_logp'] + policy_chosen_logp = outputs['policy_chosen_logp'] # (batch_size, ) + policy_rejected_logp = outputs['policy_rejected_logp'] # (batch_size, ) ref_chosen_logp = batch.get( 'ref_chosen', torch.zeros_like(policy_chosen_logp), @@ -210,6 +211,23 @@ def pairwise_offline_loss( -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing ) + elif loss_type == PairwiseOfflineEnum.APO: + # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 + # APO is not a pair-wise loss function. + # We assume the dataset contains two responses per prompt. + # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. + # The chosen and reject do not mean anything in APO + # Similar to REBEL, we assume each response has a reward in the batch. + # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch + vstars = batch['vstar'] # (batch_size, ) + loss_1 = ( + beta*(policy_chosen_logp - ref_chosen_logp) - (outputs['chosen_reward'] - vstars) + )**2 + loss_2 = ( + beta*(policy_rejected_logp - ref_rejected_logp) - (outputs['rejected_reward'] -vstars ) + )**2 + losses = (loss_1 + loss_2) / 2 + elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob chosen_reward = outputs['chosen_reward'] From dd04bffcb2842b21504551ac0f257a0d3b510071 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 16:07:41 -0400 Subject: [PATCH 019/195] apo offline initial implementaiton --- .../algorithms/offline/model_methods.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 14407f83..8de77d53 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -212,22 +212,24 @@ def pairwise_offline_loss( F.logsigmoid(-beta * logits) * label_smoothing ) elif loss_type == PairwiseOfflineEnum.APO: - # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 - # APO is not a pair-wise loss function. - # We assume the dataset contains two responses per prompt. - # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. + # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 + # APO is not a pair-wise loss function. + # We assume the dataset contains two responses per prompt. + # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. # The chosen and reject do not mean anything in APO - # Similar to REBEL, we assume each response has a reward in the batch. + # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - vstars = batch['vstar'] # (batch_size, ) + vstars = batch['vstar'] # (batch_size, ) loss_1 = ( - beta*(policy_chosen_logp - ref_chosen_logp) - (outputs['chosen_reward'] - vstars) + beta * (policy_chosen_logp - ref_chosen_logp) - + (outputs['chosen_reward'] - vstars) )**2 loss_2 = ( - beta*(policy_rejected_logp - ref_rejected_logp) - (outputs['rejected_reward'] -vstars ) + beta * (policy_rejected_logp - ref_rejected_logp) - + (outputs['rejected_reward'] - vstars) )**2 losses = (loss_1 + loss_2) / 2 - + elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob chosen_reward = outputs['chosen_reward'] From 01d28c78a942dab9a15934c86c40c1ac673f51f2 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 20:35:37 -0400 Subject: [PATCH 020/195] updates --- compose_rl/algorithms/offline/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 8de77d53..fdcaa7c5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -319,6 +319,7 @@ def pairwise_offline_loss( PairwiseOfflineEnum.RPO, PairwiseOfflineEnum.RCDPO, PairwiseOfflineEnum.REBEL, + PairwiseOfflineEnum.APO, ]: # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore From 1c62fdb30c0e92ea132b68c927de2dd992c306ad Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 20:52:42 -0400 Subject: [PATCH 021/195] add vstar entries --- compose_rl/data/preference_data.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index c08d2757..942cf0a5 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -63,6 +63,8 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] + vstars = sample['vstar'] + # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 @@ -141,6 +143,7 @@ def pairwise_preference_dataset_collate_fn( 'input_ids': input_ids, 'attention_mask': attention_masks, 'sequence_id': sequence_id, + 'vstar': vstars } if len(chosen_rewards) > 0: chosen_rewards = torch.stack(chosen_rewards) @@ -263,6 +266,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward + + if 'vstar' in sample: + return_dict['vstar'] = torch.Tensor([sample['vstar']]) + return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): From a5fa39161a70c2c1cf51354c03794f75b5efa445 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 21:00:27 -0400 Subject: [PATCH 022/195] . --- compose_rl/data/preference_data.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 942cf0a5..b11c9ab4 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -54,6 +54,7 @@ def pairwise_preference_dataset_collate_fn( prompt_lens = [] sequence_id = [] chosen_rewards = [] + vstars = [] rejected_rewards = [] for sample in data: @@ -128,6 +129,8 @@ def pairwise_preference_dataset_collate_fn( if 'chosen_reward' in sample: chosen_rewards.append(sample['chosen_reward']) rejected_rewards.append(sample['rejected_reward']) + if 'vstar' in sample: + vstars.append(sample['vstar']) input_ids = ref_collate_fn(input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) @@ -143,13 +146,16 @@ def pairwise_preference_dataset_collate_fn( 'input_ids': input_ids, 'attention_mask': attention_masks, 'sequence_id': sequence_id, - 'vstar': vstars } if len(chosen_rewards) > 0: chosen_rewards = torch.stack(chosen_rewards) rejected_rewards = torch.stack(rejected_rewards) return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards + if len(vstars) > 0: + vstars = torch.stack(vstars) + return_dict['vstar'] = vstars + return return_dict @@ -266,10 +272,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward - - if 'vstar' in sample: + + if 'vstar' in sample: return_dict['vstar'] = torch.Tensor([sample['vstar']]) - + return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): From 725e3a295d0548def4b99f1ce75243bbd475a5b0 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 21:04:17 -0400 Subject: [PATCH 023/195] . --- compose_rl/data/preference_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index b11c9ab4..1af2f983 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -64,8 +64,6 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] - vstars = sample['vstar'] - # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 From a0b0209ce9d579342a5dc753052b58e3a894c8d6 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 21:08:30 -0400 Subject: [PATCH 024/195] . --- compose_rl/algorithms/offline/model_methods.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index fdcaa7c5..8de77d53 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -319,7 +319,6 @@ def pairwise_offline_loss( PairwiseOfflineEnum.RPO, PairwiseOfflineEnum.RCDPO, PairwiseOfflineEnum.REBEL, - PairwiseOfflineEnum.APO, ]: # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore From 39377aef5284adcc4fbf2980aa7d6e5a59b89dd1 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 29 Jun 2025 21:10:32 -0400 Subject: [PATCH 025/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 8de77d53..a19bb378 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -172,7 +172,7 @@ def pairwise_offline_loss( loss_type: PairwiseOfflineEnum, beta: float, label_smoothing: float, - sft_alpha: float, + sft_alpha: float = 0., ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. From ce6d18aa8a5a46fc0eebbfe11c9e6fdce273df15 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 30 Jun 2025 00:19:55 -0400 Subject: [PATCH 026/195] added ope estimator in apo --- compose_rl/algorithms/offline/model_methods.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index a19bb378..71e512d7 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -172,7 +172,7 @@ def pairwise_offline_loss( loss_type: PairwiseOfflineEnum, beta: float, label_smoothing: float, - sft_alpha: float = 0., + sft_alpha: float = 0.0, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -230,6 +230,14 @@ def pairwise_offline_loss( )**2 losses = (loss_1 + loss_2) / 2 + # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) + # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref + # use clip to ensure the output from exp is valid + with torch.no_grad(): + estimated_rewards = torch.exp(torch.clip(policy_chosen_logp - ref_chosen_logp, max = 10.)) * outputs['chosen_reward'] + estimated_rewards += torch.exp(torch.clip(policy_rejected_logp - ref_rejected_logp, max = 10.)) * outputs['rejected_reward'] + estimated_reward = torch.mean(estimated_rewards) + elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob chosen_reward = outputs['chosen_reward'] @@ -322,6 +330,11 @@ def pairwise_offline_loss( ]: # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore + if loss_type == PairwiseOfflineEnum.APO: + forward_kl = ((ref_chosen_logp - policy_chosen_logp) + (ref_rejected_logp - policy_rejected_logp)).detach() + loss_dict['forward_kl'] = forward_kl + loss_dict['estimated_reward'] = estimated_reward + if sft_alpha > 0: # sft_losses_normalized is always defined if sft_alpha>0 snl = sft_losses_normalized.detach() # type: ignore From 4f0456e3b80f1c401f141a9412c82de3a2d15432 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 30 Jun 2025 09:33:01 -0400 Subject: [PATCH 027/195] . --- compose_rl/algorithms/offline/model_methods.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 71e512d7..45900bae 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -228,15 +228,15 @@ def pairwise_offline_loss( beta * (policy_rejected_logp - ref_rejected_logp) - (outputs['rejected_reward'] - vstars) )**2 - losses = (loss_1 + loss_2) / 2 + losses = (loss_1 + loss_2) / 2. # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref # use clip to ensure the output from exp is valid with torch.no_grad(): - estimated_rewards = torch.exp(torch.clip(policy_chosen_logp - ref_chosen_logp, max = 10.)) * outputs['chosen_reward'] - estimated_rewards += torch.exp(torch.clip(policy_rejected_logp - ref_rejected_logp, max = 10.)) * outputs['rejected_reward'] - estimated_reward = torch.mean(estimated_rewards) + estimated_rewards = torch.exp(torch.clip(policy_chosen_logp - ref_chosen_logp, max = 5.)) * outputs['chosen_reward'] + estimated_rewards += torch.exp(torch.clip(policy_rejected_logp - ref_rejected_logp, max = 5.)) * outputs['rejected_reward'] + estimated_reward = torch.mean(estimated_rewards)/2. elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob @@ -332,7 +332,7 @@ def pairwise_offline_loss( loss_dict['reward_diff'] = reward_diff.detach() # type: ignore if loss_type == PairwiseOfflineEnum.APO: forward_kl = ((ref_chosen_logp - policy_chosen_logp) + (ref_rejected_logp - policy_rejected_logp)).detach() - loss_dict['forward_kl'] = forward_kl + loss_dict['forward_kl'] = forward_kl/2. loss_dict['estimated_reward'] = estimated_reward if sft_alpha > 0: From 06ceb475bbf6cc594469be42a5e67ca0376618fe Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 30 Jun 2025 14:07:54 -0400 Subject: [PATCH 028/195] . --- compose_rl/algorithms/offline/model_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 45900bae..193d8ee9 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -154,6 +154,9 @@ def pairwise_offline_forward( if 'chosen_reward' in batch: outputs['chosen_reward'] = batch['chosen_reward'] outputs['rejected_reward'] = batch['rejected_reward'] + + if 'vstar' in batch: + outputs['vstar'] = batch['vstar'] if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( @@ -219,7 +222,7 @@ def pairwise_offline_loss( # The chosen and reject do not mean anything in APO # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - vstars = batch['vstar'] # (batch_size, ) + vstars = outputs['vstar'] # (batch_size, ) loss_1 = ( beta * (policy_chosen_logp - ref_chosen_logp) - (outputs['chosen_reward'] - vstars) From ea680a575e0d6e2980dbde63a61f92403d26caf5 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 11:47:41 -0400 Subject: [PATCH 029/195] add multimodal support for models --- .../algorithms/offline/model_methods.py | 42 ++++++++++++++----- compose_rl/algorithms/online/model_methods.py | 4 ++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 32aa3a01..d934f675 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -65,6 +65,10 @@ def pairwise_offline_forward( if pad_token_id is None: raise ValueError('Tokenizer must have a PAD token.') + is_multimodal = "pixel_values" in batch.keys() + if is_multimodal and use_attention_sequence_id: + raise NotImplementedError("Using Sequence ID is not implemented for VLMs") + # If we can use attention sequence ID, we use this logic branch. # This is determined by a value set in `train_dpo.py` if use_attention_sequence_id: @@ -102,18 +106,36 @@ def pairwise_offline_forward( pad_token_id=0, ) - batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0) - batch_attn_mask = torch.cat( - [ - chosen_attention_mask, - rejected_attention_mask, - ], - dim=0, - ) + inputs = { + "input_ids": torch.cat([chosen_inputs, rejected_inputs], dim=0), + "attention_mask": torch.cat( + [ + chosen_attention_mask, + rejected_attention_mask, + ], + dim=0, + ), + } + + if is_multimodal: + chosen_token_type_ids, rejected_token_type_ids = extract_packed_chosen_rejected( + batch['token_type_ids'], + batch['chosen_len'], + batch['rejected_len'], + concat_seq_len, + pad_token_id=0, + ) + + # TODO: Ask if assuming same pixel inputs is ok? + multimodal_inputs = { + "token_type_ids": torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), + "pixel_values": torch.cat([batch['pixel_values'], batch['pixel_values']], dim=0), + } + + inputs.update(multimodal_inputs) output_logits = model( - batch_cat_inputs, - attention_mask=batch_attn_mask, + **inputs ).logits # Extract out the chosen and rejected logits along the batch dimension diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index bdcec748..76072d1d 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -118,6 +118,10 @@ def composer_online_rl_forward( model_forward_kwargs['action_mask'] = batch['action_mask'] model_forward_kwargs['max_gen_len'] = batch['max_gen_len'] + if "pixel_values" in batch.keys(): + model_forward_kwargs['token_type_ids'] = batch['token_type_ids'] + model_forward_kwargs['pixel_values'] = batch['pixel_values'] + actor_output = model(batch['obs'], **model_forward_kwargs) logits = actor_output.logits From a7e86cc89020c520ebfc981ef138655d7c56ad9c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 11:48:42 -0400 Subject: [PATCH 030/195] linting --- .../algorithms/offline/model_methods.py | 34 ++++++++++++------- compose_rl/algorithms/online/model_methods.py | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index d934f675..95bc1567 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -65,9 +65,11 @@ def pairwise_offline_forward( if pad_token_id is None: raise ValueError('Tokenizer must have a PAD token.') - is_multimodal = "pixel_values" in batch.keys() + is_multimodal = 'pixel_values' in batch.keys() if is_multimodal and use_attention_sequence_id: - raise NotImplementedError("Using Sequence ID is not implemented for VLMs") + raise NotImplementedError( + 'Using Sequence ID is not implemented for VLMs', + ) # If we can use attention sequence ID, we use this logic branch. # This is determined by a value set in `train_dpo.py` @@ -107,14 +109,16 @@ def pairwise_offline_forward( ) inputs = { - "input_ids": torch.cat([chosen_inputs, rejected_inputs], dim=0), - "attention_mask": torch.cat( - [ - chosen_attention_mask, - rejected_attention_mask, - ], - dim=0, - ), + 'input_ids': + torch.cat([chosen_inputs, rejected_inputs], dim=0), + 'attention_mask': + torch.cat( + [ + chosen_attention_mask, + rejected_attention_mask, + ], + dim=0, + ), } if is_multimodal: @@ -128,14 +132,18 @@ def pairwise_offline_forward( # TODO: Ask if assuming same pixel inputs is ok? multimodal_inputs = { - "token_type_ids": torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - "pixel_values": torch.cat([batch['pixel_values'], batch['pixel_values']], dim=0), + 'token_type_ids': + torch.cat([chosen_token_type_ids, rejected_token_type_ids], + dim=0), + 'pixel_values': + torch.cat([batch['pixel_values'], batch['pixel_values']], + dim=0), } inputs.update(multimodal_inputs) output_logits = model( - **inputs + **inputs, ).logits # Extract out the chosen and rejected logits along the batch dimension diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 76072d1d..f45abb62 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -118,7 +118,7 @@ def composer_online_rl_forward( model_forward_kwargs['action_mask'] = batch['action_mask'] model_forward_kwargs['max_gen_len'] = batch['max_gen_len'] - if "pixel_values" in batch.keys(): + if 'pixel_values' in batch.keys(): model_forward_kwargs['token_type_ids'] = batch['token_type_ids'] model_forward_kwargs['pixel_values'] = batch['pixel_values'] From 08952f7d978a4f41f9c7cb7e0f80e9d1c3f6f2f3 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 12:15:52 -0400 Subject: [PATCH 031/195] multimodal handling for gemma3 --- compose_rl/data/preference_data.py | 51 ++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index c08d2757..6a4fa50b 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -56,6 +56,10 @@ def pairwise_preference_dataset_collate_fn( chosen_rewards = [] rejected_rewards = [] + # For VLMs + token_type_ids = [] + pixel_values = [] + for sample in data: chosen = sample['chosen'] rejected = sample['rejected'] @@ -63,6 +67,17 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] + is_multimodal = "pixel_values" in sample.keys() + if is_multimodal: + pixel_vals = sample['pixel_values'] + chosen_token_type_ids = sample['chosen_token_type_ids'] + rejected_token_type_ids = sample['rejected_token_type_ids'] + else: + pixel_vals = None + chosen_token_type_ids = None + rejected_token_type_ids = None + cat_token_type_ids = None + # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 @@ -75,6 +90,9 @@ def pairwise_preference_dataset_collate_fn( pad_len = max_seq_len - chosen_len - rejected_len cat_batch = torch.cat([chosen, rejected], dim=-1) + if is_multimodal: + cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + if pad_len < 0: # We should truncate chosen and rejected by the same amount truncate_len = abs(pad_len // 2) + 1 @@ -92,6 +110,15 @@ def pairwise_preference_dataset_collate_fn( rejected = rejected[:-truncate_len] rejected[-1] = tokenizer.eos_token_id # type: ignore + if is_multimodal: + chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] + rejected_token_type_ids = rejected_token_type_ids[:-truncate_len] + + # NOTE: GEMMA specific: 0 == text token + chosen_token_type_ids[-1] = 0 + rejected_token_type_ids[-1] = 0 + cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_batch = torch.cat([chosen, rejected], dim=-1) chosen_len = torch.tensor([len(chosen)]) @@ -108,6 +135,11 @@ def pairwise_preference_dataset_collate_fn( ], dim=-1, # type: ignore ) + if is_multimodal: + cat_token_type_ids = torch.cat([ + cat_token_type_ids, + torch.zeros(int(pad_len.item()), dtype=cat_token_type_ids.dtype), + ], dim=-1) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -127,6 +159,10 @@ def pairwise_preference_dataset_collate_fn( chosen_rewards.append(sample['chosen_reward']) rejected_rewards.append(sample['rejected_reward']) + if is_multimodal: + token_type_ids.append(cat_token_type_ids) + pixel_values.append(pixel_vals) + input_ids = ref_collate_fn(input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) sequence_id = torch.stack(sequence_id) @@ -147,6 +183,11 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards = torch.stack(rejected_rewards) return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards + + if is_multimodal: + return_dict['token_type_ids'] = token_type_ids + return_dict['pixel_values'] = pixel_values + return return_dict @@ -263,6 +304,16 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward + + if 'pixel_values' in sample: + pixel_values = self._read_binary_tokenized_sample(sample['pixel_values'], 'pixel_values') + chosen_token_type_ids = self._read_binary_tokenized_sample(sample['chosen_token_type_ids'], 'chosen_token_type_ids') + rejected_token_type_ids = self._read_binary_tokenized_sample((sample['rejected_token_type_ids']), 'rejected_token_type_ids') + + return_dict['pixel_values'] = pixel_values + return_dict['chosen_token_type_ids'] = chosen_token_type_ids + return_dict['rejected_token_type_ids'] = rejected_token_type_ids + return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): From bcc273b625be2f0cc2c1ce2793513700f4914c89 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 12:16:33 -0400 Subject: [PATCH 032/195] lint --- compose_rl/data/preference_data.py | 42 +++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 6a4fa50b..990df1fb 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -67,7 +67,7 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] - is_multimodal = "pixel_values" in sample.keys() + is_multimodal = 'pixel_values' in sample.keys() if is_multimodal: pixel_vals = sample['pixel_values'] chosen_token_type_ids = sample['chosen_token_type_ids'] @@ -91,7 +91,11 @@ def pairwise_preference_dataset_collate_fn( cat_batch = torch.cat([chosen, rejected], dim=-1) if is_multimodal: - cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) if pad_len < 0: # We should truncate chosen and rejected by the same amount @@ -112,12 +116,17 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] - rejected_token_type_ids = rejected_token_type_ids[:-truncate_len] + rejected_token_type_ids = rejected_token_type_ids[:-truncate_len + ] # NOTE: GEMMA specific: 0 == text token chosen_token_type_ids[-1] = 0 rejected_token_type_ids[-1] = 0 - cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) cat_batch = torch.cat([chosen, rejected], dim=-1) @@ -138,8 +147,12 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: cat_token_type_ids = torch.cat([ cat_token_type_ids, - torch.zeros(int(pad_len.item()), dtype=cat_token_type_ids.dtype), - ], dim=-1) + torch.zeros( + int(pad_len.item()), + dtype=cat_token_type_ids.dtype, + ), + ], + dim=-1) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -184,7 +197,7 @@ def pairwise_preference_dataset_collate_fn( return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards - if is_multimodal: + if is_multimodal: return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -306,9 +319,18 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['rejected_reward'] = rejected_reward if 'pixel_values' in sample: - pixel_values = self._read_binary_tokenized_sample(sample['pixel_values'], 'pixel_values') - chosen_token_type_ids = self._read_binary_tokenized_sample(sample['chosen_token_type_ids'], 'chosen_token_type_ids') - rejected_token_type_ids = self._read_binary_tokenized_sample((sample['rejected_token_type_ids']), 'rejected_token_type_ids') + pixel_values = self._read_binary_tokenized_sample( + sample['pixel_values'], + 'pixel_values', + ) + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample['chosen_token_type_ids'], + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + (sample['rejected_token_type_ids']), + 'rejected_token_type_ids', + ) return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids From 949a8bdf714cce59ed60dfa4b0da86335b2c0563 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 2 Jul 2025 11:37:05 -0400 Subject: [PATCH 033/195] bce loss --- .../algorithms/offline/model_methods.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 193d8ee9..ee34efa9 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -175,7 +175,8 @@ def pairwise_offline_loss( loss_type: PairwiseOfflineEnum, beta: float, label_smoothing: float, - sft_alpha: float = 0.0, + sft_alpha: float = 0.0, + bce: bool = False, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -209,6 +210,7 @@ def pairwise_offline_loss( logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l} losses = torch.zeros_like(logits) + if loss_type == PairwiseOfflineEnum.DPO: losses = ( -F.logsigmoid(beta * logits) * (1 - label_smoothing) - @@ -223,15 +225,25 @@ def pairwise_offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch vstars = outputs['vstar'] # (batch_size, ) - loss_1 = ( - beta * (policy_chosen_logp - ref_chosen_logp) - - (outputs['chosen_reward'] - vstars) - )**2 - loss_2 = ( - beta * (policy_rejected_logp - ref_rejected_logp) - - (outputs['rejected_reward'] - vstars) - )**2 - losses = (loss_1 + loss_2) / 2. + + if bce == False: + loss_1 = ( + beta * (policy_chosen_logp - ref_chosen_logp) - + (outputs['chosen_reward'] - vstars) + )**2 + loss_2 = ( + beta * (policy_rejected_logp - ref_rejected_logp) - + (outputs['rejected_reward'] - vstars) + )**2 + losses = (loss_1 + loss_2) / 2. + else: + normalized_adv_chosen = torch.sigmoid(outputs['chosen_reward'] - vstars) # put it into [0,1] + normalized_adv_reject = torch.sigmoid(outputs['rejected_reward'] - vstars) # put it into [0,1] + prob_chosen = torch.sigmoid(beta*(policy_chosen_logp - ref_chosen_logp)) # turn prediction into prob + prob_reject = torch.sigmoid(beta*(policy_rejected_logp - ref_rejected_logp)) # turn prediction into prob + loss_1 = torch.log(prob_chosen) * normalized_adv_chosen + (1. - normalized_adv_chosen) * torch.log(1 - prob_chosen) + loss_1 = torch.log(prob_reject) * normalized_adv_reject + (1. - normalized_adv_reject) * torch.log(1 - prob_reject) + losses = -(loss_1 + loss_2) / 2. # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref From 0368483bdfbf0129b7d503f26493e0c93726d8c7 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 2 Jul 2025 16:15:46 -0400 Subject: [PATCH 034/195] bce --- compose_rl/algorithms/offline/model_methods.py | 4 +++- yamls/local_dpo.yaml | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index ee34efa9..4b4bd7e4 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -176,7 +176,7 @@ def pairwise_offline_loss( beta: float, label_smoothing: float, sft_alpha: float = 0.0, - bce: bool = False, + bce: bool = False, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -192,6 +192,8 @@ def pairwise_offline_loss( preferences as noisy (preferences are flipped with probability label_smoothing). sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to be added to DPO type loss. + bce (bool): loss type that is alternative to the squared loss. It is in APO, potentially can be + used for REBEL and IPO. """ policy_chosen_logp = outputs['policy_chosen_logp'] # (batch_size, ) policy_rejected_logp = outputs['policy_rejected_logp'] # (batch_size, ) diff --git a/yamls/local_dpo.yaml b/yamls/local_dpo.yaml index c142dde2..3c6aef27 100644 --- a/yamls/local_dpo.yaml +++ b/yamls/local_dpo.yaml @@ -7,11 +7,11 @@ model: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: meta-llama/Llama-3.1-8B-Instruct + pretrained_model_name_or_path: Qwen/Qwen2.5-7B #meta-llama/Llama-3.1-8B-Instruct loggers: mlflow: - experiment_name: brandon_dpo_test + experiment_name: wensun_dpo_test callbacks: offline_rl: {} @@ -39,7 +39,7 @@ scheduler: t_warmup: 0.1dur tokenizer: - name: meta-llama/Llama-3.1-8B-Instruct + name: Qwen/Qwen2.5-7B #meta-llama/Llama-3.1-8B-Instruct kwargs: model_max_length: ${max_seq_len} trust_remote_code: true @@ -59,7 +59,7 @@ fsdp_config: sharding_strategy: FULL_SHARD activation_cpu_offload: false -max_seq_len: 2048 +max_seq_len: 4096 #2048 save_folder: /tmp/dpo_model # TODO: update for a proper save path dist_timeout: 600 max_duration: 1ep From 634e1513036c4ee891fb42819d6f2e8b5f54e740 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 14:57:06 -0400 Subject: [PATCH 035/195] fix multimodal preference loading --- compose_rl/data/preference_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 990df1fb..9abe3e75 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -320,15 +320,15 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: pixel_values = self._read_binary_tokenized_sample( - sample['pixel_values'], + sample, 'pixel_values', ) chosen_token_type_ids = self._read_binary_tokenized_sample( - sample['chosen_token_type_ids'], + sample, 'chosen_token_type_ids', ) rejected_token_type_ids = self._read_binary_tokenized_sample( - (sample['rejected_token_type_ids']), + sample, 'rejected_token_type_ids', ) From d35ed90b19832961e5304d26eb3b52d7fffa3b92 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:01:19 -0400 Subject: [PATCH 036/195] fix collator --- compose_rl/data/preference_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 9abe3e75..d669f346 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -198,6 +198,8 @@ def pairwise_preference_dataset_collate_fn( return_dict['rejected_reward'] = rejected_rewards if is_multimodal: + token_type_ids = torch.stack(token_type_ids) + pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From 8bea0e3c7d7d5fc8e9b66bb1b09529df7fc097e3 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:15:31 -0400 Subject: [PATCH 037/195] debug --- compose_rl/algorithms/offline/model_methods.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 95bc1567..4d7c56cc 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,6 +140,10 @@ def pairwise_offline_forward( dim=0), } + print("MULTIMODAL INPUTS") + for k, v in multimodal_inptus.itmes(): + print(f"{k}: {v.shape}") + inputs.update(multimodal_inputs) output_logits = model( From b305fded93838af269baee93ad9cc444497b4974 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:18:36 -0400 Subject: [PATCH 038/195] debug --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 4d7c56cc..00b74f73 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -141,7 +141,7 @@ def pairwise_offline_forward( } print("MULTIMODAL INPUTS") - for k, v in multimodal_inptus.itmes(): + for k, v in multimodal_inputs.items(): print(f"{k}: {v.shape}") inputs.update(multimodal_inputs) From b631cccd4ed58cc2a029a8017157fb37a7c3dc7f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:29:34 -0400 Subject: [PATCH 039/195] debug --- compose_rl/data/preference_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index d669f346..b157e184 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -200,6 +200,8 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) + print('HIIIIIII') + print(pixel_values[0].shape) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -269,6 +271,10 @@ def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) + if key == 'pixel_values': + print('I AM INSIDE READ BINARY') + print(temp_sample.shape) + print(len(temp_sample)) if len(temp_sample) > self.max_seq_len: log.info(f'Truncating sample: {self.num_truncated} {self.num_read}') self.num_truncated += 1 From fabb4a82f14a6cca39ef21ae368d85329124102b Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 16:23:05 -0400 Subject: [PATCH 040/195] change pixel values from being bytes to ndarray or pil --- compose_rl/data/preference_data.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index b157e184..d08c53dc 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -11,6 +11,9 @@ from streaming import StreamingDataset from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer +from PIL import Image +from torchvision import transforms + log = logging.getLogger(__name__) @@ -200,8 +203,6 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) - print('HIIIIIII') - print(pixel_values[0].shape) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -271,10 +272,6 @@ def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) - if key == 'pixel_values': - print('I AM INSIDE READ BINARY') - print(temp_sample.shape) - print(len(temp_sample)) if len(temp_sample) > self.max_seq_len: log.info(f'Truncating sample: {self.num_truncated} {self.num_read}') self.num_truncated += 1 @@ -327,10 +324,17 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['rejected_reward'] = rejected_reward if 'pixel_values' in sample: - pixel_values = self._read_binary_tokenized_sample( - sample, - 'pixel_values', - ) + if isinstance(sample['pixel_values'], np.ndarray): + pixel_values = torch.Tensor(sample['pixel_values']) + elif isinstance(sample['pixel_values'], Image): + pil_to_tensor_transform = transforms.PILToTensor() + pixel_values = pil_to_tensor_transform(sample['pixel_values']) + else: + pixel_values_type = type(sample['pixel_values']) + raise ValueError( + f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', + ) + chosen_token_type_ids = self._read_binary_tokenized_sample( sample, 'chosen_token_type_ids', From e2880b609b2bfb96cd5b32d2ce81a02ab01f22fc Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 09:37:30 -0400 Subject: [PATCH 041/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 4b4bd7e4..0d0eecd4 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -36,6 +36,8 @@ class PairwiseOfflineEnum(Enum): APO = 'apo' # Not a pair-wise preference algorithm + + def pairwise_offline_forward( model: nn.Module, tokenizer: Tokenizer, From b040daae03a0e5ff4c36487fb4cb3f326328fccb Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 16:44:58 -0400 Subject: [PATCH 042/195] . --- compose_rl/algorithms/offline/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 44a51fff..faf9aa2c 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,6 +34,7 @@ def __init__( label_smoothing: float = 0, sft_alpha: float = 0.0, average_log_prob: bool = False, + bce: bool = False **kwargs: Any, ): self.loss_type = PairwiseOfflineEnum(loss_type) @@ -41,6 +42,8 @@ def __init__( self.label_smoothing = label_smoothing self.sft_alpha = sft_alpha self.average_log_prob = average_log_prob + + self.bce = bce super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward @@ -73,6 +76,7 @@ def loss(self, outputs: CausalLMOutputWithPast, self.beta, self.label_smoothing, self.sft_alpha, + self.bce ) From 62769ed625cc72f67fabbc4ad0c23027aa0e0712 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 16:46:09 -0400 Subject: [PATCH 043/195] . --- compose_rl/algorithms/offline/model_methods.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 0d0eecd4..267b3903 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -241,6 +241,9 @@ def pairwise_offline_loss( )**2 losses = (loss_1 + loss_2) / 2. else: + print("#####################") + print("using BCE loss") + print("######################") normalized_adv_chosen = torch.sigmoid(outputs['chosen_reward'] - vstars) # put it into [0,1] normalized_adv_reject = torch.sigmoid(outputs['rejected_reward'] - vstars) # put it into [0,1] prob_chosen = torch.sigmoid(beta*(policy_chosen_logp - ref_chosen_logp)) # turn prediction into prob From 649a1138f11e725e4788c12558981471e4744419 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 16:49:51 -0400 Subject: [PATCH 044/195] . --- compose_rl/algorithms/offline/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index faf9aa2c..3bb6468e 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,7 +34,7 @@ def __init__( label_smoothing: float = 0, sft_alpha: float = 0.0, average_log_prob: bool = False, - bce: bool = False + bce: bool = False, **kwargs: Any, ): self.loss_type = PairwiseOfflineEnum(loss_type) From d49d2a2f916b995daa24634e23f8cf701d292e82 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 17:01:29 -0400 Subject: [PATCH 045/195] . --- compose_rl/algorithms/offline/model.py | 4 ---- compose_rl/algorithms/offline/model_methods.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 3bb6468e..44a51fff 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,7 +34,6 @@ def __init__( label_smoothing: float = 0, sft_alpha: float = 0.0, average_log_prob: bool = False, - bce: bool = False, **kwargs: Any, ): self.loss_type = PairwiseOfflineEnum(loss_type) @@ -42,8 +41,6 @@ def __init__( self.label_smoothing = label_smoothing self.sft_alpha = sft_alpha self.average_log_prob = average_log_prob - - self.bce = bce super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward @@ -76,7 +73,6 @@ def loss(self, outputs: CausalLMOutputWithPast, self.beta, self.label_smoothing, self.sft_alpha, - self.bce ) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 267b3903..cd8011c5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -178,7 +178,7 @@ def pairwise_offline_loss( beta: float, label_smoothing: float, sft_alpha: float = 0.0, - bce: bool = False, + bce: bool = True, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. From e7be40509d23fc39703e498db7b6bd0c2762d8fb Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 17:15:31 -0400 Subject: [PATCH 046/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index cd8011c5..b622b1d7 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -249,7 +249,7 @@ def pairwise_offline_loss( prob_chosen = torch.sigmoid(beta*(policy_chosen_logp - ref_chosen_logp)) # turn prediction into prob prob_reject = torch.sigmoid(beta*(policy_rejected_logp - ref_rejected_logp)) # turn prediction into prob loss_1 = torch.log(prob_chosen) * normalized_adv_chosen + (1. - normalized_adv_chosen) * torch.log(1 - prob_chosen) - loss_1 = torch.log(prob_reject) * normalized_adv_reject + (1. - normalized_adv_reject) * torch.log(1 - prob_reject) + loss_2 = torch.log(prob_reject) * normalized_adv_reject + (1. - normalized_adv_reject) * torch.log(1 - prob_reject) losses = -(loss_1 + loss_2) / 2. # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) From 9a2cc92726dde5d193a821eff3a8cb70f3b99144 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 21:49:15 -0400 Subject: [PATCH 047/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index b622b1d7..d8056974 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -178,7 +178,7 @@ def pairwise_offline_loss( beta: float, label_smoothing: float, sft_alpha: float = 0.0, - bce: bool = True, + bce: bool = False, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. From ece883aada8875f2c8a70735dfb4afcb3dcb9fff Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 4 Jul 2025 21:55:30 -0400 Subject: [PATCH 048/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index d8056974..850cc4a6 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -44,7 +44,7 @@ def pairwise_offline_forward( batch: MutableMapping, average_log_prob: bool = False, policy_model_config: Optional[PretrainedConfig] = None, - use_attention_sequence_id: bool = False, + use_attention_sequence_id: bool = True, ) -> dict[str, torch.Tensor]: """Forwards the model for dpo and get the chosen and rejected log probs. From 4e469e218eab7eb59d72fdeaa96acc9b6e7373c9 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 13:32:01 -0400 Subject: [PATCH 049/195] added single offline data track --- compose_rl/algorithms/offline/__init__.py | 7 +- compose_rl/algorithms/offline/callback.py | 18 ++ compose_rl/algorithms/offline/model.py | 90 ++++++++- .../algorithms/offline/model_methods.py | 137 ++++++++++++++ compose_rl/data/__init__.py | 9 + compose_rl/data/dataloader.py | 10 + compose_rl/data/offline_data.py | 178 ++++++++++++++++++ pyproject.toml | 5 + 8 files changed, 452 insertions(+), 2 deletions(-) create mode 100644 compose_rl/data/offline_data.py diff --git a/compose_rl/algorithms/offline/__init__.py b/compose_rl/algorithms/offline/__init__.py index 08f8132b..5b9f4a14 100644 --- a/compose_rl/algorithms/offline/__init__.py +++ b/compose_rl/algorithms/offline/__init__.py @@ -1,14 +1,19 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -from compose_rl.algorithms.offline.callback import ReferencePolicyCallback +from compose_rl.algorithms.offline.callback import ReferencePolicyCallback, PairwiseReferencePolicyCallback from compose_rl.algorithms.offline.model import ( + ComposerHFOfflinePolicyLM, + ComposerMPTOfflinePolicyLM, ComposerHFPairwiseOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) __all__ = [ + 'ComposerHFOfflinePolicyLM', + 'ComposerMPTOfflinePolicyLM', 'ComposerMPTPairwiseOfflinePolicyLM', 'ComposerHFPairwiseOfflinePolicyLM', + 'PairwiseReferencePolicyCallback', 'ReferencePolicyCallback', ] diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 266c9951..f10a86ab 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -75,6 +75,24 @@ def after_load(self, state: State, logger: Logger) -> None: callbacks=load_checkpoint_callbacks, ) + def before_forward(self, state: State, logger: Logger) -> Optional[int]: + # Before every batch we need to do a forwards pass over the reference model + with get_precision_context(state.precision): + with torch.no_grad(): + assert self.reference_model is not None + reference_outputs = self.reference_model(state.batch) + state.batch.update({ + 'ref_logp': reference_outputs['policy_logp'], + }) + + +class PairwiseReferencePolicyCallback(ReferencePolicyCallback): + """Callback to run reference policy in pairwise offline RL. + + Args: + train_config (dict): Training config passed to callback via foundry train.py as + callback is registered under callbacks_with_config registry. + """ def before_forward(self, state: State, logger: Logger) -> Optional[int]: # Before every batch we need to do a forwards pass over the reference model with get_precision_context(state.precision): diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 44a51fff..f692331d 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -"""Pairwise Offline RL Composer Implementation.""" +"""Offline RL Composer Implementation.""" from __future__ import annotations @@ -14,6 +14,9 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from compose_rl.algorithms.offline.model_methods import ( + OfflineEnum, + offline_forward, + offline_loss, PairwiseOfflineEnum, pairwise_offline_forward, pairwise_offline_loss, @@ -23,6 +26,91 @@ log = logging.getLogger(__name__) +class ComposerMPTOfflinePolicyLM(ComposerMPTCausalLM): + """MPT model wrapper for offline rl model.""" + + def __init__( + self, + loss_type: str = 'apo', + beta: float = 0.1, + average_log_prob: bool = False, + **kwargs: Any, + ): + self.loss_type = OfflineEnum(loss_type) + self.beta = beta + self.average_log_prob = average_log_prob + + super().__init__(**kwargs) + self.train_metrics = None # DPOLM does not support eval_forward + + def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: + assert self.tokenizer is not None + return offline_forward( + model=self.model, + tokenizer=self.tokenizer, + batch=batch, + average_log_prob=self.average_log_prob, + policy_model_config=self.config, + ) + + def eval_forward( + self, + batch: MutableMapping, + outputs: CausalLMOutputWithPast, + ) -> None: + raise ValueError('Eval forward is not implemented for ComposerDPOLM.') + + def loss(self, outputs: CausalLMOutputWithPast, + batch: Mapping) -> dict[str, torch.Tensor]: + return offline_loss( + outputs, + batch, + self.loss_type, + self.beta, + ) + + +class ComposerHFOfflinePolicyLM(ComposerHFCausalLM): + """HF class wrapper for offline rl model.""" + + def __init__( + self, + loss_type: str = 'apo', + beta: float = 0.1, + average_log_prob: bool = False, + **kwargs: Any, + ): + self.loss_type = OfflineEnum(loss_type) + self.beta = beta + self.average_log_prob = average_log_prob + + super().__init__(**kwargs) + self.train_metrics = None # DPOLM does not support eval_forward + + def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: + assert self.tokenizer is not None + return offline_forward( + model=self.model, + tokenizer=self.tokenizer, + batch=batch, + average_log_prob=self.average_log_prob, + ) + + def eval_forward( + self, + batch: MutableMapping, + outputs: CausalLMOutputWithPast, + ) -> None: + raise ValueError('Eval forward is not implemented for ComposerHFDPOLM.') + + def loss(self, outputs: CausalLMOutputWithPast, + batch: Mapping) -> dict[str, torch.Tensor]: + return offline_loss( + outputs, + batch, + self.loss_type, + self.beta, + ) class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM): """MPT model wrapper for DPO model.""" diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 850cc4a6..22dfc2c9 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -35,8 +35,145 @@ class PairwiseOfflineEnum(Enum): KTO = 'kto' APO = 'apo' # Not a pair-wise preference algorithm +class OfflineEnum(Enum): + APO = 'apo' +def offline_forward( + model: nn.Module, + tokenizer: Tokenizer, + batch: MutableMapping, + average_log_prob: bool = False, + policy_model_config: Optional[PretrainedConfig] = None, +) -> dict[str, torch.Tensor]: + """Forwards the model for dpo and get the chosen and rejected log probs. + + Args: + model (nn.Module): Model we are forwarding. + tokenizer (Tokenizer): Tokenizer for the model. + batch (Dict[str, torch.LongTensor]): Batch over which we should forward the model. + Note: this batch has chosen and rejected concated along the sequence dimension. + average_log_prob (bool): Whether should we average the log probabilities. + policy_model_config: Policy model config. + """ + if policy_model_config is not None and hasattr(model, 'transformer'): + clear_mb_load_balancing_loss( + policy_model_config, + model.transformer, # type: ignore + ) + + seq_len = batch['input_ids'].size(1) + pad_token_id = tokenizer.pad_token_id + if pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + # If we can't use attn_seq_id then we need to unpack each batch and + # Pack along the batch dimension instead. + output_logits = model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + ).logits + + labels = extract_packed_chosen_rejected( + batch['input_ids'], + batch['response_len'], + 0, + seq_len, + pad_token_id=0, + ) + + logps = get_batch_logp( + labels, + output_logits, + batch['prompt_len'], + batch['response_len'], + average_log_prob, + ) + + outputs: dict[str, torch.Tensor] = { + 'policy_logp': logps, + 'response_len': batch['response_len'], + } + + if 'reward' in batch: + outputs['reward'] = batch['reward'] + + if 'vstar' in batch: + outputs['vstar'] = batch['vstar'] + + if policy_model_config is not None and hasattr(model, 'transformer'): + lbl = get_mb_load_balancing_loss( + policy_model_config, + model.transformer, # type: ignore + ) + if lbl is not None: + outputs['lbl'] = lbl + + return outputs + +def offline_loss( + outputs: CausalLMOutputWithPast, + batch: Mapping, + loss_type: OfflineEnum, + beta: float, + bce: bool = False, +): + policy_logp = outputs['policy_logp'] # (batch_size, ) + ref_logp = batch.get( + 'ref_logp', + torch.zeros_like(policy_logp), + ) + + if loss_type == OfflineEnum.APO: + # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 + # APO is not a pair-wise loss function. + # We assume the dataset contains two responses per prompt. + # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. + # The chosen and reject do not mean anything in APO + # Similar to REBEL, we assume each response has a reward in the batch. + # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch + vstars = outputs['vstar'] # (batch_size, ) + + if bce == False: + losses = ( + beta * (policy_logp - ref_logp) - + (outputs['reward'] - vstars) + )**2 + else: + normalized_adv_chosen = torch.sigmoid(outputs['reward'] - vstars) # put it into [0,1] + reward_prob = torch.sigmoid(beta*(policy_logp - ref_logp)) # turn prediction into prob + losses = torch.log(reward_prob) * normalized_adv_chosen + (1. - normalized_adv_chosen) * torch.log(1 - reward_prob) + losses = -1 * losses + + # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) + # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref + # use clip to ensure the output from exp is valid + with torch.no_grad(): + estimated_rewards = torch.exp(torch.clip(policy_logp - ref_logp, max = 5.)) * outputs['reward'] + estimated_reward = torch.mean(estimated_rewards) + + losses = losses.mean() + + implicit_rewards = beta * (policy_logp - ref_logp).detach() + + # Logging KL margins for comparing different methods + reverse_kl = (policy_logp - ref_logp).detach() + forward_kl = (ref_logp - policy_logp).detach() + loss_dict = { + 'implicit_rewards': implicit_rewards, + 'reverse_kl': reverse_kl, + 'forward_kl': forward_kl, + } + if loss_type == OfflineEnum.APO: + loss_dict['estimated_reward'] = estimated_reward + + if 'lbl' in outputs: + losses += outputs['lbl'] + loss_dict['lbl'] = outputs['lbl'] + + loss_dict['total'] = losses + + return loss_dict def pairwise_offline_forward( model: nn.Module, diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 5031dd10..92933703 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -10,6 +10,7 @@ build_messages_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, + build_offline_dataloader, ) from compose_rl.data.messages_data import messages_dataset_collate_fn from compose_rl.data.preference_data import ( @@ -18,14 +19,22 @@ ) from compose_rl.data.prompt_data import prompt_dataset_collate_fn +from compose_rl.data.offline_data import ( + offline_dataset_collate_fn, + OfflineStreamingDataset, +) + __all__ = [ 'build_pairwise_preference_dataloader', 'build_finegrained_preference_dataloader', 'build_messages_dataloader', + 'build_offline_dataloader', 'build_prompt_dataloader', 'DummyDataset', 'finegrained_preference_dataset_collate_fn', 'MinibatchRolloutBuffer', + 'offline_dataset_collate_fn', + 'OfflineStreamingDataset', 'pairwise_preference_dataset_collate_fn', 'prompt_dataset_collate_fn', 'messages_dataset_collate_fn', diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3085c26a..0a2180b3 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -20,6 +20,10 @@ finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, ) +from compose_rl.data.offline_data import ( + offline_dataset_collate_fn, + OfflineStreamingDataset, +) from compose_rl.data.prompt_data import ( PromptStreamingDataset, prompt_dataset_collate_fn, @@ -30,6 +34,7 @@ 'build_pairwise_preference_dataloader', 'build_prompt_dataloader', 'build_messages_dataloader', + 'build_offline_dataloader', ] @@ -126,3 +131,8 @@ def build_preference_dataloader( MessagesStreamingDataset, messages_dataset_collate_fn, ) + +build_offline_dataloader = generate_dataloader_builder( + OfflineStreamingDataset, + offline_dataset_collate_fn, +) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py new file mode 100644 index 00000000..ef60bf82 --- /dev/null +++ b/compose_rl/data/offline_data.py @@ -0,0 +1,178 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Build a reward dataset and dataloader for training.""" + +import logging +from typing import Any + +import numpy as np +import torch +from streaming import StreamingDataset +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer + +log = logging.getLogger(__name__) + + +def offline_dataset_collate_fn( + tokenizer: PreTrainedTokenizer, + max_seq_len: int, + data: list[dict[str, torch.Tensor]], +) -> dict[str, Any]: + """Collator for offline data. + + Args: + tokenizer (Tokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + data (list[dict[str, torch.Tensor]]): The preference data to collate. + """ + if tokenizer.eos_token_id is None: + raise ValueError('Tokenizer must have an EOS token.') + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + + batch_input_ids = [] + attention_masks = [] + sequence_lens = [] + prompt_lens = [] + rewards = [] + vstars = [] + + for sample in data: + input_ids = sample['input_ids'] + prompt_len = sample['prompt_len'] + sequence_len = sample['sequence_len'] + + # Note: if we do any truncation, we force the last token to be EOS + # https://github.com/mosaicml/RLHF/issues/101 + + # Add the eos token if it's not in the chosen sample + if input_ids[-1] != tokenizer.eos_token_id: + input_ids[-1] = tokenizer.eos_token_id # type: ignore + + pad_len = max_seq_len - sequence_len + + if pad_len < 0: + # We should truncate with an additional token left for eos + truncate_len = pad_len + 1 + + log.warning(( + f'Sequence length: {sequence_len}' + f' are too long for max_seq_len: {max_seq_len}' + f' truncating by {truncate_len[0]} tokens.' + )) + + # Truncate each value by truncate length, and make the last token EOS + input_ids = input_ids[:-truncate_len] + input_ids[-1] = tokenizer.eos_token_id # type: ignore + + sequence_len = torch.tensor([len(sequence_len)]) + + pad_len = max_seq_len - sequence_len + + if pad_len > 0: + input_ids = torch.cat( + [ + input_ids, + torch.ones(int(pad_len.item()), dtype=input_ids.dtype) * + tokenizer.pad_token_id, # type: ignore + ], + dim=-1, # type: ignore + ) + + attention_mask = torch.logical_not( + torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore + ) + + batch_input_ids.append(input_ids) + attention_masks.append(attention_mask) + sequence_lens.append(sequence_len) + prompt_lens.append(prompt_len) + if 'reward' in sample: + rewards.append(sample['reward']) + if 'vstar' in sample: + vstars.append(sample['vstar']) + + batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] + attention_masks = torch.stack(attention_masks) + + sequence_lens = torch.cat(sequence_lens) + prompt_lens = torch.cat(prompt_lens) + return_dict = { + 'sequence_lens': sequence_lens, + 'prompt_len': prompt_lens, + 'input_ids': batch_input_ids, + 'attention_mask': attention_masks, + } + if len(rewards) > 0: + rewards = torch.stack(rewards) + return_dict['reward'] = rewards + if len(vstars) > 0: + vstars = torch.stack(vstars) + return_dict['vstar'] = vstars + + return return_dict + + +class OfflineStreamingDataset(StreamingDataset): + """Dataloader for streaming in preference data.""" + + def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): + self.max_seq_len = max_seq_len + super().__init__(**kwargs) + self.num_truncated = 0 + self.num_read = 0 + + def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): + self.num_read += 1 + temp_sample = torch.from_numpy(np.frombuffer(sample[key])) + if len(temp_sample) > self.max_seq_len: + log.info(f'Truncating sample: {self.num_truncated} {self.num_read}') + self.num_truncated += 1 + truncated = torch.from_numpy( + np.frombuffer(sample[key][self.max_seq_len:], dtype=np.int64), + ) + log.info(f'Truncating: {truncated}') + decoded_arr = torch.from_numpy( + np.frombuffer(sample[key], + dtype=np.int64)[:self.max_seq_len].copy(), + ) + return decoded_arr + + # How to process a sample + def __getitem__(self, idx: int) -> dict[str, Any]: + """Get an item from StreamingDataset at a given index. + + Args: + idx (int): the index where we fetch the data in the StreamingDataset. + """ + sample = super().__getitem__(idx) + + # Read Samples + sample['input_ids'] = sample['prompt'] + sample['response'] + input_ids = self._read_binary_tokenized_sample(sample, 'input_ids') + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + + # Get Lenghts + prompt_len = len(prompt) + sequence_len = len(input_ids) + + return_dict = { + 'input_ids': input_ids, + 'sequence_len': torch.Tensor([sequence_len]).to(torch.int64), + 'prompt_len': torch.Tensor([prompt_len]).to(torch.int64), + } + # If rewards are given, add them to the return dict + if 'reward' in sample: + return_dict['reward'] = torch.Tensor([sample['reward']]) + + if 'vstar' in sample: + return_dict['vstar'] = torch.Tensor([sample['vstar']]) + + return return_dict diff --git a/pyproject.toml b/pyproject.toml index f13f1efe..ba0e9fb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ hf_classifier_rm = "compose_rl.algorithms.reward_modeling:ComposerHFClassifierRe hf_causal_classifier_rm = "compose_rl.algorithms.reward_modeling:ComposerHFCausalClassifierRewardModel" mpt_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerMPTPairwiseOfflinePolicyLM" hf_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerHFPairwiseOfflinePolicyLM" +mpt_offline_lm = "compose_rl.algorithms.offline:ComposerMPTOfflinePolicyLM" +hf_offline_lm = "compose_rl.algorithms.offline:ComposerHFOfflinePolicyLM" mpt_actor_critic_lm = "compose_rl.algorithms.online:ComposerMPTPolicyLM" hf_actor_critic_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM" hf_critic_free_lm = "compose_rl.algorithms.online:ComposerHFCriticFreePolicyLM" @@ -60,10 +62,13 @@ pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader" finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader" prompt = "compose_rl.data:build_prompt_dataloader" messages = "compose_rl.data:build_messages_dataloader" +offline = "compose_rl.data:build_offline_dataloader" [project.entry-points."llmfoundry_callbacks_with_config"] offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback" +pairwise_offline_rl = "compose_rl.algorithms.offline:PairwiseReferencePolicyCallback" on_policy_rl = "compose_rl.algorithms.online:OnPolicyCallback" + # Backwards Compatibility dpo = "compose_rl.algorithms.offline:ReferencePolicyCallback" ppo = "compose_rl.algorithms.online:OnPolicyCallback" From 3059bc4b462cd1e72863b3e934fbc709a1994cc0 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 13:36:44 -0400 Subject: [PATCH 050/195] cleanup --- compose_rl/algorithms/offline/model.py | 2 -- .../algorithms/offline/model_methods.py | 19 ++----------------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index f692331d..39701dfb 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -47,7 +47,6 @@ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: assert self.tokenizer is not None return offline_forward( model=self.model, - tokenizer=self.tokenizer, batch=batch, average_log_prob=self.average_log_prob, policy_model_config=self.config, @@ -91,7 +90,6 @@ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: assert self.tokenizer is not None return offline_forward( model=self.model, - tokenizer=self.tokenizer, batch=batch, average_log_prob=self.average_log_prob, ) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 22dfc2c9..0dacff20 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -41,7 +41,6 @@ class OfflineEnum(Enum): def offline_forward( model: nn.Module, - tokenizer: Tokenizer, batch: MutableMapping, average_log_prob: bool = False, policy_model_config: Optional[PretrainedConfig] = None, @@ -62,11 +61,6 @@ def offline_forward( model.transformer, # type: ignore ) - seq_len = batch['input_ids'].size(1) - pad_token_id = tokenizer.pad_token_id - if pad_token_id is None: - raise ValueError('Tokenizer must have a PAD token.') - # If we can't use attn_seq_id then we need to unpack each batch and # Pack along the batch dimension instead. output_logits = model( @@ -74,25 +68,16 @@ def offline_forward( attention_mask=batch["attention_mask"], ).logits - labels = extract_packed_chosen_rejected( - batch['input_ids'], - batch['response_len'], - 0, - seq_len, - pad_token_id=0, - ) - logps = get_batch_logp( - labels, + batch['input_ids'], output_logits, batch['prompt_len'], - batch['response_len'], + batch['sequence_len'], average_log_prob, ) outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, - 'response_len': batch['response_len'], } if 'reward' in batch: From a3a2e7171ae8078cf6e8564404be08fee0eed4c7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 13:59:11 -0400 Subject: [PATCH 051/195] bce fix --- compose_rl/algorithms/offline/model_methods.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 0dacff20..1f01014c 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -101,7 +101,6 @@ def offline_loss( batch: Mapping, loss_type: OfflineEnum, beta: float, - bce: bool = False, ): policy_logp = outputs['policy_logp'] # (batch_size, ) ref_logp = batch.get( @@ -118,17 +117,10 @@ def offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch vstars = outputs['vstar'] # (batch_size, ) - - if bce == False: - losses = ( - beta * (policy_logp - ref_logp) - - (outputs['reward'] - vstars) - )**2 - else: - normalized_adv_chosen = torch.sigmoid(outputs['reward'] - vstars) # put it into [0,1] - reward_prob = torch.sigmoid(beta*(policy_logp - ref_logp)) # turn prediction into prob - losses = torch.log(reward_prob) * normalized_adv_chosen + (1. - normalized_adv_chosen) * torch.log(1 - reward_prob) - losses = -1 * losses + losses = ( + beta * (policy_logp - ref_logp) - + (outputs['reward'] - vstars) + )**2 # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref From 9ecb7d0e3e0207b572a24f1c634acdcb9c7cb30e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 14:59:29 -0400 Subject: [PATCH 052/195] linkt --- compose_rl/algorithms/offline/__init__.py | 7 +- compose_rl/algorithms/offline/callback.py | 1 + compose_rl/algorithms/offline/model.py | 4 +- .../algorithms/offline/model_methods.py | 71 ++++++++++++------- compose_rl/data/__init__.py | 11 ++- compose_rl/data/dataloader.py | 8 +-- yamls/local_dpo.yaml | 12 ++-- 7 files changed, 70 insertions(+), 44 deletions(-) diff --git a/compose_rl/algorithms/offline/__init__.py b/compose_rl/algorithms/offline/__init__.py index 5b9f4a14..685e5a31 100644 --- a/compose_rl/algorithms/offline/__init__.py +++ b/compose_rl/algorithms/offline/__init__.py @@ -1,11 +1,14 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -from compose_rl.algorithms.offline.callback import ReferencePolicyCallback, PairwiseReferencePolicyCallback +from compose_rl.algorithms.offline.callback import ( + PairwiseReferencePolicyCallback, + ReferencePolicyCallback, +) from compose_rl.algorithms.offline.model import ( ComposerHFOfflinePolicyLM, - ComposerMPTOfflinePolicyLM, ComposerHFPairwiseOfflinePolicyLM, + ComposerMPTOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index f10a86ab..1a94dd21 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -93,6 +93,7 @@ class PairwiseReferencePolicyCallback(ReferencePolicyCallback): train_config (dict): Training config passed to callback via foundry train.py as callback is registered under callbacks_with_config registry. """ + def before_forward(self, state: State, logger: Logger) -> Optional[int]: # Before every batch we need to do a forwards pass over the reference model with get_precision_context(state.precision): diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 7a6a1c5b..3007d344 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -15,9 +15,9 @@ from compose_rl.algorithms.offline.model_methods import ( OfflineEnum, + PairwiseOfflineEnum, offline_forward, offline_loss, - PairwiseOfflineEnum, pairwise_offline_forward, pairwise_offline_loss, ) @@ -26,6 +26,7 @@ log = logging.getLogger(__name__) + class ComposerMPTOfflinePolicyLM(ComposerMPTCausalLM): """MPT model wrapper for offline rl model.""" @@ -110,6 +111,7 @@ def loss(self, outputs: CausalLMOutputWithPast, self.beta, ) + class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM): """MPT model wrapper for DPO model.""" diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index ac97e6e6..a5a25db5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -35,6 +35,7 @@ class PairwiseOfflineEnum(Enum): KTO = 'kto' APO = 'apo' # Not a pair-wise preference algorithm + class OfflineEnum(Enum): APO = 'apo' @@ -64,8 +65,8 @@ def offline_forward( # If we can't use attn_seq_id then we need to unpack each batch and # Pack along the batch dimension instead. output_logits = model( - batch["input_ids"], - attention_mask=batch["attention_mask"], + batch['input_ids'], + attention_mask=batch['attention_mask'], ).logits logps = get_batch_logp( @@ -82,7 +83,7 @@ def offline_forward( if 'reward' in batch: outputs['reward'] = batch['reward'] - + if 'vstar' in batch: outputs['vstar'] = batch['vstar'] @@ -96,6 +97,7 @@ def offline_forward( return outputs + def offline_loss( outputs: CausalLMOutputWithPast, batch: Mapping, @@ -118,15 +120,16 @@ def offline_loss( # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch vstars = outputs['vstar'] # (batch_size, ) losses = ( - beta * (policy_logp - ref_logp) - - (outputs['reward'] - vstars) + beta * (policy_logp - ref_logp) - (outputs['reward'] - vstars) )**2 # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref # use clip to ensure the output from exp is valid with torch.no_grad(): - estimated_rewards = torch.exp(torch.clip(policy_logp - ref_logp, max = 5.)) * outputs['reward'] + estimated_rewards = torch.exp( + torch.clip(policy_logp - ref_logp, max=5.), + ) * outputs['reward'] estimated_reward = torch.mean(estimated_rewards) losses = losses.mean() @@ -152,6 +155,7 @@ def offline_loss( return loss_dict + def pairwise_offline_forward( model: nn.Module, tokenizer: Tokenizer, @@ -274,7 +278,7 @@ def pairwise_offline_forward( if 'chosen_reward' in batch: outputs['chosen_reward'] = batch['chosen_reward'] outputs['rejected_reward'] = batch['rejected_reward'] - + if 'vstar' in batch: outputs['vstar'] = batch['vstar'] @@ -295,8 +299,8 @@ def pairwise_offline_loss( loss_type: PairwiseOfflineEnum, beta: float, label_smoothing: float, - sft_alpha: float = 0.0, - bce: bool = False, + sft_alpha: float = 0.0, + bce: bool = False, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -312,7 +316,7 @@ def pairwise_offline_loss( preferences as noisy (preferences are flipped with probability label_smoothing). sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to be added to DPO type loss. - bce (bool): loss type that is alternative to the squared loss. It is in APO, potentially can be + bce (bool): loss type that is alternative to the squared loss. It is in APO, potentially can be used for REBEL and IPO. """ policy_chosen_logp = outputs['policy_chosen_logp'] # (batch_size, ) @@ -347,7 +351,7 @@ def pairwise_offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch vstars = outputs['vstar'] # (batch_size, ) - + if bce == False: loss_1 = ( beta * (policy_chosen_logp - ref_chosen_logp) - @@ -359,24 +363,40 @@ def pairwise_offline_loss( )**2 losses = (loss_1 + loss_2) / 2. else: - print("#####################") - print("using BCE loss") - print("######################") - normalized_adv_chosen = torch.sigmoid(outputs['chosen_reward'] - vstars) # put it into [0,1] - normalized_adv_reject = torch.sigmoid(outputs['rejected_reward'] - vstars) # put it into [0,1] - prob_chosen = torch.sigmoid(beta*(policy_chosen_logp - ref_chosen_logp)) # turn prediction into prob - prob_reject = torch.sigmoid(beta*(policy_rejected_logp - ref_rejected_logp)) # turn prediction into prob - loss_1 = torch.log(prob_chosen) * normalized_adv_chosen + (1. - normalized_adv_chosen) * torch.log(1 - prob_chosen) - loss_2 = torch.log(prob_reject) * normalized_adv_reject + (1. - normalized_adv_reject) * torch.log(1 - prob_reject) + print('#####################') + print('using BCE loss') + print('######################') + normalized_adv_chosen = torch.sigmoid( + outputs['chosen_reward'] - vstars, + ) # put it into [0,1] + normalized_adv_reject = torch.sigmoid( + outputs['rejected_reward'] - vstars, + ) # put it into [0,1] + prob_chosen = torch.sigmoid( + beta * (policy_chosen_logp - ref_chosen_logp), + ) # turn prediction into prob + prob_reject = torch.sigmoid( + beta * (policy_rejected_logp - ref_rejected_logp), + ) # turn prediction into prob + loss_1 = torch.log(prob_chosen) * normalized_adv_chosen + ( + 1. - normalized_adv_chosen + ) * torch.log(1 - prob_chosen) + loss_2 = torch.log(prob_reject) * normalized_adv_reject + ( + 1. - normalized_adv_reject + ) * torch.log(1 - prob_reject) losses = -(loss_1 + loss_2) / 2. # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref # use clip to ensure the output from exp is valid with torch.no_grad(): - estimated_rewards = torch.exp(torch.clip(policy_chosen_logp - ref_chosen_logp, max = 5.)) * outputs['chosen_reward'] - estimated_rewards += torch.exp(torch.clip(policy_rejected_logp - ref_rejected_logp, max = 5.)) * outputs['rejected_reward'] - estimated_reward = torch.mean(estimated_rewards)/2. + estimated_rewards = torch.exp( + torch.clip(policy_chosen_logp - ref_chosen_logp, max=5.), + ) * outputs['chosen_reward'] + estimated_rewards += torch.exp( + torch.clip(policy_rejected_logp - ref_rejected_logp, max=5.), + ) * outputs['rejected_reward'] + estimated_reward = torch.mean(estimated_rewards) / 2. elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob @@ -471,8 +491,9 @@ def pairwise_offline_loss( # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore if loss_type == PairwiseOfflineEnum.APO: - forward_kl = ((ref_chosen_logp - policy_chosen_logp) + (ref_rejected_logp - policy_rejected_logp)).detach() - loss_dict['forward_kl'] = forward_kl/2. + forward_kl = ((ref_chosen_logp - policy_chosen_logp) + + (ref_rejected_logp - policy_rejected_logp)).detach() + loss_dict['forward_kl'] = forward_kl / 2. loss_dict['estimated_reward'] = estimated_reward if sft_alpha > 0: diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 92933703..5d2ca370 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -8,22 +8,21 @@ from compose_rl.data.dataloader import ( build_finegrained_preference_dataloader, build_messages_dataloader, + build_offline_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, - build_offline_dataloader, ) from compose_rl.data.messages_data import messages_dataset_collate_fn +from compose_rl.data.offline_data import ( + OfflineStreamingDataset, + offline_dataset_collate_fn, +) from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, ) from compose_rl.data.prompt_data import prompt_dataset_collate_fn -from compose_rl.data.offline_data import ( - offline_dataset_collate_fn, - OfflineStreamingDataset, -) - __all__ = [ 'build_pairwise_preference_dataloader', 'build_finegrained_preference_dataloader', diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 0a2180b3..e72d1576 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -14,16 +14,16 @@ MessagesStreamingDataset, messages_dataset_collate_fn, ) +from compose_rl.data.offline_data import ( + OfflineStreamingDataset, + offline_dataset_collate_fn, +) from compose_rl.data.preference_data import ( FinegrainedPreferenceStreamingDataset, PairwisePreferenceStreamingDataset, finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, ) -from compose_rl.data.offline_data import ( - offline_dataset_collate_fn, - OfflineStreamingDataset, -) from compose_rl.data.prompt_data import ( PromptStreamingDataset, prompt_dataset_collate_fn, diff --git a/yamls/local_dpo.yaml b/yamls/local_dpo.yaml index 3c6aef27..129388de 100644 --- a/yamls/local_dpo.yaml +++ b/yamls/local_dpo.yaml @@ -7,7 +7,7 @@ model: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: Qwen/Qwen2.5-7B #meta-llama/Llama-3.1-8B-Instruct + pretrained_model_name_or_path: Qwen/Qwen2.5-7B loggers: mlflow: @@ -20,7 +20,7 @@ callbacks: window_size: 10 memory_monitor: {} # hf_checkpointer: - # save_folder: # TODO: insert save path for huggingface checkpoints + # save_folder: # TODO: insert save path for huggingface checkpoints # save_interval: 1ep optimizer: @@ -39,7 +39,7 @@ scheduler: t_warmup: 0.1dur tokenizer: - name: Qwen/Qwen2.5-7B #meta-llama/Llama-3.1-8B-Instruct + name: Qwen/Qwen2.5-7B kwargs: model_max_length: ${max_seq_len} trust_remote_code: true @@ -59,7 +59,7 @@ fsdp_config: sharding_strategy: FULL_SHARD activation_cpu_offload: false -max_seq_len: 4096 #2048 +max_seq_len: 4096 save_folder: /tmp/dpo_model # TODO: update for a proper save path dist_timeout: 600 max_duration: 1ep @@ -68,8 +68,8 @@ progress_bar: false train_loader: name: pairwise_preference dataset: - # local: # TODO: insert local dataset path - # remote: # TODO: insert remote dataset path if applicable + # local: # TODO: insert local dataset path + # remote: # TODO: insert remote dataset path if applicable split: train shuffle: true From 829b7328d8538473d484cb6a3a6723d797c55ed6 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 20:32:06 -0400 Subject: [PATCH 053/195] update --- tests/test_offline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_offline.py b/tests/test_offline.py index 2838700b..33f7b303 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -21,7 +21,7 @@ ComposerHFPairwiseOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) -from compose_rl.algorithms.offline.callback import ReferencePolicyCallback +from compose_rl.algorithms.offline.callback import PairwiseReferencePolicyCallback from compose_rl.data import pairwise_preference_dataset_collate_fn from tests.common import PairwisePreference, world_size @@ -64,7 +64,7 @@ def test_load_checkpoint_with_offline_callback( train_config = { 'model': model_config, } - reference_policy_callback = ReferencePolicyCallback( + reference_policy_callback = PairwiseReferencePolicyCallback( train_config=train_config, ) @@ -125,7 +125,7 @@ def test_reference_policy_callback_forward( 'fsdp_config': {}, 'seed': 17, } - callback = ReferencePolicyCallback(train_config=train_config) + callback = PairwiseReferencePolicyCallback(train_config=train_config) Trainer( model=model, callbacks=callback, @@ -209,7 +209,7 @@ def test_train( trainer = Trainer( model=model, train_dataloader=dataloader, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='1ep', ) @@ -295,7 +295,7 @@ def test_checkpoint_reloading( model=model, train_dataloader=dataloader, loggers=in_memory_logger, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='8ba', autoresume=True, @@ -317,7 +317,7 @@ def test_checkpoint_reloading( model=model, train_dataloader=dataloader, loggers=in_memory_logger, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='8ba', save_overwrite=True, From d687b0277a36311a9d14fddd560cdee8b18590d8 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 20:32:41 -0400 Subject: [PATCH 054/195] update --- yamls/local_dpo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/local_dpo.yaml b/yamls/local_dpo.yaml index 129388de..94de3507 100644 --- a/yamls/local_dpo.yaml +++ b/yamls/local_dpo.yaml @@ -14,7 +14,7 @@ loggers: experiment_name: wensun_dpo_test callbacks: - offline_rl: {} + pairwise_offline_rl: {} lr_monitor: {} speed_monitor: window_size: 10 From 22ae6554941f16c321a02b14e4d1bee5f6b715b1 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 20:47:35 -0400 Subject: [PATCH 055/195] bug fix for attention mask --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index ef60bf82..c1edc796 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -87,7 +87,7 @@ def offline_dataset_collate_fn( ) attention_mask = torch.logical_not( - torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore + torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore ) batch_input_ids.append(input_ids) From a112dc38c762d84285d4754719d232e177c822e2 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 20:50:33 -0400 Subject: [PATCH 056/195] sequence len typo fix --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index c1edc796..7c17465a 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -105,7 +105,7 @@ def offline_dataset_collate_fn( sequence_lens = torch.cat(sequence_lens) prompt_lens = torch.cat(prompt_lens) return_dict = { - 'sequence_lens': sequence_lens, + 'sequence_len': sequence_lens, 'prompt_len': prompt_lens, 'input_ids': batch_input_ids, 'attention_mask': attention_masks, From d6bd97561d866e70e306e2925cc90d59543b5f2e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 20:58:13 -0400 Subject: [PATCH 057/195] update --- .../algorithms/offline/model_methods.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index a5a25db5..833d7921 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -81,12 +81,6 @@ def offline_forward( 'policy_logp': logps, } - if 'reward' in batch: - outputs['reward'] = batch['reward'] - - if 'vstar' in batch: - outputs['vstar'] = batch['vstar'] - if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( policy_model_config, @@ -118,9 +112,19 @@ def offline_loss( # The chosen and reject do not mean anything in APO # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - vstars = outputs['vstar'] # (batch_size, ) + # + # + print("INSIDE APO LOSS") + print("Batch") + for k, v in batch.items(): + print(f"{k}: {v.shape}") + print("Outputs") + for k, v in outputs.items(): + print(f"{k}: {v.shape}") + + vstars = batch['vstar'] # (batch_size, ) losses = ( - beta * (policy_logp - ref_logp) - (outputs['reward'] - vstars) + beta * (policy_logp - ref_logp) - (batch['reward'] - vstars) )**2 # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) From b38736a2be7ec58a1956754ec85bb73e9e5e61e7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 21:10:03 -0400 Subject: [PATCH 058/195] update --- compose_rl/algorithms/offline/model_methods.py | 7 ++++++- compose_rl/data/offline_data.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 833d7921..2bc3c001 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -69,14 +69,19 @@ def offline_forward( attention_mask=batch['attention_mask'], ).logits + print("LOGITS") + print(output_logits.shape) + logps = get_batch_logp( - batch['input_ids'], + batch['input_ids'].clone(), output_logits, batch['prompt_len'], batch['sequence_len'], average_log_prob, ) + print(logps.shape) + outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, } diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 7c17465a..5f9d816b 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -170,9 +170,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: } # If rewards are given, add them to the return dict if 'reward' in sample: - return_dict['reward'] = torch.Tensor([sample['reward']]) + return_dict['reward'] = torch.Tensor(sample['reward']) if 'vstar' in sample: - return_dict['vstar'] = torch.Tensor([sample['vstar']]) + return_dict['vstar'] = torch.Tensor(sample['vstar']) return return_dict From 15aabfdbe430952302984a9e3606f409f200dab7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 21:14:50 -0400 Subject: [PATCH 059/195] fix --- compose_rl/data/offline_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 5f9d816b..904b6498 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -170,9 +170,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: } # If rewards are given, add them to the return dict if 'reward' in sample: - return_dict['reward'] = torch.Tensor(sample['reward']) + return_dict['reward'] = torch.from_numpy(sample['reward']).float() if 'vstar' in sample: - return_dict['vstar'] = torch.Tensor(sample['vstar']) + return_dict['vstar'] = torch.from_numpy(sample['vstar']).float() return return_dict From 6672e5d94dba433ffb65541d020b38c6eadffb08 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 21:21:35 -0400 Subject: [PATCH 060/195] shape fix --- compose_rl/algorithms/offline/model_methods.py | 15 +-------------- compose_rl/data/offline_data.py | 8 ++++---- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 2bc3c001..7979808c 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -69,19 +69,14 @@ def offline_forward( attention_mask=batch['attention_mask'], ).logits - print("LOGITS") - print(output_logits.shape) - logps = get_batch_logp( - batch['input_ids'].clone(), + batch['input_ids'], output_logits, batch['prompt_len'], batch['sequence_len'], average_log_prob, ) - print(logps.shape) - outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, } @@ -119,14 +114,6 @@ def offline_loss( # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch # # - print("INSIDE APO LOSS") - print("Batch") - for k, v in batch.items(): - print(f"{k}: {v.shape}") - print("Outputs") - for k, v in outputs.items(): - print(f"{k}: {v.shape}") - vstars = batch['vstar'] # (batch_size, ) losses = ( beta * (policy_logp - ref_logp) - (batch['reward'] - vstars) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 904b6498..98496759 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -111,10 +111,10 @@ def offline_dataset_collate_fn( 'attention_mask': attention_masks, } if len(rewards) > 0: - rewards = torch.stack(rewards) + rewards = torch.cat(rewards) return_dict['reward'] = rewards if len(vstars) > 0: - vstars = torch.stack(vstars) + vstars = torch.cat(vstars) return_dict['vstar'] = vstars return return_dict @@ -170,9 +170,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: } # If rewards are given, add them to the return dict if 'reward' in sample: - return_dict['reward'] = torch.from_numpy(sample['reward']).float() + return_dict['reward'] = torch.Tensor([sample['reward']]) if 'vstar' in sample: - return_dict['vstar'] = torch.from_numpy(sample['vstar']).float() + return_dict['vstar'] = torch.Tensor([sample['vstar']]) return return_dict From 3d57b0a42f2a49f824d11e1f1051c83082c78dd4 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 21:28:24 -0400 Subject: [PATCH 061/195] fix --- compose_rl/algorithms/offline/model_methods.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 7979808c..7e99beac 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -114,7 +114,23 @@ def offline_loss( # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch # # + # Check data types vstars = batch['vstar'] # (batch_size, ) + print(f"policy_logp dtype: {policy_logp.dtype}") + print(f"ref_logp dtype: {ref_logp.dtype}") + print(f"batch['reward'] dtype: {batch['reward'].dtype}") + print(f"vstars dtype: {vstars.dtype}") + print(f"beta dtype: {type(beta)}") + +# Check for any boolean tensors + print(f"policy_logp is bool: {policy_logp.dtype == torch.bool}") + print(f"ref_logp is bool: {ref_logp.dtype == torch.bool}") + +# Check shapes + print(f"policy_logp shape: {policy_logp.shape}") + print(f"ref_logp shape: {ref_logp.shape}") + print(f"batch['reward'] shape: {batch['reward'].shape}") + print(f"vstars shape: {vstars.shape}") losses = ( beta * (policy_logp - ref_logp) - (batch['reward'] - vstars) )**2 From 7b042542f026213d85081f58b9766c74b52b1bbc Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 7 Jul 2025 21:37:12 -0400 Subject: [PATCH 062/195] fix --- .../algorithms/offline/model_methods.py | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 7e99beac..f80c3e4d 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -107,32 +107,11 @@ def offline_loss( if loss_type == OfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 # APO is not a pair-wise loss function. - # We assume the dataset contains two responses per prompt. - # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. - # The chosen and reject do not mean anything in APO # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - # - # - # Check data types - vstars = batch['vstar'] # (batch_size, ) - print(f"policy_logp dtype: {policy_logp.dtype}") - print(f"ref_logp dtype: {ref_logp.dtype}") - print(f"batch['reward'] dtype: {batch['reward'].dtype}") - print(f"vstars dtype: {vstars.dtype}") - print(f"beta dtype: {type(beta)}") - -# Check for any boolean tensors - print(f"policy_logp is bool: {policy_logp.dtype == torch.bool}") - print(f"ref_logp is bool: {ref_logp.dtype == torch.bool}") - -# Check shapes - print(f"policy_logp shape: {policy_logp.shape}") - print(f"ref_logp shape: {ref_logp.shape}") - print(f"batch['reward'] shape: {batch['reward'].shape}") - print(f"vstars shape: {vstars.shape}") + losses = ( - beta * (policy_logp - ref_logp) - (batch['reward'] - vstars) + beta * (policy_logp - ref_logp) - (batch['reward'] - batch['vstar']) )**2 # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) @@ -141,7 +120,7 @@ def offline_loss( with torch.no_grad(): estimated_rewards = torch.exp( torch.clip(policy_logp - ref_logp, max=5.), - ) * outputs['reward'] + ) * batch['reward'] estimated_reward = torch.mean(estimated_rewards) losses = losses.mean() From 11df0ad77afca079c6efea6bb08acf53c44fe51c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 07:49:43 -0400 Subject: [PATCH 063/195] fix --- .../algorithms/offline/model_methods.py | 73 +------------------ 1 file changed, 2 insertions(+), 71 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f80c3e4d..57975218 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -25,6 +25,8 @@ get_mb_load_balancing_loss, ) +class OfflineEnum(Enum): + APO = 'apo' class PairwiseOfflineEnum(Enum): DPO = 'dpo' @@ -33,11 +35,6 @@ class PairwiseOfflineEnum(Enum): REBEL = 'rebel' IPO = 'ipo' KTO = 'kto' - APO = 'apo' # Not a pair-wise preference algorithm - - -class OfflineEnum(Enum): - APO = 'apo' def offline_forward( @@ -291,7 +288,6 @@ def pairwise_offline_loss( beta: float, label_smoothing: float, sft_alpha: float = 0.0, - bce: bool = False, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -307,8 +303,6 @@ def pairwise_offline_loss( preferences as noisy (preferences are flipped with probability label_smoothing). sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to be added to DPO type loss. - bce (bool): loss type that is alternative to the squared loss. It is in APO, potentially can be - used for REBEL and IPO. """ policy_chosen_logp = outputs['policy_chosen_logp'] # (batch_size, ) policy_rejected_logp = outputs['policy_rejected_logp'] # (batch_size, ) @@ -333,62 +327,6 @@ def pairwise_offline_loss( -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing ) - elif loss_type == PairwiseOfflineEnum.APO: - # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 - # APO is not a pair-wise loss function. - # We assume the dataset contains two responses per prompt. - # The name chosen and reject just refers response 1 and response 2. This is for design simplicity. - # The chosen and reject do not mean anything in APO - # Similar to REBEL, we assume each response has a reward in the batch. - # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - vstars = outputs['vstar'] # (batch_size, ) - - if bce == False: - loss_1 = ( - beta * (policy_chosen_logp - ref_chosen_logp) - - (outputs['chosen_reward'] - vstars) - )**2 - loss_2 = ( - beta * (policy_rejected_logp - ref_rejected_logp) - - (outputs['rejected_reward'] - vstars) - )**2 - losses = (loss_1 + loss_2) / 2. - else: - print('#####################') - print('using BCE loss') - print('######################') - normalized_adv_chosen = torch.sigmoid( - outputs['chosen_reward'] - vstars, - ) # put it into [0,1] - normalized_adv_reject = torch.sigmoid( - outputs['rejected_reward'] - vstars, - ) # put it into [0,1] - prob_chosen = torch.sigmoid( - beta * (policy_chosen_logp - ref_chosen_logp), - ) # turn prediction into prob - prob_reject = torch.sigmoid( - beta * (policy_rejected_logp - ref_rejected_logp), - ) # turn prediction into prob - loss_1 = torch.log(prob_chosen) * normalized_adv_chosen + ( - 1. - normalized_adv_chosen - ) * torch.log(1 - prob_chosen) - loss_2 = torch.log(prob_reject) * normalized_adv_reject + ( - 1. - normalized_adv_reject - ) * torch.log(1 - prob_reject) - losses = -(loss_1 + loss_2) / 2. - - # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) - # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref - # use clip to ensure the output from exp is valid - with torch.no_grad(): - estimated_rewards = torch.exp( - torch.clip(policy_chosen_logp - ref_chosen_logp, max=5.), - ) * outputs['chosen_reward'] - estimated_rewards += torch.exp( - torch.clip(policy_rejected_logp - ref_rejected_logp, max=5.), - ) * outputs['rejected_reward'] - estimated_reward = torch.mean(estimated_rewards) / 2. - elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob chosen_reward = outputs['chosen_reward'] @@ -446,8 +384,6 @@ def pairwise_offline_loss( ), 0, ) - else: - raise ValueError(f'Loss type: {loss_type} is not supported.') if sft_alpha > 0: sft_losses = -1 * sft_alpha * policy_chosen_logp @@ -481,11 +417,6 @@ def pairwise_offline_loss( ]: # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore - if loss_type == PairwiseOfflineEnum.APO: - forward_kl = ((ref_chosen_logp - policy_chosen_logp) + - (ref_rejected_logp - policy_rejected_logp)).detach() - loss_dict['forward_kl'] = forward_kl / 2. - loss_dict['estimated_reward'] = estimated_reward if sft_alpha > 0: # sft_losses_normalized is always defined if sft_alpha>0 From 36d01205b6ed0b476d909637aa1cbfff583295e3 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 07:50:18 -0400 Subject: [PATCH 064/195] pre-commit --- compose_rl/algorithms/offline/model_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 57975218..12844c2d 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -25,9 +25,11 @@ get_mb_load_balancing_loss, ) + class OfflineEnum(Enum): APO = 'apo' + class PairwiseOfflineEnum(Enum): DPO = 'dpo' RPO = 'rpo' @@ -108,7 +110,8 @@ def offline_loss( # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch losses = ( - beta * (policy_logp - ref_logp) - (batch['reward'] - batch['vstar']) + beta * (policy_logp - ref_logp) - + (batch['reward'] - batch['vstar']) )**2 # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) From 3a33026c90349f8192718525b53fe6ebc3c1b442 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 07:56:47 -0400 Subject: [PATCH 065/195] precommit isort --- tests/test_offline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_offline.py b/tests/test_offline.py index 33f7b303..fe8132a6 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -21,7 +21,8 @@ ComposerHFPairwiseOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) -from compose_rl.algorithms.offline.callback import PairwiseReferencePolicyCallback +from compose_rl.algorithms.offline.callback import \ + PairwiseReferencePolicyCallback from compose_rl.data import pairwise_preference_dataset_collate_fn from tests.common import PairwisePreference, world_size From d44610e4d6673fff8da799c24273aac997cbc63e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 12:03:42 -0400 Subject: [PATCH 066/195] support ndarray typing --- compose_rl/data/offline_data.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 98496759..91cb94cd 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -155,9 +155,20 @@ def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) # Read Samples - sample['input_ids'] = sample['prompt'] + sample['response'] - input_ids = self._read_binary_tokenized_sample(sample, 'input_ids') - prompt = self._read_binary_tokenized_sample(sample, 'prompt') + input_ids, prompt = [], [] + if isinstance(sample['prompt'], bytes): + sample['input_ids'] = sample['prompt'] + sample['response'] + input_ids = self._read_binary_tokenized_sample(sample, 'input_ids') + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + elif isinstance(sample['prompt'], np.ndarray): + input_ids = np.concatenate([sample['prompt'], sample['response']]) + sample['input_ids'] = input_ids[:self.max_seq_len].tolist().copy() + prompt = sample['prompt'].tolist().copy() + else: + token_type = type(sample['input_ids']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) # Get Lenghts prompt_len = len(prompt) From e0e015ae260134f968988445ac981f2f9b9fe6bd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:13:21 -0400 Subject: [PATCH 067/195] support ndarray typing --- compose_rl/data/preference_data.py | 73 ++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index d08c53dc..6d698ebc 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -293,21 +293,44 @@ def __getitem__(self, idx: int) -> dict[str, Any]: idx (int): the index where we fetch the data in the StreamingDataset. """ sample = super().__getitem__(idx) + # Handle prompt if available - if 'prompt' in sample: + if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses - sample['chosen'] = sample['prompt'] + sample['chosen'] - sample['rejected'] = sample['prompt'] + sample['rejected'] - chosen = self._read_binary_tokenized_sample(sample, 'chosen') - rejected = self._read_binary_tokenized_sample(sample, 'rejected') - - if 'prompt' in sample: - prompt = self._read_binary_tokenized_sample(sample, 'prompt') - prompt_len = len(prompt) + if 'prompt' in sample: + sample['chosen'] = sample['prompt'] + sample['chosen'] + sample['rejected'] = sample['prompt'] + sample['rejected'] + chosen = self._read_binary_tokenized_sample(sample, 'chosen') + rejected = self._read_binary_tokenized_sample(sample, 'rejected') + + if 'prompt' in sample: + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + prompt_len = len(prompt) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) + + elif isinstance(sample['chosen'], np.ndarray): + if 'prompt' in sample: + sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) + sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) + + chosen = sample['chosen'][:self.max_seq_len].tolist().copy() + rejected = sample['rejected'][:self.max_seq_len].tolist().copy() + + if 'prompt' in sample: + prompt_len = len(sample['prompt']) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) else: - # Only use prefix matching version of prompt_len when - # 'prompt' is not directly given in the sample - prompt_len = self.find_prompt_length(chosen, rejected) + token_type = type(sample['chosen']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) + chosen_len, rejected_len = len(chosen), len(rejected) return_dict = { 'chosen': chosen, @@ -335,14 +358,24 @@ def __getitem__(self, idx: int) -> dict[str, Any]: f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', ) - chosen_token_type_ids = self._read_binary_tokenized_sample( - sample, - 'chosen_token_type_ids', - ) - rejected_token_type_ids = self._read_binary_tokenized_sample( - sample, - 'rejected_token_type_ids', - ) + if isinstance(sample['chosen_token_type_ids'], bytes): + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'rejected_token_type_ids', + ) + elif isinstance(sample['chosen_token_type_ids'], np.ndarray): + chosen_token_type_ids = sample['chosen_token_type_ids'][:self.max_seq_len].tolist().copy() + rejected_token_type_ids = sample['rejected_token_type_ids'][:self.max_seq_len].tolist().copy() + else: + token_type = type(sample['chosen_token_type_ids']) + raise ValueError( + f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', + ) + return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids From 2c1c0d4c421e10ae51fc11a81638449548febba1 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:17:06 -0400 Subject: [PATCH 068/195] PIL image support --- compose_rl/data/preference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 6d698ebc..155ddd51 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -349,7 +349,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): pixel_values = torch.Tensor(sample['pixel_values']) - elif isinstance(sample['pixel_values'], Image): + elif isinstance(sample['pixel_values'], Image.Image): pil_to_tensor_transform = transforms.PILToTensor() pixel_values = pil_to_tensor_transform(sample['pixel_values']) else: From 15ca2c6577ffc03d6ccc19bf19938e9f2309975c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:20:52 -0400 Subject: [PATCH 069/195] numpy support bug fix --- compose_rl/data/preference_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 155ddd51..90d014b9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -316,8 +316,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) - chosen = sample['chosen'][:self.max_seq_len].tolist().copy() - rejected = sample['rejected'][:self.max_seq_len].tolist().copy() + chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) + rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) if 'prompt' in sample: prompt_len = len(sample['prompt']) @@ -368,8 +368,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'rejected_token_type_ids', ) elif isinstance(sample['chosen_token_type_ids'], np.ndarray): - chosen_token_type_ids = sample['chosen_token_type_ids'][:self.max_seq_len].tolist().copy() - rejected_token_type_ids = sample['rejected_token_type_ids'][:self.max_seq_len].tolist().copy() + chosen_token_type_ids = torch.from_numpy(sample['chosen_token_type_ids'][:self.max_seq_len]) + rejected_token_type_ids = torch.from_numpy(sample['rejected_token_type_ids'][:self.max_seq_len]) else: token_type = type(sample['chosen_token_type_ids']) raise ValueError( From 375d2576b1d65f40ab65f002a8d416f176cb52ef Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:25:50 -0400 Subject: [PATCH 070/195] pixel values into lists --- compose_rl/algorithms/offline/model_methods.py | 4 +--- compose_rl/data/preference_data.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 00b74f73..f927617f 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -135,9 +135,7 @@ def pairwise_offline_forward( 'token_type_ids': torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - 'pixel_values': - torch.cat([batch['pixel_values'], batch['pixel_values']], - dim=0), + 'pixel_values': batch['pixel_values'] * 2, # double the list } print("MULTIMODAL INPUTS") diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 90d014b9..08dca8bb 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -202,7 +202,6 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) - pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From ebf43faaa3e9a072a62c8e3a7839a04ecee18862 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:29:32 -0400 Subject: [PATCH 071/195] logging fix --- compose_rl/algorithms/offline/model_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f927617f..f05b7bde 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,7 +140,10 @@ def pairwise_offline_forward( print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - print(f"{k}: {v.shape}") + if isinstance(v, torch.tensor): + print(f"{k}: {v.shape}") + else: + print(f"{k}: {len(v)}") inputs.update(multimodal_inputs) From daae699af73e053c4944822ff8ff3362f6a80200 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:32:00 -0400 Subject: [PATCH 072/195] fix --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f05b7bde..272fba8f 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,7 +140,7 @@ def pairwise_offline_forward( print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - if isinstance(v, torch.tensor): + if isinstance(v, torch.Tensor): print(f"{k}: {v.shape}") else: print(f"{k}: {len(v)}") From 82e2f1d0256fd279c35b9975cbe0ca70ec15216f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:40:45 -0400 Subject: [PATCH 073/195] change back to tensor --- compose_rl/algorithms/offline/model_methods.py | 9 ++++----- compose_rl/data/preference_data.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 272fba8f..00b74f73 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -135,15 +135,14 @@ def pairwise_offline_forward( 'token_type_ids': torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - 'pixel_values': batch['pixel_values'] * 2, # double the list + 'pixel_values': + torch.cat([batch['pixel_values'], batch['pixel_values']], + dim=0), } print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - if isinstance(v, torch.Tensor): - print(f"{k}: {v.shape}") - else: - print(f"{k}: {len(v)}") + print(f"{k}: {v.shape}") inputs.update(multimodal_inputs) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 08dca8bb..90d014b9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -202,6 +202,7 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) + pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From 5a1891d4c7e910896ca8c505bf7f86b23891569a Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Tue, 8 Jul 2025 14:47:15 -0400 Subject: [PATCH 074/195] Update offline_data.py --- compose_rl/data/offline_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 91cb94cd..0835dc23 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -162,8 +162,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt = self._read_binary_tokenized_sample(sample, 'prompt') elif isinstance(sample['prompt'], np.ndarray): input_ids = np.concatenate([sample['prompt'], sample['response']]) - sample['input_ids'] = input_ids[:self.max_seq_len].tolist().copy() - prompt = sample['prompt'].tolist().copy() + sample['input_ids'] = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt = torch.from_numpy(sample['prompt']) else: token_type = type(sample['input_ids']) raise ValueError( From 6af3c99695aabf6842fd713c86972c67103c2ef4 Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Tue, 8 Jul 2025 15:03:09 -0400 Subject: [PATCH 075/195] Update offline_data.py --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 0835dc23..ba20c01b 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -162,7 +162,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt = self._read_binary_tokenized_sample(sample, 'prompt') elif isinstance(sample['prompt'], np.ndarray): input_ids = np.concatenate([sample['prompt'], sample['response']]) - sample['input_ids'] = torch.from_numpy(input_ids[:self.max_seq_len]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) prompt = torch.from_numpy(sample['prompt']) else: token_type = type(sample['input_ids']) From dbeec36aaa0f298142ed247b5a40e1587e92fc40 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:26:38 -0400 Subject: [PATCH 076/195] nd array for pixel_values --- compose_rl/data/preference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 90d014b9..f1cab307 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -348,7 +348,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): - pixel_values = torch.Tensor(sample['pixel_values']) + pixel_values = torch.from_numpy(sample['pixel_values']) elif isinstance(sample['pixel_values'], Image.Image): pil_to_tensor_transform = transforms.PILToTensor() pixel_values = pil_to_tensor_transform(sample['pixel_values']) From e521e728a0e3a55da711257c1b679a05cbe3457e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:41:18 -0400 Subject: [PATCH 077/195] fix --- .../algorithms/offline/model_methods.py | 4 +- compose_rl/data/preference_data.py | 56 +++++++++++-------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 00b74f73..abcc7fe9 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,9 +140,9 @@ def pairwise_offline_forward( dim=0), } - print("MULTIMODAL INPUTS") + print('MULTIMODAL INPUTS') for k, v in multimodal_inputs.items(): - print(f"{k}: {v.shape}") + print(f'{k}: {v.shape}') inputs.update(multimodal_inputs) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index f1cab307..71fc2bd9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -8,11 +8,10 @@ import numpy as np import torch -from streaming import StreamingDataset -from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer - from PIL import Image +from streaming import StreamingDataset from torchvision import transforms +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer log = logging.getLogger(__name__) @@ -118,9 +117,11 @@ def pairwise_preference_dataset_collate_fn( rejected[-1] = tokenizer.eos_token_id # type: ignore if is_multimodal: - chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] - rejected_token_type_ids = rejected_token_type_ids[:-truncate_len - ] + chosen_token_type_ids = chosen_token_type_ids[: + -truncate_len # type: ignore + ] + rejected_token_type_ids = rejected_token_type_ids[: # type: ignore + -truncate_len] # NOTE: GEMMA specific: 0 == text token chosen_token_type_ids[-1] = 0 @@ -148,14 +149,16 @@ def pairwise_preference_dataset_collate_fn( dim=-1, # type: ignore ) if is_multimodal: - cat_token_type_ids = torch.cat([ - cat_token_type_ids, - torch.zeros( - int(pad_len.item()), - dtype=cat_token_type_ids.dtype, - ), - ], - dim=-1) + cat_token_type_ids = torch.cat( + [ + cat_token_type_ids, # type: ignore + torch.zeros( + int(pad_len.item()), + dtype=cat_token_type_ids.dtype, # type: ignore + ), + ], + dim=-1, + ) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -176,7 +179,7 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards.append(sample['rejected_reward']) if is_multimodal: - token_type_ids.append(cat_token_type_ids) + token_type_ids.append(cat_token_type_ids) # type: ignore pixel_values.append(pixel_vals) input_ids = ref_collate_fn(input_ids)['input_ids'] @@ -200,7 +203,7 @@ def pairwise_preference_dataset_collate_fn( return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards - if is_multimodal: + if is_multimodal: # type: ignore token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids @@ -293,7 +296,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: idx (int): the index where we fetch the data in the StreamingDataset. """ sample = super().__getitem__(idx) - + # Handle prompt if available if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses @@ -313,8 +316,14 @@ def __getitem__(self, idx: int) -> dict[str, Any]: elif isinstance(sample['chosen'], np.ndarray): if 'prompt' in sample: - sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) - sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) + sample['chosen'] = np.concatenate([ + sample['prompt'], + sample['chosen'], + ]) + sample['rejected'] = np.concatenate([ + sample['prompt'], + sample['rejected'], + ]) chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) @@ -368,15 +377,18 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'rejected_token_type_ids', ) elif isinstance(sample['chosen_token_type_ids'], np.ndarray): - chosen_token_type_ids = torch.from_numpy(sample['chosen_token_type_ids'][:self.max_seq_len]) - rejected_token_type_ids = torch.from_numpy(sample['rejected_token_type_ids'][:self.max_seq_len]) + chosen_token_type_ids = torch.from_numpy( + sample['chosen_token_type_ids'][:self.max_seq_len], + ) + rejected_token_type_ids = torch.from_numpy( + sample['rejected_token_type_ids'][:self.max_seq_len], + ) else: token_type = type(sample['chosen_token_type_ids']) raise ValueError( f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', ) - return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids return_dict['rejected_token_type_ids'] = rejected_token_type_ids From 83d1e3bef5b42d2bb909ffd5b452705ad0a3fefd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:52:05 -0400 Subject: [PATCH 078/195] fix --- compose_rl/algorithms/offline/model_methods.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index abcc7fe9..95bc1567 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,10 +140,6 @@ def pairwise_offline_forward( dim=0), } - print('MULTIMODAL INPUTS') - for k, v in multimodal_inputs.items(): - print(f'{k}: {v.shape}') - inputs.update(multimodal_inputs) output_logits = model( From b85e4f477c1987b05056bcf95a12ac176ff7612e Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Tue, 8 Jul 2025 17:33:12 -0400 Subject: [PATCH 079/195] Update compose_rl/algorithms/offline/model_methods.py Co-authored-by: bcui-db <141345999+bcui-db@users.noreply.github.com> --- compose_rl/algorithms/offline/model_methods.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 12844c2d..1d12ea4d 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -61,8 +61,6 @@ def offline_forward( model.transformer, # type: ignore ) - # If we can't use attn_seq_id then we need to unpack each batch and - # Pack along the batch dimension instead. output_logits = model( batch['input_ids'], attention_mask=batch['attention_mask'], From 26c0ab8617fc84691471cfa7b611793415f6ded6 Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Tue, 8 Jul 2025 17:34:01 -0400 Subject: [PATCH 080/195] Update compose_rl/data/offline_data.py Co-authored-by: bcui-db <141345999+bcui-db@users.noreply.github.com> --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index ba20c01b..c63b2325 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -60,7 +60,7 @@ def offline_dataset_collate_fn( if pad_len < 0: # We should truncate with an additional token left for eos - truncate_len = pad_len + 1 + truncate_len = abs(pad_len) + 1 log.warning(( f'Sequence length: {sequence_len}' From 99165b0e78e0d6f4d92aea5de4f29b56a184248d Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 9 Jul 2025 10:45:25 -0400 Subject: [PATCH 081/195] remove vstar from preference dat --- compose_rl/data/preference_data.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 1af2f983..93a3fb6c 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -54,7 +54,6 @@ def pairwise_preference_dataset_collate_fn( prompt_lens = [] sequence_id = [] chosen_rewards = [] - vstars = [] rejected_rewards = [] for sample in data: @@ -127,8 +126,6 @@ def pairwise_preference_dataset_collate_fn( if 'chosen_reward' in sample: chosen_rewards.append(sample['chosen_reward']) rejected_rewards.append(sample['rejected_reward']) - if 'vstar' in sample: - vstars.append(sample['vstar']) input_ids = ref_collate_fn(input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) @@ -150,9 +147,6 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards = torch.stack(rejected_rewards) return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards - if len(vstars) > 0: - vstars = torch.stack(vstars) - return_dict['vstar'] = vstars return return_dict @@ -271,9 +265,6 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward - if 'vstar' in sample: - return_dict['vstar'] = torch.Tensor([sample['vstar']]) - return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): From b68ca4b49be432b7bc476aa882138ca100b059c1 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 9 Jul 2025 10:53:22 -0400 Subject: [PATCH 082/195] add pixel values to forward pass --- compose_rl/algorithms/offline/model.py | 10 +++++-- .../algorithms/offline/model_methods.py | 29 ++++++++++++++----- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 3007d344..108118e0 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -14,7 +14,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from compose_rl.algorithms.offline.model_methods import ( - OfflineEnum, + RegressionOfflineEnum, PairwiseOfflineEnum, offline_forward, offline_loss, @@ -35,11 +35,13 @@ def __init__( loss_type: str = 'apo', beta: float = 0.1, average_log_prob: bool = False, + temperature: float = 1.0, **kwargs: Any, ): - self.loss_type = OfflineEnum(loss_type) + self.loss_type = RegressionOfflineEnum(loss_type) self.beta = beta self.average_log_prob = average_log_prob + self.temperature = temperature super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward @@ -78,11 +80,13 @@ def __init__( loss_type: str = 'apo', beta: float = 0.1, average_log_prob: bool = False, + temperature: float = 1.0, **kwargs: Any, ): - self.loss_type = OfflineEnum(loss_type) + self.loss_type = RegressionOfflineEnum(loss_type) self.beta = beta self.average_log_prob = average_log_prob + self.temperature = temperature super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 18dac341..cb7c1d94 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -26,7 +26,7 @@ ) -class OfflineEnum(Enum): +class RegressionOfflineEnum(Enum): APO = 'apo' @@ -44,6 +44,7 @@ def offline_forward( batch: MutableMapping, average_log_prob: bool = False, policy_model_config: Optional[PretrainedConfig] = None, + temperature: float = 1.0, ) -> dict[str, torch.Tensor]: """Forwards the model for dpo and get the chosen and rejected log probs. @@ -55,16 +56,27 @@ def offline_forward( average_log_prob (bool): Whether should we average the log probabilities. policy_model_config: Policy model config. """ + is_multimodal = 'pixel_values' in batch.keys() + if policy_model_config is not None and hasattr(model, 'transformer'): clear_mb_load_balancing_loss( policy_model_config, model.transformer, # type: ignore ) - output_logits = model( - batch['input_ids'], - attention_mask=batch['attention_mask'], - ).logits + inputs = { + "input_ids": batch['input_ids'], + "attention_mask": batch['attention_mask'], + } + + if is_multimodal: + multimodal_inputs = { + 'token_type_ids': batch['token_type_ids'], + 'pixel_values': batch['pixel_values'], + } + inputs.update(multimodal_inputs) + + output_logits = model(**inputs).logits logps = get_batch_logp( batch['input_ids'], @@ -72,6 +84,7 @@ def offline_forward( batch['prompt_len'], batch['sequence_len'], average_log_prob, + temperature=temperature, ) outputs: dict[str, torch.Tensor] = { @@ -92,7 +105,7 @@ def offline_forward( def offline_loss( outputs: CausalLMOutputWithPast, batch: Mapping, - loss_type: OfflineEnum, + loss_type: RegressionOfflineEnum, beta: float, ): policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -101,7 +114,7 @@ def offline_loss( torch.zeros_like(policy_logp), ) - if loss_type == OfflineEnum.APO: + if loss_type == RegressionOfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 # APO is not a pair-wise loss function. # Similar to REBEL, we assume each response has a reward in the batch. @@ -133,7 +146,7 @@ def offline_loss( 'reverse_kl': reverse_kl, 'forward_kl': forward_kl, } - if loss_type == OfflineEnum.APO: + if loss_type == RegressionOfflineEnum.APO: loss_dict['estimated_reward'] = estimated_reward if 'lbl' in outputs: From ba9e453719dae946fe020fb620840feb19b826de Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 9 Jul 2025 11:21:29 -0400 Subject: [PATCH 083/195] offline single stream multimodal support --- compose_rl/data/offline_data.py | 70 +++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index c63b2325..fb46d050 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -8,7 +8,9 @@ import numpy as np import torch +from PIL import Image from streaming import StreamingDataset +from torchvision import transforms from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer log = logging.getLogger(__name__) @@ -44,11 +46,23 @@ def offline_dataset_collate_fn( rewards = [] vstars = [] + # For VLMs + batch_token_type_ids = [] + pixel_values = [] + for sample in data: input_ids = sample['input_ids'] prompt_len = sample['prompt_len'] sequence_len = sample['sequence_len'] + is_multimodal = 'pixel_values' in sample.keys() + if is_multimodal: + pixel_vals = sample['pixel_values'] + token_type_ids = sample['token_type_ids'] + else: + pixel_vals = None + token_type_ids = None + # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 @@ -72,6 +86,11 @@ def offline_dataset_collate_fn( input_ids = input_ids[:-truncate_len] input_ids[-1] = tokenizer.eos_token_id # type: ignore + if is_multimodal: + token_type_ids = token_type_ids[:-truncate_len] + # NOTE: GEMMA specific: 0 == text token + token_type_ids[-1] = 0 + sequence_len = torch.tensor([len(sequence_len)]) pad_len = max_seq_len - sequence_len @@ -85,6 +104,17 @@ def offline_dataset_collate_fn( ], dim=-1, # type: ignore ) + if is_multimodal: + token_type_ids = torch.cat( + [ + token_type_ids, # type: ignore + torch.zeros( + int(pad_len.item()), + dtype=token_type_ids.dtype, # type: ignore + ), + ], + dim=-1, + ) attention_mask = torch.logical_not( torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore @@ -99,6 +129,10 @@ def offline_dataset_collate_fn( if 'vstar' in sample: vstars.append(sample['vstar']) + if is_multimodal: + batch_token_type_ids.append(token_type_ids) # type: ignore + pixel_values.append(pixel_vals) + batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) @@ -117,6 +151,12 @@ def offline_dataset_collate_fn( vstars = torch.cat(vstars) return_dict['vstar'] = vstars + if is_multimodal: # type: ignore + token_type_ids = torch.stack(batch_token_type_ids) + pixel_values = torch.stack(pixel_values) + return_dict['token_type_ids'] = token_type_ids + return_dict['pixel_values'] = pixel_values + return return_dict @@ -186,4 +226,34 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'vstar' in sample: return_dict['vstar'] = torch.Tensor([sample['vstar']]) + if 'pixel_values' in sample: + if isinstance(sample['pixel_values'], np.ndarray): + pixel_values = torch.from_numpy(sample['pixel_values']) + elif isinstance(sample['pixel_values'], Image.Image): + pil_to_tensor_transform = transforms.PILToTensor() + pixel_values = pil_to_tensor_transform(sample['pixel_values']) + else: + pixel_values_type = type(sample['pixel_values']) + raise ValueError( + f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', + ) + + if isinstance(sample['token_type_ids'], bytes): + token_type_ids = self._read_binary_tokenized_sample( + sample, + 'token_type_ids', + ) + elif isinstance(sample['token_type_ids'], np.ndarray): + token_type_ids = torch.from_numpy( + sample['token_type_ids'][:self.max_seq_len], + ) + else: + token_type = type(sample['token_type_ids']) + raise ValueError( + f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', + ) + + return_dict['pixel_values'] = pixel_values + return_dict['token_type_ids'] = token_type_ids + return return_dict From 06d6a2f6d366a4975df5dec5f69f83035ba09fab Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 9 Jul 2025 11:31:31 -0400 Subject: [PATCH 084/195] temperature scaling --- compose_rl/algorithms/offline/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 108118e0..58783e2e 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -97,6 +97,7 @@ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: model=self.model, batch=batch, average_log_prob=self.average_log_prob, + temperature=self.temperature, ) def eval_forward( From c2da2f65ada1e3d31f3f4835663f58fded2999e9 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 9 Jul 2025 11:39:25 -0400 Subject: [PATCH 085/195] add processor to Dataset to ensure proper HF checkpointing --- compose_rl/data/offline_data.py | 11 ++++++++--- compose_rl/data/preference_data.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index fb46d050..33fdc17c 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -4,14 +4,14 @@ """Build a reward dataset and dataloader for training.""" import logging -from typing import Any +from typing import Any, Optional import numpy as np import torch from PIL import Image from streaming import StreamingDataset from torchvision import transforms -from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer, AutoProcessor log = logging.getLogger(__name__) @@ -163,12 +163,17 @@ def offline_dataset_collate_fn( class OfflineStreamingDataset(StreamingDataset): """Dataloader for streaming in preference data.""" - def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): + def __init__(self, max_seq_len: int, processor_name: Optional[str] = None, **kwargs: dict[str, Any]): self.max_seq_len = max_seq_len super().__init__(**kwargs) self.num_truncated = 0 self.num_read = 0 + # For proper multimodal HF checkpointing + self.processor = None + if processor_name is not None: + self.processor = AutoProcessor.from_pretrained(processor_name) + def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 71fc2bd9..f591996f 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -4,14 +4,14 @@ """Build a reward dataset and dataloader for training.""" import logging -from typing import Any +from typing import Any, Optional import numpy as np import torch from PIL import Image from streaming import StreamingDataset from torchvision import transforms -from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer, AutoProcessor log = logging.getLogger(__name__) @@ -266,12 +266,17 @@ def finegrained_preference_dataset_collate_fn( class PairwisePreferenceStreamingDataset(StreamingDataset): """Dataloader for streaming in preference data.""" - def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): + def __init__(self, max_seq_len: int, processor_name: Optional[str] = None, **kwargs: dict[str, Any]): self.max_seq_len = max_seq_len super().__init__(**kwargs) self.num_truncated = 0 self.num_read = 0 + # For proper multimodal HF checkpointing + self.processor = None + if processor_name is not None: + self.processor = AutoProcessor.from_pretrained(processor_name) + def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) From 952a976a92c2a87e0d50a30d9082fb165af7b245 Mon Sep 17 00:00:00 2001 From: wensun Date: Thu, 10 Jul 2025 19:43:29 -0400 Subject: [PATCH 086/195] quick test for shape --- compose_rl/algorithms/offline/model_methods.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index cb7c1d94..072ef5c0 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -109,6 +109,12 @@ def offline_loss( beta: float, ): policy_logp = outputs['policy_logp'] # (batch_size, ) + + #test shape + print("################################################") + print(policy_logp.shape) + print("################################################") + ref_logp = batch.get( 'ref_logp', torch.zeros_like(policy_logp), From 4f5fb8b891f81f7a333e14e63a40d7330d39d09c Mon Sep 17 00:00:00 2001 From: wensun Date: Thu, 10 Jul 2025 19:51:22 -0400 Subject: [PATCH 087/195] convert it back --- compose_rl/algorithms/offline/model_methods.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 072ef5c0..3987e194 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -110,11 +110,6 @@ def offline_loss( ): policy_logp = outputs['policy_logp'] # (batch_size, ) - #test shape - print("################################################") - print(policy_logp.shape) - print("################################################") - ref_logp = batch.get( 'ref_logp', torch.zeros_like(policy_logp), From ddd0ad7e762af70321de218cd825c5b4f44aae0a Mon Sep 17 00:00:00 2001 From: wensun Date: Thu, 10 Jul 2025 20:00:56 -0400 Subject: [PATCH 088/195] add another metric to track batch advantage --- compose_rl/algorithms/offline/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 3987e194..c458d796 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -149,6 +149,7 @@ def offline_loss( } if loss_type == RegressionOfflineEnum.APO: loss_dict['estimated_reward'] = estimated_reward + loss_dict['batch_advantage'] = torch.mean(batch['reward'] - batch['vstar']) if 'lbl' in outputs: losses += outputs['lbl'] From 7eb0951f1ecba22e8eab6809f88a9ecfe56e6dd1 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 11 Jul 2025 10:40:31 -0400 Subject: [PATCH 089/195] . --- compose_rl/algorithms/offline/model_methods.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c458d796..0d906c2e 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -65,8 +65,8 @@ def offline_forward( ) inputs = { - "input_ids": batch['input_ids'], - "attention_mask": batch['attention_mask'], + 'input_ids': batch['input_ids'], + 'attention_mask': batch['attention_mask'], } if is_multimodal: @@ -109,7 +109,7 @@ def offline_loss( beta: float, ): policy_logp = outputs['policy_logp'] # (batch_size, ) - + ref_logp = batch.get( 'ref_logp', torch.zeros_like(policy_logp), @@ -149,7 +149,9 @@ def offline_loss( } if loss_type == RegressionOfflineEnum.APO: loss_dict['estimated_reward'] = estimated_reward - loss_dict['batch_advantage'] = torch.mean(batch['reward'] - batch['vstar']) + loss_dict['batch_advantage'] = torch.mean( + batch['reward'] - batch['vstar'] + ) if 'lbl' in outputs: losses += outputs['lbl'] From 0a3b8cd8967216951cf3624e8341d3372217d0ab Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 11 Jul 2025 10:40:56 -0400 Subject: [PATCH 090/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 0d906c2e..c7f77be5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -150,7 +150,7 @@ def offline_loss( if loss_type == RegressionOfflineEnum.APO: loss_dict['estimated_reward'] = estimated_reward loss_dict['batch_advantage'] = torch.mean( - batch['reward'] - batch['vstar'] + batch['reward'] - batch['vstar'], ) if 'lbl' in outputs: From 5d0c53fab19097e090b733ed9bc7d9412d283f0d Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Fri, 11 Jul 2025 15:53:31 -0400 Subject: [PATCH 091/195] quick fix --- compose_rl/data/offline_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 33fdc17c..e32ec220 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -231,6 +231,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'vstar' in sample: return_dict['vstar'] = torch.Tensor([sample['vstar']]) + if 'v-star' in sample: + return_dict['vstar'] = torch.Tensor([sample['v-star']]) + if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): pixel_values = torch.from_numpy(sample['pixel_values']) From 144d80ecfff977f6be6252f05ac79bbb07fcca36 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 12 Jul 2025 16:49:11 -0400 Subject: [PATCH 092/195] added bce --- compose_rl/algorithms/offline/model.py | 5 +++++ compose_rl/algorithms/offline/model_methods.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 58783e2e..9b15d7fe 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,12 +34,14 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, + bce: bool = False, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) self.beta = beta + self.bce = bce self.average_log_prob = average_log_prob self.temperature = temperature @@ -69,6 +71,7 @@ def loss(self, outputs: CausalLMOutputWithPast, batch, self.loss_type, self.beta, + self.bce ) @@ -79,6 +82,7 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, + bce: bool = False, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, @@ -114,6 +118,7 @@ def loss(self, outputs: CausalLMOutputWithPast, batch, self.loss_type, self.beta, + self.bce, ) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c7f77be5..9ac00372 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -107,6 +107,7 @@ def offline_loss( batch: Mapping, loss_type: RegressionOfflineEnum, beta: float, + bce: bool = False, ): policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -121,10 +122,17 @@ def offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch - losses = ( - beta * (policy_logp - ref_logp) - - (batch['reward'] - batch['vstar']) - )**2 + if bce == False: + losses = ( + beta * (policy_logp - ref_logp) - + (batch['reward'] - batch['vstar']) + )**2 + elif bce == True: + predicted_prob = F.sigmoid(beta * (policy_logp - ref_logp)) + actual_prob = F.sigmoid(batch['reward'] - batch['vstar']) + losses = -(actual_prob * torch.log(predicted_prob) + + (1.-actual_prob)*torch.log(1.-predicted_prob) + ) # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref From ba2d9e89ca4c9d3da47218c223928ea7207c2e72 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 12 Jul 2025 21:00:27 -0400 Subject: [PATCH 093/195] temporally just set bce to be true --- compose_rl/algorithms/offline/model.py | 4 ++-- compose_rl/algorithms/offline/model_methods.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 9b15d7fe..01569056 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,7 +34,7 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, - bce: bool = False, + bce: bool = True, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, @@ -82,7 +82,7 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, - bce: bool = False, + bce: bool = True, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 9ac00372..8d2b4a05 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -107,7 +107,7 @@ def offline_loss( batch: Mapping, loss_type: RegressionOfflineEnum, beta: float, - bce: bool = False, + bce: bool = True, ): policy_logp = outputs['policy_logp'] # (batch_size, ) From b94b59460387c619da32a1441944578a6654a387 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 12 Jul 2025 21:05:00 -0400 Subject: [PATCH 094/195] . --- compose_rl/algorithms/offline/model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 01569056..58783e2e 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -34,14 +34,12 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, - bce: bool = True, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) self.beta = beta - self.bce = bce self.average_log_prob = average_log_prob self.temperature = temperature @@ -71,7 +69,6 @@ def loss(self, outputs: CausalLMOutputWithPast, batch, self.loss_type, self.beta, - self.bce ) @@ -82,7 +79,6 @@ def __init__( self, loss_type: str = 'apo', beta: float = 0.1, - bce: bool = True, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, @@ -118,7 +114,6 @@ def loss(self, outputs: CausalLMOutputWithPast, batch, self.loss_type, self.beta, - self.bce, ) From 94359e6726d0bdb68fdd7d650ba82f5c37f29b0e Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 12 Jul 2025 21:14:19 -0400 Subject: [PATCH 095/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 8d2b4a05..9ac00372 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -107,7 +107,7 @@ def offline_loss( batch: Mapping, loss_type: RegressionOfflineEnum, beta: float, - bce: bool = True, + bce: bool = False, ): policy_logp = outputs['policy_logp'] # (batch_size, ) From 844470db015f1e210a30281a781798eef2ec0280 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Sun, 13 Jul 2025 16:36:47 -0400 Subject: [PATCH 096/195] computation --- .../algorithms/offline/model_methods.py | 22 ++++++++++++++----- compose_rl/data/offline_data.py | 20 +++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 9ac00372..9856c974 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -106,7 +106,8 @@ def offline_loss( outputs: CausalLMOutputWithPast, batch: Mapping, loss_type: RegressionOfflineEnum, - beta: float, + beta1: float, + beta2: float, bce: bool = False, ): policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -121,15 +122,24 @@ def offline_loss( # APO is not a pair-wise loss function. # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch + # + vstar = batch.get('vstar', None) + if vstar is None: + vstar_rewards = batch.get('vstar_rewards', None) + assert vstar_rewards is not None + exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) + vstar = beta1 * torch.log(exponentiated_mean) + + assert vstar.shape == batch['reward'].shape if bce == False: losses = ( - beta * (policy_logp - ref_logp) - - (batch['reward'] - batch['vstar']) + beta2 * (policy_logp - ref_logp) - + (batch['reward'] - vstar) )**2 elif bce == True: - predicted_prob = F.sigmoid(beta * (policy_logp - ref_logp)) - actual_prob = F.sigmoid(batch['reward'] - batch['vstar']) + predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) + actual_prob = F.sigmoid(batch['reward'] - vstar) losses = -(actual_prob * torch.log(predicted_prob) + (1.-actual_prob)*torch.log(1.-predicted_prob) ) @@ -145,7 +155,7 @@ def offline_loss( losses = losses.mean() - implicit_rewards = beta * (policy_logp - ref_logp).detach() + implicit_rewards = beta2 * (policy_logp - ref_logp).detach() # Logging KL margins for comparing different methods reverse_kl = (policy_logp - ref_logp).detach() diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index e32ec220..a5e7886a 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -45,6 +45,7 @@ def offline_dataset_collate_fn( prompt_lens = [] rewards = [] vstars = [] + vstar_rewards = [] # For VLMs batch_token_type_ids = [] @@ -128,6 +129,8 @@ def offline_dataset_collate_fn( rewards.append(sample['reward']) if 'vstar' in sample: vstars.append(sample['vstar']) + if 'vstar_rewards' in sample: + vstar_rewards.append(sample['vstar_rewards']) if is_multimodal: batch_token_type_ids.append(token_type_ids) # type: ignore @@ -150,6 +153,9 @@ def offline_dataset_collate_fn( if len(vstars) > 0: vstars = torch.cat(vstars) return_dict['vstar'] = vstars + if len(vstar_rewards) > 0: + vstar_rewards = torch.stack(vstar_rewards) + return_dict['vstar_rewards'] = vstar_rewards if is_multimodal: # type: ignore token_type_ids = torch.stack(batch_token_type_ids) @@ -229,11 +235,25 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['reward'] = torch.Tensor([sample['reward']]) if 'vstar' in sample: + assert 'vstar-rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['vstar']]) if 'v-star' in sample: + assert 'vstar-rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['v-star']]) + if 'vstar-rewards' in sample: + assert 'vstar' not in sample + assert 'v-star' not in sample + if isinstance(sample['vstar_rewards'], np.ndarray): + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar-rewards']) + else: + rewards_type = type(sample['vstar_rewards']) + raise ValueError( + f'Expect vstar_rewards to be numpy.ndarray type, but got {rewards_type}', + ) + + if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): pixel_values = torch.from_numpy(sample['pixel_values']) From 8e25f8106e9c1edd6ff9185f740a1987f96473a7 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Sun, 13 Jul 2025 16:39:56 -0400 Subject: [PATCH 097/195] model compatibility --- compose_rl/algorithms/offline/model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 58783e2e..7747ded5 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -33,13 +33,15 @@ class ComposerMPTOfflinePolicyLM(ComposerMPTCausalLM): def __init__( self, loss_type: str = 'apo', - beta: float = 0.1, + beta1: float = 0.5, + beta2: float = 0.1, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) - self.beta = beta + self.beta1 = beta1 + self.beta2 = beta2 self.average_log_prob = average_log_prob self.temperature = temperature @@ -68,7 +70,8 @@ def loss(self, outputs: CausalLMOutputWithPast, outputs, batch, self.loss_type, - self.beta, + self.beta1, + self.beta2, ) @@ -78,13 +81,15 @@ class ComposerHFOfflinePolicyLM(ComposerHFCausalLM): def __init__( self, loss_type: str = 'apo', - beta: float = 0.1, + beta1: float = 0.5, + beta2: float = 0.1, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) - self.beta = beta + self.beta1 = beta1 + self.beta2 = beta2 self.average_log_prob = average_log_prob self.temperature = temperature @@ -113,7 +118,8 @@ def loss(self, outputs: CausalLMOutputWithPast, outputs, batch, self.loss_type, - self.beta, + self.beta1, + self.beta2, ) From 171b0f14f8499e379f5b1d43001809808aa4fe84 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Sun, 13 Jul 2025 17:03:29 -0400 Subject: [PATCH 098/195] multistep computation --- compose_rl/algorithms/offline/model.py | 6 ++++++ compose_rl/algorithms/offline/model_methods.py | 9 ++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 7747ded5..33a96637 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -35,6 +35,7 @@ def __init__( loss_type: str = 'apo', beta1: float = 0.5, beta2: float = 0.1, + multistep: bool = False, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, @@ -42,6 +43,7 @@ def __init__( self.loss_type = RegressionOfflineEnum(loss_type) self.beta1 = beta1 self.beta2 = beta2 + self.multistep = multistep self.average_log_prob = average_log_prob self.temperature = temperature @@ -72,6 +74,7 @@ def loss(self, outputs: CausalLMOutputWithPast, self.loss_type, self.beta1, self.beta2, + self.multistep, ) @@ -83,6 +86,7 @@ def __init__( loss_type: str = 'apo', beta1: float = 0.5, beta2: float = 0.1, + multistep: bool = False, average_log_prob: bool = False, temperature: float = 1.0, **kwargs: Any, @@ -90,6 +94,7 @@ def __init__( self.loss_type = RegressionOfflineEnum(loss_type) self.beta1 = beta1 self.beta2 = beta2 + self.multistep = multistep self.average_log_prob = average_log_prob self.temperature = temperature @@ -120,6 +125,7 @@ def loss(self, outputs: CausalLMOutputWithPast, self.loss_type, self.beta1, self.beta2, + self.multistep, ) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 9856c974..e78be26b 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -108,6 +108,7 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, + multistep: bool = False, bce: bool = False, ): policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -127,7 +128,13 @@ def offline_loss( if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None - exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) + if not multistep: + exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) + else: + exponentiated_mean = torch.mean( + vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), + dim=-1, + ) vstar = beta1 * torch.log(exponentiated_mean) assert vstar.shape == batch['reward'].shape From 26ec9a421e255eb8b573f7823b662996182e4c3f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 14 Jul 2025 10:47:50 -0400 Subject: [PATCH 099/195] multimodal fix --- compose_rl/algorithms/offline/model_methods.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index e78be26b..0a49f0de 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -139,6 +139,9 @@ def offline_loss( assert vstar.shape == batch['reward'].shape + # temporary + vstar = beta1 * torch.log(vstar) + if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - From 331ae42b50ae3a182c7741b872b583a24233e272 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 14 Jul 2025 11:40:17 -0400 Subject: [PATCH 100/195] batch advantage computation fix --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 0a49f0de..8f0a2a99 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -178,7 +178,7 @@ def offline_loss( if loss_type == RegressionOfflineEnum.APO: loss_dict['estimated_reward'] = estimated_reward loss_dict['batch_advantage'] = torch.mean( - batch['reward'] - batch['vstar'], + batch['reward'] - vstar, ) if 'lbl' in outputs: From 29975331ef313e3c172e6b0daaaf4be4ceeb884d Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 14 Jul 2025 14:10:52 -0400 Subject: [PATCH 101/195] qrpo --- .../algorithms/offline/model_methods.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 8f0a2a99..55db989e 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -28,6 +28,7 @@ class RegressionOfflineEnum(Enum): APO = 'apo' + QRPO = 'qrpo' class PairwiseOfflineEnum(Enum): @@ -139,9 +140,6 @@ def offline_loss( assert vstar.shape == batch['reward'].shape - # temporary - vstar = beta1 * torch.log(vstar) - if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - @@ -153,15 +151,24 @@ def offline_loss( losses = -(actual_prob * torch.log(predicted_prob) + (1.-actual_prob)*torch.log(1.-predicted_prob) ) - - # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) - # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref - # use clip to ensure the output from exp is valid - with torch.no_grad(): - estimated_rewards = torch.exp( - torch.clip(policy_logp - ref_logp, max=5.), - ) * batch['reward'] - estimated_reward = torch.mean(estimated_rewards) + elif loss_type == RegressionOfflineEnum.QRPO: + vstar_rewards = batch.get('vstar_rewards', None) + assert vstar_rewards is not None + if not multistep: + reward_q = torch.mean((batch['reward'].view(-1, 1) < vstar_rewards).float(), dim=-1) + else: + raise NotImplementedError("Multistep for QRPO not implemented") + + losses = (reward_q - beta2 * torch.log(beta2) - 1 - beta2 * (policy_logp - ref_logp)) ** 2 + + # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) + # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref + # use clip to ensure the output from exp is valid + with torch.no_grad(): + estimated_rewards = torch.exp( + torch.clip(policy_logp - ref_logp, max=5.), + ) * batch['reward'] + estimated_reward = torch.mean(estimated_rewards) losses = losses.mean() @@ -174,9 +181,9 @@ def offline_loss( 'implicit_rewards': implicit_rewards, 'reverse_kl': reverse_kl, 'forward_kl': forward_kl, + 'estimated_reward': estimated_reward, } if loss_type == RegressionOfflineEnum.APO: - loss_dict['estimated_reward'] = estimated_reward loss_dict['batch_advantage'] = torch.mean( batch['reward'] - vstar, ) From aaef5f84f683811f854f8d40ea7389e7d52d0cd6 Mon Sep 17 00:00:00 2001 From: jdchang1 Date: Mon, 14 Jul 2025 15:24:14 -0400 Subject: [PATCH 102/195] Update model_methods.py --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 55db989e..c09ca665 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -155,7 +155,7 @@ def offline_loss( vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None if not multistep: - reward_q = torch.mean((batch['reward'].view(-1, 1) < vstar_rewards).float(), dim=-1) + reward_q = torch.mean((batch['reward'].view(-1, 1) >= vstar_rewards).float(), dim=-1) else: raise NotImplementedError("Multistep for QRPO not implemented") From 204df50a2ed558879fd2751dbdb38eef49c0e145 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 14 Jul 2025 16:49:02 -0400 Subject: [PATCH 103/195] fix --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index a5e7886a..262b6a58 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -246,7 +246,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: assert 'vstar' not in sample assert 'v-star' not in sample if isinstance(sample['vstar_rewards'], np.ndarray): - return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar-rewards']) + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards']) else: rewards_type = type(sample['vstar_rewards']) raise ValueError( From 2e50aff32807cd2c1edf3f3161405d7e2e8ac837 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 14 Jul 2025 16:50:17 -0400 Subject: [PATCH 104/195] fix --- compose_rl/data/offline_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 262b6a58..6d31c6ed 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -235,14 +235,14 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['reward'] = torch.Tensor([sample['reward']]) if 'vstar' in sample: - assert 'vstar-rewards' not in sample + assert 'vstar_rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['vstar']]) if 'v-star' in sample: - assert 'vstar-rewards' not in sample + assert 'vstar_rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['v-star']]) - if 'vstar-rewards' in sample: + if 'vstar_rewards' in sample: assert 'vstar' not in sample assert 'v-star' not in sample if isinstance(sample['vstar_rewards'], np.ndarray): From fdf8ec27e96315feb21a72a44c9b1818209c3383 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 17 Jul 2025 15:22:25 -0400 Subject: [PATCH 105/195] remove the need for preprocessed image inputs --- compose_rl/data/offline_data.py | 42 ++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 6d31c6ed..b59a03e1 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -13,6 +13,28 @@ from torchvision import transforms from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer, AutoProcessor +import base64 +from io import BytesIO + +def base64_to_pil(base64_string: str): + """ + Converts a base64 string to a PIL Image object. + + Args: + base64_string: The base64 encoded string of the image. + + Returns: + A PIL Image object, or None if an error occurs. + """ + try: + image_data = base64.b64decode(base64_string) + image_stream = BytesIO(image_data) + image = Image.open(image_stream) + return image + except Exception as e: + print(f"Error decoding base64 string: {e}") + return None + log = logging.getLogger(__name__) @@ -253,8 +275,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: f'Expect vstar_rewards to be numpy.ndarray type, but got {rewards_type}', ) - + # Gemma 3 Specific if 'pixel_values' in sample: + assert 'image' not in sample if isinstance(sample['pixel_values'], np.ndarray): pixel_values = torch.from_numpy(sample['pixel_values']) elif isinstance(sample['pixel_values'], Image.Image): @@ -284,4 +307,21 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['pixel_values'] = pixel_values return_dict['token_type_ids'] = token_type_ids + if 'image' in sample: + assert 'pixel_values' not in sample + assert 'token_type_ids' not in sample + assert self.processor is not None + + text = self.processor.decode(return_dict['input_ids']) + input_tensors = self.processor( + text=text, + images=base64_to_pil(sample['image']), + return_tensors="pt", + padding=False, + truncation=False, + ) + return_dict['pixel_values'] = input_tensors['pixel_values'][0] + return_dict['token_type_ids'] = input_tensors['token_type_ids'][0] + assert return_dict['token_type_ids'].size(-1) == return_dict['input_ids'].size(-1) + return return_dict From 43b4324105c9f8277e0d74f61b298032c48e9efa Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 17 Jul 2025 15:25:21 -0400 Subject: [PATCH 106/195] add token type ids back in --- compose_rl/data/offline_data.py | 35 +++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index b59a03e1..b341a26d 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -288,7 +288,23 @@ def __getitem__(self, idx: int) -> dict[str, Any]: raise ValueError( f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', ) + return_dict['pixel_values'] = pixel_values + + if 'image' in sample: + assert 'pixel_values' not in sample + assert self.processor is not None + + text = self.processor.decode(return_dict['input_ids']) + input_tensors = self.processor( + text=text, + images=base64_to_pil(sample['image']), + return_tensors="pt", + padding=False, + truncation=False, + ) + return_dict['pixel_values'] = input_tensors['pixel_values'][0] + if 'token_type_ids' in sample: if isinstance(sample['token_type_ids'], bytes): token_type_ids = self._read_binary_tokenized_sample( sample, @@ -303,25 +319,6 @@ def __getitem__(self, idx: int) -> dict[str, Any]: raise ValueError( f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', ) - - return_dict['pixel_values'] = pixel_values return_dict['token_type_ids'] = token_type_ids - if 'image' in sample: - assert 'pixel_values' not in sample - assert 'token_type_ids' not in sample - assert self.processor is not None - - text = self.processor.decode(return_dict['input_ids']) - input_tensors = self.processor( - text=text, - images=base64_to_pil(sample['image']), - return_tensors="pt", - padding=False, - truncation=False, - ) - return_dict['pixel_values'] = input_tensors['pixel_values'][0] - return_dict['token_type_ids'] = input_tensors['token_type_ids'][0] - assert return_dict['token_type_ids'].size(-1) == return_dict['input_ids'].size(-1) - return return_dict From da38deee2e3caff10e31c333cb73d11a72a2e5c8 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 15:37:49 -0400 Subject: [PATCH 107/195] initial set up for dealing with multi turn dataformat --- compose_rl/data/offline_data.py | 76 +++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index b341a26d..0c391586 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -78,6 +78,13 @@ def offline_dataset_collate_fn( prompt_len = sample['prompt_len'] sequence_len = sample['sequence_len'] + # for multi-turn tool-use, which contains masks that mask out non-assistant turns + # it contains 1 at the token positions from the assistant turns, and 0 otherwise + has_mask = 'mask' in sample.keys() + if has_mask: + mask = sample['mask'] # torch tensor array + assert mask.shape == input_ids.shape + is_multimodal = 'pixel_values' in sample.keys() if is_multimodal: pixel_vals = sample['pixel_values'] @@ -119,6 +126,7 @@ def offline_dataset_collate_fn( pad_len = max_seq_len - sequence_len if pad_len > 0: + # right padding with padding token input_ids = torch.cat( [ input_ids, @@ -142,10 +150,15 @@ def offline_dataset_collate_fn( attention_mask = torch.logical_not( torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore ) + if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm + if len(mask) <= len(attention_mask): # this happens when we padded input_ids + attention_mask[0:len(mask)] *= mask # zero out the token positions that do not belong to the assistant turns + else: # this happens when we truncate input_id + attention_mask *= mask[0:len(attention_mask)] batch_input_ids.append(input_ids) attention_masks.append(attention_mask) - sequence_lens.append(sequence_len) + sequence_lens.append(sequence_len) # TODO: this sequence_len is out of dated? prompt_lens.append(prompt_len) if 'reward' in sample: rewards.append(sample['reward']) @@ -157,6 +170,7 @@ def offline_dataset_collate_fn( if is_multimodal: batch_token_type_ids.append(token_type_ids) # type: ignore pixel_values.append(pixel_vals) + batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) @@ -228,30 +242,48 @@ def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) # Read Samples - input_ids, prompt = [], [] - if isinstance(sample['prompt'], bytes): - sample['input_ids'] = sample['prompt'] + sample['response'] - input_ids = self._read_binary_tokenized_sample(sample, 'input_ids') - prompt = self._read_binary_tokenized_sample(sample, 'prompt') - elif isinstance(sample['prompt'], np.ndarray): - input_ids = np.concatenate([sample['prompt'], sample['response']]) - input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) - prompt = torch.from_numpy(sample['prompt']) - else: - token_type = type(sample['input_ids']) - raise ValueError( - f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', - ) - - # Get Lenghts - prompt_len = len(prompt) - sequence_len = len(input_ids) - + if 'prompt' in sample: + input_ids, prompt = [], [] + if isinstance(sample['prompt'], bytes): + sample['input_ids'] = sample['prompt'] + sample['response'] + input_ids = self._read_binary_tokenized_sample(sample, 'input_ids') + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + elif isinstance(sample['prompt'], np.ndarray): + input_ids = np.concatenate([sample['prompt'], sample['response']]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt = torch.from_numpy(sample['prompt']) + else: + token_type = type(sample['input_ids']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) + # Get Lenghts + prompt_len = len(prompt) + sequence_len = len(input_ids) + + elif 'input' in sample and 'mask' in sample: # input already combines prompt and reponse (e.g., used in the tool call setup) + input_ids, mask = [],[] + if isinstance(sample['input'], bytes): + input_ids = self._read_binary_tokenized_sample(sample, 'input') + mask = self._read_binary_tokenized_sample(sample, 'mask') + elif isinstance(sample['input'], np.ndarray): + input_ids = torch.from_numpy(sample['input'], dtype = torch.int64) + mask = torch.from_numpy(sample['mask'], dtype=torch.int64) + else: + token_type = type(sample['input']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) + return_dict = { 'input_ids': input_ids, - 'sequence_len': torch.Tensor([sequence_len]).to(torch.int64), - 'prompt_len': torch.Tensor([prompt_len]).to(torch.int64), + 'sequence_len': torch.Tensor([len(input_ids)]).to(torch.int64), + 'prompt_len': torch.Tensor([0]).to(torch.int64) } + + if 'mask' in sample and 'input' in sample: + return_dict['mask'] = mask + # If rewards are given, add them to the return dict if 'reward' in sample: return_dict['reward'] = torch.Tensor([sample['reward']]) From 51912d3cf02e5593711c7ba4c96c5ccb76e70866 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 17:43:07 -0400 Subject: [PATCH 108/195] . --- compose_rl/data/offline_data.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 0c391586..d697bd58 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -69,6 +69,9 @@ def offline_dataset_collate_fn( vstars = [] vstar_rewards = [] + # for multi-turn + masks = [] + # For VLMs batch_token_type_ids = [] pixel_values = [] @@ -151,15 +154,19 @@ def offline_dataset_collate_fn( torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore ) if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm - if len(mask) <= len(attention_mask): # this happens when we padded input_ids - attention_mask[0:len(mask)] *= mask # zero out the token positions that do not belong to the assistant turns - else: # this happens when we truncate input_id - attention_mask *= mask[0:len(attention_mask)] + if len(mask) <= len(attention_mask): # this happens when we padded input_ids, so we should pad mask + mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=token_type_ids.dtype)],dim =-1) + else: # this happens when we truncate input_id, so we truncate mask + mask = mask[0: len(attention_mask)] + assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape batch_input_ids.append(input_ids) attention_masks.append(attention_mask) sequence_lens.append(sequence_len) # TODO: this sequence_len is out of dated? prompt_lens.append(prompt_len) + + if has_mask: + masks.append(mask) if 'reward' in sample: rewards.append(sample['reward']) if 'vstar' in sample: @@ -183,6 +190,10 @@ def offline_dataset_collate_fn( 'input_ids': batch_input_ids, 'attention_mask': attention_masks, } + if len(masks) > 0: + masks = torch.stack(masks) + assert masks.shape == attention_masks.shape + return_dict['mask'] = masks if len(rewards) > 0: rewards = torch.cat(rewards) return_dict['reward'] = rewards @@ -274,11 +285,13 @@ def __getitem__(self, idx: int) -> dict[str, Any]: raise ValueError( f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', ) + prompt_len = 0 + sequence_len = len(input_ids) return_dict = { 'input_ids': input_ids, - 'sequence_len': torch.Tensor([len(input_ids)]).to(torch.int64), - 'prompt_len': torch.Tensor([0]).to(torch.int64) + 'sequence_len': torch.Tensor([len(sequence_len)]).to(torch.int64), + 'prompt_len': torch.Tensor([prompt_len]).to(torch.int64) } if 'mask' in sample and 'input' in sample: From a48305bf162127d9588c8aeff98091c36c0a8c2e Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 19:28:53 -0400 Subject: [PATCH 109/195] . --- compose_rl/algorithms/offline/callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 1a94dd21..734eeabe 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -81,8 +81,10 @@ def before_forward(self, state: State, logger: Logger) -> Optional[int]: with torch.no_grad(): assert self.reference_model is not None reference_outputs = self.reference_model(state.batch) + print(reference_outputs.keys()) state.batch.update({ 'ref_logp': reference_outputs['policy_logp'], + #'raw_ref_logits': reference_outputs.logits, }) From 1e9b8554b62e2ef9254e6f9da1158038068db4db Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 19:33:12 -0400 Subject: [PATCH 110/195] . --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index d697bd58..492c57f5 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -290,7 +290,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict = { 'input_ids': input_ids, - 'sequence_len': torch.Tensor([len(sequence_len)]).to(torch.int64), + 'sequence_len': torch.Tensor([sequence_len]).to(torch.int64), 'prompt_len': torch.Tensor([prompt_len]).to(torch.int64) } From 0ac2d29a64f90a152c6458470d62ef9681fc38ff Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 19:43:12 -0400 Subject: [PATCH 111/195] . --- .../algorithms/offline/model_methods.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c09ca665..75461933 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -23,6 +23,7 @@ extract_packed_chosen_rejected, get_batch_logp, get_mb_load_balancing_loss, + get_log_probs_from_logits, ) @@ -92,6 +93,8 @@ def offline_forward( 'policy_logp': logps, } + outputs["raw_logits"] = output_logits # raw logits (bs, seq_len, vocab_size) + if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( policy_model_config, @@ -119,6 +122,29 @@ def offline_loss( torch.zeros_like(policy_logp), ) + has_mask = 'mask' in batch + if has_mask: + raw_policy_logits = outputs['raw_logits'][:,:-1] + assert 'raw_ref_logits' in batch + raw_ref_logits = batch['raw_ref_logits'][:,:-1] + assert raw_policy_logits.shape == raw_ref_logits + policy_logps = get_log_probs_from_logits(raw_policy_logits, batch['input_ids'][:,1:]) + ref_logps = get_log_probs_from_logits(raw_ref_logits, batch['input_ids'][:,1:]) + assert policy_logps.size(1) == batch['input_ids'].size(1) - 1 + + #apply masks + #1 apply attention mask + policy_logps *= batch["attention_mask"][:,1:] # shift right by 1. + ref_logps *= batch["attention_mask"][:,1:] + # apply position mask + policy_logps *= batch['mask'][:,1:] + ref_logps *= batch['mask'][:,1:] + + policy_logp = torch.sum(policy_logps, dim = -1) + ref_logp = torch.sum(ref_logps, dim = -1) + + + if loss_type == RegressionOfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 # APO is not a pair-wise loss function. From 00304a0793470635f692457f703bbbcd552ae054 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 20:20:05 -0400 Subject: [PATCH 112/195] . --- compose_rl/algorithms/offline/callback.py | 3 +-- .../algorithms/offline/model_methods.py | 23 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 734eeabe..99295ed0 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -81,10 +81,9 @@ def before_forward(self, state: State, logger: Logger) -> Optional[int]: with torch.no_grad(): assert self.reference_model is not None reference_outputs = self.reference_model(state.batch) - print(reference_outputs.keys()) state.batch.update({ 'ref_logp': reference_outputs['policy_logp'], - #'raw_ref_logits': reference_outputs.logits, + 'raw_ref_logits': reference_outputs['raw_logits'], }) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 75461933..56793f67 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -115,25 +115,18 @@ def offline_loss( multistep: bool = False, bce: bool = False, ): - policy_logp = outputs['policy_logp'] # (batch_size, ) - - ref_logp = batch.get( - 'ref_logp', - torch.zeros_like(policy_logp), - ) - - has_mask = 'mask' in batch + + has_mask = 'mask' in batch # handling mask explicitly if has_mask: raw_policy_logits = outputs['raw_logits'][:,:-1] assert 'raw_ref_logits' in batch raw_ref_logits = batch['raw_ref_logits'][:,:-1] - assert raw_policy_logits.shape == raw_ref_logits + assert raw_policy_logits.shape == raw_ref_logits.shape policy_logps = get_log_probs_from_logits(raw_policy_logits, batch['input_ids'][:,1:]) ref_logps = get_log_probs_from_logits(raw_ref_logits, batch['input_ids'][:,1:]) - assert policy_logps.size(1) == batch['input_ids'].size(1) - 1 - #apply masks - #1 apply attention mask + # apply masks + # first apply attention mask policy_logps *= batch["attention_mask"][:,1:] # shift right by 1. ref_logps *= batch["attention_mask"][:,1:] # apply position mask @@ -142,7 +135,13 @@ def offline_loss( policy_logp = torch.sum(policy_logps, dim = -1) ref_logp = torch.sum(ref_logps, dim = -1) + else: + policy_logp = outputs['policy_logp'] # (batch_size, ) + ref_logp = batch.get( + 'ref_logp', + torch.zeros_like(policy_logp), + ) if loss_type == RegressionOfflineEnum.APO: From 207842776e9f6372520d24df4f7161e3e37b9e66 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 20:25:16 -0400 Subject: [PATCH 113/195] . --- compose_rl/data/offline_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 492c57f5..09dbc23a 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -278,8 +278,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: input_ids = self._read_binary_tokenized_sample(sample, 'input') mask = self._read_binary_tokenized_sample(sample, 'mask') elif isinstance(sample['input'], np.ndarray): - input_ids = torch.from_numpy(sample['input'], dtype = torch.int64) - mask = torch.from_numpy(sample['mask'], dtype=torch.int64) + input_ids = torch.from_numpy(sample['input']).to(torch.int64) + mask = torch.from_numpy(sample['mask']).to(torch.int64) else: token_type = type(sample['input']) raise ValueError( From 1019e74893af533f7450d6b0ac837b3728c0dcfc Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 20:29:44 -0400 Subject: [PATCH 114/195] . --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 09dbc23a..776a9456 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -154,7 +154,7 @@ def offline_dataset_collate_fn( torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore ) if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm - if len(mask) <= len(attention_mask): # this happens when we padded input_ids, so we should pad mask + if len(mask) < len(attention_mask): # this happens when we padded input_ids, so we should pad mask mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=token_type_ids.dtype)],dim =-1) else: # this happens when we truncate input_id, so we truncate mask mask = mask[0: len(attention_mask)] From 254ca009bf2c66cfe87a92ae72332bae893e71df Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 20:43:29 -0400 Subject: [PATCH 115/195] . --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 776a9456..c0ae1255 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -155,7 +155,7 @@ def offline_dataset_collate_fn( ) if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm if len(mask) < len(attention_mask): # this happens when we padded input_ids, so we should pad mask - mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=token_type_ids.dtype)],dim =-1) + mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=mask.dtype)],dim =-1) else: # this happens when we truncate input_id, so we truncate mask mask = mask[0: len(attention_mask)] assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape From a4492861800581a5bec51d379802c2b7e36847af Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 21:14:43 -0400 Subject: [PATCH 116/195] . --- compose_rl/algorithms/offline/callback.py | 1 - .../algorithms/offline/model_methods.py | 60 ++++++++----------- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 99295ed0..1a94dd21 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -83,7 +83,6 @@ def before_forward(self, state: State, logger: Logger) -> Optional[int]: reference_outputs = self.reference_model(state.batch) state.batch.update({ 'ref_logp': reference_outputs['policy_logp'], - 'raw_ref_logits': reference_outputs['raw_logits'], }) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 56793f67..c92dd301 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -59,6 +59,7 @@ def offline_forward( policy_model_config: Policy model config. """ is_multimodal = 'pixel_values' in batch.keys() + has_mask = 'mask' in batch.keys() if policy_model_config is not None and hasattr(model, 'transformer'): clear_mb_load_balancing_loss( @@ -80,21 +81,29 @@ def offline_forward( output_logits = model(**inputs).logits - logps = get_batch_logp( - batch['input_ids'], - output_logits, - batch['prompt_len'], - batch['sequence_len'], - average_log_prob, - temperature=temperature, - ) + if has_mask is False: + logps = get_batch_logp( + batch['input_ids'], + output_logits, + batch['prompt_len'], + batch['sequence_len'], + average_log_prob, + temperature=temperature, + ) + else: + token_policy_logps = get_log_probs_from_logits( + output_logits[:,:-1], + batch['input_ids'][:,1:] + ) + # apply attention_mask and mask explicitly + token_policy_logps *= batch['attention_mask'][:,1:] + token_policy_logps *= batch['mask'][:,1:] + logps = torch.sum(token_policy_logps, dim = -1) # (bs, ) outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, } - outputs["raw_logits"] = output_logits # raw logits (bs, seq_len, vocab_size) - if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( policy_model_config, @@ -116,33 +125,12 @@ def offline_loss( bce: bool = False, ): - has_mask = 'mask' in batch # handling mask explicitly - if has_mask: - raw_policy_logits = outputs['raw_logits'][:,:-1] - assert 'raw_ref_logits' in batch - raw_ref_logits = batch['raw_ref_logits'][:,:-1] - assert raw_policy_logits.shape == raw_ref_logits.shape - policy_logps = get_log_probs_from_logits(raw_policy_logits, batch['input_ids'][:,1:]) - ref_logps = get_log_probs_from_logits(raw_ref_logits, batch['input_ids'][:,1:]) - - # apply masks - # first apply attention mask - policy_logps *= batch["attention_mask"][:,1:] # shift right by 1. - ref_logps *= batch["attention_mask"][:,1:] - # apply position mask - policy_logps *= batch['mask'][:,1:] - ref_logps *= batch['mask'][:,1:] - - policy_logp = torch.sum(policy_logps, dim = -1) - ref_logp = torch.sum(ref_logps, dim = -1) - else: - policy_logp = outputs['policy_logp'] # (batch_size, ) - - ref_logp = batch.get( - 'ref_logp', - torch.zeros_like(policy_logp), - ) + policy_logp = outputs['policy_logp'] # (batch_size, ) + ref_logp = batch.get( + 'ref_logp', + torch.zeros_like(policy_logp), + ) if loss_type == RegressionOfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 From 6637844777a346d9322d6df2f4980099e5bcc12e Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 21:23:20 -0400 Subject: [PATCH 117/195] . --- compose_rl/algorithms/offline/model_methods.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c92dd301..e985c3d5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -91,10 +91,14 @@ def offline_forward( temperature=temperature, ) else: + print("has mask...") token_policy_logps = get_log_probs_from_logits( output_logits[:,:-1], batch['input_ids'][:,1:] ) + print(torch.sum(batch['mask'])) + print(torch.sum(batch['attention_mask'])) + # apply attention_mask and mask explicitly token_policy_logps *= batch['attention_mask'][:,1:] token_policy_logps *= batch['mask'][:,1:] From dddc4b385e625b4491ceba9349b855f3fee407ac Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 20 Jul 2025 21:28:38 -0400 Subject: [PATCH 118/195] working version of optimizing tool call in traj level --- compose_rl/algorithms/offline/model_methods.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index e985c3d5..c92dd301 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -91,14 +91,10 @@ def offline_forward( temperature=temperature, ) else: - print("has mask...") token_policy_logps = get_log_probs_from_logits( output_logits[:,:-1], batch['input_ids'][:,1:] ) - print(torch.sum(batch['mask'])) - print(torch.sum(batch['attention_mask'])) - # apply attention_mask and mask explicitly token_policy_logps *= batch['attention_mask'][:,1:] token_policy_logps *= batch['mask'][:,1:] From 9e4395d2a86cd67b0f613597b3a8e05934d15f2e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 21 Jul 2025 11:20:48 -0400 Subject: [PATCH 119/195] exclude gold --- compose_rl/algorithms/offline/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c92dd301..69c11083 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -142,6 +142,7 @@ def offline_loss( if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None + vstar_rewards = vstar_rewards[:, 1:] # exclude gold if not multistep: exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) else: From c6d9c64c3cbf696ba6f53b0eb93f69f2cfd017de Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 20:22:41 -0400 Subject: [PATCH 120/195] fixed a seq len bug --- compose_rl/data/offline_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index c0ae1255..dfa4d2d3 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -124,9 +124,9 @@ def offline_dataset_collate_fn( # NOTE: GEMMA specific: 0 == text token token_type_ids[-1] = 0 - sequence_len = torch.tensor([len(sequence_len)]) + #sequence_len = torch.tensor([len(sequence_len)]) # TODO: check this line of code, len(sequence_len) should be one?? - pad_len = max_seq_len - sequence_len + #pad_len = max_seq_len - sequence_len if pad_len > 0: # right padding with padding token From 3f08ed11f5ccab623c497c73cb85e6df88819bb6 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 21:35:11 -0400 Subject: [PATCH 121/195] . --- compose_rl/data/offline_data.py | 179 ++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index dfa4d2d3..930553a6 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -38,6 +38,185 @@ def base64_to_pil(base64_string: str): log = logging.getLogger(__name__) +def offline_dataset_collate_fn_test( + tokenizer: PreTrainedTokenizer, + max_seq_len: int, + data: list[dict[str, torch.Tensor]], +) -> dict[str, Any]: + """Collator for offline data. + + Args: + tokenizer (Tokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + data (list[dict[str, torch.Tensor]]): The preference data to collate. + """ + if tokenizer.eos_token_id is None: + raise ValueError('Tokenizer must have an EOS token.') + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + + batch_input_ids = [item['input_ids'] for item in data] + + + batch_input_ids = [] + attention_masks = [] + sequence_lens = [] + prompt_lens = [] + rewards = [] + vstars = [] + vstar_rewards = [] + + # for multi-turn + masks = [] + + # For VLMs + batch_token_type_ids = [] + pixel_values = [] + + for sample in data: + input_ids = sample['input_ids'] + prompt_len = sample['prompt_len'] + sequence_len = sample['sequence_len'] + + # for multi-turn tool-use, which contains masks that mask out non-assistant turns + # it contains 1 at the token positions from the assistant turns, and 0 otherwise + has_mask = 'mask' in sample.keys() + if has_mask: + mask = sample['mask'] # torch tensor array + assert mask.shape == input_ids.shape + + is_multimodal = 'pixel_values' in sample.keys() + if is_multimodal: + pixel_vals = sample['pixel_values'] + token_type_ids = sample['token_type_ids'] + else: + pixel_vals = None + token_type_ids = None + + # Note: if we do any truncation, we force the last token to be EOS + # https://github.com/mosaicml/RLHF/issues/101 + + # Add the eos token if it's not in the chosen sample + if input_ids[-1] != tokenizer.eos_token_id: + input_ids[-1] = tokenizer.eos_token_id # type: ignore + + pad_len = max_seq_len - sequence_len + + if pad_len < 0: + # We should truncate with an additional token left for eos + truncate_len = abs(pad_len) + 1 + + log.warning(( + f'Sequence length: {sequence_len}' + f' are too long for max_seq_len: {max_seq_len}' + f' truncating by {truncate_len[0]} tokens.' + )) + + # Truncate each value by truncate length, and make the last token EOS + input_ids = input_ids[:-truncate_len] + input_ids[-1] = tokenizer.eos_token_id # type: ignore + + if is_multimodal: + token_type_ids = token_type_ids[:-truncate_len] + # NOTE: GEMMA specific: 0 == text token + token_type_ids[-1] = 0 + + sequence_len = torch.tensor([len(input_ids)]) # TODO: check this line of code, len(sequence_len) should be one?? + pad_len = max_seq_len - sequence_len # TODO: so if it truncted, then pad_len = 1 in this case? + + if pad_len > 0: + # right padding with padding token + input_ids = torch.cat( + [ + input_ids, + torch.ones(int(pad_len.item()), dtype=input_ids.dtype) * + tokenizer.pad_token_id, # type: ignore + ], + dim=-1, # type: ignore + ) + if is_multimodal: + token_type_ids = torch.cat( + [ + token_type_ids, # type: ignore + torch.zeros( + int(pad_len.item()), + dtype=token_type_ids.dtype, # type: ignore + ), + ], + dim=-1, + ) + + attention_mask = torch.logical_not( + torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore + ) + if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm + if len(mask) < len(attention_mask): # this happens when we padded input_ids, so we should pad mask + mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=mask.dtype)],dim =-1) + else: # this happens when we truncate input_id, so we truncate mask + mask = mask[0: len(attention_mask)] + assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape + + batch_input_ids.append(input_ids) + attention_masks.append(attention_mask) + sequence_lens.append(sequence_len) # TODO: this sequence_len is out of dated? + prompt_lens.append(prompt_len) + + if has_mask: + masks.append(mask) + if 'reward' in sample: + rewards.append(sample['reward']) + if 'vstar' in sample: + vstars.append(sample['vstar']) + if 'vstar_rewards' in sample: + vstar_rewards.append(sample['vstar_rewards']) + + if is_multimodal: + batch_token_type_ids.append(token_type_ids) # type: ignore + pixel_values.append(pixel_vals) + + + batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] + attention_masks = torch.stack(attention_masks) + + sequence_lens = torch.cat(sequence_lens) + prompt_lens = torch.cat(prompt_lens) + return_dict = { + 'sequence_len': sequence_lens, + 'prompt_len': prompt_lens, + 'input_ids': batch_input_ids, + 'attention_mask': attention_masks, + } + if len(masks) > 0: + masks = torch.stack(masks) + assert masks.shape == attention_masks.shape + return_dict['mask'] = masks + if len(rewards) > 0: + rewards = torch.cat(rewards) + return_dict['reward'] = rewards + if len(vstars) > 0: + vstars = torch.cat(vstars) + return_dict['vstar'] = vstars + if len(vstar_rewards) > 0: + vstar_rewards = torch.stack(vstar_rewards) + return_dict['vstar_rewards'] = vstar_rewards + + if is_multimodal: # type: ignore + token_type_ids = torch.stack(batch_token_type_ids) + pixel_values = torch.stack(pixel_values) + return_dict['token_type_ids'] = token_type_ids + return_dict['pixel_values'] = pixel_values + + return return_dict + + + + def offline_dataset_collate_fn( tokenizer: PreTrainedTokenizer, max_seq_len: int, From 7292533355b22eb1e7f5e8d46edda52cce618067 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 21:39:29 -0400 Subject: [PATCH 122/195] . --- compose_rl/data/offline_data.py | 184 +------------------------------- 1 file changed, 2 insertions(+), 182 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 930553a6..99a77597 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -38,185 +38,6 @@ def base64_to_pil(base64_string: str): log = logging.getLogger(__name__) -def offline_dataset_collate_fn_test( - tokenizer: PreTrainedTokenizer, - max_seq_len: int, - data: list[dict[str, torch.Tensor]], -) -> dict[str, Any]: - """Collator for offline data. - - Args: - tokenizer (Tokenizer): The model's tokenizer. - max_seq_len (int): The maximum sequence length of the model. - data (list[dict[str, torch.Tensor]]): The preference data to collate. - """ - if tokenizer.eos_token_id is None: - raise ValueError('Tokenizer must have an EOS token.') - if tokenizer.pad_token_id is None: - raise ValueError('Tokenizer must have a PAD token.') - - ref_collate_fn = DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm=False, - mlm_probability=0.0, - ) - - batch_input_ids = [item['input_ids'] for item in data] - - - batch_input_ids = [] - attention_masks = [] - sequence_lens = [] - prompt_lens = [] - rewards = [] - vstars = [] - vstar_rewards = [] - - # for multi-turn - masks = [] - - # For VLMs - batch_token_type_ids = [] - pixel_values = [] - - for sample in data: - input_ids = sample['input_ids'] - prompt_len = sample['prompt_len'] - sequence_len = sample['sequence_len'] - - # for multi-turn tool-use, which contains masks that mask out non-assistant turns - # it contains 1 at the token positions from the assistant turns, and 0 otherwise - has_mask = 'mask' in sample.keys() - if has_mask: - mask = sample['mask'] # torch tensor array - assert mask.shape == input_ids.shape - - is_multimodal = 'pixel_values' in sample.keys() - if is_multimodal: - pixel_vals = sample['pixel_values'] - token_type_ids = sample['token_type_ids'] - else: - pixel_vals = None - token_type_ids = None - - # Note: if we do any truncation, we force the last token to be EOS - # https://github.com/mosaicml/RLHF/issues/101 - - # Add the eos token if it's not in the chosen sample - if input_ids[-1] != tokenizer.eos_token_id: - input_ids[-1] = tokenizer.eos_token_id # type: ignore - - pad_len = max_seq_len - sequence_len - - if pad_len < 0: - # We should truncate with an additional token left for eos - truncate_len = abs(pad_len) + 1 - - log.warning(( - f'Sequence length: {sequence_len}' - f' are too long for max_seq_len: {max_seq_len}' - f' truncating by {truncate_len[0]} tokens.' - )) - - # Truncate each value by truncate length, and make the last token EOS - input_ids = input_ids[:-truncate_len] - input_ids[-1] = tokenizer.eos_token_id # type: ignore - - if is_multimodal: - token_type_ids = token_type_ids[:-truncate_len] - # NOTE: GEMMA specific: 0 == text token - token_type_ids[-1] = 0 - - sequence_len = torch.tensor([len(input_ids)]) # TODO: check this line of code, len(sequence_len) should be one?? - pad_len = max_seq_len - sequence_len # TODO: so if it truncted, then pad_len = 1 in this case? - - if pad_len > 0: - # right padding with padding token - input_ids = torch.cat( - [ - input_ids, - torch.ones(int(pad_len.item()), dtype=input_ids.dtype) * - tokenizer.pad_token_id, # type: ignore - ], - dim=-1, # type: ignore - ) - if is_multimodal: - token_type_ids = torch.cat( - [ - token_type_ids, # type: ignore - torch.zeros( - int(pad_len.item()), - dtype=token_type_ids.dtype, # type: ignore - ), - ], - dim=-1, - ) - - attention_mask = torch.logical_not( - torch.eq(input_ids, tokenizer.pad_token_id), # type: ignore - ) - if has_mask: # combine the two masks together so that we can forget about mask and only use attention_mask inside algorithm - if len(mask) < len(attention_mask): # this happens when we padded input_ids, so we should pad mask - mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=mask.dtype)],dim =-1) - else: # this happens when we truncate input_id, so we truncate mask - mask = mask[0: len(attention_mask)] - assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape - - batch_input_ids.append(input_ids) - attention_masks.append(attention_mask) - sequence_lens.append(sequence_len) # TODO: this sequence_len is out of dated? - prompt_lens.append(prompt_len) - - if has_mask: - masks.append(mask) - if 'reward' in sample: - rewards.append(sample['reward']) - if 'vstar' in sample: - vstars.append(sample['vstar']) - if 'vstar_rewards' in sample: - vstar_rewards.append(sample['vstar_rewards']) - - if is_multimodal: - batch_token_type_ids.append(token_type_ids) # type: ignore - pixel_values.append(pixel_vals) - - - batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] - attention_masks = torch.stack(attention_masks) - - sequence_lens = torch.cat(sequence_lens) - prompt_lens = torch.cat(prompt_lens) - return_dict = { - 'sequence_len': sequence_lens, - 'prompt_len': prompt_lens, - 'input_ids': batch_input_ids, - 'attention_mask': attention_masks, - } - if len(masks) > 0: - masks = torch.stack(masks) - assert masks.shape == attention_masks.shape - return_dict['mask'] = masks - if len(rewards) > 0: - rewards = torch.cat(rewards) - return_dict['reward'] = rewards - if len(vstars) > 0: - vstars = torch.cat(vstars) - return_dict['vstar'] = vstars - if len(vstar_rewards) > 0: - vstar_rewards = torch.stack(vstar_rewards) - return_dict['vstar_rewards'] = vstar_rewards - - if is_multimodal: # type: ignore - token_type_ids = torch.stack(batch_token_type_ids) - pixel_values = torch.stack(pixel_values) - return_dict['token_type_ids'] = token_type_ids - return_dict['pixel_values'] = pixel_values - - return return_dict - - - - def offline_dataset_collate_fn( tokenizer: PreTrainedTokenizer, max_seq_len: int, @@ -303,9 +124,8 @@ def offline_dataset_collate_fn( # NOTE: GEMMA specific: 0 == text token token_type_ids[-1] = 0 - #sequence_len = torch.tensor([len(sequence_len)]) # TODO: check this line of code, len(sequence_len) should be one?? - - #pad_len = max_seq_len - sequence_len + sequence_len = torch.tensor([len(input_ids)]) # TODO: check this line of code, len(sequence_len) should be one?? + pad_len = max_seq_len - sequence_len # TODO: seems that in this case, pad_len = 1? so it enters the next if statement? if pad_len > 0: # right padding with padding token From 49bcae88fbaf2ef82664678a6ab9378d09098900 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 22:30:25 -0400 Subject: [PATCH 123/195] testing new collator --- compose_rl/data/__init__.py | 2 + compose_rl/data/dataloader.py | 3 +- compose_rl/data/offline_data.py | 84 +++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 5d2ca370..b5e14cba 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -16,6 +16,7 @@ from compose_rl.data.offline_data import ( OfflineStreamingDataset, offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, @@ -33,6 +34,7 @@ 'finegrained_preference_dataset_collate_fn', 'MinibatchRolloutBuffer', 'offline_dataset_collate_fn', + 'offline_dataset_collate_fn_test', 'OfflineStreamingDataset', 'pairwise_preference_dataset_collate_fn', 'prompt_dataset_collate_fn', diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index e72d1576..7c4ec6f5 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -17,6 +17,7 @@ from compose_rl.data.offline_data import ( OfflineStreamingDataset, offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) from compose_rl.data.preference_data import ( FinegrainedPreferenceStreamingDataset, @@ -134,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 99a77597..ec14e407 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -211,6 +211,90 @@ def offline_dataset_collate_fn( return return_dict +def offline_dataset_collate_fn_test( + tokenizer: PreTrainedTokenizer, + max_seq_len: int, + data: list[dict[str, torch.Tensor]], +) -> dict[str, Any]: + """Collator for offline data. + + Args: + tokenizer (Tokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + data (list[dict[str, torch.Tensor]]): The preference data to collate. + """ + if tokenizer.eos_token_id is None: + raise ValueError('Tokenizer must have an EOS token.') + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + tokenizer.padding_side = 'right' # right + + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + + list_input_ids = [item['input_ids'] for item in data] + ret = ref_collate_fn(list_input_ids) + batch_input_ids = ret['input_ids'] + attention_masks = torch.logical_not( + torch.eq(batch_input_ids, tokenizer.eos_token_id) + ).to(torch.int64) + + batch_max_seq_len = batch_input_ids.shape[1] + if batch_max_seq_len > max_seq_len: # truncate both input_ids and attenion_mask + batch_input_ids = batch_input_ids[:,:max_seq_len] + attention_masks = attention_masks[:,:max_seq_len] + + for i in range(batch_input_ids.shape[0]): + if batch_input_ids[i,-1] != tokenizer.eos_token_id and batch_input_ids[i,-1] != tokenizer.pad_token_id: + batch_input_ids[i,-1] = tokenizer.eos_token_id + + sequence_lens = torch.sum(attention_masks, dim = -1) # sum of all 1 in attention mask, row-wise + prompt_lens = torch.cat([item['prompt_len'] for item in data]) + + masks, rewards, vstars, vstar_rewards = [], [],[],[] + + if 'reward' in data[0].keys(): + rewards = torch.cat([item['reward'] for item in data]) + if 'vstar' in data[0].keys(): + vstars = torch.cat([item['vstar'] for item in data]) + if 'vstar_rewards' in data[0].keys(): + vstar_rewards = torch.stack([item['vstar_rewards'] for item in data]) + + has_mask = 'mask' in data[0].keys() + if has_mask: + for i in range(len(batch_input_ids.shape[0])): + mask_i = data[i]['mask'] + if len(mask_i) < len(batch_input_ids[i]): # right padded + all_zeros = torch.zeros(len(batch_input_ids[i])) + all_zeros[0:len(mask_i)] = mask_i + mask_i = all_zeros + else: # truncated + mask_i = mask_i[0:len(batch_input_ids[i])] + masks.append(mask_i) + masks = torch.stack(masks) + + return_dict = { + 'sequence_len': sequence_lens, + 'prompt_len': prompt_lens, + 'input_ids': batch_input_ids, + 'attention_mask': attention_masks, + } + if len(masks) > 0: + assert masks.shape == attention_masks.shape + return_dict['mask'] = masks + if len(rewards) > 0: + return_dict['reward'] = rewards + if len(vstars) > 0: + return_dict['vstar'] = vstars + if len(vstar_rewards) > 0: + return_dict['vstar_rewards'] = vstar_rewards + + return return_dict + class OfflineStreamingDataset(StreamingDataset): """Dataloader for streaming in preference data.""" From 3942181dd2022fd9c3940195157b642e77c97e31 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 22:33:59 -0400 Subject: [PATCH 124/195] testing new collator --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index ec14e407..7849fb94 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -266,7 +266,7 @@ def offline_dataset_collate_fn_test( has_mask = 'mask' in data[0].keys() if has_mask: - for i in range(len(batch_input_ids.shape[0])): + for i in range(batch_input_ids.shape[0]): mask_i = data[i]['mask'] if len(mask_i) < len(batch_input_ids[i]): # right padded all_zeros = torch.zeros(len(batch_input_ids[i])) From 0842b916389c24b874ac5fdf9d9f646d8f5bbc57 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 22:58:39 -0400 Subject: [PATCH 125/195] changed it back to the original collator, test for the new one looks good --- compose_rl/data/dataloader.py | 2 +- compose_rl/data/offline_data.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 7c4ec6f5..3fb37ec7 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn_test, + offline_dataset_collate_fn, ) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 7849fb94..5727b776 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -237,7 +237,7 @@ def offline_dataset_collate_fn_test( ) list_input_ids = [item['input_ids'] for item in data] - ret = ref_collate_fn(list_input_ids) + ret = ref_collate_fn(list_input_ids) # right padded based on the longest sequence in the batch batch_input_ids = ret['input_ids'] attention_masks = torch.logical_not( torch.eq(batch_input_ids, tokenizer.eos_token_id) @@ -247,8 +247,10 @@ def offline_dataset_collate_fn_test( if batch_max_seq_len > max_seq_len: # truncate both input_ids and attenion_mask batch_input_ids = batch_input_ids[:,:max_seq_len] attention_masks = attention_masks[:,:max_seq_len] - + + # pad eos token on the sequence that is truncated for i in range(batch_input_ids.shape[0]): + # check if this sequence is truncated if batch_input_ids[i,-1] != tokenizer.eos_token_id and batch_input_ids[i,-1] != tokenizer.pad_token_id: batch_input_ids[i,-1] = tokenizer.eos_token_id From 74a88429c22c7f3f3fb0b3b176f79fd3d5fc6af8 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 23:46:59 -0400 Subject: [PATCH 126/195] . --- compose_rl/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3fb37ec7..7c4ec6f5 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) From 76df519e5e903a3a02f1435a24d8357dd2bba72c Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 21 Jul 2025 23:48:40 -0400 Subject: [PATCH 127/195] . --- compose_rl/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 7c4ec6f5..3fb37ec7 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn_test, + offline_dataset_collate_fn, ) From 076994f7cfcc4ed4df7a2e38c1dae36447573984 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 22 Jul 2025 22:49:25 -0400 Subject: [PATCH 128/195] . --- compose_rl/algorithms/offline/model_methods.py | 7 ++++++- compose_rl/data/dataloader.py | 2 +- compose_rl/data/offline_data.py | 9 ++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c92dd301..dd6476ba 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -121,9 +121,11 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, + gamma: float = 0.1, multistep: bool = False, bce: bool = False, ): + # gamma: r + gamma * bonus (bonus can be used to model things like tool use) policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -153,10 +155,13 @@ def offline_loss( assert vstar.shape == batch['reward'].shape + bonuses = batch.get('bonus', None) + if bonuses is None: + bonuses = torch.zeros_like(batch['reward']) if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - - (batch['reward'] - vstar) + (batch['reward'] + gamma * bonuses - vstar) )**2 elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3fb37ec7..7c4ec6f5 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 5727b776..fc9563ed 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -257,10 +257,12 @@ def offline_dataset_collate_fn_test( sequence_lens = torch.sum(attention_masks, dim = -1) # sum of all 1 in attention mask, row-wise prompt_lens = torch.cat([item['prompt_len'] for item in data]) - masks, rewards, vstars, vstar_rewards = [], [],[],[] + masks, rewards, vstars, vstar_rewards, bonuses = [], [],[],[], [] if 'reward' in data[0].keys(): rewards = torch.cat([item['reward'] for item in data]) + if 'bonus' in data[0].keys(): + bonuses = torch.cat([item['bonus'] for item in data]) if 'vstar' in data[0].keys(): vstars = torch.cat([item['vstar'] for item in data]) if 'vstar_rewards' in data[0].keys(): @@ -290,6 +292,8 @@ def offline_dataset_collate_fn_test( return_dict['mask'] = masks if len(rewards) > 0: return_dict['reward'] = rewards + if len(bonuses) > 0: + return_dict['bonus'] = bonuses if len(vstars) > 0: return_dict['vstar'] = vstars if len(vstar_rewards) > 0: @@ -385,6 +389,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # If rewards are given, add them to the return dict if 'reward' in sample: return_dict['reward'] = torch.Tensor([sample['reward']]) + + if 'bonus' in sample: + return_dict['bonus'] = torch.Tensor([sample['bonus']]) if 'vstar' in sample: assert 'vstar_rewards' not in sample From 3e9c9b022036fab3d609e4a38cfc35c8f288ef3f Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 22 Jul 2025 23:02:35 -0400 Subject: [PATCH 129/195] . --- compose_rl/algorithms/offline/model_methods.py | 7 +++++-- compose_rl/data/offline_data.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index dd6476ba..29e5de82 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -144,8 +144,10 @@ def offline_loss( if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None - if not multistep: - exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) + vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) + print('vstar_bonus shape {}'.format(vstar_bonus.shape)) + if not multistep: + exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*vstar_bonus) / beta1), dim=-1) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), @@ -156,6 +158,7 @@ def offline_loss( assert vstar.shape == batch['reward'].shape bonuses = batch.get('bonus', None) + print('bonuses shape {}'.format(bonuses.shape)) if bonuses is None: bonuses = torch.zeros_like(batch['reward']) if bce == False: diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index fc9563ed..9876cf66 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -257,7 +257,7 @@ def offline_dataset_collate_fn_test( sequence_lens = torch.sum(attention_masks, dim = -1) # sum of all 1 in attention mask, row-wise prompt_lens = torch.cat([item['prompt_len'] for item in data]) - masks, rewards, vstars, vstar_rewards, bonuses = [], [],[],[], [] + masks, rewards, vstars, vstar_rewards, bonuses, vstar_bonus = [], [],[],[], [],[] if 'reward' in data[0].keys(): rewards = torch.cat([item['reward'] for item in data]) @@ -267,6 +267,8 @@ def offline_dataset_collate_fn_test( vstars = torch.cat([item['vstar'] for item in data]) if 'vstar_rewards' in data[0].keys(): vstar_rewards = torch.stack([item['vstar_rewards'] for item in data]) + if 'vstar_bonus' in data[0].keys(): + vstar_bonus = torch.stack([item['vstar_bonus'] for item in data]) has_mask = 'mask' in data[0].keys() if has_mask: @@ -298,6 +300,9 @@ def offline_dataset_collate_fn_test( return_dict['vstar'] = vstars if len(vstar_rewards) > 0: return_dict['vstar_rewards'] = vstar_rewards + if len(vstar_bonus) > 0: + return_dict['vstar_bonus'] = vstar_bonus + return return_dict From 96563543565fc8455436e3b9a3b8533112120190 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 22 Jul 2025 23:13:11 -0400 Subject: [PATCH 130/195] . --- compose_rl/algorithms/offline/model_methods.py | 6 +----- compose_rl/data/dataloader.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 29e5de82..d4684727 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -145,7 +145,6 @@ def offline_loss( vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) - print('vstar_bonus shape {}'.format(vstar_bonus.shape)) if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*vstar_bonus) / beta1), dim=-1) else: @@ -157,10 +156,7 @@ def offline_loss( assert vstar.shape == batch['reward'].shape - bonuses = batch.get('bonus', None) - print('bonuses shape {}'.format(bonuses.shape)) - if bonuses is None: - bonuses = torch.zeros_like(batch['reward']) + bonuses = batch.get('bonus', torch.zeros_like(batch['reward'])) if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 7c4ec6f5..3fb37ec7 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn_test, + offline_dataset_collate_fn, ) From 5cce1206be56bb3123dda6aad72be0bfdce47f08 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 22 Jul 2025 23:14:00 -0400 Subject: [PATCH 131/195] . --- compose_rl/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3fb37ec7..7c4ec6f5 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) From 0812e792c2f20913dcb91628b4d414aa3dedc767 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 22 Jul 2025 23:41:19 -0400 Subject: [PATCH 132/195] tested adding bonus --- compose_rl/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 7c4ec6f5..3fb37ec7 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn_test, + offline_dataset_collate_fn, ) From 5b803e4ccc9b1edee7bb99e425852fcfc34f53ce Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 09:32:52 -0400 Subject: [PATCH 133/195] . --- compose_rl/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3fb37ec7..7c4ec6f5 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -135,5 +135,5 @@ def build_preference_dataloader( build_offline_dataloader = generate_dataloader_builder( OfflineStreamingDataset, - offline_dataset_collate_fn, + offline_dataset_collate_fn_test, ) From 9739fc2e78901652f35095f70a00b4be875d4aab Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 09:57:39 -0400 Subject: [PATCH 134/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index d4684727..2baea00d 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -121,7 +121,7 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, - gamma: float = 0.1, + gamma: float = 1., multistep: bool = False, bce: bool = False, ): From a099f472f6f31eade990012778681aa64041c0a2 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 14:41:00 -0400 Subject: [PATCH 135/195] . --- compose_rl/algorithms/offline/model_methods.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 2baea00d..29ca281e 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -121,7 +121,7 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, - gamma: float = 1., + gamma: float = 0.5, multistep: bool = False, bce: bool = False, ): @@ -145,8 +145,9 @@ def offline_loss( vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) + added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 if not multistep: - exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*vstar_bonus) / beta1), dim=-1) + exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), @@ -157,10 +158,11 @@ def offline_loss( assert vstar.shape == batch['reward'].shape bonuses = batch.get('bonus', torch.zeros_like(batch['reward'])) + added_bonuses = bonuses * batch['reward'] # true added bonus = 1 if both bonus = 1 and reward = 1 if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - - (batch['reward'] + gamma * bonuses - vstar) + (batch['reward'] + gamma * added_bonuses - vstar) )**2 elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) From 6e7befee6f5e7b8b4119f22aaf58390f5e57079f Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 15:43:07 -0400 Subject: [PATCH 136/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 29ca281e..49b56ca5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -148,6 +148,7 @@ def offline_loss( added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) + print(torch.max(vstar_rewards+gamma*added_vstar_bonus)) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), @@ -164,6 +165,7 @@ def offline_loss( beta2 * (policy_logp - ref_logp) - (batch['reward'] + gamma * added_bonuses - vstar) )**2 + print(torch.max(batch['reward']+gamma*added_bonuses)) elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) actual_prob = F.sigmoid(batch['reward'] - vstar) From 096bb76c83f347f607169d30ac67cf4d1aed8852 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 15:58:13 -0400 Subject: [PATCH 137/195] . --- compose_rl/algorithms/offline/model_methods.py | 4 ++-- compose_rl/data/offline_data.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 49b56ca5..37518db5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -148,7 +148,7 @@ def offline_loss( added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) - print(torch.max(vstar_rewards+gamma*added_vstar_bonus)) + print("1. {}".format(torch.max(vstar_rewards+gamma*added_vstar_bonus))) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), @@ -165,7 +165,7 @@ def offline_loss( beta2 * (policy_logp - ref_logp) - (batch['reward'] + gamma * added_bonuses - vstar) )**2 - print(torch.max(batch['reward']+gamma*added_bonuses)) + print("2. {}".format(torch.max(batch['reward']+gamma*added_bonuses))) elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) actual_prob = F.sigmoid(batch['reward'] - vstar) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 9876cf66..94f36a09 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -263,12 +263,14 @@ def offline_dataset_collate_fn_test( rewards = torch.cat([item['reward'] for item in data]) if 'bonus' in data[0].keys(): bonuses = torch.cat([item['bonus'] for item in data]) + print(bonuses) if 'vstar' in data[0].keys(): vstars = torch.cat([item['vstar'] for item in data]) if 'vstar_rewards' in data[0].keys(): vstar_rewards = torch.stack([item['vstar_rewards'] for item in data]) if 'vstar_bonus' in data[0].keys(): vstar_bonus = torch.stack([item['vstar_bonus'] for item in data]) + print(vstar_bonus) has_mask = 'mask' in data[0].keys() if has_mask: @@ -402,6 +404,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]: assert 'vstar_rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['vstar']]) + if 'vstar_bonus' in sample: + return_dict['vstar_bonus'] = torch.Tensor([sample['vstar_bonus']]) + + if 'v-star' in sample: assert 'vstar_rewards' not in sample return_dict['vstar'] = torch.Tensor([sample['v-star']]) From 32931f5f8c7018777ebd3369e1a6e41ecfe9c76b Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:04:06 -0400 Subject: [PATCH 138/195] . --- compose_rl/data/offline_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 94f36a09..96261875 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -405,8 +405,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['vstar'] = torch.Tensor([sample['vstar']]) if 'vstar_bonus' in sample: - return_dict['vstar_bonus'] = torch.Tensor([sample['vstar_bonus']]) - + return_dict['vstar_bonus'] = torch.from_numpy(sample["vstar_bonus"]) if 'v-star' in sample: assert 'vstar_rewards' not in sample From f28bea7c6a2ff5bafc8a3e9f61e3839c14c49d5e Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:08:48 -0400 Subject: [PATCH 139/195] . --- compose_rl/algorithms/offline/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 37518db5..1b8310a6 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -145,6 +145,7 @@ def offline_loss( vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) + print("vstar_bonus: {}".format(vstar_bonus)) added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) From 7561a2541ee42d77085e6996469737dcecccb819 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:14:34 -0400 Subject: [PATCH 140/195] . --- compose_rl/algorithms/offline/model_methods.py | 3 +++ compose_rl/data/offline_data.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 1b8310a6..d20262ee 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -147,8 +147,11 @@ def offline_loss( vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) print("vstar_bonus: {}".format(vstar_bonus)) added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 + print("added_vstar_bonus: {}".format(added_vstar_bonus)) if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) + print("vstar reward: {}".format(vstar_rewards)) + print("combined: {}".format(vstar_rewards+gamma*added_vstar_bonus)) print("1. {}".format(torch.max(vstar_rewards+gamma*added_vstar_bonus))) else: exponentiated_mean = torch.mean( diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 96261875..673319af 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -263,14 +263,12 @@ def offline_dataset_collate_fn_test( rewards = torch.cat([item['reward'] for item in data]) if 'bonus' in data[0].keys(): bonuses = torch.cat([item['bonus'] for item in data]) - print(bonuses) if 'vstar' in data[0].keys(): vstars = torch.cat([item['vstar'] for item in data]) if 'vstar_rewards' in data[0].keys(): vstar_rewards = torch.stack([item['vstar_rewards'] for item in data]) if 'vstar_bonus' in data[0].keys(): vstar_bonus = torch.stack([item['vstar_bonus'] for item in data]) - print(vstar_bonus) has_mask = 'mask' in data[0].keys() if has_mask: @@ -298,12 +296,14 @@ def offline_dataset_collate_fn_test( return_dict['reward'] = rewards if len(bonuses) > 0: return_dict['bonus'] = bonuses + #print("data collator {}".format(return_dict['bonus'])) if len(vstars) > 0: return_dict['vstar'] = vstars if len(vstar_rewards) > 0: return_dict['vstar_rewards'] = vstar_rewards if len(vstar_bonus) > 0: return_dict['vstar_bonus'] = vstar_bonus + #print("data collator {}".format(return_dict['vstar_bonus'])) return return_dict From f593881ada3dfbd20275762b0e5b81c156bde4a9 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:24:26 -0400 Subject: [PATCH 141/195] . --- compose_rl/algorithms/offline/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index d20262ee..1a2043db 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -151,6 +151,7 @@ def offline_loss( if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) print("vstar reward: {}".format(vstar_rewards)) + print(gamma) print("combined: {}".format(vstar_rewards+gamma*added_vstar_bonus)) print("1. {}".format(torch.max(vstar_rewards+gamma*added_vstar_bonus))) else: From 19531c65f60ddaf42fc45cdfc410fc026d0e9ee7 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:25:21 -0400 Subject: [PATCH 142/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 1a2043db..85194741 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -152,7 +152,7 @@ def offline_loss( exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) print("vstar reward: {}".format(vstar_rewards)) print(gamma) - print("combined: {}".format(vstar_rewards+gamma*added_vstar_bonus)) + print("combined: {}".format(vstar_rewards.to(torch.float)+gamma*added_vstar_bonus.to(torch.float))) print("1. {}".format(torch.max(vstar_rewards+gamma*added_vstar_bonus))) else: exponentiated_mean = torch.mean( From 091a3ce532c7b306eeb188cf1b33da430c29e4e5 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:32:31 -0400 Subject: [PATCH 143/195] . --- compose_rl/algorithms/offline/model_methods.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 85194741..f5fb4480 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -121,7 +121,7 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, - gamma: float = 0.5, + eta: float = 0.5, multistep: bool = False, bce: bool = False, ): @@ -149,11 +149,11 @@ def offline_loss( added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 print("added_vstar_bonus: {}".format(added_vstar_bonus)) if not multistep: - exponentiated_mean = torch.mean(torch.exp((vstar_rewards+gamma*added_vstar_bonus) / beta1), dim=-1) + exponentiated_mean = torch.mean(torch.exp((vstar_rewards+eta*added_vstar_bonus) / beta1), dim=-1) print("vstar reward: {}".format(vstar_rewards)) - print(gamma) - print("combined: {}".format(vstar_rewards.to(torch.float)+gamma*added_vstar_bonus.to(torch.float))) - print("1. {}".format(torch.max(vstar_rewards+gamma*added_vstar_bonus))) + print(eta) + print("combined: {}".format(vstar_rewards.to(torch.float)+eta*added_vstar_bonus.to(torch.float))) + print("1. {}".format(torch.max(vstar_rewards+eta*added_vstar_bonus))) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), @@ -168,9 +168,9 @@ def offline_loss( if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - - (batch['reward'] + gamma * added_bonuses - vstar) + (batch['reward'] + eta * added_bonuses - vstar) )**2 - print("2. {}".format(torch.max(batch['reward']+gamma*added_bonuses))) + print("2. {}".format(torch.max(batch['reward']+eta*added_bonuses))) elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) actual_prob = F.sigmoid(batch['reward'] - vstar) From 0dfccdd9cf57c56951d82ce4c48d82f8687a842b Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:36:41 -0400 Subject: [PATCH 144/195] . --- compose_rl/algorithms/offline/model_methods.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f5fb4480..a6c0fb72 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -125,7 +125,7 @@ def offline_loss( multistep: bool = False, bce: bool = False, ): - # gamma: r + gamma * bonus (bonus can be used to model things like tool use) + # eta: r + eta * bonus (bonus can be used to model things like tool use) policy_logp = outputs['policy_logp'] # (batch_size, ) @@ -140,6 +140,7 @@ def offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch # + print("eta{}".format(eta)) vstar = batch.get('vstar', None) if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) @@ -152,7 +153,7 @@ def offline_loss( exponentiated_mean = torch.mean(torch.exp((vstar_rewards+eta*added_vstar_bonus) / beta1), dim=-1) print("vstar reward: {}".format(vstar_rewards)) print(eta) - print("combined: {}".format(vstar_rewards.to(torch.float)+eta*added_vstar_bonus.to(torch.float))) + print("combined: {}".format(vstar_rewards+eta*added_vstar_bonus)) print("1. {}".format(torch.max(vstar_rewards+eta*added_vstar_bonus))) else: exponentiated_mean = torch.mean( From e8d04b4cd7a8220604381621819c7de982b4232e Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:42:19 -0400 Subject: [PATCH 145/195] . --- compose_rl/algorithms/offline/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 33a96637..c0cd865f 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -74,7 +74,7 @@ def loss(self, outputs: CausalLMOutputWithPast, self.loss_type, self.beta1, self.beta2, - self.multistep, + multistep = self.multistep, ) From 8f5063cbfe9f341dd3a0c4badd01c0b7b9bd9ada Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 16:48:22 -0400 Subject: [PATCH 146/195] . --- compose_rl/algorithms/offline/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index c0cd865f..6329c9a0 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -69,11 +69,11 @@ def eval_forward( def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]: return offline_loss( - outputs, - batch, - self.loss_type, - self.beta1, - self.beta2, + outputs = outputs, + batch = batch, + loss_type = self.loss_type, + beta1 = self.beta1, + beta2 = self.beta2, multistep = self.multistep, ) From 917723c2bd3499a25977bcc294caf72bfacbb15c Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 18:41:37 -0400 Subject: [PATCH 147/195] . --- compose_rl/algorithms/offline/model.py | 18 ++++++++++++------ compose_rl/algorithms/offline/model_methods.py | 4 ++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 6329c9a0..1131d944 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -35,6 +35,7 @@ def __init__( loss_type: str = 'apo', beta1: float = 0.5, beta2: float = 0.1, + eta: float = 0.5, multistep: bool = False, average_log_prob: bool = False, temperature: float = 1.0, @@ -43,6 +44,7 @@ def __init__( self.loss_type = RegressionOfflineEnum(loss_type) self.beta1 = beta1 self.beta2 = beta2 + self.eta = eta self.multistep = multistep self.average_log_prob = average_log_prob self.temperature = temperature @@ -74,6 +76,7 @@ def loss(self, outputs: CausalLMOutputWithPast, loss_type = self.loss_type, beta1 = self.beta1, beta2 = self.beta2, + eta = self.eta, multistep = self.multistep, ) @@ -86,6 +89,7 @@ def __init__( loss_type: str = 'apo', beta1: float = 0.5, beta2: float = 0.1, + eta: float = 0.5, multistep: bool = False, average_log_prob: bool = False, temperature: float = 1.0, @@ -94,6 +98,7 @@ def __init__( self.loss_type = RegressionOfflineEnum(loss_type) self.beta1 = beta1 self.beta2 = beta2 + self.eta = eta self.multistep = multistep self.average_log_prob = average_log_prob self.temperature = temperature @@ -120,12 +125,13 @@ def eval_forward( def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]: return offline_loss( - outputs, - batch, - self.loss_type, - self.beta1, - self.beta2, - self.multistep, + outputs = outputs, + batch = batch, + loss_type = self.loss_type, + beta1 = self.beta1, + beta2 = self.beta2, + eta = self.eta, + multistep = self.multistep, ) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index a6c0fb72..874b931b 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -121,7 +121,7 @@ def offline_loss( loss_type: RegressionOfflineEnum, beta1: float, beta2: float, - eta: float = 0.5, + eta: float, multistep: bool = False, bce: bool = False, ): @@ -157,7 +157,7 @@ def offline_loss( print("1. {}".format(torch.max(vstar_rewards+eta*added_vstar_bonus))) else: exponentiated_mean = torch.mean( - vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), + vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), # TODO: something is wrong here. dim=-1, ) vstar = beta1 * torch.log(exponentiated_mean) From bf4031885ecf4021b9855699cb6c4885c6b55a15 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 23 Jul 2025 18:49:17 -0400 Subject: [PATCH 148/195] . --- compose_rl/algorithms/offline/model_methods.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 874b931b..25a78579 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,21 +140,14 @@ def offline_loss( # Similar to REBEL, we assume each response has a reward in the batch. # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch # - print("eta{}".format(eta)) vstar = batch.get('vstar', None) if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) - print("vstar_bonus: {}".format(vstar_bonus)) added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 - print("added_vstar_bonus: {}".format(added_vstar_bonus)) if not multistep: exponentiated_mean = torch.mean(torch.exp((vstar_rewards+eta*added_vstar_bonus) / beta1), dim=-1) - print("vstar reward: {}".format(vstar_rewards)) - print(eta) - print("combined: {}".format(vstar_rewards+eta*added_vstar_bonus)) - print("1. {}".format(torch.max(vstar_rewards+eta*added_vstar_bonus))) else: exponentiated_mean = torch.mean( vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), # TODO: something is wrong here. @@ -171,7 +164,6 @@ def offline_loss( beta2 * (policy_logp - ref_logp) - (batch['reward'] + eta * added_bonuses - vstar) )**2 - print("2. {}".format(torch.max(batch['reward']+eta*added_bonuses))) elif bce == True: predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) actual_prob = F.sigmoid(batch['reward'] - vstar) From f40974b0b26707de8483e49837f5fbdbd789a461 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 24 Jul 2025 10:19:43 -0400 Subject: [PATCH 149/195] rgb --- compose_rl/data/offline_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index c0ae1255..b6b673db 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -30,6 +30,7 @@ def base64_to_pil(base64_string: str): image_data = base64.b64decode(base64_string) image_stream = BytesIO(image_data) image = Image.open(image_stream) + image = image.convert("RGB") return image except Exception as e: print(f"Error decoding base64 string: {e}") From 64ca7a760873c2f005d3f96b4900447e8c472827 Mon Sep 17 00:00:00 2001 From: abaheti95 Date: Sun, 27 Jul 2025 18:27:52 +0000 Subject: [PATCH 150/195] tracking sequence entropies in offline rl forward --- .../algorithms/offline/model_methods.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 69c11083..b49ef35e 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -24,6 +24,9 @@ get_batch_logp, get_mb_load_balancing_loss, get_log_probs_from_logits, + make_action_mask, + get_token_entropies, + get_sequence_entropies, ) @@ -80,6 +83,9 @@ def offline_forward( inputs.update(multimodal_inputs) output_logits = model(**inputs).logits + # Calculate token entropies from the logits + token_entropies = get_token_entropies(logits=output_logits) + token_entropies = token_entropies.detach() if has_mask is False: logps = get_batch_logp( @@ -90,6 +96,17 @@ def offline_forward( average_log_prob, temperature=temperature, ) + # Calculate sequence entropies + action_mask = make_action_mask( + batch['prompt_len'], + batch['sequence_len'], + batch['attention_mask'].shape, + device=output_logits.device, + ) + sequence_entropies = get_sequence_entropies( + token_entropies=token_entropies, + action_mask=action_mask + ) else: token_policy_logps = get_log_probs_from_logits( output_logits[:,:-1], @@ -99,9 +116,17 @@ def offline_forward( token_policy_logps *= batch['attention_mask'][:,1:] token_policy_logps *= batch['mask'][:,1:] logps = torch.sum(token_policy_logps, dim = -1) # (bs, ) + # Calculate sequence entropies + # TODO: confirm with JC and Adyasha if this is correct + combined_mask = batch['attention_mask'] * batch['mask'] + sequence_entropies = get_sequence_entropies( + token_entropies=token_entropies, + action_mask=combined_mask, + ) outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, + 'sequence_entropies': sequence_entropies, } if policy_model_config is not None and hasattr(model, 'transformer'): @@ -196,6 +221,7 @@ def offline_loss( 'reverse_kl': reverse_kl, 'forward_kl': forward_kl, 'estimated_reward': estimated_reward, + 'sequence_entropies': outputs['sequence_entropies'], # Track detached sequence entropies in the loss dict } if loss_type == RegressionOfflineEnum.APO: loss_dict['batch_advantage'] = torch.mean( From 5e9c33bf0e31a2ca0b777282347f7dca31cd5d8f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 28 Jul 2025 16:56:04 -0400 Subject: [PATCH 151/195] excluding gold --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index b49ef35e..0ea92a6d 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -167,7 +167,7 @@ def offline_loss( if vstar is None: vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None - vstar_rewards = vstar_rewards[:, 1:] # exclude gold + # vstar_rewards = vstar_rewards[:, 1:] # exclude gold if not multistep: exponentiated_mean = torch.mean(torch.exp(vstar_rewards / beta1), dim=-1) else: From 21695ddec38392adb8c37c095f5ed0c5060ee994 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 20 Aug 2025 14:09:05 -0400 Subject: [PATCH 152/195] . --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 1c4d4227..3fce0ecf 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -282,7 +282,7 @@ def offline_dataset_collate_fn_test( else: # truncated mask_i = mask_i[0:len(batch_input_ids[i])] masks.append(mask_i) - masks = torch.stack(masks) + masks = torch.stack(masks) return_dict = { 'sequence_len': sequence_lens, From b67fbe6af5f1384f406c78f5c9ce7d0c8d2eda8d Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 26 Aug 2025 14:22:45 -0400 Subject: [PATCH 153/195] add sequence id --- compose_rl/data/offline_data.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 3fce0ecf..bbde9684 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -64,6 +64,7 @@ def offline_dataset_collate_fn( batch_input_ids = [] attention_masks = [] + sequence_id = [] sequence_lens = [] prompt_lens = [] rewards = [] @@ -160,6 +161,11 @@ def offline_dataset_collate_fn( mask = mask[0: len(attention_mask)] assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape + + cur_sequence_id = torch.tensor(([0] * sequence_len) + + ([-1] * max(0, int(pad_len.item()))),) + + sequence_id.append(cur_sequence_id) batch_input_ids.append(input_ids) attention_masks.append(attention_mask) sequence_lens.append(sequence_len) # TODO: this sequence_len is out of dated? @@ -181,6 +187,7 @@ def offline_dataset_collate_fn( batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) + sequence_id = torch.stack(sequence_id) sequence_lens = torch.cat(sequence_lens) prompt_lens = torch.cat(prompt_lens) @@ -189,6 +196,7 @@ def offline_dataset_collate_fn( 'prompt_len': prompt_lens, 'input_ids': batch_input_ids, 'attention_mask': attention_masks, + 'sequence_id': sequence_id, } if len(masks) > 0: masks = torch.stack(masks) From de597f4b9b2fc454e3819f65e729573d16c61dde Mon Sep 17 00:00:00 2001 From: wensun Date: Thu, 28 Aug 2025 11:42:02 -0400 Subject: [PATCH 154/195] pad token id fix --- compose_rl/data/offline_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py index 3fce0ecf..10a26e2b 100644 --- a/compose_rl/data/offline_data.py +++ b/compose_rl/data/offline_data.py @@ -241,7 +241,7 @@ def offline_dataset_collate_fn_test( ret = ref_collate_fn(list_input_ids) # right padded based on the longest sequence in the batch batch_input_ids = ret['input_ids'] attention_masks = torch.logical_not( - torch.eq(batch_input_ids, tokenizer.eos_token_id) + torch.eq(batch_input_ids, tokenizer.pad_token_id) ).to(torch.int64) batch_max_seq_len = batch_input_ids.shape[1] From 517d2b934e0eb6c020f306b82fc52739ec69aad8 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 21:01:20 -0400 Subject: [PATCH 155/195] add reference model loading --- compose_rl/algorithms/offline/callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 1a94dd21..b0bfaded 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -33,7 +33,8 @@ def __init__( self.reference_model = None def after_load(self, state: State, logger: Logger) -> None: - model_config = self.train_config['model'] + #model_config = self.train_config['model'] + model_config = self.train_config['reference_model'] init_context = process_init_device( model_config, self.train_config.get('fsdp_config'), From 4093fbb2d1c4a8b6570994594ee886d9202ccf19 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 21:05:49 -0400 Subject: [PATCH 156/195] . --- compose_rl/algorithms/offline/callback.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index b0bfaded..e9fb96d8 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -40,6 +40,9 @@ def after_load(self, state: State, logger: Logger) -> None: self.train_config.get('fsdp_config'), ) name = model_config.pop('name') + print("################################################") + print(f"reference model name: {name}") + print("################################################") self.reference_model = build_composer_model( name=name, cfg=model_config, From f246ea9e6c7c7adf4a5f5729280deec6231bab2b Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 21:16:13 -0400 Subject: [PATCH 157/195] . --- compose_rl/algorithms/offline/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index e9fb96d8..d8eed13d 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -34,7 +34,7 @@ def __init__( def after_load(self, state: State, logger: Logger) -> None: #model_config = self.train_config['model'] - model_config = self.train_config['reference_model'] + model_config = self.train_config['variables']['reference_model'] init_context = process_init_device( model_config, self.train_config.get('fsdp_config'), From 5d14693a822f7ef81f5c1864c083242a6ba7e1c7 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 21:34:51 -0400 Subject: [PATCH 158/195] . --- compose_rl/algorithms/offline/callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index d8eed13d..3a1bffe5 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -41,7 +41,8 @@ def after_load(self, state: State, logger: Logger) -> None: ) name = model_config.pop('name') print("################################################") - print(f"reference model name: {name}") + print("reference model config:") + print(model_config) print("################################################") self.reference_model = build_composer_model( name=name, From 306dbd9f4f0225a4e6e2a573f1c03128d90885d8 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 12:18:39 -0400 Subject: [PATCH 159/195] first version of a unified dataloader --- compose_rl/data/rl_data.py | 279 +++++++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 compose_rl/data/rl_data.py diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py new file mode 100644 index 00000000..b49c14cc --- /dev/null +++ b/compose_rl/data/rl_data.py @@ -0,0 +1,279 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Build dataloader for RL training.""" + +import logging +from typing import Any, Optional +import compose_rl.utils as utils + +import numpy as np +import torch +from streaming import StreamingDataset +from transformers import PreTrainedTokenizer,DataCollatorForLanguageModeling + +log = logging.getLogger(__name__) + + +def dataset_collate_fn( + tokenizer: PreTrainedTokenizer, + max_seq_len: int, + data: list[dict[str, Any]], +) -> dict[str, Any]: + """Collator for RL data. + + Args: + tokenizer (PreTrainedTokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + data (list[dict[str, Any]]): The RL data to collate. + """ + if tokenizer.eos_token_id is None: + raise ValueError('Tokenizer must have an EOS token.') + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + tokenizer.padding_side = 'right' # right + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + list_of_input_ids = [] + list_of_prompt_len = [] + list_of_prompts = [] + list_of_prompt_ids = [] + list_of_num_turns = [] + return_dict: dict[str, Any] = {} + + # case 1: input_ids key is present + if 'input_ids' in data[0]: + list_of_input_ids = [item['input_ids'] for item in data] + list_of_prompt_len = [item['prompt_len'] for item in data] + + # case 2: turn_data key is present + elif "turn_data" in data[0]: + for data_point in data: + list_of_input_ids.extend([turn['input_ids'] for turn in data_point['turn_data']]) + list_of_prompt_len.extend([turn['prompt_len'] for turn in data_point['turn_data']]) + list_of_num_turns.append(torch.tensor([len(data_point['turn_data'])], dtype=torch.int64)) + + elif 'prompt' in data[0]: + list_of_prompts = [item['prompt'] for item in data] + list_of_prompt_len = [item['prompt_len'] for item in data] + list_of_prompt_ids = [item['prompt_id'] for item in data] + + + if len(list_of_input_ids) > 0: # dealing with input_ids if it not empty. batch, padd, and truncate based on max_seq_len + batch_input_ids = ref_collate_fn(list_of_input_ids)['input_ids'] + attention_masks = torch.logical_not(torch.eq(batch_input_ids, tokenizer.pad_token_id)).to(torch.int64) + # truncate if length of the batch exceeds max_seq_len + batch_max_seq_len = batch_input_ids.shape[1] + if batch_max_seq_len > max_seq_len: + batch_input_ids = batch_input_ids[:,:max_seq_len] + attention_masks = attention_masks[:,:max_seq_len] + # pad eos token on the sequence that is truncated + for i in range(batch_input_ids.shape[0]): + if batch_input_ids[i,-1] != tokenizer.eos_token_id and batch_input_ids[i,-1] != tokenizer.pad_token_id: + batch_input_ids[i,-1] = tokenizer.eos_token_id + + sequence_lens = torch.sum(attention_masks, dim = -1) + prompt_lens = torch.cat(list_of_prompt_len) + + # Add sequence_id tracking (like offline_dataset_collate_fn) + sequence_id = [] + for i in range(batch_input_ids.shape[0]): + cur_seq_len = int(sequence_lens[i].item()) + pad_len = int(batch_input_ids.shape[1] - cur_seq_len) + cur_sequence_id = torch.tensor([0] * cur_seq_len + [-1] * pad_len) + sequence_id.append(cur_sequence_id) + + return_dict = { + 'input_ids': batch_input_ids, + 'attention_mask': attention_masks, + 'sequence_len': sequence_lens, + 'prompt_len': prompt_lens, + 'sequence_id': torch.stack(sequence_id), + } + + if 'mask' in data[0]: # check if additional mask is provided, if so process it and add it to the return dict + masks = [] + for i in range(batch_input_ids.shape[0]): + mask_i = data[i]['mask'] + if len(mask_i) < len(batch_input_ids[i]): # right padded + all_zeros = torch.zeros(len(batch_input_ids[i])) + all_zeros[0:len(mask_i)] = mask_i + mask_i = all_zeros + else: # truncated + mask_i = mask_i[0:len(batch_input_ids[i])] + masks.append(mask_i) + masks = torch.stack(masks) + return_dict['mask'] = masks + + if len(list_of_prompts) > 0: # dealing with prompts if present + tokenizer.padding_side = 'left' # switch to left padding for prompts + return_dict['prompt'] = ref_collate_fn(list_of_prompts)['input_ids'] + prompt_attention_mask = torch.logical_not(torch.eq(return_dict['prompt'], tokenizer.pad_token_id)).to(torch.int64) + return_dict['prompt_attention_mask'] = prompt_attention_mask + return_dict['prompt_id'] = torch.cat(list_of_prompt_ids) + return_dict['prompt_len'] = torch.cat(list_of_prompt_len) + + + if len(list_of_num_turns) > 0: # this is the case where we have turn level data + assert 'turn_data' in data[0], "turn_data must be present if num_turns is present" + return_dict['num_turns'] = torch.cat(list_of_num_turns) + + + if 'reward' in data[0]: + return_dict['reward'] = torch.cat([item['reward'] for item in data]) + if 'bonus' in data[0]: + return_dict['bonus'] = torch.cat([item['bonus'] for item in data]) + if 'vstar_rewards' in data[0]: + return_dict['vstar_rewards'] = torch.stack([item['vstar_rewards'] for item in data]) + if 'vstar_bonus' in data[0]: + return_dict['vstar_bonus'] = torch.stack([item['vstar_bonus'] for item in data]) + if "verified_answer" in data[0]: + return_dict['verified_answer'] = list(utils.flatten([item['verified_answer'] for item in data])) + + return return_dict + + + +class RLStreamingDataset(StreamingDataset): + """Dataloader for streaming in RL data.""" + + def __init__(self, + max_seq_len: int, + max_gen_len: int, + tokenizer: PreTrainedTokenizer, + chat_template: Optional[str] = None, + chat_template_path: Optional[str] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.max_seq_len = max_seq_len + self.max_gen_len = max_gen_len + self.tokenizer = tokenizer + + # Handle chat template (priority: file path > direct template > default) + if chat_template_path is not None: + # Load template from file + import os + + # Convert to absolute path for clarity + abs_template_path = os.path.abspath(chat_template_path) + + if not os.path.exists(abs_template_path): + raise FileNotFoundError(f"Chat template file not found: {chat_template_path} (resolved to: {abs_template_path})") + + with open(abs_template_path, 'r', encoding='utf-8') as f: + self.chat_template = f.read().strip() + log.info(f"Loaded chat template from: {abs_template_path}") + # Apply it to the tokenizer + self.tokenizer.chat_template = self.chat_template + + elif chat_template is not None: + # Use direct template string + self.chat_template = chat_template + # Apply it to the tokenizer + self.tokenizer.chat_template = chat_template + + else: + # Use tokenizer's default chat template + self.chat_template = getattr(tokenizer, 'chat_template', None) + + + def __getitem__(self, idx: int) -> dict[str, Any]: + sample = super().__getitem__(idx) + + return_dict: dict[str, Any] = {} + + prompt_id = None + prompt = None + prompt_len = None + input_ids = None + sequence_len = None + mask = None + turn_data: list[dict[str, Any]] = [] + + # case 0: just contains prompt. This is for online RL setting + if 'prompt' in sample and 'response' not in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + prompt = torch.from_numpy(sample['prompt']) + prompt_id = idx + prompt_len = len(prompt) + + # case 1: prompt + response, we assume both are tokenized ndarray; this is for standard single turn offline rl + elif 'prompt' in sample and 'response' in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + assert isinstance(sample['response'], np.ndarray), f"Response must be a numpy array, but got {type(sample['response'])}" + input_ids = np.concatenate([sample['prompt'], sample['response']]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt_len = len(torch.from_numpy(sample['prompt'])) + sequence_len = len(input_ids) + + # case 2: input + mask, this is can be for single turn or multi-turn offline RL. mask is used to mask out non-assistant turns + elif 'input' in sample and 'mask' in sample: + assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" + assert isinstance(sample['mask'], np.ndarray), f"Mask must be a numpy array, but got {type(sample['mask'])}" + + input_ids = torch.from_numpy(sample['input']).to(torch.int64) + mask = torch.from_numpy(sample['mask']).to(torch.int64) + + prompt_len = 0 + sequence_len = len(input_ids) + + # case 3: for multi-turn data, and sample['messages] contains a list of messages in text + elif 'messages' in sample: + messages = sample['messages'] + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == 'assistant': + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] + + assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] + input_ids = history_assistant + prompt_len = len(history) + sequence_len = len(input_ids) + + turn_data.append({ + 'input_ids': input_ids, + 'prompt_len': prompt_len, + 'sequence_len': sequence_len, + }) + + else: + raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") + + if len(turn_data) > 0: + return_dict['turn_data'] = turn_data + if prompt_id is not None: + return_dict['prompt_id'] = prompt_id + if prompt is not None: + return_dict['prompt'] = prompt + if prompt_len is not None: + return_dict['prompt_len'] = torch.tensor([prompt_len], dtype=torch.int64) + if input_ids is not None: + return_dict['input_ids'] = input_ids + if sequence_len is not None: + return_dict['sequence_len'] = torch.tensor([sequence_len], dtype=torch.int64) + if mask is not None: + return_dict['mask'] = mask + + if 'reward' in sample: + return_dict['reward'] = torch.tensor([sample['reward']]) + if 'bonus' in sample: + return_dict['bonus'] = torch.tensor([sample['bonus']]) + if 'vstar_rewards' in sample: + assert isinstance(sample['vstar_rewards'], np.ndarray), f"Vstar rewards must be a numpy array, but got {type(sample['vstar_rewards'])}" + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards']) + if 'vstar_bonus' in sample: + assert isinstance(sample['vstar_bonus'], np.ndarray), f"Vstar bonus must be a numpy array, but got {type(sample['vstar_bonus'])}" + return_dict['vstar_bonus'] = torch.from_numpy(sample['vstar_bonus']) + if 'verified_answer' in sample: + assert isinstance(sample['verified_answer'], str), f"Verified answer must be a string, but got {type(sample['verified_answer'])}" + return_dict['verified_answer'] = sample['verified_answer'] + + return return_dict + \ No newline at end of file From d0d7c951069cdb861571581ff4eff9e89bbd304f Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 14:32:28 -0400 Subject: [PATCH 160/195] dataloader --- compose_rl/data/__init__.py | 8 ++ compose_rl/data/dataloader.py | 10 ++ compose_rl/data/rl_data.py | 5 +- pyproject.toml | 1 + yamls/offline_apo.yaml | 167 ++++++++++++++++++++++++++++++++++ 5 files changed, 187 insertions(+), 4 deletions(-) create mode 100644 yamls/offline_apo.yaml diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index b5e14cba..87ec50a4 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -11,6 +11,7 @@ build_offline_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, + build_rl_dataloader, ) from compose_rl.data.messages_data import messages_dataset_collate_fn from compose_rl.data.offline_data import ( @@ -23,6 +24,10 @@ pairwise_preference_dataset_collate_fn, ) from compose_rl.data.prompt_data import prompt_dataset_collate_fn +from compose_rl.data.rl_data import ( + RLStreamingDataset, + dataset_collate_fn, +) __all__ = [ 'build_pairwise_preference_dataloader', @@ -30,6 +35,7 @@ 'build_messages_dataloader', 'build_offline_dataloader', 'build_prompt_dataloader', + 'build_rl_dataloader', 'DummyDataset', 'finegrained_preference_dataset_collate_fn', 'MinibatchRolloutBuffer', @@ -39,4 +45,6 @@ 'pairwise_preference_dataset_collate_fn', 'prompt_dataset_collate_fn', 'messages_dataset_collate_fn', + 'RLStreamingDataset', + 'dataset_collate_fn', ] diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 7c4ec6f5..ed600780 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -29,6 +29,10 @@ PromptStreamingDataset, prompt_dataset_collate_fn, ) +from compose_rl.data.rl_data import ( + RLStreamingDataset, + dataset_collate_fn, +) __all__ = [ 'build_finegrained_preference_dataloader', @@ -36,6 +40,7 @@ 'build_prompt_dataloader', 'build_messages_dataloader', 'build_offline_dataloader', + 'build_rl_dataloader', ] @@ -137,3 +142,8 @@ def build_preference_dataloader( OfflineStreamingDataset, offline_dataset_collate_fn_test, ) + +build_rl_dataloader = generate_dataloader_builder( + RLStreamingDataset, + dataset_collate_fn, +) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index b49c14cc..a12d4367 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -143,14 +143,12 @@ class RLStreamingDataset(StreamingDataset): def __init__(self, max_seq_len: int, - max_gen_len: int, tokenizer: PreTrainedTokenizer, chat_template: Optional[str] = None, chat_template_path: Optional[str] = None, **kwargs: Any): super().__init__(**kwargs) self.max_seq_len = max_seq_len - self.max_gen_len = max_gen_len self.tokenizer = tokenizer # Handle chat template (priority: file path > direct template > default) @@ -275,5 +273,4 @@ def __getitem__(self, idx: int) -> dict[str, Any]: assert isinstance(sample['verified_answer'], str), f"Verified answer must be a string, but got {type(sample['verified_answer'])}" return_dict['verified_answer'] = sample['verified_answer'] - return return_dict - \ No newline at end of file + return return_dict \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ba0e9fb6..fbc0434b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloade prompt = "compose_rl.data:build_prompt_dataloader" messages = "compose_rl.data:build_messages_dataloader" offline = "compose_rl.data:build_offline_dataloader" +unified_rl = "compose_rl.data:build_rl_dataloader" [project.entry-points."llmfoundry_callbacks_with_config"] offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback" diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml new file mode 100644 index 00000000..c5306bf8 --- /dev/null +++ b/yamls/offline_apo.yaml @@ -0,0 +1,167 @@ +name: apo-single-stream-openr1 +#name: apo-single-stream-math + +image: mosaicml/dle:nightly-latest +scheduling: + priority: high + max_retries: 0 + preemptible: true + retry_on_system_failure: false + +compute: + gpus: 16 #32 #16 + cluster: r5z2p1 + +#run_name: offline_apo_math_qwen +run_name: offline_apo_openr1_traj_wise_stream_deepseek + +parameters: + seed: 7338 + model: + name: hf_offline_lm + beta1: 1 + beta2: 0.01 + eta: 0.5 + loss_type: apo + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B #deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + #Qwen/Qwen3-4B-Instruct-2507 + #/tmp/local_qwen + #Qwen/Qwen3-4B + #Qwen/Qwen2.5-Coder-32B-Instruct + #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + #Qwen/Qwen2.5-7B + + loggers: + mlflow: + tracking_uri: databricks + experiment_name: offline_apo_openr1_single_resposne + #/Users/j.chang@databricks.com/mlflow_experiments/offline_apo_single_response_openr1 + + callbacks: + offline_rl: {} + lr_monitor: {} + scheduled_gc: + batch_interval: 2000 + speed_monitor: + window_size: 10 + memory_monitor: {} + hf_checkpointer: + overwrite: true + precision: bfloat16 + save_folder: s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/{run_name} + #s3://data-force-one-datasets/mosaicml-internal-checkpoints/jchang/models/hf/{run_name} + save_interval: 1ep + + optimizer: + lr: 1.0e-06 #7.0e-07 + eps: 1.0e-10 + name: decoupled_adamw + betas: + - 0.9 + - 0.95 + weight_decay: 1.0e-8 + precision: amp_bf16 + scheduler: + name: cosine_with_warmup + alpha_f: 0.05 + t_warmup: 0.02dur + #name: constant_with_warmup + #t_warmup: 0.01dur + + tokenizer: + name: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B #deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + #Qwen/Qwen3-4B-Instruct-2507 + #Qwen/Qwen2.5-7B + #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + kwargs: + trust_remote_code: true + padding_side: right + + algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1 + eval_first: false + + fsdp_config: + verbose: false + cpu_offload: false + mixed_precision: PURE + state_dict_type: sharded + forward_prefetch: true + backward_prefetch: BACKWARD_PRE + limit_all_gathers: true + sharding_strategy: FULL_SHARD + activation_cpu_offload: false + activation_checkpointing: true + activation_checkpointing_reentrant: false + + max_seq_len: 16384 #65536 #35000 #16384 #65K does not work due to oom. + dist_timeout: 3600 + max_duration: 1ep + autoresume: true + progress_bar: false + train_loader: + name: unified_rl + dataset: + local: /tmp/dataset/ + split: train + remote: dbfs:/Volumes/datasets/wensun/data/offline_apo/traj_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v3_medium_0.5/ + #dbfs:/Volumes/datasets/wensun/data/offline_apo/step_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ + # dbfs:/Volumes/datasets/wensun/data/offline_apo/traj_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v3_medium_0.5/ + #dbfs:/Volumes/datasets/jchang/data/offline_apo/apo_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ + shuffle: true + max_seq_len: ${max_seq_len} + shuffle_seed: 7338 + download_retry: 4 + download_timeout: 600 + drop_last: false + num_workers: 8 + save_folder: s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/mpt/{run_name} + #s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/{run_name} + save_overwrite: True + eval_interval: 1ep + save_interval: 1ep + log_to_console: true + python_log_level: debug + load_weights_only: true + console_log_interval: 5ba + device_eval_batch_size: 1 + eval_subset_num_batches: -1 + global_train_batch_size: 256 + device_train_microbatch_size: 1 + + + variables: + reference_model: + name: hf_offline_lm + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + +integrations: +- integration_type: git_repo + path: custom-llm-foundry + git_repo: mosaicml/llm-foundry + git_branch: main + pip_install: . --no-deps +- integration_type: git_repo + git_repo: databricks/Compose-RL + git_branch: wensun/offline_apo + pip_install: .[gpu] --no-deps +env_variables: + AWS_PROFILE: data-force-one +command: |- + # Run llm foundry train + + #mkdir -p /tmp/local_qwen + #s5cmd --numworkers 32 cp s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/offline_apo_tool_call-1-1.0-0.01-1e-06-4-256-4b_f_t/huggingface/ba689/* /tmp/local_qwen + + cd llm-foundry/scripts/train/ + composer train.py /mnt/config/parameters.yaml From d522ea644b6d0021bd2b26bb96595cd989c2b27d Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 15:22:56 -0400 Subject: [PATCH 161/195] tokenizer --- compose_rl/data/dataloader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index ed600780..caa39c78 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -92,6 +92,11 @@ def build_preference_dataloader( MessagesStreamingDataset, ) and 'tokenizer' not in dataset_cfg: dataset_cfg['tokenizer'] = tokenizer + if issubclass( + dataset_cls, + RLStreamingDataset, + ) and 'tokenizer' not in dataset_cfg: + dataset_cfg['tokenizer'] = tokenizer streaming_dataset = dataset_cls( streams=streams, # type: ignore From fca744d02833f7cfbd5a2acc261b50fae06ba491 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 16:27:59 -0400 Subject: [PATCH 162/195] . --- compose_rl/data/rl_data.py | 44 ++++++++++++++++++++++++++++++++++++++ example_tools.jsonl | 1 + 2 files changed, 45 insertions(+) create mode 100644 example_tools.jsonl diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index a12d4367..74c46148 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -146,6 +146,8 @@ def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: Optional[str] = None, chat_template_path: Optional[str] = None, + tools: Optional[list[dict[str, Any]]] = None, + tools_path: Optional[str] = None, **kwargs: Any): super().__init__(**kwargs) self.max_seq_len = max_seq_len @@ -178,6 +180,48 @@ def __init__(self, # Use tokenizer's default chat template self.chat_template = getattr(tokenizer, 'chat_template', None) + # Handle tools (priority: file path > direct tools > None) + if tools_path is not None: + # Load tools from JSONL file (one JSON object per line) + import json + import os + + abs_tools_path = os.path.abspath(tools_path) + if not os.path.exists(abs_tools_path): + raise FileNotFoundError(f"Tools file not found: {tools_path} (resolved to: {abs_tools_path})") + + self.tools = [] + with open(abs_tools_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: # Skip empty lines + continue + try: + tool = json.loads(line) + if not isinstance(tool, dict): + raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") + self.tools.append(tool) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") + + log.info(f"Loaded {len(self.tools)} tools from JSONL file: {abs_tools_path}") + + elif tools is not None: + # Use direct tools list + if not isinstance(tools, list): + raise ValueError(f"Tools must be a list, but got {type(tools)}") + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") + + self.tools = tools + log.info(f"Using {len(self.tools)} tools provided directly") + + else: + # No tools provided + self.tools = None + def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) diff --git a/example_tools.jsonl b/example_tools.jsonl new file mode 100644 index 00000000..a8122875 --- /dev/null +++ b/example_tools.jsonl @@ -0,0 +1 @@ +{"type": "function", "function": {"name": "vector_search", "description": "Search for similar documents in a vector index using text queries.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Text query to search for similar documents"}}, "required": ["query"]}}} \ No newline at end of file From e9470b4d10a2fc002c10dd706c214205f30b0734 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 21:45:50 -0400 Subject: [PATCH 163/195] test tool loading and message loading --- compose_rl/data/rl_data.py | 13 +++++++------ yamls/offline_apo.yaml | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 74c46148..90600c7d 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -181,6 +181,7 @@ def __init__(self, self.chat_template = getattr(tokenizer, 'chat_template', None) # Handle tools (priority: file path > direct tools > None) + self.tools = [] if tools_path is not None: # Load tools from JSONL file (one JSON object per line) import json @@ -190,7 +191,6 @@ def __init__(self, if not os.path.exists(abs_tools_path): raise FileNotFoundError(f"Tools file not found: {tools_path} (resolved to: {abs_tools_path})") - self.tools = [] with open(abs_tools_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): line = line.strip() @@ -218,9 +218,9 @@ def __init__(self, self.tools = tools log.info(f"Using {len(self.tools)} tools provided directly") - else: - # No tools provided - self.tools = None + print("############# Debug: tools #############") + print(self.tools) + print("############# Debug: tools #############") def __getitem__(self, idx: int) -> dict[str, Any]: @@ -265,14 +265,15 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # case 3: for multi-turn data, and sample['messages] contains a list of messages in text elif 'messages' in sample: + print("############# Debug: messages in sample #############") messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" for i in range(len(messages)): message = messages[i] assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools = self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools = self.tools, add_generation_prompt=False, return_tensors='pt')[0] assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] input_ids = history_assistant diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml index c5306bf8..e9f36ab3 100644 --- a/yamls/offline_apo.yaml +++ b/yamls/offline_apo.yaml @@ -116,6 +116,7 @@ parameters: #dbfs:/Volumes/datasets/jchang/data/offline_apo/apo_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ shuffle: true max_seq_len: ${max_seq_len} + tools_path: ../example_tools.jsonl shuffle_seed: 7338 download_retry: 4 download_timeout: 600 From 0511355d7b0550735199f9deb957b3098a914c2c Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 22:20:53 -0400 Subject: [PATCH 164/195] . --- compose_rl/data/rl_data.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 90600c7d..727847ab 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -211,16 +211,32 @@ def __init__(self, if not isinstance(tools, list): raise ValueError(f"Tools must be a list, but got {type(tools)}") + # Clean and validate tools to remove any Undefined objects + self.tools = [] for i, tool in enumerate(tools): if not isinstance(tool, dict): raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") + + try: + cleaned_tool = self._clean_tool_structure(tool) + self.tools.append(cleaned_tool) + except Exception as e: + log.error(f"Failed to clean tool {i}: {e}") + print(f"############# Debug: Problematic tool {i} #############") + print(f"Tool type: {type(tool)}") + print(f"Tool content: {tool}") + print("############# End Debug #############") + raise ValueError(f"Tool {i} contains invalid data: {e}") - self.tools = tools - log.info(f"Using {len(self.tools)} tools provided directly") + print("############# Debug: Cleaned tools #############") + print(f"Tools type: {type(self.tools)}") + print(f"Number of tools: {len(self.tools)}") + if self.tools: + print(f"First tool: {self.tools[0]}") + print("############# End Debug #############") + log.info(f"Using {len(self.tools)} cleaned tools provided directly") - print("############# Debug: tools #############") - print(self.tools) - print("############# Debug: tools #############") + def __getitem__(self, idx: int) -> dict[str, Any]: From 23576c4da065642f37e54b302d348ff49886357f Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 22:44:17 -0400 Subject: [PATCH 165/195] . --- compose_rl/data/rl_data.py | 25 +++++-------------------- yamls/offline_apo.yaml | 5 +++-- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 727847ab..88a62129 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -201,6 +201,9 @@ def __init__(self, if not isinstance(tool, dict): raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") self.tools.append(tool) + print("############# Debug: tools #############") + print(self.tools) + print("############# Debug: tools #############") except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") @@ -211,30 +214,12 @@ def __init__(self, if not isinstance(tools, list): raise ValueError(f"Tools must be a list, but got {type(tools)}") - # Clean and validate tools to remove any Undefined objects - self.tools = [] for i, tool in enumerate(tools): if not isinstance(tool, dict): raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") - - try: - cleaned_tool = self._clean_tool_structure(tool) - self.tools.append(cleaned_tool) - except Exception as e: - log.error(f"Failed to clean tool {i}: {e}") - print(f"############# Debug: Problematic tool {i} #############") - print(f"Tool type: {type(tool)}") - print(f"Tool content: {tool}") - print("############# End Debug #############") - raise ValueError(f"Tool {i} contains invalid data: {e}") - print("############# Debug: Cleaned tools #############") - print(f"Tools type: {type(self.tools)}") - print(f"Number of tools: {len(self.tools)}") - if self.tools: - print(f"First tool: {self.tools[0]}") - print("############# End Debug #############") - log.info(f"Using {len(self.tools)} cleaned tools provided directly") + self.tools = tools + log.info(f"Using {len(self.tools)} tools provided directly") diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml index e9f36ab3..a82304b3 100644 --- a/yamls/offline_apo.yaml +++ b/yamls/offline_apo.yaml @@ -116,7 +116,7 @@ parameters: #dbfs:/Volumes/datasets/jchang/data/offline_apo/apo_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ shuffle: true max_seq_len: ${max_seq_len} - tools_path: ../example_tools.jsonl + tools_path: ../../../compose-rl/example_tools.jsonl shuffle_seed: 7338 download_retry: 4 download_timeout: 600 @@ -153,6 +153,7 @@ integrations: git_branch: main pip_install: . --no-deps - integration_type: git_repo + path: compose-rl git_repo: databricks/Compose-RL git_branch: wensun/offline_apo pip_install: .[gpu] --no-deps @@ -164,5 +165,5 @@ command: |- #mkdir -p /tmp/local_qwen #s5cmd --numworkers 32 cp s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/offline_apo_tool_call-1-1.0-0.01-1e-06-4-256-4b_f_t/huggingface/ba689/* /tmp/local_qwen - cd llm-foundry/scripts/train/ + cd custom-llm-foundry/scripts/train/ composer train.py /mnt/config/parameters.yaml From acda258f76457c05fee916910ddceae8e79b4ef8 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 22:48:18 -0400 Subject: [PATCH 166/195] . --- compose_rl/data/rl_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 88a62129..5d50e4ec 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -273,8 +273,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: message = messages[i] assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools = self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools = self.tools, add_generation_prompt=False, return_tensors='pt')[0] + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] input_ids = history_assistant From ffce2f2658962ea73d54fadafcad117b74c653ea Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 22:50:16 -0400 Subject: [PATCH 167/195] . --- compose_rl/data/rl_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 5d50e4ec..b14713e4 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -270,6 +270,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" for i in range(len(messages)): + print("############# Debug: message in messages #############") + print(messages[i]) + print("############# Debug: message in messages #############") message = messages[i] assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': From 2bfe9c41e1bbe99a5063c17ff1d833ad3d01e34f Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 22:56:03 -0400 Subject: [PATCH 168/195] . --- compose_rl/data/rl_data.py | 51 +++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index b14713e4..e87c346f 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -269,15 +269,58 @@ def __getitem__(self, idx: int) -> dict[str, Any]: print("############# Debug: messages in sample #############") messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + + # Clean messages to convert any Undefined objects to proper values + cleaned_messages = [] + for msg in messages: + cleaned_msg = {} + for key, value in msg.items(): + # Convert Undefined objects to None or proper values + if hasattr(value, '__class__') and 'Undefined' in str(value.__class__): + print(f"Found Undefined value in message[{key}]: {value}") + cleaned_msg[key] = None # Convert Undefined to None + elif str(value) == 'None' and not isinstance(value, type(None)): + # Handle cases where Undefined prints as 'None' but isn't actually None + cleaned_msg[key] = None + else: + cleaned_msg[key] = value + cleaned_messages.append(cleaned_msg) + + messages = cleaned_messages + print(f"Using {len(messages)} cleaned messages") + + # Test that cleaned messages are JSON serializable + import json + try: + json.dumps(messages) + print("āœ… Cleaned messages are JSON serializable") + except (TypeError, ValueError) as e: + print(f"āŒ Cleaned messages still not JSON serializable: {e}") + for i in range(len(messages)): - print("############# Debug: message in messages #############") + print("############# Debug: cleaned message #############") print(messages[i]) - print("############# Debug: message in messages #############") + print("############# Debug: cleaned message #############") message = messages[i] assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] + try: + print(f"šŸ”„ Applying chat template for history (messages 0 to {i-1})") + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + print("āœ… History template applied successfully") + except Exception as e: + print(f"āŒ Error in history template: {e}") + print(f"Problematic messages slice: {messages[:i]}") + raise e + + try: + print(f"šŸ”„ Applying chat template for history_assistant (messages 0 to {i})") + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] + print("āœ… History_assistant template applied successfully") + except Exception as e: + print(f"āŒ Error in history_assistant template: {e}") + print(f"Problematic messages slice: {messages[:i+1]}") + raise e assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] input_ids = history_assistant From 09c3ae4c046826269542a3d8ab6e732fda2af6ab Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 23:01:26 -0400 Subject: [PATCH 169/195] . --- compose_rl/data/rl_data.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index e87c346f..d537bf1e 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -271,6 +271,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" # Clean messages to convert any Undefined objects to proper values + import json cleaned_messages = [] for msg in messages: cleaned_msg = {} @@ -282,6 +283,16 @@ def __getitem__(self, idx: int) -> dict[str, Any]: elif str(value) == 'None' and not isinstance(value, type(None)): # Handle cases where Undefined prints as 'None' but isn't actually None cleaned_msg[key] = None + elif key == 'tool_calls' and isinstance(value, str) and value and value != 'None': + # Parse tool_calls JSON string into proper Python objects + try: + print(f"šŸ”§ Parsing tool_calls JSON: {value}") + parsed_tool_calls = json.loads(value) + cleaned_msg[key] = parsed_tool_calls + print(f"āœ… Successfully parsed tool_calls: {parsed_tool_calls}") + except json.JSONDecodeError as e: + print(f"āŒ Failed to parse tool_calls JSON: {e}") + cleaned_msg[key] = None # Fallback to None if parsing fails else: cleaned_msg[key] = value cleaned_messages.append(cleaned_msg) @@ -306,7 +317,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if message['role'] == 'assistant': try: print(f"šŸ”„ Applying chat template for history (messages 0 to {i-1})") - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=None, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token print("āœ… History template applied successfully") except Exception as e: print(f"āŒ Error in history template: {e}") @@ -315,7 +326,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: try: print(f"šŸ”„ Applying chat template for history_assistant (messages 0 to {i})") - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, add_generation_prompt=False, return_tensors='pt')[0] + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=None, add_generation_prompt=False, return_tensors='pt')[0] print("āœ… History_assistant template applied successfully") except Exception as e: print(f"āŒ Error in history_assistant template: {e}") From 230f5f4946cec47f88b5edb1aca70a2b7a49c2f5 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 23:15:53 -0400 Subject: [PATCH 170/195] . --- compose_rl/data/rl_data.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index d537bf1e..688f6a6f 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -220,9 +220,6 @@ def __init__(self, self.tools = tools log.info(f"Using {len(self.tools)} tools provided directly") - - - def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) @@ -284,15 +281,11 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # Handle cases where Undefined prints as 'None' but isn't actually None cleaned_msg[key] = None elif key == 'tool_calls' and isinstance(value, str) and value and value != 'None': - # Parse tool_calls JSON string into proper Python objects - try: - print(f"šŸ”§ Parsing tool_calls JSON: {value}") - parsed_tool_calls = json.loads(value) - cleaned_msg[key] = parsed_tool_calls - print(f"āœ… Successfully parsed tool_calls: {parsed_tool_calls}") - except json.JSONDecodeError as e: - print(f"āŒ Failed to parse tool_calls JSON: {e}") - cleaned_msg[key] = None # Fallback to None if parsing fails + # Keep tool_calls as string to avoid parsing issues + # When we add tools back later, we can parse this properly + print(f"šŸ”§ Keeping tool_calls as string: {value[:100]}...") + cleaned_msg[key] = str(value) # Ensure it's a plain string + print(f"āœ… Preserved tool_calls as string (length: {len(value)})") else: cleaned_msg[key] = value cleaned_messages.append(cleaned_msg) @@ -317,7 +310,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if message['role'] == 'assistant': try: print(f"šŸ”„ Applying chat template for history (messages 0 to {i-1})") - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=None, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token print("āœ… History template applied successfully") except Exception as e: print(f"āŒ Error in history template: {e}") @@ -326,7 +319,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: try: print(f"šŸ”„ Applying chat template for history_assistant (messages 0 to {i})") - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=None, add_generation_prompt=False, return_tensors='pt')[0] + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] print("āœ… History_assistant template applied successfully") except Exception as e: print(f"āŒ Error in history_assistant template: {e}") From dddbdd4c00290eb8cd2ce1f9ffb99853fbe66ed2 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 23:24:13 -0400 Subject: [PATCH 171/195] . --- compose_rl/data/rl_data.py | 41 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 688f6a6f..000526f3 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -267,28 +267,29 @@ def __getitem__(self, idx: int) -> dict[str, Any]: messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" - # Clean messages to convert any Undefined objects to proper values + # Clean messages by doing JSON round-trip - forces all values to be native Python types import json cleaned_messages = [] - for msg in messages: - cleaned_msg = {} - for key, value in msg.items(): - # Convert Undefined objects to None or proper values - if hasattr(value, '__class__') and 'Undefined' in str(value.__class__): - print(f"Found Undefined value in message[{key}]: {value}") - cleaned_msg[key] = None # Convert Undefined to None - elif str(value) == 'None' and not isinstance(value, type(None)): - # Handle cases where Undefined prints as 'None' but isn't actually None - cleaned_msg[key] = None - elif key == 'tool_calls' and isinstance(value, str) and value and value != 'None': - # Keep tool_calls as string to avoid parsing issues - # When we add tools back later, we can parse this properly - print(f"šŸ”§ Keeping tool_calls as string: {value[:100]}...") - cleaned_msg[key] = str(value) # Ensure it's a plain string - print(f"āœ… Preserved tool_calls as string (length: {len(value)})") - else: - cleaned_msg[key] = value - cleaned_messages.append(cleaned_msg) + for i, msg in enumerate(messages): + try: + # Serialize and deserialize to clean all Undefined objects + json_str = json.dumps(msg) + cleaned_msg = json.loads(json_str) + print(f"āœ… Message {i} cleaned via JSON round-trip") + cleaned_messages.append(cleaned_msg) + except (TypeError, ValueError) as e: + print(f"āŒ Message {i} failed JSON round-trip: {e}") + print(f" Problematic message: {msg}") + # Fallback: create a minimal safe message + safe_msg = { + 'role': msg.get('role', 'unknown'), + 'content': str(msg.get('content', '')) if msg.get('content') else None, + 'tool_calls': None, + 'tool_call_id': None, + 'name': None + } + print(f" Using fallback safe message: {safe_msg}") + cleaned_messages.append(safe_msg) messages = cleaned_messages print(f"Using {len(messages)} cleaned messages") From 969d18220671caafaed7a6179fa0af49fbad5e03 Mon Sep 17 00:00:00 2001 From: wensun Date: Fri, 5 Sep 2025 23:30:21 -0400 Subject: [PATCH 172/195] . --- compose_rl/data/rl_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 000526f3..6c8a57c8 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -311,7 +311,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if message['role'] == 'assistant': try: print(f"šŸ”„ Applying chat template for history (messages 0 to {i-1})") - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=None, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token print("āœ… History template applied successfully") except Exception as e: print(f"āŒ Error in history template: {e}") @@ -320,7 +320,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: try: print(f"šŸ”„ Applying chat template for history_assistant (messages 0 to {i})") - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=None, add_generation_prompt=False, return_tensors='pt')[0] print("āœ… History_assistant template applied successfully") except Exception as e: print(f"āŒ Error in history_assistant template: {e}") From 2f9631017dd4a4ee44c712436fcd6e5fbcc09ac0 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 00:37:07 -0400 Subject: [PATCH 173/195] new data formt --- compose_rl/data/rl_data.py | 40 +------------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 6c8a57c8..a4fa3806 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -266,46 +266,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: print("############# Debug: messages in sample #############") messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" - - # Clean messages by doing JSON round-trip - forces all values to be native Python types - import json - cleaned_messages = [] - for i, msg in enumerate(messages): - try: - # Serialize and deserialize to clean all Undefined objects - json_str = json.dumps(msg) - cleaned_msg = json.loads(json_str) - print(f"āœ… Message {i} cleaned via JSON round-trip") - cleaned_messages.append(cleaned_msg) - except (TypeError, ValueError) as e: - print(f"āŒ Message {i} failed JSON round-trip: {e}") - print(f" Problematic message: {msg}") - # Fallback: create a minimal safe message - safe_msg = { - 'role': msg.get('role', 'unknown'), - 'content': str(msg.get('content', '')) if msg.get('content') else None, - 'tool_calls': None, - 'tool_call_id': None, - 'name': None - } - print(f" Using fallback safe message: {safe_msg}") - cleaned_messages.append(safe_msg) - - messages = cleaned_messages - print(f"Using {len(messages)} cleaned messages") - - # Test that cleaned messages are JSON serializable - import json - try: - json.dumps(messages) - print("āœ… Cleaned messages are JSON serializable") - except (TypeError, ValueError) as e: - print(f"āŒ Cleaned messages still not JSON serializable: {e}") - + for i in range(len(messages)): - print("############# Debug: cleaned message #############") - print(messages[i]) - print("############# Debug: cleaned message #############") message = messages[i] assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': From d6f31a74268f5fa9ecc77d2bb59253c9d19c5448 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 00:40:14 -0400 Subject: [PATCH 174/195] test tool --- compose_rl/data/rl_data.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index a4fa3806..38d9f8ff 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -272,20 +272,16 @@ def __getitem__(self, idx: int) -> dict[str, Any]: assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" if message['role'] == 'assistant': try: - print(f"šŸ”„ Applying chat template for history (messages 0 to {i-1})") - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=None, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token - print("āœ… History template applied successfully") + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token except Exception as e: - print(f"āŒ Error in history template: {e}") + print(f"Error in history template: {e}") print(f"Problematic messages slice: {messages[:i]}") raise e try: - print(f"šŸ”„ Applying chat template for history_assistant (messages 0 to {i})") - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=None, add_generation_prompt=False, return_tensors='pt')[0] - print("āœ… History_assistant template applied successfully") + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] except Exception as e: - print(f"āŒ Error in history_assistant template: {e}") + print(f"Error in history_assistant template: {e}") print(f"Problematic messages slice: {messages[:i+1]}") raise e @@ -293,6 +289,11 @@ def __getitem__(self, idx: int) -> dict[str, Any]: input_ids = history_assistant prompt_len = len(history) sequence_len = len(input_ids) + print(f"History: {history}") + print(f"History assistant: {history_assistant}") + print(f"Input ids: {input_ids}") + print(f"Prompt len: {prompt_len}") + print(f"Sequence len: {sequence_len}") turn_data.append({ 'input_ids': input_ids, From 6cdb6197a715f6845f24cc14e59282af9739c3fa Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 00:43:09 -0400 Subject: [PATCH 175/195] test tool --- compose_rl/data/rl_data.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 38d9f8ff..31e106f7 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -263,7 +263,6 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # case 3: for multi-turn data, and sample['messages] contains a list of messages in text elif 'messages' in sample: - print("############# Debug: messages in sample #############") messages = sample['messages'] assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" @@ -280,6 +279,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: try: history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + history_assistan_text = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=False, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] except Exception as e: print(f"Error in history_assistant template: {e}") print(f"Problematic messages slice: {messages[:i+1]}") @@ -289,11 +289,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: input_ids = history_assistant prompt_len = len(history) sequence_len = len(input_ids) - print(f"History: {history}") - print(f"History assistant: {history_assistant}") - print(f"Input ids: {input_ids}") - print(f"Prompt len: {prompt_len}") - print(f"Sequence len: {sequence_len}") + print(f"History assistant: {history_assistan_text}") turn_data.append({ 'input_ids': input_ids, From 2382630afd4a52f52e0ecc0efdfd7168fa07b1db Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 00:46:00 -0400 Subject: [PATCH 176/195] test tool --- compose_rl/data/rl_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 31e106f7..2f71bf9f 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -279,7 +279,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: try: history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] - history_assistan_text = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=False, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + history_assistan_text = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=False, tools=self.tools, add_generation_prompt=False, return_tensors='pt') except Exception as e: print(f"Error in history_assistant template: {e}") print(f"Problematic messages slice: {messages[:i+1]}") From 42cf662eb8dbeef793950d62e85770f26d8a529d Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 09:39:34 -0400 Subject: [PATCH 177/195] message working --- compose_rl/data/rl_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 2f71bf9f..7f60b921 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -289,7 +289,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: input_ids = history_assistant prompt_len = len(history) sequence_len = len(input_ids) - print(f"History assistant: {history_assistan_text}") + turn_data.append({ 'input_ids': input_ids, From 2642698ca42520077200a2c3c0c25201b937de07 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 15:05:14 -0400 Subject: [PATCH 178/195] implemented flatten messages --- compose_rl/data/rl_data.py | 155 +++++--- test_rl_data_comprehensive.py | 651 ++++++++++++++++++++++++++++++++++ yamls/offline_apo.yaml | 1 + 3 files changed, 755 insertions(+), 52 deletions(-) create mode 100644 test_rl_data_comprehensive.py diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 7f60b921..4eb7d302 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -148,11 +148,13 @@ def __init__(self, chat_template_path: Optional[str] = None, tools: Optional[list[dict[str, Any]]] = None, tools_path: Optional[str] = None, + flatten_messages: bool = False, **kwargs: Any): super().__init__(**kwargs) self.max_seq_len = max_seq_len self.tokenizer = tokenizer - + self.flatten_messages = flatten_messages + # Handle chat template (priority: file path > direct template > default) if chat_template_path is not None: # Load template from file @@ -180,6 +182,8 @@ def __init__(self, # Use tokenizer's default chat template self.chat_template = getattr(tokenizer, 'chat_template', None) + print(f"Using chat template: {self.chat_template}") + # Handle tools (priority: file path > direct tools > None) self.tools = [] if tools_path is not None: @@ -201,9 +205,6 @@ def __init__(self, if not isinstance(tool, dict): raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") self.tools.append(tool) - print("############# Debug: tools #############") - print(self.tools) - print("############# Debug: tools #############") except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") @@ -220,6 +221,57 @@ def __init__(self, self.tools = tools log.info(f"Using {len(self.tools)} tools provided directly") + + print(f"Using {len(self.tools)} tools, and tools are {self.tools}") + + def _convert_messages_to_turn_wise_data(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + #convert mesage to turn wise data + turn_data: list[dict[str, Any]] = [] + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == 'assistant': + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] + + input_ids = history_assistant + prompt_len = len(history) + sequence_len = len(input_ids) + + turn_data.append({ + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + }) + + return turn_data # list of dict + + def _convert_messages_to_traj_wise_data(self, messages: list[dict[str, Any]]) -> dict[str, Any]: + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + mask = [] + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == "assistant": + history = self.tokenizer.apply_chat_template(messages[0:i], return_tensors='pt', add_generation_prompt=True, tokenize=True, tools=self.tools)[0] + history_assistant = self.tokenizer.apply_chat_template(messages[0:i+1], return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + generation_len = len(history_assistant) - len(history) + current_mask = [0]*len(history) + [1]*generation_len + current_mask[0:len(mask)] = mask + mask = current_mask + + input_ids = self.tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + assert len(input_ids) == len(mask), f"Input ids and mask must have the same length" + return_dict = { + 'input_ids': input_ids, + 'mask': torch.tensor(mask, dtype=torch.int64), + } + + return return_dict # dict + def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) @@ -232,7 +284,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: input_ids = None sequence_len = None mask = None - turn_data: list[dict[str, Any]] = [] + turn_data = None # list[dict[str, Any]] = [] # case 0: just contains prompt. This is for online RL setting if 'prompt' in sample and 'response' not in sample: @@ -241,6 +293,13 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt_id = idx prompt_len = len(prompt) + # return dict for case 0: + return_dict = { + 'prompt': prompt, + 'prompt_id': prompt_id, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + } + # case 1: prompt + response, we assume both are tokenized ndarray; this is for standard single turn offline rl elif 'prompt' in sample and 'response' in sample: assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" @@ -250,6 +309,13 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt_len = len(torch.from_numpy(sample['prompt'])) sequence_len = len(input_ids) + # return dict for case 1: + return_dict = { + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + # case 2: input + mask, this is can be for single turn or multi-turn offline RL. mask is used to mask out non-assistant turns elif 'input' in sample and 'mask' in sample: assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" @@ -260,61 +326,46 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt_len = 0 sequence_len = len(input_ids) + + # return dict for case 2: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } # case 3: for multi-turn data, and sample['messages] contains a list of messages in text elif 'messages' in sample: messages = sample['messages'] - assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" - - for i in range(len(messages)): - message = messages[i] - assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" - if message['role'] == 'assistant': - try: - history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token - except Exception as e: - print(f"Error in history template: {e}") - print(f"Problematic messages slice: {messages[:i]}") - raise e - - try: - history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] - history_assistan_text = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=False, tools=self.tools, add_generation_prompt=False, return_tensors='pt') - except Exception as e: - print(f"Error in history_assistant template: {e}") - print(f"Problematic messages slice: {messages[:i+1]}") - raise e - - assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] - input_ids = history_assistant - prompt_len = len(history) - sequence_len = len(input_ids) - + if self.flatten_messages is False: + turn_data = self._convert_messages_to_turn_wise_data(messages) # list of dict, one dict per assistant turn + # return dict for case 3.a: + return_dict = { + 'turn_data': turn_data, + } + else: + traj_data = self._convert_messages_to_traj_wise_data(messages) # dict, flatten the message into a single trajectory + input_ids = traj_data['input_ids'] + mask = traj_data['mask'] + prompt_len = 0 + sequence_len = len(input_ids) + # return dict for case 3.b: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, # Already converted to tensor in _convert_messages_to_traj_wise_data + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + print("#### test: print assistant tokens ####") + print(self.tokenizer.decode(input_ids[mask.bool()])) + print("#### test: done printing assistant tokens ####") - turn_data.append({ - 'input_ids': input_ids, - 'prompt_len': prompt_len, - 'sequence_len': sequence_len, - }) - else: raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") - if len(turn_data) > 0: - return_dict['turn_data'] = turn_data - if prompt_id is not None: - return_dict['prompt_id'] = prompt_id - if prompt is not None: - return_dict['prompt'] = prompt - if prompt_len is not None: - return_dict['prompt_len'] = torch.tensor([prompt_len], dtype=torch.int64) - if input_ids is not None: - return_dict['input_ids'] = input_ids - if sequence_len is not None: - return_dict['sequence_len'] = torch.tensor([sequence_len], dtype=torch.int64) - if mask is not None: - return_dict['mask'] = mask - + # now add additional keys if 'reward' in sample: return_dict['reward'] = torch.tensor([sample['reward']]) if 'bonus' in sample: diff --git a/test_rl_data_comprehensive.py b/test_rl_data_comprehensive.py new file mode 100644 index 00000000..8a81d853 --- /dev/null +++ b/test_rl_data_comprehensive.py @@ -0,0 +1,651 @@ +#!/usr/bin/env python3 +""" +Comprehensive test cases for RLStreamingDataset +Self-contained with mocked dependencies to avoid import issues. +""" + +import json +import os +import tempfile +import torch +import numpy as np +from typing import Any, Dict, List, Optional, Union +from unittest.mock import MagicMock + + +# Mock StreamingDataset base class +class MockStreamingDataset: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __getitem__(self, idx): + # Override in actual usage + pass + + +# Mock PreTrainedTokenizer +class MockTokenizer: + def __init__(self): + self.pad_token_id = 0 + self.eos_token_id = 2 + self.padding_side = 'right' + self.chat_template = None + + def apply_chat_template(self, messages, tokenize=False, tools=None, add_generation_prompt=False, return_tensors=None): + """Mock chat template application""" + # Simulate tokenization by converting to simple token IDs + if not tokenize: + return "mocked_template_string" + + # Create mock token sequences based on message content + total_tokens = [] + + for msg in messages: + content = msg.get('content', '') or '' + role = msg.get('role', 'unknown') + + # Add role-specific prefix tokens + if role == 'system': + total_tokens.extend([1001, 1002]) # system prefix tokens + elif role == 'user': + total_tokens.extend([1003, 1004]) # user prefix tokens + elif role == 'assistant': + total_tokens.extend([1005, 1006]) # assistant prefix tokens + + # Simple tokenization: each word becomes a token ID + if content: + words = content.split() + token_ids = [hash(word) % 800 + 100 for word in words] # Mock content token IDs + total_tokens.extend(token_ids) + + # Add role-specific suffix tokens + total_tokens.append(1010) # end of message token + + # Add special tokens based on generation prompt + if add_generation_prompt: + total_tokens.extend([1005, 1006]) # Add assistant prefix for generation + else: + # Only add EOS if this is the final completion (no generation prompt) + if messages and messages[-1].get('role') == 'assistant': + total_tokens.append(self.eos_token_id) + + result_tensor = torch.tensor(total_tokens) + + if return_tensors == 'pt': + return result_tensor.unsqueeze(0) # Batch dimension + return result_tensor + + +# Copy the core RLStreamingDataset logic (simplified) +class RLStreamingDataset(MockStreamingDataset): + """Dataloader for streaming in RL data.""" + + def __init__(self, + max_seq_len: int, + tokenizer, + chat_template: Optional[str] = None, + chat_template_path: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tools_path: Optional[str] = None, + flatten_messages: bool = False, + **kwargs: Any): + super().__init__(**kwargs) + self.max_seq_len = max_seq_len + self.tokenizer = tokenizer + self.flatten_messages = flatten_messages + + # Handle chat template (priority: file path > direct template > default) + if chat_template_path is not None: + # Load template from file + abs_template_path = os.path.abspath(chat_template_path) + if not os.path.exists(abs_template_path): + raise FileNotFoundError(f"Chat template file not found: {chat_template_path} (resolved to: {abs_template_path})") + + with open(abs_template_path, 'r', encoding='utf-8') as f: + self.chat_template = f.read().strip() + # Apply it to the tokenizer + self.tokenizer.chat_template = self.chat_template + elif chat_template is not None: + # Use direct template string + self.chat_template = chat_template + self.tokenizer.chat_template = chat_template + else: + # Use tokenizer's default chat template + self.chat_template = getattr(tokenizer, 'chat_template', None) + + # Handle tools (priority: file path > direct tools > None) + self.tools = [] + if tools_path is not None: + # Load tools from JSONL file (one JSON object per line) + abs_tools_path = os.path.abspath(tools_path) + if not os.path.exists(abs_tools_path): + raise FileNotFoundError(f"Tools file not found: {tools_path} (resolved to: {abs_tools_path})") + + with open(abs_tools_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: # Skip empty lines + continue + try: + tool = json.loads(line) + if not isinstance(tool, dict): + raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") + self.tools.append(tool) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") + + elif tools is not None: + # Use direct tools list + if not isinstance(tools, list): + raise ValueError(f"Tools must be a list, but got {type(tools)}") + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") + + self.tools = tools + + def _convert_messages_to_turn_wise_data(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert messages to turn wise data""" + turn_data: List[Dict[str, Any]] = [] + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == 'assistant': + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" + + input_ids = history_assistant + prompt_len = len(history) + sequence_len = len(input_ids) + + turn_data.append({ + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + }) + + return turn_data + + def _convert_messages_to_traj_wise_data(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: + """Convert messages to trajectory wise data""" + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + mask = [] + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == "assistant": + history = self.tokenizer.apply_chat_template(messages[0:i], return_tensors='pt', add_generation_prompt=True, tokenize=True, tools=self.tools)[0] + history_assistant = self.tokenizer.apply_chat_template(messages[0:i+1], return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + generation_len = len(history_assistant) - len(history) + current_mask = [0]*len(history) + [1]*generation_len + current_mask[0:len(mask)] = mask + mask = current_mask + + input_ids = self.tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + assert len(input_ids) == len(mask), f"Input ids and mask must have the same length, got {len(input_ids)} vs {len(mask)}" + return_dict = { + 'input_ids': input_ids, + 'mask': torch.tensor(mask, dtype=torch.int64), + } + return return_dict + + def __getitem__(self, idx: int) -> Dict[str, Any]: + # Mock the parent's __getitem__ to return test data + if hasattr(self, '_test_samples') and idx < len(self._test_samples): + sample = self._test_samples[idx] + else: + raise IndexError(f"Index {idx} out of range") + + return_dict: Dict[str, Any] = {} + + prompt_id = None + prompt = None + prompt_len = None + input_ids = None + sequence_len = None + mask = None + turn_data = None + + # case 0: just contains prompt. This is for online RL setting + if 'prompt' in sample and 'response' not in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + prompt = torch.from_numpy(sample['prompt']) + prompt_id = idx + prompt_len = len(prompt) + + # return dict for case 0: + return_dict = { + 'prompt': prompt, + 'prompt_id': prompt_id, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + } + + # case 1: prompt + response, we assume both are tokenized ndarray; this is for standard single turn offline rl + elif 'prompt' in sample and 'response' in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + assert isinstance(sample['response'], np.ndarray), f"Response must be a numpy array, but got {type(sample['response'])}" + input_ids = np.concatenate([sample['prompt'], sample['response']]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt_len = len(torch.from_numpy(sample['prompt'])) + sequence_len = len(input_ids) + + # return dict for case 1: + return_dict = { + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 2: input + mask, this is can be for single turn or multi-turn offline RL. mask is used to mask out non-assistant turns + elif 'input' in sample and 'mask' in sample: + assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" + assert isinstance(sample['mask'], np.ndarray), f"Mask must be a numpy array, but got {type(sample['mask'])}" + + input_ids = torch.from_numpy(sample['input']).to(torch.int64) + mask = torch.from_numpy(sample['mask']).to(torch.int64) + + prompt_len = 0 + sequence_len = len(input_ids) + + # return dict for case 2: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 3: for multi-turn data, and sample['messages'] contains a list of messages in text + elif 'messages' in sample: + messages = sample['messages'] + if self.flatten_messages is False: + turn_data = self._convert_messages_to_turn_wise_data(messages) # list of dict, one dict per assistant turn + # return dict for case 3.a: + return_dict = { + 'turn_data': turn_data, + } + else: + traj_data = self._convert_messages_to_traj_wise_data(messages) # dict, flatten the message into a single trajectory + input_ids = traj_data['input_ids'] + mask = traj_data['mask'] + prompt_len = 0 + sequence_len = len(input_ids) + # return dict for case 3.b: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, # Already converted to tensor in _convert_messages_to_traj_wise_data + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + else: + raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") + + # now add additional keys + if 'reward' in sample: + return_dict['reward'] = torch.tensor([sample['reward']]) + if 'bonus' in sample: + return_dict['bonus'] = torch.tensor([sample['bonus']]) + if 'vstar_rewards' in sample: + assert isinstance(sample['vstar_rewards'], np.ndarray), f"Vstar rewards must be a numpy array, but got {type(sample['vstar_rewards'])}" + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards']) + if 'vstar_bonus' in sample: + assert isinstance(sample['vstar_bonus'], np.ndarray), f"Vstar bonus must be a numpy array, but got {type(sample['vstar_bonus'])}" + return_dict['vstar_bonus'] = torch.from_numpy(sample['vstar_bonus']) + if 'verified_answer' in sample: + assert isinstance(sample['verified_answer'], str), f"Verified answer must be a string, but got {type(sample['verified_answer'])}" + return_dict['verified_answer'] = sample['verified_answer'] + + return return_dict + + +# Test cases +def test_case_0_prompt_only(): + """Test case 0: prompt only (online RL)""" + print("🧪 Testing Case 0: Prompt only...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'prompt': np.array([10, 11, 12, 13, 14]), + 'reward': 0.5 + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'prompt' in result + assert 'prompt_id' in result + assert 'prompt_len' in result + assert 'reward' in result + assert isinstance(result['prompt'], torch.Tensor) + assert result['prompt_id'] == 0 + assert result['prompt_len'].item() == 5 + assert result['reward'].item() == 0.5 + + print("āœ… Case 0 passed!") + + +def test_case_1_prompt_response(): + """Test case 1: prompt + response (single turn offline RL)""" + print("🧪 Testing Case 1: Prompt + Response...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'prompt': np.array([10, 11, 12]), + 'response': np.array([20, 21, 22, 23]), + 'reward': 1.0 + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'prompt_len' in result + assert 'sequence_len' in result + assert 'reward' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert result['prompt_len'].item() == 3 + assert result['sequence_len'].item() == 7 # 3 + 4 = 7 + assert result['reward'].item() == 1.0 + + # Check concatenation + expected_input_ids = torch.tensor([10, 11, 12, 20, 21, 22, 23]) + assert torch.equal(result['input_ids'], expected_input_ids) + + print("āœ… Case 1 passed!") + + +def test_case_2_input_mask(): + """Test case 2: input + mask (multi-turn with mask)""" + print("🧪 Testing Case 2: Input + Mask...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'input': np.array([10, 11, 12, 20, 21, 22]), + 'mask': np.array([0, 0, 0, 1, 1, 1]), # Only last 3 tokens are assistant + 'vstar_rewards': np.array([0.1, 0.2, 0.3]) + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'mask' in result + assert 'prompt_len' in result + assert 'sequence_len' in result + assert 'vstar_rewards' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert isinstance(result['mask'], torch.Tensor) + assert result['prompt_len'].item() == 0 + assert result['sequence_len'].item() == 6 + + # Check mask values + expected_mask = torch.tensor([0, 0, 0, 1, 1, 1]) + assert torch.equal(result['mask'], expected_mask) + + print("āœ… Case 2 passed!") + + +def test_case_3a_messages_turn_wise(): + """Test case 3a: messages with turn-wise processing""" + print("🧪 Testing Case 3a: Messages (Turn-wise)...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, flatten_messages=False) + + # Mock test data + test_sample = { + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi there! How can I help you?'}, + {'role': 'user', 'content': 'Tell me a joke.'}, + {'role': 'assistant', 'content': 'Why did the chicken cross the road?'} + ] + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'turn_data' in result + assert isinstance(result['turn_data'], list) + assert len(result['turn_data']) == 2 # Two assistant turns + + # Check first turn + turn1 = result['turn_data'][0] + assert 'input_ids' in turn1 + assert 'prompt_len' in turn1 + assert 'sequence_len' in turn1 + assert isinstance(turn1['input_ids'], torch.Tensor) + assert turn1['prompt_len'].item() > 0 + assert turn1['sequence_len'].item() > 0 + + # Check second turn + turn2 = result['turn_data'][1] + assert 'input_ids' in turn2 + assert 'prompt_len' in turn2 + assert 'sequence_len' in turn2 + assert isinstance(turn2['input_ids'], torch.Tensor) + + print("āœ… Case 3a passed!") + + +def test_case_3b_messages_trajectory_wise(): + """Test case 3b: messages with trajectory-wise processing""" + print("🧪 Testing Case 3b: Messages (Trajectory-wise)...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, flatten_messages=True) + + # Mock test data + test_sample = { + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi there!'}, + {'role': 'user', 'content': 'Thanks.'}, + {'role': 'assistant', 'content': 'Welcome!'} + ] + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'mask' in result + assert 'prompt_len' in result + assert 'sequence_len' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert isinstance(result['mask'], torch.Tensor) + assert result['prompt_len'].item() == 0 + assert result['sequence_len'].item() > 0 + + # Check mask - should have 1s for assistant tokens, 0s for others + assert len(result['mask']) == len(result['input_ids']) + assert torch.sum(result['mask']).item() > 0 # Some tokens should be marked as assistant + + print("āœ… Case 3b passed!") + + +def test_tools_from_list(): + """Test tools loading from direct list""" + print("🧪 Testing Tools from List...") + + tokenizer = MockTokenizer() + test_tools = [ + { + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object", "properties": {}} + } + } + ] + + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools=test_tools) + + # Assertions + assert len(dataset.tools) == 1 + assert dataset.tools[0]["type"] == "function" + assert dataset.tools[0]["function"]["name"] == "test_tool" + + print("āœ… Tools from list passed!") + + +def test_tools_from_jsonl_file(): + """Test tools loading from JSONL file""" + print("🧪 Testing Tools from JSONL file...") + + # Create temporary JSONL file + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f: + tool1 = {"type": "function", "function": {"name": "tool1", "description": "First tool"}} + tool2 = {"type": "function", "function": {"name": "tool2", "description": "Second tool"}} + f.write(json.dumps(tool1) + '\n') + f.write(json.dumps(tool2) + '\n') + temp_file = f.name + + try: + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools_path=temp_file) + + # Assertions + assert len(dataset.tools) == 2 + assert dataset.tools[0]["function"]["name"] == "tool1" + assert dataset.tools[1]["function"]["name"] == "tool2" + + print("āœ… Tools from JSONL file passed!") + + finally: + os.unlink(temp_file) + + +def test_chat_template_from_string(): + """Test chat template from direct string""" + print("🧪 Testing Chat Template from String...") + + tokenizer = MockTokenizer() + test_template = "Custom template: {{ content }}" + + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, chat_template=test_template) + + # Assertions + assert dataset.chat_template == test_template + assert dataset.tokenizer.chat_template == test_template + + print("āœ… Chat template from string passed!") + + +def test_error_cases(): + """Test error handling""" + print("🧪 Testing Error Cases...") + + tokenizer = MockTokenizer() + + # Test invalid tools + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools="invalid") + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Tools must be a list" in str(e) + + # Test invalid tool in list + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools=["invalid"]) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Tool 0 must be a dictionary" in str(e) + + # Test non-existent tools file + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools_path="/non/existent/file.jsonl") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError as e: + assert "Tools file not found" in str(e) + + # Test invalid sample format + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + dataset._test_samples = [{'invalid': 'data'}] + + try: + dataset[0] + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Sample must contain" in str(e) + + print("āœ… Error cases passed!") + + +def test_additional_fields(): + """Test additional fields like bonus, verified_answer""" + print("🧪 Testing Additional Fields...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data with additional fields + test_sample = { + 'prompt': np.array([10, 11, 12]), + 'response': np.array([20, 21]), + 'bonus': 2.5, + 'vstar_bonus': np.array([0.1, 0.2]), + 'verified_answer': 'This is verified' + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'bonus' in result + assert 'vstar_bonus' in result + assert 'verified_answer' in result + assert result['bonus'].item() == 2.5 + assert isinstance(result['vstar_bonus'], torch.Tensor) + assert result['verified_answer'] == 'This is verified' + + print("āœ… Additional fields passed!") + + +def run_all_tests(): + """Run all test cases""" + print("šŸš€ Running RLStreamingDataset Tests...\n") + + try: + test_case_0_prompt_only() + test_case_1_prompt_response() + test_case_2_input_mask() + test_case_3a_messages_turn_wise() + test_case_3b_messages_trajectory_wise() + test_tools_from_list() + test_tools_from_jsonl_file() + test_chat_template_from_string() + test_additional_fields() + test_error_cases() + + print("\nšŸŽ‰ All tests passed! RLStreamingDataset is working correctly.") + + except Exception as e: + print(f"\nāŒ Test failed: {e}") + raise + + +if __name__ == "__main__": + run_all_tests() diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml index a82304b3..c2970ed3 100644 --- a/yamls/offline_apo.yaml +++ b/yamls/offline_apo.yaml @@ -116,6 +116,7 @@ parameters: #dbfs:/Volumes/datasets/jchang/data/offline_apo/apo_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ shuffle: true max_seq_len: ${max_seq_len} + flatten_messages: true # Enable trajectory-wise flattened processing for messages tools_path: ../../../compose-rl/example_tools.jsonl shuffle_seed: 7338 download_retry: 4 From efe896568fef1f7b5179498fcf56dadbdfa86b73 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 15:42:22 -0400 Subject: [PATCH 179/195] test non flatten message --- compose_rl/data/rl_data.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 4eb7d302..3da32f2a 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -45,25 +45,26 @@ def dataset_collate_fn( list_of_num_turns = [] return_dict: dict[str, Any] = {} - # case 1: input_ids key is present + # case 1: input_ids key is present # for single turn and flattten messages case if 'input_ids' in data[0]: list_of_input_ids = [item['input_ids'] for item in data] list_of_prompt_len = [item['prompt_len'] for item in data] - # case 2: turn_data key is present + # case 2: turn_data key is present: for non-flatterned messages elif "turn_data" in data[0]: for data_point in data: list_of_input_ids.extend([turn['input_ids'] for turn in data_point['turn_data']]) list_of_prompt_len.extend([turn['prompt_len'] for turn in data_point['turn_data']]) list_of_num_turns.append(torch.tensor([len(data_point['turn_data'])], dtype=torch.int64)) + # case 3: prompt key is presented; no response key elif 'prompt' in data[0]: list_of_prompts = [item['prompt'] for item in data] list_of_prompt_len = [item['prompt_len'] for item in data] list_of_prompt_ids = [item['prompt_id'] for item in data] - - if len(list_of_input_ids) > 0: # dealing with input_ids if it not empty. batch, padd, and truncate based on max_seq_len + # dealing with input_ids if it not empty. batch, padd, and truncate based on max_seq_len + if len(list_of_input_ids) > 0: batch_input_ids = ref_collate_fn(list_of_input_ids)['input_ids'] attention_masks = torch.logical_not(torch.eq(batch_input_ids, tokenizer.pad_token_id)).to(torch.int64) # truncate if length of the batch exceeds max_seq_len @@ -99,11 +100,11 @@ def dataset_collate_fn( masks = [] for i in range(batch_input_ids.shape[0]): mask_i = data[i]['mask'] - if len(mask_i) < len(batch_input_ids[i]): # right padded + if len(mask_i) < len(batch_input_ids[i]): # input_id got right padded all_zeros = torch.zeros(len(batch_input_ids[i])) all_zeros[0:len(mask_i)] = mask_i mask_i = all_zeros - else: # truncated + else: # input_ids got truncated mask_i = mask_i[0:len(batch_input_ids[i])] masks.append(mask_i) masks = torch.stack(masks) @@ -117,12 +118,13 @@ def dataset_collate_fn( return_dict['prompt_id'] = torch.cat(list_of_prompt_ids) return_dict['prompt_len'] = torch.cat(list_of_prompt_len) - - if len(list_of_num_turns) > 0: # this is the case where we have turn level data + # this is the case where we have turn level data and messages are not flatterned + if len(list_of_num_turns) > 0: assert 'turn_data' in data[0], "turn_data must be present if num_turns is present" return_dict['num_turns'] = torch.cat(list_of_num_turns) + assert return_dict['input_ids'].shape[0] == torch.sum(return_dict['num_turns']), "input_ids and num_turns must have the same length" - + if 'reward' in data[0]: return_dict['reward'] = torch.cat([item['reward'] for item in data]) if 'bonus' in data[0]: @@ -358,10 +360,6 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), } - print("#### test: print assistant tokens ####") - print(self.tokenizer.decode(input_ids[mask.bool()])) - print("#### test: done printing assistant tokens ####") - else: raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") From 46a74ceda1a01254ca23ce7f540dc0e6e36f1495 Mon Sep 17 00:00:00 2001 From: wensun Date: Sat, 6 Sep 2025 15:52:39 -0400 Subject: [PATCH 180/195] test non flatten message --- compose_rl/data/rl_data.py | 3 +++ test_rl_data_comprehensive.py | 1 + 2 files changed, 4 insertions(+) diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 3da32f2a..43593f93 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -122,6 +122,9 @@ def dataset_collate_fn( if len(list_of_num_turns) > 0: assert 'turn_data' in data[0], "turn_data must be present if num_turns is present" return_dict['num_turns'] = torch.cat(list_of_num_turns) + print("#### test: print num_turns ####") + print(return_dict['num_turns']) + print("#### test: done printing num_turns ####") assert return_dict['input_ids'].shape[0] == torch.sum(return_dict['num_turns']), "input_ids and num_turns must have the same length" diff --git a/test_rl_data_comprehensive.py b/test_rl_data_comprehensive.py index 8a81d853..d55d6468 100644 --- a/test_rl_data_comprehensive.py +++ b/test_rl_data_comprehensive.py @@ -638,6 +638,7 @@ def run_all_tests(): test_tools_from_jsonl_file() test_chat_template_from_string() test_additional_fields() + test_variable_length_vstar_fields() test_error_cases() print("\nšŸŽ‰ All tests passed! RLStreamingDataset is working correctly.") From 0819de1cbf5d751ade8ad06b3b6da9b3a8e5c70d Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 8 Sep 2025 23:31:41 -0400 Subject: [PATCH 181/195] first version of value function integration using flatten_message --- compose_rl/algorithms/offline/callback.py | 40 +++++++ compose_rl/algorithms/offline/model.py | 3 + .../algorithms/offline/model_methods.py | 100 ++++++++++++++++-- yamls/offline_apo.yaml | 11 +- 4 files changed, 145 insertions(+), 9 deletions(-) diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 3a1bffe5..5215e903 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -31,6 +31,7 @@ def __init__( ): self.train_config = copy.deepcopy(train_config) self.reference_model = None + self.auxiliary_model = None # Add auxiliary model def after_load(self, state: State, logger: Logger) -> None: #model_config = self.train_config['model'] @@ -80,6 +81,37 @@ def after_load(self, state: State, logger: Logger) -> None: callbacks=load_checkpoint_callbacks, ) + # Load auxiliary model following the same pattern + if 'auxiliary_model' in self.train_config.get('variables', {}): + aux_model_config = self.train_config['variables']['auxiliary_model'] + aux_init_context = process_init_device( + aux_model_config, + self.train_config.get('fsdp_config'), + ) + aux_name = aux_model_config.pop('name') + print("################################################") + print("auxiliary model config:") + print(aux_model_config) + print("################################################") + self.auxiliary_model = build_composer_model( + name=aux_name, + cfg=aux_model_config, + tokenizer=state.model.tokenizer, # type: ignore + init_context=aux_init_context, + master_weights_dtype=aux_model_config.get('master_weights_dtype', None), + ) + + # Load auxiliary model with same checkpoint loading procedure + _ = Trainer( + model=self.auxiliary_model, + parallelism_config={'fsdp': state.fsdp_config}, + precision=state.precision, + load_weights_only=True, + load_strict_model_weights=False, + load_path=original_load_path, + callbacks=load_checkpoint_callbacks, + ) + def before_forward(self, state: State, logger: Logger) -> Optional[int]: # Before every batch we need to do a forwards pass over the reference model with get_precision_context(state.precision): @@ -88,7 +120,15 @@ def before_forward(self, state: State, logger: Logger) -> Optional[int]: reference_outputs = self.reference_model(state.batch) state.batch.update({ 'ref_logp': reference_outputs['policy_logp'], + 'ref_token_policy_logps': reference_outputs['token_policy_logps'], }) + + # Add auxiliary model forward pass if available + if self.auxiliary_model is not None: + auxiliary_outputs = self.auxiliary_model(state.batch) + state.batch.update({ + 'aux_first_num_bins_logits': auxiliary_outputs['first_num_bins_logits'], + }) class PairwiseReferencePolicyCallback(ReferencePolicyCallback): diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 1131d944..ca89c94c 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -93,6 +93,7 @@ def __init__( multistep: bool = False, average_log_prob: bool = False, temperature: float = 1.0, + num_bins: int = 1, # Add num_bins parameter **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) @@ -102,6 +103,7 @@ def __init__( self.multistep = multistep self.average_log_prob = average_log_prob self.temperature = temperature + self.num_bins = num_bins # Store num_bins super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward @@ -113,6 +115,7 @@ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: batch=batch, average_log_prob=self.average_log_prob, temperature=self.temperature, + num_bins=self.num_bins, # Pass num_bins to offline_forward ) def eval_forward( diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 6f7d549c..f922c733 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -33,6 +33,7 @@ class RegressionOfflineEnum(Enum): APO = 'apo' QRPO = 'qrpo' + APO_CRITIC = 'apo_critic' class PairwiseOfflineEnum(Enum): @@ -50,6 +51,7 @@ def offline_forward( average_log_prob: bool = False, policy_model_config: Optional[PretrainedConfig] = None, temperature: float = 1.0, + num_bins: int = 1, ) -> dict[str, torch.Tensor]: """Forwards the model for dpo and get the chosen and rejected log probs. @@ -60,6 +62,7 @@ def offline_forward( Note: this batch has chosen and rejected concated along the sequence dimension. average_log_prob (bool): Whether should we average the log probabilities. policy_model_config: Policy model config. + num_bins: Number of bins for the histogram. used for modeling distributional value function; will return the first num_bins logits """ is_multimodal = 'pixel_values' in batch.keys() has_mask = 'mask' in batch.keys() @@ -82,7 +85,11 @@ def offline_forward( } inputs.update(multimodal_inputs) - output_logits = model(**inputs).logits + output_logits = model(**inputs).logits # (batch_size, seq_len, vocab_size) + token_policy_logps = get_log_probs_from_logits( + output_logits[:,:-1], + batch['input_ids'][:,1:] + ) # tokenize logps (batch_size, seq_len-1) here seq_len-1 because we shifted # Calculate token entropies from the logits token_entropies = get_token_entropies(logits=output_logits) token_entropies = token_entropies.detach() @@ -108,10 +115,6 @@ def offline_forward( action_mask=action_mask ) else: - token_policy_logps = get_log_probs_from_logits( - output_logits[:,:-1], - batch['input_ids'][:,1:] - ) # apply attention_mask and mask explicitly token_policy_logps *= batch['attention_mask'][:,1:] token_policy_logps *= batch['mask'][:,1:] @@ -127,7 +130,11 @@ def offline_forward( outputs: dict[str, torch.Tensor] = { 'policy_logp': logps, 'sequence_entropies': sequence_entropies, + 'token_policy_logps': token_policy_logps, } + if num_bins >= 1: + first_num_bins_logits = output_logits[:,:,:num_bins] # take the first num_bins logits (batch_size, seq_len, num_bins) + outputs['first_num_bins_logits'] = first_num_bins_logits if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( @@ -140,6 +147,32 @@ def offline_forward( return outputs +def _extract_segments(mask: torch.Tensor) -> list[tuple[int, int]]: + """Extract contiguous segments where mask == 1. + + Args: + mask: 1D tensor of 0s and 1s + + Returns: + List of (start, end) tuples for each contiguous segment of 1s + """ + # Convert to CPU and ensure integer type for reliable comparison + mask_cpu = mask.cpu().int() + + # Find transitions: 0->1 (start) and 1->0 (end) + # Pad with 0 to handle edge cases + padded_mask = torch.cat([torch.tensor([0]), mask_cpu, torch.tensor([0])]) + diff = torch.diff(padded_mask) + + # Find starts (0->1 transitions) and ends (1->0 transitions) + starts = torch.where(diff == 1)[0].tolist() # Convert to Python list + ends = torch.where(diff == -1)[0].tolist() # Convert to Python list + + # Adjust indices (remove padding offset) + segments = [(start, end - 1) for start, end in zip(starts, ends)] + + return segments + def offline_loss( outputs: CausalLMOutputWithPast, batch: Mapping, @@ -158,6 +191,9 @@ def offline_loss( 'ref_logp', torch.zeros_like(policy_logp), ) + + # Initialize vstar to avoid "possibly unbound" warning + vstar = None if loss_type == RegressionOfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 @@ -203,8 +239,56 @@ def offline_loss( else: raise NotImplementedError("Multistep for QRPO not implemented") - losses = (reward_q - beta2 * torch.log(beta2) - 1 - beta2 * (policy_logp - ref_logp)) ** 2 - + losses = (reward_q - beta2 * torch.log(torch.tensor(beta2)) - 1 - beta2 * (policy_logp - ref_logp)) ** 2 + + elif loss_type == RegressionOfflineEnum.APO_CRITIC: + # grab necessaryinformation for this actor-critic style APO loss: + first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model + assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' + num_bins = first_num_bins_logits.shape[2] + mask = batch.get('mask', None) + assert mask is not None, 'must have a mask when using APO_CRITIC' + attention_mask = batch.get('attention_mask', None) + assert attention_mask is not None, 'must have an attention mask' + token_policy_logps = outputs.get('token_policy_logps', None) # (batch_size, seq_len-1) + assert token_policy_logps is not None, 'must have a token policy logps' + ref_token_policy_logps = batch.get('ref_token_policy_logps', None) + assert ref_token_policy_logps is not None, 'must have a reference token policy logps' + assert ref_token_policy_logps.shape == token_policy_logps.shape, 'must have the same shape for token policy logps and reference token policy logps' + + bs = first_num_bins_logits.shape[0] + device = first_num_bins_logits.device + losses = torch.zeros(bs, device=device) + + # Create arange tensor on correct device + arange_tensor = torch.arange(num_bins, device=device, dtype=torch.float32) + + for i in range(bs): + combined_mask = mask[i][1:] * attention_mask[i][1:] # mask starts from the second token + # scan through combined_mask, and compute loss per at each turn + segments = _extract_segments(combined_mask) + segment_losses = [] + + for segment in segments: + seg_logp = torch.sum(token_policy_logps[i][segment[0]:segment[1]+1]) + seg_ref_logp = torch.sum(ref_token_policy_logps[i][segment[0]:segment[1]+1]) + logits_start = first_num_bins_logits[i, segment[0], :] + logits_end = first_num_bins_logits[i, segment[1], :] + + # Fix device issues and use pre-computed arange tensor + vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp((arange_tensor/num_bins)*1.0/beta1))) + vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp((arange_tensor/num_bins)*1.0/beta1))) + + segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 + segment_losses.append(segment_loss) + + # Accumulate losses across segments for this batch item + if segment_losses: + losses[i] = torch.stack(segment_losses).mean() # Average loss across segments + else: + losses[i] = torch.tensor(0.0, device=device) # No valid segments + + # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref # use clip to ensure the output from exp is valid @@ -228,7 +312,7 @@ def offline_loss( 'estimated_reward': estimated_reward, 'sequence_entropies': outputs['sequence_entropies'], # Track detached sequence entropies in the loss dict } - if loss_type == RegressionOfflineEnum.APO: + if loss_type == RegressionOfflineEnum.APO and vstar is not None: loss_dict['batch_advantage'] = torch.mean( batch['reward'] - vstar, ) diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml index c2970ed3..00ecf871 100644 --- a/yamls/offline_apo.yaml +++ b/yamls/offline_apo.yaml @@ -22,7 +22,7 @@ parameters: beta1: 1 beta2: 0.01 eta: 0.5 - loss_type: apo + loss_type: apo_critic #apo pretrained: true init_device: mixed use_auth_token: true @@ -146,6 +146,15 @@ parameters: use_auth_token: true use_flash_attention_2: true pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + + auxiliary_model: # Add your second model here + name: hf_offline_lm + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B # Different model or same + num_bins: 32 # Number of bins for distributional value function integrations: - integration_type: git_repo From 8cbbc409d7cf0003c4bdba02aa29abe2fba64183 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 09:10:38 -0400 Subject: [PATCH 182/195] testing value function integration --- .../algorithms/offline/model_methods.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f922c733..01244302 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -194,6 +194,7 @@ def offline_loss( # Initialize vstar to avoid "possibly unbound" warning vstar = None + advantages = None if loss_type == RegressionOfflineEnum.APO: # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 @@ -220,6 +221,7 @@ def offline_loss( bonuses = batch.get('bonus', torch.zeros_like(batch['reward'])) added_bonuses = bonuses * batch['reward'] # true added bonus = 1 if both bonus = 1 and reward = 1 + advantages = batch['reward'] + eta * added_bonuses - vstar if bce == False: losses = ( beta2 * (policy_logp - ref_logp) - @@ -243,15 +245,18 @@ def offline_loss( elif loss_type == RegressionOfflineEnum.APO_CRITIC: # grab necessaryinformation for this actor-critic style APO loss: + print('------using APO_CRITIC loss: grabbing necessary information from the batch------') first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' num_bins = first_num_bins_logits.shape[2] + mask = batch.get('mask', None) - assert mask is not None, 'must have a mask when using APO_CRITIC' + assert mask is not None, 'must have a mask when using APO_CRITIC -- we use musk to grab the value function information at the right token positions' attention_mask = batch.get('attention_mask', None) assert attention_mask is not None, 'must have an attention mask' + token_policy_logps = outputs.get('token_policy_logps', None) # (batch_size, seq_len-1) - assert token_policy_logps is not None, 'must have a token policy logps' + assert token_policy_logps is not None, 'must have a token policy logps -- we need to calculate sum of logps explicitly here based on mask' ref_token_policy_logps = batch.get('ref_token_policy_logps', None) assert ref_token_policy_logps is not None, 'must have a reference token policy logps' assert ref_token_policy_logps.shape == token_policy_logps.shape, 'must have the same shape for token policy logps and reference token policy logps' @@ -259,10 +264,10 @@ def offline_loss( bs = first_num_bins_logits.shape[0] device = first_num_bins_logits.device losses = torch.zeros(bs, device=device) + advantages = torch.zeros(bs, device=device) - # Create arange tensor on correct device - arange_tensor = torch.arange(num_bins, device=device, dtype=torch.float32) - + # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins + bin_values = torch.arange(num_bins, device=device, dtype=torch.float32)*1.0 / num_bins for i in range(bs): combined_mask = mask[i][1:] * attention_mask[i][1:] # mask starts from the second token # scan through combined_mask, and compute loss per at each turn @@ -275,17 +280,19 @@ def offline_loss( logits_start = first_num_bins_logits[i, segment[0], :] logits_end = first_num_bins_logits[i, segment[1], :] - # Fix device issues and use pre-computed arange tensor - vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp((arange_tensor/num_bins)*1.0/beta1))) - vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp((arange_tensor/num_bins)*1.0/beta1))) + # use pre-computed arange tensor + vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) + vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 segment_losses.append(segment_loss) + advantages[i] += (vstar_end - vstar_start).detach() # Accumulate losses across segments for this batch item if segment_losses: losses[i] = torch.stack(segment_losses).mean() # Average loss across segments else: + print('------no valid segments------') losses[i] = torch.tensor(0.0, device=device) # No valid segments @@ -300,21 +307,22 @@ def offline_loss( losses = losses.mean() - implicit_rewards = beta2 * (policy_logp - ref_logp).detach() + #implicit_rewards = beta2 * (policy_logp - ref_logp).detach() # Logging KL margins for comparing different methods - reverse_kl = (policy_logp - ref_logp).detach() forward_kl = (ref_logp - policy_logp).detach() loss_dict = { - 'implicit_rewards': implicit_rewards, - 'reverse_kl': reverse_kl, + #'implicit_rewards': implicit_rewards, 'forward_kl': forward_kl, 'estimated_reward': estimated_reward, 'sequence_entropies': outputs['sequence_entropies'], # Track detached sequence entropies in the loss dict } - if loss_type == RegressionOfflineEnum.APO and vstar is not None: + + #if loss_type == RegressionOfflineEnum.APO and vstar is not None: + if advantages is not None: loss_dict['batch_advantage'] = torch.mean( - batch['reward'] - vstar, + #batch['reward'] - vstar, + advantages, ) if 'lbl' in outputs: From 11efde5f0850a4c2f459343dde0933fd9cdcb61e Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 09:19:20 -0400 Subject: [PATCH 183/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 01244302..da9aaff2 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -278,7 +278,7 @@ def offline_loss( seg_logp = torch.sum(token_policy_logps[i][segment[0]:segment[1]+1]) seg_ref_logp = torch.sum(ref_token_policy_logps[i][segment[0]:segment[1]+1]) logits_start = first_num_bins_logits[i, segment[0], :] - logits_end = first_num_bins_logits[i, segment[1], :] + logits_end = first_num_bins_logits[i, segment[1]+1, :] # TODO: double check if segment[1] or segment[1]+1 # use pre-computed arange tensor vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) From 2bc485cd62325d3df1af766f6367383ab9fd3a4f Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 09:39:04 -0400 Subject: [PATCH 184/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index da9aaff2..c1c707bb 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -191,6 +191,8 @@ def offline_loss( 'ref_logp', torch.zeros_like(policy_logp), ) + + print('------using loss type------: ', loss_type) # Initialize vstar to avoid "possibly unbound" warning vstar = None From 003ec73193f6166bc521a0ebff28ea308178ea95 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 09:53:41 -0400 Subject: [PATCH 185/195] . --- compose_rl/algorithms/offline/model.py | 3 +++ compose_rl/algorithms/offline/model_methods.py | 5 ++++- yamls/offline_apo.yaml | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index ca89c94c..bed4c589 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -104,6 +104,9 @@ def __init__( self.average_log_prob = average_log_prob self.temperature = temperature self.num_bins = num_bins # Store num_bins + + # Debug print to verify loss_type + print(f"🚨 DEBUG: Model initialized with loss_type = {self.loss_type} (value: {self.loss_type.value})") super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index c1c707bb..f5f02153 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -199,6 +199,8 @@ def offline_loss( advantages = None if loss_type == RegressionOfflineEnum.APO: + print('šŸ“Š ------USING REGULAR APO LOSS (not apo_critic)------') + print(f'šŸ“Š loss_type = {loss_type}, loss_type.value = {loss_type.value}') # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 # APO is not a pair-wise loss function. # Similar to REBEL, we assume each response has a reward in the batch. @@ -247,7 +249,8 @@ def offline_loss( elif loss_type == RegressionOfflineEnum.APO_CRITIC: # grab necessaryinformation for this actor-critic style APO loss: - print('------using APO_CRITIC loss: grabbing necessary information from the batch------') + print('šŸŽÆ ------USING APO_CRITIC LOSS: grabbing necessary information from the batch------') + print(f'šŸŽÆ loss_type = {loss_type}, loss_type.value = {loss_type.value}') first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' num_bins = first_num_bins_logits.shape[2] diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml index 00ecf871..8ba7cde3 100644 --- a/yamls/offline_apo.yaml +++ b/yamls/offline_apo.yaml @@ -22,7 +22,7 @@ parameters: beta1: 1 beta2: 0.01 eta: 0.5 - loss_type: apo_critic #apo + loss_type: apo_critic pretrained: true init_device: mixed use_auth_token: true From 91914db5ed9be52b6da5190a54efaa924ca0a182 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 10:10:27 -0400 Subject: [PATCH 186/195] . --- compose_rl/algorithms/offline/model.py | 3 --- compose_rl/algorithms/offline/model_methods.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index bed4c589..ca89c94c 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -104,9 +104,6 @@ def __init__( self.average_log_prob = average_log_prob self.temperature = temperature self.num_bins = num_bins # Store num_bins - - # Debug print to verify loss_type - print(f"🚨 DEBUG: Model initialized with loss_type = {self.loss_type} (value: {self.loss_type.value})") super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f5f02153..9aa04486 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -191,16 +191,12 @@ def offline_loss( 'ref_logp', torch.zeros_like(policy_logp), ) - - print('------using loss type------: ', loss_type) # Initialize vstar to avoid "possibly unbound" warning vstar = None advantages = None if loss_type == RegressionOfflineEnum.APO: - print('šŸ“Š ------USING REGULAR APO LOSS (not apo_critic)------') - print(f'šŸ“Š loss_type = {loss_type}, loss_type.value = {loss_type.value}') # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 # APO is not a pair-wise loss function. # Similar to REBEL, we assume each response has a reward in the batch. @@ -250,7 +246,6 @@ def offline_loss( elif loss_type == RegressionOfflineEnum.APO_CRITIC: # grab necessaryinformation for this actor-critic style APO loss: print('šŸŽÆ ------USING APO_CRITIC LOSS: grabbing necessary information from the batch------') - print(f'šŸŽÆ loss_type = {loss_type}, loss_type.value = {loss_type.value}') first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' num_bins = first_num_bins_logits.shape[2] From f5e99c39bbff024117a9bb36bca3aef5aba24224 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 14:59:14 -0400 Subject: [PATCH 187/195] testing apo critic with reward and vstar --- compose_rl/algorithms/offline/model_methods.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 9aa04486..0aeeccd8 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -265,6 +265,12 @@ def offline_loss( device = first_num_bins_logits.device losses = torch.zeros(bs, device=device) advantages = torch.zeros(bs, device=device) + + # for debugging purpose, let's grab reward and vstar_rewards here + rewards = batch.get('reward', None) + assert rewards is not None, 'reward must be present in batch for APO_CRITIC' + vstar_rewards = batch.get('vstar_rewards', None) + assert vstar_rewards is not None, 'vstar_rewards must be present in batch for APO_CRITIC' # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins bin_values = torch.arange(num_bins, device=device, dtype=torch.float32)*1.0 / num_bins @@ -280,9 +286,15 @@ def offline_loss( logits_start = first_num_bins_logits[i, segment[0], :] logits_end = first_num_bins_logits[i, segment[1]+1, :] # TODO: double check if segment[1] or segment[1]+1 + # for debugging purpose, let's use reward and vstar_rewards here + vstar_end = rewards[i] + vstar_start = beta1*torch.log(torch.mean(torch.exp(vstar_rewards[i]/beta1))) + # use pre-computed arange tensor - vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) - vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) + #vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) + #vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) + #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) + #vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 segment_losses.append(segment_loss) From f176ec5e27b74720dcff54a6973b6807fe8486db Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 17:24:55 -0400 Subject: [PATCH 188/195] repupose the apo critic code for the standard apo code just for testing --- .../algorithms/offline/model_methods.py | 20 ++++++++++++++----- compose_rl/data/rl_data.py | 2 ++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 0aeeccd8..1c2bc261 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -272,6 +272,9 @@ def offline_loss( vstar_rewards = batch.get('vstar_rewards', None) assert vstar_rewards is not None, 'vstar_rewards must be present in batch for APO_CRITIC' + # for debugging purpose, let's compute vstar here + advantages = rewards - beta1*torch.log(torch.mean(torch.exp(vstar_rewards/beta1), dim = -1)) + # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins bin_values = torch.arange(num_bins, device=device, dtype=torch.float32)*1.0 / num_bins for i in range(bs): @@ -290,19 +293,26 @@ def offline_loss( vstar_end = rewards[i] vstar_start = beta1*torch.log(torch.mean(torch.exp(vstar_rewards[i]/beta1))) + # for debugging purpose, let's do traj-wise APO actually here: + segment_logp_diff = seg_logp - seg_ref_logp + segment_losses.append(segment_logp_diff) + # use pre-computed arange tensor #vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) #vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) - - segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 - segment_losses.append(segment_loss) - advantages[i] += (vstar_end - vstar_start).detach() + #segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 + #segment_losses.append(segment_loss) + #advantages[i] += (vstar_end - vstar_start).detach() + + # Accumulate losses across segments for this batch item if segment_losses: - losses[i] = torch.stack(segment_losses).mean() # Average loss across segments + #losses[i] = torch.stack(segment_losses).mean() # Average loss across segments + # debugging purpose: + losses[i] = ((beta2 *torch.stack(segment_losses).sum()) - advantages[i])**2 else: print('------no valid segments------') losses[i] = torch.tensor(0.0, device=device) # No valid segments diff --git a/compose_rl/data/rl_data.py b/compose_rl/data/rl_data.py index 43593f93..7fb51fb1 100644 --- a/compose_rl/data/rl_data.py +++ b/compose_rl/data/rl_data.py @@ -323,6 +323,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # case 2: input + mask, this is can be for single turn or multi-turn offline RL. mask is used to mask out non-assistant turns elif 'input' in sample and 'mask' in sample: + print('------case 2: input + mask------') + assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" assert isinstance(sample['mask'], np.ndarray), f"Mask must be a numpy array, but got {type(sample['mask'])}" From 485abf843dcbeaf84ae913d7550527371d3d64ec Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 17:35:41 -0400 Subject: [PATCH 189/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 1c2bc261..9f38a028 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -298,6 +298,7 @@ def offline_loss( segment_losses.append(segment_logp_diff) # use pre-computed arange tensor + # below is the implementation we wanted: #vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) @@ -312,6 +313,7 @@ def offline_loss( if segment_losses: #losses[i] = torch.stack(segment_losses).mean() # Average loss across segments # debugging purpose: + print(segment_losses) losses[i] = ((beta2 *torch.stack(segment_losses).sum()) - advantages[i])**2 else: print('------no valid segments------') From 983fc88d396a7d6cfb7e3f4c3adae39f0d3fd88c Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 17:36:25 -0400 Subject: [PATCH 190/195] . --- compose_rl/algorithms/offline/model_methods.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 9f38a028..b5871159 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -313,7 +313,8 @@ def offline_loss( if segment_losses: #losses[i] = torch.stack(segment_losses).mean() # Average loss across segments # debugging purpose: - print(segment_losses) + print(segment_losses[0]) + print(len(segment_losses)) losses[i] = ((beta2 *torch.stack(segment_losses).sum()) - advantages[i])**2 else: print('------no valid segments------') From bf1e617fc2ad9b39de7d90e89a1e405e1a44dc6e Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 19:58:49 -0400 Subject: [PATCH 191/195] . --- .../algorithms/offline/model_methods.py | 41 ++++--------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index b5871159..1d84d1c3 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -245,7 +245,6 @@ def offline_loss( elif loss_type == RegressionOfflineEnum.APO_CRITIC: # grab necessaryinformation for this actor-critic style APO loss: - print('šŸŽÆ ------USING APO_CRITIC LOSS: grabbing necessary information from the batch------') first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' num_bins = first_num_bins_logits.shape[2] @@ -265,17 +264,8 @@ def offline_loss( device = first_num_bins_logits.device losses = torch.zeros(bs, device=device) advantages = torch.zeros(bs, device=device) - - # for debugging purpose, let's grab reward and vstar_rewards here - rewards = batch.get('reward', None) - assert rewards is not None, 'reward must be present in batch for APO_CRITIC' - vstar_rewards = batch.get('vstar_rewards', None) - assert vstar_rewards is not None, 'vstar_rewards must be present in batch for APO_CRITIC' - # for debugging purpose, let's compute vstar here - advantages = rewards - beta1*torch.log(torch.mean(torch.exp(vstar_rewards/beta1), dim = -1)) - - # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins + # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins -- using left end points of bins bin_values = torch.arange(num_bins, device=device, dtype=torch.float32)*1.0 / num_bins for i in range(bs): combined_mask = mask[i][1:] * attention_mask[i][1:] # mask starts from the second token @@ -289,33 +279,20 @@ def offline_loss( logits_start = first_num_bins_logits[i, segment[0], :] logits_end = first_num_bins_logits[i, segment[1]+1, :] # TODO: double check if segment[1] or segment[1]+1 - # for debugging purpose, let's use reward and vstar_rewards here - vstar_end = rewards[i] - vstar_start = beta1*torch.log(torch.mean(torch.exp(vstar_rewards[i]/beta1))) - - # for debugging purpose, let's do traj-wise APO actually here: - segment_logp_diff = seg_logp - seg_ref_logp - segment_losses.append(segment_logp_diff) - # use pre-computed arange tensor # below is the implementation we wanted: - #vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) - #vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) + vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) + vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) #vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) - #segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 - #segment_losses.append(segment_loss) - #advantages[i] += (vstar_end - vstar_start).detach() - - + segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 + segment_losses.append(segment_loss) + advantages[i] += (vstar_end - vstar_start).detach() - # Accumulate losses across segments for this batch item + # Accumulate losses across segments for this batch item and average advantage if segment_losses: - #losses[i] = torch.stack(segment_losses).mean() # Average loss across segments - # debugging purpose: - print(segment_losses[0]) - print(len(segment_losses)) - losses[i] = ((beta2 *torch.stack(segment_losses).sum()) - advantages[i])**2 + losses[i] = torch.stack(segment_losses).mean() # Average loss across segments + advantages[i] = advantages[i]/len(segment_losses) # average advantage across segments else: print('------no valid segments------') losses[i] = torch.tensor(0.0, device=device) # No valid segments From 8172e57396d8ed499413d1fdd8d84881b050dde4 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 20:07:01 -0400 Subject: [PATCH 192/195] . --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 1d84d1c3..a57fec8c 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -154,7 +154,7 @@ def _extract_segments(mask: torch.Tensor) -> list[tuple[int, int]]: mask: 1D tensor of 0s and 1s Returns: - List of (start, end) tuples for each contiguous segment of 1s + List of (start, end) tuples for each contiguous segment of 1s; both start and end are inclusive """ # Convert to CPU and ensure integer type for reliable comparison mask_cpu = mask.cpu().int() From 6566e0ef76a9b3b4a55b3c9539c7aafd69978a75 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 9 Sep 2025 22:59:58 -0400 Subject: [PATCH 193/195] use reward at the last step --- compose_rl/algorithms/offline/model_methods.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index a57fec8c..19a034d6 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -273,7 +273,7 @@ def offline_loss( segments = _extract_segments(combined_mask) segment_losses = [] - for segment in segments: + for k, segment in enumerate(segments): seg_logp = torch.sum(token_policy_logps[i][segment[0]:segment[1]+1]) seg_ref_logp = torch.sum(ref_token_policy_logps[i][segment[0]:segment[1]+1]) logits_start = first_num_bins_logits[i, segment[0], :] @@ -282,7 +282,10 @@ def offline_loss( # use pre-computed arange tensor # below is the implementation we wanted: vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) - vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) + if k == len(segments) - 1: + vstar_end = batch['reward'][i] + else: + vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) #vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 From 03af246a1d631e12ed449529c14e1ef5227213f4 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 10 Sep 2025 22:09:58 -0400 Subject: [PATCH 194/195] use vstar at the beginning --- compose_rl/algorithms/offline/model_methods.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 19a034d6..85e2b7a5 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -281,8 +281,12 @@ def offline_loss( # use pre-computed arange tensor # below is the implementation we wanted: - vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) - if k == len(segments) - 1: + if k == 0: # for first segment, we can just use vstar_reward r(y) as V*(y). + vstar_start = beta1*torch.log(torch.mean(torch.exp(batch['vstar_rewards'][i]/beta1))) + else: + vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) + + if k == len(segments) - 1: # for last segment, we can just use reward r(y) as V*(y). vstar_end = batch['reward'][i] else: vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) From 03797863addbe17a6992838150891b1494475981 Mon Sep 17 00:00:00 2001 From: Owen Oertell <53378167+Owen-Oertell@users.noreply.github.com> Date: Thu, 11 Sep 2025 12:01:12 -0400 Subject: [PATCH 195/195] Add value learning into offline RL (#151) Adding value learning to offline apo. Support two more args in model. `distributional_value_learning` (to use top n logits) and `top_n_logits`. Also add a new loss type `value_learning`. --------- Co-authored-by: Owen Oertell --- compose_rl/algorithms/offline/model.py | 26 ++- .../algorithms/offline/model_methods.py | 32 ++++ .../metrics/offline_learning_metrics.py | 156 ++++++++++++++++++ 3 files changed, 208 insertions(+), 6 deletions(-) create mode 100644 compose_rl/metrics/offline_learning_metrics.py diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index ca89c94c..871e064d 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -12,6 +12,13 @@ from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.modeling_outputs import CausalLMOutputWithPast +from compose_rl.metrics.offline_learning_metrics import ( + TestEstimatedRewardLossMetric, + TestImplicitRewardsLossMetric, + TestKLDivergenceLossMetric, + TestSequenceEntropiesLossMetric, + TestTotalLossMetric +) from compose_rl.algorithms.offline.model_methods import ( RegressionOfflineEnum, @@ -94,6 +101,7 @@ def __init__( average_log_prob: bool = False, temperature: float = 1.0, num_bins: int = 1, # Add num_bins parameter + distributional_value_learning: bool = True, **kwargs: Any, ): self.loss_type = RegressionOfflineEnum(loss_type) @@ -104,9 +112,11 @@ def __init__( self.average_log_prob = average_log_prob self.temperature = temperature self.num_bins = num_bins # Store num_bins - + self.distributional_value_learning = distributional_value_learning + super().__init__(**kwargs) self.train_metrics = None # DPOLM does not support eval_forward + self.val_metrics = {metric.__class__.__name__ : metric for metric in [TestEstimatedRewardLossMetric(), TestImplicitRewardsLossMetric(), TestKLDivergenceLossMetric(), TestSequenceEntropiesLossMetric(), TestTotalLossMetric()]} def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: assert self.tokenizer is not None @@ -121,9 +131,13 @@ def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: def eval_forward( self, batch: MutableMapping, - outputs: CausalLMOutputWithPast, - ) -> None: - raise ValueError('Eval forward is not implemented for ComposerHFDPOLM.') + outputs: CausalLMOutputWithPast | None = None, + ) -> dict[str, torch.Tensor]: + with torch.no_grad(): + fwd = self.forward(batch) + loss = self.loss(fwd, batch) + + return loss def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]: @@ -133,11 +147,11 @@ def loss(self, outputs: CausalLMOutputWithPast, loss_type = self.loss_type, beta1 = self.beta1, beta2 = self.beta2, - eta = self.eta, + eta = self.eta, multistep = self.multistep, + distributional_value_learning = self.distributional_value_learning, ) - class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM): """MPT model wrapper for DPO model.""" diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 85e2b7a5..28628792 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -34,6 +34,7 @@ class RegressionOfflineEnum(Enum): APO = 'apo' QRPO = 'qrpo' APO_CRITIC = 'apo_critic' + VALUE_LEARNING = 'value_learning' class PairwiseOfflineEnum(Enum): @@ -182,6 +183,7 @@ def offline_loss( eta: float, multistep: bool = False, bce: bool = False, + distributional_value_learning: bool = True, ): # eta: r + eta * bonus (bonus can be used to model things like tool use) @@ -243,6 +245,36 @@ def offline_loss( losses = (reward_q - beta2 * torch.log(torch.tensor(beta2)) - 1 - beta2 * (policy_logp - ref_logp)) ** 2 + elif loss_type == RegressionOfflineEnum.VALUE_LEARNING: + + # loss for VALUE_LEARNING is just regressing the first logit in the batch to the value of batch['reward'] + # shape of policy_logits: (batch_size, gen_len, 0) + # shape of batch['reward']: (batch_size, ) + # and then the subtraction will broadcast. + + assert batch['reward'] is not None, "reward must be in the batch. called from offline_loss fn" + # option 1: single value learning. regress directly to the value. + if distributional_value_learning == False: + assert outputs['first_num_bins_logits'] is not None, "first_num_bins_logits must be in the batch. called from offline_loss fn" + losses = (outputs['first_num_bins_logits'][:,:,0] - batch['reward']) ** 2 + losses *= batch['attention_mask'] + + # option 2: distributional value learning. given n logits, we predict and the do softmax to get a distribution. + else: # (distributional_value_learning == True): + first_n_logits = outputs['first_num_bins_logits'] + top_n_logits = first_n_logits.shape[2] + bucketized_reward = torch.bucketize(batch['reward'], torch.linspace(0, 1, top_n_logits).to(batch['reward'].device)).to(batch['reward'].device) + + input = first_n_logits.reshape(-1, first_n_logits.size(-1)) + target = bucketized_reward.repeat_interleave(first_n_logits.size(1)) + + losses = F.cross_entropy(input, target, reduction='none') + + # reshape masks to match flattened losses + losses *= batch['attention_mask'].view(-1) + + # note in this case, you don't need to mask based on the next one, just the true tokens should get a value. + elif loss_type == RegressionOfflineEnum.APO_CRITIC: # grab necessaryinformation for this actor-critic style APO loss: first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model diff --git a/compose_rl/metrics/offline_learning_metrics.py b/compose_rl/metrics/offline_learning_metrics.py new file mode 100644 index 00000000..2eac04ef --- /dev/null +++ b/compose_rl/metrics/offline_learning_metrics.py @@ -0,0 +1,156 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from torchmetrics import Metric + + +class TestTotalLossMetric(Metric): + """Metric for tracking total training loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'total' in batch: + self.loss_sum += batch['total'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestImplicitRewardsLossMetric(Metric): + """Metric for tracking implicit rewards loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'implicit_rewards' in batch: + self.loss_sum += batch['implicit_rewards'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestKLDivergenceLossMetric(Metric): + """Metric for tracking KL divergence losses.""" + + full_state_update = False + + def __init__(self, kl_type: str = 'reverse', dist_sync_on_step: bool = False, **kwargs: Any): + """Initialize KL divergence loss metric. + + Args: + kl_type: Type of KL divergence to track ('reverse' or 'forward') + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.kl_type = kl_type + self.loss_key = f'{kl_type}_kl' + + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if self.loss_key in batch: + self.loss_sum += batch[self.loss_key].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestEstimatedRewardLossMetric(Metric): + """Metric for tracking estimated reward loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + # print("keys in batch: ", batch.keys()) + # print("updating in estimated reward loss metric") + if 'estimated_reward' in batch: + self.loss_sum += batch['estimated_reward'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestSequenceEntropiesLossMetric(Metric): + """Metric for tracking sequence entropies loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'sequence_entropies' in batch: + self.loss_sum += batch['sequence_entropies'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) \ No newline at end of file