diff --git a/compose_rl/algorithms/offline/__init__.py b/compose_rl/algorithms/offline/__init__.py index 08f8132b..685e5a31 100644 --- a/compose_rl/algorithms/offline/__init__.py +++ b/compose_rl/algorithms/offline/__init__.py @@ -1,14 +1,22 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -from compose_rl.algorithms.offline.callback import ReferencePolicyCallback +from compose_rl.algorithms.offline.callback import ( + PairwiseReferencePolicyCallback, + ReferencePolicyCallback, +) from compose_rl.algorithms.offline.model import ( + ComposerHFOfflinePolicyLM, ComposerHFPairwiseOfflinePolicyLM, + ComposerMPTOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) __all__ = [ + 'ComposerHFOfflinePolicyLM', + 'ComposerMPTOfflinePolicyLM', 'ComposerMPTPairwiseOfflinePolicyLM', 'ComposerHFPairwiseOfflinePolicyLM', + 'PairwiseReferencePolicyCallback', 'ReferencePolicyCallback', ] diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 266c9951..5215e903 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -31,14 +31,20 @@ def __init__( ): self.train_config = copy.deepcopy(train_config) self.reference_model = None + self.auxiliary_model = None # Add auxiliary model def after_load(self, state: State, logger: Logger) -> None: - model_config = self.train_config['model'] + #model_config = self.train_config['model'] + model_config = self.train_config['variables']['reference_model'] init_context = process_init_device( model_config, self.train_config.get('fsdp_config'), ) name = model_config.pop('name') + print("################################################") + print("reference model config:") + print(model_config) + print("################################################") self.reference_model = build_composer_model( name=name, cfg=model_config, @@ -75,6 +81,64 @@ def after_load(self, state: State, logger: Logger) -> None: 
callbacks=load_checkpoint_callbacks, ) + # Load auxiliary model following the same pattern + if 'auxiliary_model' in self.train_config.get('variables', {}): + aux_model_config = self.train_config['variables']['auxiliary_model'] + aux_init_context = process_init_device( + aux_model_config, + self.train_config.get('fsdp_config'), + ) + aux_name = aux_model_config.pop('name') + print("################################################") + print("auxiliary model config:") + print(aux_model_config) + print("################################################") + self.auxiliary_model = build_composer_model( + name=aux_name, + cfg=aux_model_config, + tokenizer=state.model.tokenizer, # type: ignore + init_context=aux_init_context, + master_weights_dtype=aux_model_config.get('master_weights_dtype', None), + ) + + # Load auxiliary model with same checkpoint loading procedure + _ = Trainer( + model=self.auxiliary_model, + parallelism_config={'fsdp': state.fsdp_config}, + precision=state.precision, + load_weights_only=True, + load_strict_model_weights=False, + load_path=original_load_path, + callbacks=load_checkpoint_callbacks, + ) + + def before_forward(self, state: State, logger: Logger) -> Optional[int]: + # Before every batch we need to do a forwards pass over the reference model + with get_precision_context(state.precision): + with torch.no_grad(): + assert self.reference_model is not None + reference_outputs = self.reference_model(state.batch) + state.batch.update({ + 'ref_logp': reference_outputs['policy_logp'], + 'ref_token_policy_logps': reference_outputs['token_policy_logps'], + }) + + # Add auxiliary model forward pass if available + if self.auxiliary_model is not None: + auxiliary_outputs = self.auxiliary_model(state.batch) + state.batch.update({ + 'aux_first_num_bins_logits': auxiliary_outputs['first_num_bins_logits'], + }) + + +class PairwiseReferencePolicyCallback(ReferencePolicyCallback): + """Callback to run reference policy in pairwise offline RL. 
+ + Args: + train_config (dict): Training config passed to callback via foundry train.py as + callback is registered under callbacks_with_config registry. + """ + def before_forward(self, state: State, logger: Logger) -> Optional[int]: # Before every batch we need to do a forwards pass over the reference model with get_precision_context(state.precision): diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 806842c1..871e064d 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -"""Pairwise Offline RL Composer Implementation.""" +"""Offline RL Composer Implementation.""" from __future__ import annotations @@ -12,9 +12,19 @@ from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.modeling_outputs import CausalLMOutputWithPast +from compose_rl.metrics.offline_learning_metrics import ( + TestEstimatedRewardLossMetric, + TestImplicitRewardsLossMetric, + TestKLDivergenceLossMetric, + TestSequenceEntropiesLossMetric, + TestTotalLossMetric +) from compose_rl.algorithms.offline.model_methods import ( + RegressionOfflineEnum, PairwiseOfflineEnum, + offline_forward, + offline_loss, pairwise_offline_forward, pairwise_offline_loss, ) @@ -24,6 +34,124 @@ log = logging.getLogger(__name__) +class ComposerMPTOfflinePolicyLM(ComposerMPTCausalLM): + """MPT model wrapper for offline rl model.""" + + def __init__( + self, + loss_type: str = 'apo', + beta1: float = 0.5, + beta2: float = 0.1, + eta: float = 0.5, + multistep: bool = False, + average_log_prob: bool = False, + temperature: float = 1.0, + **kwargs: Any, + ): + self.loss_type = RegressionOfflineEnum(loss_type) + self.beta1 = beta1 + self.beta2 = beta2 + self.eta = eta + self.multistep = multistep + self.average_log_prob = average_log_prob + 
self.temperature = temperature + + super().__init__(**kwargs) + self.train_metrics = None # DPOLM does not support eval_forward + + def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: + assert self.tokenizer is not None + return offline_forward( + model=self.model, + batch=batch, + average_log_prob=self.average_log_prob, + policy_model_config=self.config, + ) + + def eval_forward( + self, + batch: MutableMapping, + outputs: CausalLMOutputWithPast, + ) -> None: + raise ValueError('Eval forward is not implemented for ComposerDPOLM.') + + def loss(self, outputs: CausalLMOutputWithPast, + batch: Mapping) -> dict[str, torch.Tensor]: + return offline_loss( + outputs = outputs, + batch = batch, + loss_type = self.loss_type, + beta1 = self.beta1, + beta2 = self.beta2, + eta = self.eta, + multistep = self.multistep, + ) + + +class ComposerHFOfflinePolicyLM(ComposerHFCausalLM): + """HF class wrapper for offline rl model.""" + + def __init__( + self, + loss_type: str = 'apo', + beta1: float = 0.5, + beta2: float = 0.1, + eta: float = 0.5, + multistep: bool = False, + average_log_prob: bool = False, + temperature: float = 1.0, + num_bins: int = 1, # Add num_bins parameter + distributional_value_learning: bool = True, + **kwargs: Any, + ): + self.loss_type = RegressionOfflineEnum(loss_type) + self.beta1 = beta1 + self.beta2 = beta2 + self.eta = eta + self.multistep = multistep + self.average_log_prob = average_log_prob + self.temperature = temperature + self.num_bins = num_bins # Store num_bins + self.distributional_value_learning = distributional_value_learning + + super().__init__(**kwargs) + self.train_metrics = None # DPOLM does not support eval_forward + self.val_metrics = {metric.__class__.__name__ : metric for metric in [TestEstimatedRewardLossMetric(), TestImplicitRewardsLossMetric(), TestKLDivergenceLossMetric(), TestSequenceEntropiesLossMetric(), TestTotalLossMetric()]} + + def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]: + assert 
self.tokenizer is not None + return offline_forward( + model=self.model, + batch=batch, + average_log_prob=self.average_log_prob, + temperature=self.temperature, + num_bins=self.num_bins, # Pass num_bins to offline_forward + ) + + def eval_forward( + self, + batch: MutableMapping, + outputs: CausalLMOutputWithPast | None = None, + ) -> dict[str, torch.Tensor]: + with torch.no_grad(): + fwd = self.forward(batch) + loss = self.loss(fwd, batch) + + return loss + + def loss(self, outputs: CausalLMOutputWithPast, + batch: Mapping) -> dict[str, torch.Tensor]: + return offline_loss( + outputs = outputs, + batch = batch, + loss_type = self.loss_type, + beta1 = self.beta1, + beta2 = self.beta2, + eta = self.eta, + multistep = self.multistep, + distributional_value_learning = self.distributional_value_learning, + ) + class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM): """MPT model wrapper for DPO model.""" diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 4da46019..28628792 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -23,9 +23,20 @@ extract_packed_chosen_rejected, get_batch_logp, get_mb_load_balancing_loss, + get_log_probs_from_logits, + make_action_mask, + get_token_entropies, + get_sequence_entropies, ) +class RegressionOfflineEnum(Enum): + APO = 'apo' + QRPO = 'qrpo' + APO_CRITIC = 'apo_critic' + VALUE_LEARNING = 'value_learning' + + class PairwiseOfflineEnum(Enum): DPO = 'dpo' RPO = 'rpo' @@ -35,6 +46,335 @@ class PairwiseOfflineEnum(Enum): KTO = 'kto' +def offline_forward( + model: nn.Module, + batch: MutableMapping, + average_log_prob: bool = False, + policy_model_config: Optional[PretrainedConfig] = None, + temperature: float = 1.0, + num_bins: int = 1, +) -> dict[str, torch.Tensor]: + """Forwards the model for dpo and get the chosen and rejected log probs. + + Args: + model (nn.Module): Model we are forwarding. 
+ tokenizer (Tokenizer): Tokenizer for the model. + batch (Dict[str, torch.LongTensor]): Batch over which we should forward the model. + Note: this batch has chosen and rejected concated along the sequence dimension. + average_log_prob (bool): Whether should we average the log probabilities. + policy_model_config: Policy model config. + num_bins: Number of bins for the histogram. used for modeling distributional value function; will return the first num_bins logits + """ + is_multimodal = 'pixel_values' in batch.keys() + has_mask = 'mask' in batch.keys() + + if policy_model_config is not None and hasattr(model, 'transformer'): + clear_mb_load_balancing_loss( + policy_model_config, + model.transformer, # type: ignore + ) + + inputs = { + 'input_ids': batch['input_ids'], + 'attention_mask': batch['attention_mask'], + } + + if is_multimodal: + multimodal_inputs = { + 'token_type_ids': batch['token_type_ids'], + 'pixel_values': batch['pixel_values'], + } + inputs.update(multimodal_inputs) + + output_logits = model(**inputs).logits # (batch_size, seq_len, vocab_size) + token_policy_logps = get_log_probs_from_logits( + output_logits[:,:-1], + batch['input_ids'][:,1:] + ) # tokenize logps (batch_size, seq_len-1) here seq_len-1 because we shifted + # Calculate token entropies from the logits + token_entropies = get_token_entropies(logits=output_logits) + token_entropies = token_entropies.detach() + + if has_mask is False: + logps = get_batch_logp( + batch['input_ids'], + output_logits, + batch['prompt_len'], + batch['sequence_len'], + average_log_prob, + temperature=temperature, + ) + # Calculate sequence entropies + action_mask = make_action_mask( + batch['prompt_len'], + batch['sequence_len'], + batch['attention_mask'].shape, + device=output_logits.device, + ) + sequence_entropies = get_sequence_entropies( + token_entropies=token_entropies, + action_mask=action_mask + ) + else: + # apply attention_mask and mask explicitly + token_policy_logps *= 
batch['attention_mask'][:,1:] + token_policy_logps *= batch['mask'][:,1:] + logps = torch.sum(token_policy_logps, dim = -1) # (bs, ) + # Calculate sequence entropies + # TODO: confirm with JC and Adyasha if this is correct + combined_mask = batch['attention_mask'] * batch['mask'] + sequence_entropies = get_sequence_entropies( + token_entropies=token_entropies, + action_mask=combined_mask, + ) + + outputs: dict[str, torch.Tensor] = { + 'policy_logp': logps, + 'sequence_entropies': sequence_entropies, + 'token_policy_logps': token_policy_logps, + } + if num_bins >= 1: + first_num_bins_logits = output_logits[:,:,:num_bins] # take the first num_bins logits (batch_size, seq_len, num_bins) + outputs['first_num_bins_logits'] = first_num_bins_logits + + if policy_model_config is not None and hasattr(model, 'transformer'): + lbl = get_mb_load_balancing_loss( + policy_model_config, + model.transformer, # type: ignore + ) + if lbl is not None: + outputs['lbl'] = lbl + + return outputs + + +def _extract_segments(mask: torch.Tensor) -> list[tuple[int, int]]: + """Extract contiguous segments where mask == 1. 
+ + Args: + mask: 1D tensor of 0s and 1s + + Returns: + List of (start, end) tuples for each contiguous segment of 1s; both start and end are inclusive + """ + # Convert to CPU and ensure integer type for reliable comparison + mask_cpu = mask.cpu().int() + + # Find transitions: 0->1 (start) and 1->0 (end) + # Pad with 0 to handle edge cases + padded_mask = torch.cat([torch.tensor([0]), mask_cpu, torch.tensor([0])]) + diff = torch.diff(padded_mask) + + # Find starts (0->1 transitions) and ends (1->0 transitions) + starts = torch.where(diff == 1)[0].tolist() # Convert to Python list + ends = torch.where(diff == -1)[0].tolist() # Convert to Python list + + # Adjust indices (remove padding offset) + segments = [(start, end - 1) for start, end in zip(starts, ends)] + + return segments + +def offline_loss( + outputs: CausalLMOutputWithPast, + batch: Mapping, + loss_type: RegressionOfflineEnum, + beta1: float, + beta2: float, + eta: float, + multistep: bool = False, + bce: bool = False, + distributional_value_learning: bool = True, +): + # eta: r + eta * bonus (bonus can be used to model things like tool use) + + policy_logp = outputs['policy_logp'] # (batch_size, ) + + ref_logp = batch.get( + 'ref_logp', + torch.zeros_like(policy_logp), + ) + + # Initialize vstar to avoid "possibly unbound" warning + vstar = None + advantages = None + + if loss_type == RegressionOfflineEnum.APO: + # Reproducing the APO loss from APO paper: https://arxiv.org/pdf/2505.20686 on page 3 + # APO is not a pair-wise loss function. + # Similar to REBEL, we assume each response has a reward in the batch. 
+ # We assume that the dataset contains vstar values, i.e., V^star(x) for each prompt x in the batch + # + vstar = batch.get('vstar', None) + if vstar is None: + vstar_rewards = batch.get('vstar_rewards', None) + assert vstar_rewards is not None + vstar_bonus = batch.get('vstar_bonus', torch.zeros_like(vstar_rewards)) + added_vstar_bonus = vstar_bonus * vstar_rewards # true added bonus is 1 iff both bonus = 1 and reward = 1 + if not multistep: + exponentiated_mean = torch.mean(torch.exp((vstar_rewards+eta*added_vstar_bonus) / beta1), dim=-1) + else: + exponentiated_mean = torch.mean( + vstar_rewards * torch.exp(batch['reward'] / beta1).view(-1, 1) + (1 - vstar_rewards), # TODO: something is wrong here. + dim=-1, + ) + vstar = beta1 * torch.log(exponentiated_mean) + + assert vstar.shape == batch['reward'].shape + + bonuses = batch.get('bonus', torch.zeros_like(batch['reward'])) + added_bonuses = bonuses * batch['reward'] # true added bonus = 1 if both bonus = 1 and reward = 1 + advantages = batch['reward'] + eta * added_bonuses - vstar + if bce == False: + losses = ( + beta2 * (policy_logp - ref_logp) - + (batch['reward'] + eta * added_bonuses - vstar) + )**2 + elif bce == True: + predicted_prob = F.sigmoid(beta2 * (policy_logp - ref_logp)) + actual_prob = F.sigmoid(batch['reward'] - vstar) + losses = -(actual_prob * torch.log(predicted_prob) + + (1.-actual_prob)*torch.log(1.-predicted_prob) + ) + elif loss_type == RegressionOfflineEnum.QRPO: + vstar_rewards = batch.get('vstar_rewards', None) + assert vstar_rewards is not None + if not multistep: + reward_q = torch.mean((batch['reward'].view(-1, 1) >= vstar_rewards).float(), dim=-1) + else: + raise NotImplementedError("Multistep for QRPO not implemented") + + losses = (reward_q - beta2 * torch.log(torch.tensor(beta2)) - 1 - beta2 * (policy_logp - ref_logp)) ** 2 + + elif loss_type == RegressionOfflineEnum.VALUE_LEARNING: + + # loss for VALUE_LEARNING is just regressing the first logit in the batch to the value of 
batch['reward'] + # shape of policy_logits: (batch_size, gen_len, 0) + # shape of batch['reward']: (batch_size, ) + # and then the subtraction will broadcast. + + assert batch['reward'] is not None, "reward must be in the batch. called from offline_loss fn" + # option 1: single value learning. regress directly to the value. + if distributional_value_learning == False: + assert outputs['first_num_bins_logits'] is not None, "first_num_bins_logits must be in the batch. called from offline_loss fn" + losses = (outputs['first_num_bins_logits'][:,:,0] - batch['reward']) ** 2 + losses *= batch['attention_mask'] + + # option 2: distributional value learning. given n logits, we predict and the do softmax to get a distribution. + else: # (distributional_value_learning == True): + first_n_logits = outputs['first_num_bins_logits'] + top_n_logits = first_n_logits.shape[2] + bucketized_reward = torch.bucketize(batch['reward'], torch.linspace(0, 1, top_n_logits).to(batch['reward'].device)).to(batch['reward'].device) + + input = first_n_logits.reshape(-1, first_n_logits.size(-1)) + target = bucketized_reward.repeat_interleave(first_n_logits.size(1)) + + losses = F.cross_entropy(input, target, reduction='none') + + # reshape masks to match flattened losses + losses *= batch['attention_mask'].view(-1) + + # note in this case, you don't need to mask based on the next one, just the true tokens should get a value. 
+ + elif loss_type == RegressionOfflineEnum.APO_CRITIC: + # grab necessaryinformation for this actor-critic style APO loss: + first_num_bins_logits = batch.get('aux_first_num_bins_logits', None) # from the auxiliary distributional value function model + assert first_num_bins_logits is not None, 'must have a value model that returns the first num_bins logits' + num_bins = first_num_bins_logits.shape[2] + + mask = batch.get('mask', None) + assert mask is not None, 'must have a mask when using APO_CRITIC -- we use musk to grab the value function information at the right token positions' + attention_mask = batch.get('attention_mask', None) + assert attention_mask is not None, 'must have an attention mask' + + token_policy_logps = outputs.get('token_policy_logps', None) # (batch_size, seq_len-1) + assert token_policy_logps is not None, 'must have a token policy logps -- we need to calculate sum of logps explicitly here based on mask' + ref_token_policy_logps = batch.get('ref_token_policy_logps', None) + assert ref_token_policy_logps is not None, 'must have a reference token policy logps' + assert ref_token_policy_logps.shape == token_policy_logps.shape, 'must have the same shape for token policy logps and reference token policy logps' + + bs = first_num_bins_logits.shape[0] + device = first_num_bins_logits.device + losses = torch.zeros(bs, device=device) + advantages = torch.zeros(bs, device=device) + + # define value bin values: 0, 1/num_bins, 2/num_bins, ..., (num_bins-1)/num_bins -- using left end points of bins + bin_values = torch.arange(num_bins, device=device, dtype=torch.float32)*1.0 / num_bins + for i in range(bs): + combined_mask = mask[i][1:] * attention_mask[i][1:] # mask starts from the second token + # scan through combined_mask, and compute loss per at each turn + segments = _extract_segments(combined_mask) + segment_losses = [] + + for k, segment in enumerate(segments): + seg_logp = torch.sum(token_policy_logps[i][segment[0]:segment[1]+1]) + seg_ref_logp 
= torch.sum(ref_token_policy_logps[i][segment[0]:segment[1]+1]) + logits_start = first_num_bins_logits[i, segment[0], :] + logits_end = first_num_bins_logits[i, segment[1]+1, :] # TODO: double check if segment[1] or segment[1]+1 + + # use pre-computed arange tensor + # below is the implementation we wanted: + if k == 0: # for first segment, we can just use vstar_reward r(y) as V*(y). + vstar_start = beta1*torch.log(torch.mean(torch.exp(batch['vstar_rewards'][i]/beta1))) + else: + vstar_start = beta1*torch.log(torch.softmax(logits_start,dim=0).dot(torch.exp(bin_values/beta1))) + + if k == len(segments) - 1: # for last segment, we can just use reward r(y) as V*(y). + vstar_end = batch['reward'][i] + else: + vstar_end = beta1*torch.log(torch.softmax(logits_end,dim=0).dot(torch.exp(bin_values/beta1))) + #vstar_start = beta1*torch.log(torch.sum(torch.softmax(logits_start,dim=0)*torch.exp(bin_values/beta1))) + #vstar_end = beta1*torch.log(torch.sum(torch.softmax(logits_end,dim=0)*torch.exp(bin_values/beta1))) + segment_loss = (beta2 * (seg_logp - seg_ref_logp) - (vstar_end - vstar_start))**2 + segment_losses.append(segment_loss) + advantages[i] += (vstar_end - vstar_start).detach() + + # Accumulate losses across segments for this batch item and average advantage + if segment_losses: + losses[i] = torch.stack(segment_losses).mean() # Average loss across segments + advantages[i] = advantages[i]/len(segment_losses) # average advantage across segments + else: + print('------no valid segments------') + losses[i] = torch.tensor(0.0, device=device) # No valid segments + + + # Estimate policy's reward via offine method, i.e., importance weighting here (can be high variance) + # formula: sum_y exp( log pi(y) - log pi_ref(y) ) r(y) where y ~ pi_ref + # use clip to ensure the output from exp is valid + with torch.no_grad(): + estimated_rewards = torch.exp( + torch.clip(policy_logp - ref_logp, max=5.), + ) * batch['reward'] + estimated_reward = torch.mean(estimated_rewards) + + 
losses = losses.mean() + + #implicit_rewards = beta2 * (policy_logp - ref_logp).detach() + + # Logging KL margins for comparing different methods + forward_kl = (ref_logp - policy_logp).detach() + loss_dict = { + #'implicit_rewards': implicit_rewards, + 'forward_kl': forward_kl, + 'estimated_reward': estimated_reward, + 'sequence_entropies': outputs['sequence_entropies'], # Track detached sequence entropies in the loss dict + } + + #if loss_type == RegressionOfflineEnum.APO and vstar is not None: + if advantages is not None: + loss_dict['batch_advantage'] = torch.mean( + #batch['reward'] - vstar, + advantages, + ) + + if 'lbl' in outputs: + losses += outputs['lbl'] + loss_dict['lbl'] = outputs['lbl'] + + loss_dict['total'] = losses + + return loss_dict + + def pairwise_offline_forward( model: nn.Module, tokenizer: Tokenizer, @@ -67,6 +407,12 @@ def pairwise_offline_forward( if pad_token_id is None: raise ValueError('Tokenizer must have a PAD token.') + is_multimodal = 'pixel_values' in batch.keys() + if is_multimodal and use_attention_sequence_id: + raise NotImplementedError( + 'Using Sequence ID is not implemented for VLMs', + ) + # If we can use attention sequence ID, we use this logic branch. 
# This is determined by a value set in `train_dpo.py` if use_attention_sequence_id: @@ -104,18 +450,42 @@ def pairwise_offline_forward( pad_token_id=0, ) - batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0) - batch_attn_mask = torch.cat( - [ - chosen_attention_mask, - rejected_attention_mask, - ], - dim=0, - ) + inputs = { + 'input_ids': + torch.cat([chosen_inputs, rejected_inputs], dim=0), + 'attention_mask': + torch.cat( + [ + chosen_attention_mask, + rejected_attention_mask, + ], + dim=0, + ), + } + + if is_multimodal: + chosen_token_type_ids, rejected_token_type_ids = extract_packed_chosen_rejected( + batch['token_type_ids'], + batch['chosen_len'], + batch['rejected_len'], + concat_seq_len, + pad_token_id=0, + ) + + # TODO: Ask if assuming same pixel inputs is ok? + multimodal_inputs = { + 'token_type_ids': + torch.cat([chosen_token_type_ids, rejected_token_type_ids], + dim=0), + 'pixel_values': + torch.cat([batch['pixel_values'], batch['pixel_values']], + dim=0), + } + + inputs.update(multimodal_inputs) output_logits = model( - batch_cat_inputs, - attention_mask=batch_attn_mask, + **inputs, ).logits # Extract out the chosen and rejected logits along the batch dimension @@ -158,6 +528,9 @@ def pairwise_offline_forward( outputs['chosen_reward'] = batch['chosen_reward'] outputs['rejected_reward'] = batch['rejected_reward'] + if 'vstar' in batch: + outputs['vstar'] = batch['vstar'] + if policy_model_config is not None and hasattr(model, 'transformer'): lbl = get_mb_load_balancing_loss( policy_model_config, @@ -175,7 +548,7 @@ def pairwise_offline_loss( loss_type: PairwiseOfflineEnum, beta: float, label_smoothing: float, - sft_alpha: float, + sft_alpha: float = 0.0, ) -> dict[str, torch.Tensor]: """Computes pairwise offline RL losses. @@ -192,8 +565,8 @@ def pairwise_offline_loss( sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to be added to DPO type loss. 
""" - policy_chosen_logp = outputs['policy_chosen_logp'] - policy_rejected_logp = outputs['policy_rejected_logp'] + policy_chosen_logp = outputs['policy_chosen_logp'] # (batch_size, ) + policy_rejected_logp = outputs['policy_rejected_logp'] # (batch_size, ) ref_chosen_logp = batch.get( 'ref_chosen', torch.zeros_like(policy_chosen_logp), @@ -209,6 +582,7 @@ def pairwise_offline_loss( logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l} losses = torch.zeros_like(logits) + if loss_type == PairwiseOfflineEnum.DPO: losses = ( -F.logsigmoid(beta * logits) * (1 - label_smoothing) - @@ -271,8 +645,6 @@ def pairwise_offline_loss( ), 0, ) - else: - raise ValueError(f'Loss type: {loss_type} is not supported.') if sft_alpha > 0: sft_losses = -1 * sft_alpha * policy_chosen_logp @@ -306,6 +678,7 @@ def pairwise_offline_loss( ]: # reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL loss_dict['reward_diff'] = reward_diff.detach() # type: ignore + if sft_alpha > 0: # sft_losses_normalized is always defined if sft_alpha>0 snl = sft_losses_normalized.detach() # type: ignore diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index acd00f4e..c583e5f0 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -124,6 +124,10 @@ def composer_online_rl_forward( model_forward_kwargs['action_mask'] = batch['action_mask'] model_forward_kwargs['max_gen_len'] = batch['max_gen_len'] + if 'pixel_values' in batch.keys(): + model_forward_kwargs['token_type_ids'] = batch['token_type_ids'] + model_forward_kwargs['pixel_values'] = batch['pixel_values'] + actor_output = model(batch['obs'], **model_forward_kwargs) logits = actor_output.logits @@ -160,7 +164,9 @@ def critic_loss( ) -> MutableMapping: if loss_type == OnPolicyEnum.PPO: advantages = batch['advantages'] - v_preds = outputs['values'][:, :-1] * batch['action_mask'] + v_preds = 
outputs['values'][:, :-1] * batch[ + 'action_mask' + ] #TODO: why shift -1? eos token? why no shift in log_probs? v_preds = v_preds.to(advantages.dtype) values = batch['values'][:, :-1] * batch['action_mask'] diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 5031dd10..87ec50a4 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -8,25 +8,43 @@ from compose_rl.data.dataloader import ( build_finegrained_preference_dataloader, build_messages_dataloader, + build_offline_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, + build_rl_dataloader, ) from compose_rl.data.messages_data import messages_dataset_collate_fn +from compose_rl.data.offline_data import ( + OfflineStreamingDataset, + offline_dataset_collate_fn, + offline_dataset_collate_fn_test, +) from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, ) from compose_rl.data.prompt_data import prompt_dataset_collate_fn +from compose_rl.data.rl_data import ( + RLStreamingDataset, + dataset_collate_fn, +) __all__ = [ 'build_pairwise_preference_dataloader', 'build_finegrained_preference_dataloader', 'build_messages_dataloader', + 'build_offline_dataloader', 'build_prompt_dataloader', + 'build_rl_dataloader', 'DummyDataset', 'finegrained_preference_dataset_collate_fn', 'MinibatchRolloutBuffer', + 'offline_dataset_collate_fn', + 'offline_dataset_collate_fn_test', + 'OfflineStreamingDataset', 'pairwise_preference_dataset_collate_fn', 'prompt_dataset_collate_fn', 'messages_dataset_collate_fn', + 'RLStreamingDataset', + 'dataset_collate_fn', ] diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 3085c26a..caa39c78 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -14,6 +14,11 @@ MessagesStreamingDataset, messages_dataset_collate_fn, ) +from compose_rl.data.offline_data import ( + OfflineStreamingDataset, + 
offline_dataset_collate_fn, + offline_dataset_collate_fn_test, +) from compose_rl.data.preference_data import ( FinegrainedPreferenceStreamingDataset, PairwisePreferenceStreamingDataset, @@ -24,12 +29,18 @@ PromptStreamingDataset, prompt_dataset_collate_fn, ) +from compose_rl.data.rl_data import ( + RLStreamingDataset, + dataset_collate_fn, +) __all__ = [ 'build_finegrained_preference_dataloader', 'build_pairwise_preference_dataloader', 'build_prompt_dataloader', 'build_messages_dataloader', + 'build_offline_dataloader', + 'build_rl_dataloader', ] @@ -81,6 +92,11 @@ def build_preference_dataloader( MessagesStreamingDataset, ) and 'tokenizer' not in dataset_cfg: dataset_cfg['tokenizer'] = tokenizer + if issubclass( + dataset_cls, + RLStreamingDataset, + ) and 'tokenizer' not in dataset_cfg: + dataset_cfg['tokenizer'] = tokenizer streaming_dataset = dataset_cls( streams=streams, # type: ignore @@ -126,3 +142,13 @@ def build_preference_dataloader( MessagesStreamingDataset, messages_dataset_collate_fn, ) + +build_offline_dataloader = generate_dataloader_builder( + OfflineStreamingDataset, + offline_dataset_collate_fn_test, +) + +build_rl_dataloader = generate_dataloader_builder( + RLStreamingDataset, + dataset_collate_fn, +) diff --git a/compose_rl/data/offline_data.py b/compose_rl/data/offline_data.py new file mode 100644 index 00000000..278adc48 --- /dev/null +++ b/compose_rl/data/offline_data.py @@ -0,0 +1,480 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Build a reward dataset and dataloader for training.""" + +import logging +from typing import Any, Optional + +import numpy as np +import torch +from PIL import Image +from streaming import StreamingDataset +from torchvision import transforms +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer, AutoProcessor + +import base64 +from io import BytesIO + +def base64_to_pil(base64_string: str): + """ + Converts a base64 string to a PIL Image 
def offline_dataset_collate_fn(
    tokenizer: PreTrainedTokenizer,
    max_seq_len: int,
    data: list[dict[str, torch.Tensor]],
) -> dict[str, Any]:
    """Collator for offline data.

    Per-sample processing: forces the final token to EOS, truncates sequences
    longer than ``max_seq_len`` (again forcing EOS), then right-pads with the
    PAD token. Optional per-sample fields (``mask``, ``reward``, ``vstar``,
    ``vstar_rewards``, and ``pixel_values``/``token_type_ids`` for VLMs) are
    batched when present.

    Args:
        tokenizer (Tokenizer): The model's tokenizer.
        max_seq_len (int): The maximum sequence length of the model.
        data (list[dict[str, torch.Tensor]]): The preference data to collate.

    Returns:
        A dict with ``sequence_len``, ``prompt_len``, ``input_ids``,
        ``attention_mask``, ``sequence_id``, plus any optional batched fields.
    """
    if tokenizer.eos_token_id is None:
        raise ValueError('Tokenizer must have an EOS token.')
    if tokenizer.pad_token_id is None:
        raise ValueError('Tokenizer must have a PAD token.')

    ref_collate_fn = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        mlm_probability=0.0,
    )

    batch_input_ids = []
    attention_masks = []
    sequence_id = []
    sequence_lens = []
    prompt_lens = []
    rewards = []
    vstars = []
    vstar_rewards = []

    # for multi-turn
    masks = []

    # For VLMs
    batch_token_type_ids = []
    pixel_values = []

    for sample in data:
        input_ids = sample['input_ids']
        prompt_len = sample['prompt_len']
        # NOTE(review): sequence_len is presumably a 1-element int64 tensor
        # (see OfflineStreamingDataset.__getitem__) — the tensor arithmetic
        # below (pad_len.item(), truncate_len[0]) relies on that; confirm.
        sequence_len = sample['sequence_len']

        # for multi-turn tool-use, which contains masks that mask out non-assistant turns
        # it contains 1 at the token positions from the assistant turns, and 0 otherwise
        has_mask = 'mask' in sample.keys()
        if has_mask:
            mask = sample['mask']  # torch tensor array
            assert mask.shape == input_ids.shape

        is_multimodal = 'pixel_values' in sample.keys()
        if is_multimodal:
            pixel_vals = sample['pixel_values']
            token_type_ids = sample['token_type_ids']
        else:
            pixel_vals = None
            token_type_ids = None

        # Note: if we do any truncation, we force the last token to be EOS
        # https://github.com/mosaicml/RLHF/issues/101

        # Add the eos token if it's not in the chosen sample
        if input_ids[-1] != tokenizer.eos_token_id:
            input_ids[-1] = tokenizer.eos_token_id  # type: ignore

        pad_len = max_seq_len - sequence_len

        if pad_len < 0:
            # We should truncate with an additional token left for eos
            truncate_len = abs(pad_len) + 1

            log.warning((
                f'Sequence length: {sequence_len}'
                f' are too long for max_seq_len: {max_seq_len}'
                f' truncating by {truncate_len[0]} tokens.'
            ))

            # Truncate each value by truncate length, and make the last token EOS
            input_ids = input_ids[:-truncate_len]
            input_ids[-1] = tokenizer.eos_token_id  # type: ignore

            if is_multimodal:
                token_type_ids = token_type_ids[:-truncate_len]
                # NOTE: GEMMA specific: 0 == text token
                token_type_ids[-1] = 0

            # Recompute the post-truncation length; after truncation pad_len
            # becomes 1 (one slot was reserved for EOS), so the padding branch
            # below runs and restores the row to exactly max_seq_len.
            # NOTE(review): this re-wraps sequence_len as a fresh 1-element
            # tensor — confirm downstream consumers expect that shape.
            sequence_len = torch.tensor([len(input_ids)])
            pad_len = max_seq_len - sequence_len

        if pad_len > 0:
            # right padding with padding token
            input_ids = torch.cat(
                [
                    input_ids,
                    torch.ones(int(pad_len.item()), dtype=input_ids.dtype) *
                    tokenizer.pad_token_id,  # type: ignore
                ],
                dim=-1,  # type: ignore
            )
            if is_multimodal:
                token_type_ids = torch.cat(
                    [
                        token_type_ids,  # type: ignore
                        torch.zeros(
                            int(pad_len.item()),
                            dtype=token_type_ids.dtype,  # type: ignore
                        ),
                    ],
                    dim=-1,
                )

        # Attention mask is 1 wherever the row is not PAD.
        attention_mask = torch.logical_not(
            torch.eq(input_ids, tokenizer.pad_token_id),  # type: ignore
        )
        # Combine the two masks together so that we can forget about mask and
        # only use attention_mask inside the algorithm.
        if has_mask:
            if len(mask) < len(attention_mask):  # this happens when we padded input_ids, so we should pad mask
                mask = torch.cat([mask, torch.zeros(len(attention_mask)-len(mask), dtype=mask.dtype)],dim =-1)
            else:  # this happens when we truncate input_ids, so we truncate mask
                mask = mask[0: len(attention_mask)]
            assert mask.shape == attention_mask.shape and mask.shape == input_ids.shape

        # 0 marks real tokens, -1 marks padding positions.
        cur_sequence_id = torch.tensor(([0] * sequence_len) +
                                       ([-1] * max(0, int(pad_len.item()))),)

        sequence_id.append(cur_sequence_id)
        batch_input_ids.append(input_ids)
        attention_masks.append(attention_mask)
        # NOTE(review): if the sample hit the truncation branch this is the
        # recomputed length; otherwise it is the raw sample value — confirm
        # both are 1-element tensors so torch.cat below works.
        sequence_lens.append(sequence_len)
        prompt_lens.append(prompt_len)

        if has_mask:
            masks.append(mask)
        if 'reward' in sample:
            rewards.append(sample['reward'])
        if 'vstar' in sample:
            vstars.append(sample['vstar'])
        if 'vstar_rewards' in sample:
            vstar_rewards.append(sample['vstar_rewards'])

        if is_multimodal:
            batch_token_type_ids.append(token_type_ids)  # type: ignore
            pixel_values.append(pixel_vals)

    batch_input_ids = ref_collate_fn(batch_input_ids)['input_ids']
    attention_masks = torch.stack(attention_masks)
    sequence_id = torch.stack(sequence_id)

    sequence_lens = torch.cat(sequence_lens)
    prompt_lens = torch.cat(prompt_lens)
    return_dict = {
        'sequence_len': sequence_lens,
        'prompt_len': prompt_lens,
        'input_ids': batch_input_ids,
        'attention_mask': attention_masks,
        'sequence_id': sequence_id,
    }
    if len(masks) > 0:
        masks = torch.stack(masks)
        assert masks.shape == attention_masks.shape
        return_dict['mask'] = masks
    if len(rewards) > 0:
        rewards = torch.cat(rewards)
        return_dict['reward'] = rewards
    if len(vstars) > 0:
        vstars = torch.cat(vstars)
        return_dict['vstar'] = vstars
    if len(vstar_rewards) > 0:
        vstar_rewards = torch.stack(vstar_rewards)
        return_dict['vstar_rewards'] = vstar_rewards

    # NOTE(review): is_multimodal here is the value from the *last* loop
    # iteration, i.e. this assumes every sample in a batch is uniformly
    # multimodal (and that data is non-empty) — confirm upstream.
    if is_multimodal:  # type: ignore
        token_type_ids = torch.stack(batch_token_type_ids)
        pixel_values = torch.stack(pixel_values)
        return_dict['token_type_ids'] = token_type_ids
        return_dict['pixel_values'] = pixel_values

    return return_dict
def offline_dataset_collate_fn_test(
    tokenizer: PreTrainedTokenizer,
    max_seq_len: int,
    data: list[dict[str, torch.Tensor]],
) -> dict[str, Any]:
    """Collator for offline data.

    Variant that delegates right-padding to the HF collator for the whole
    batch at once (padded to the longest sequence in the batch), then
    truncates the batch to ``max_seq_len`` and repairs the final token of
    truncated rows to EOS.

    Args:
        tokenizer (Tokenizer): The model's tokenizer.
        max_seq_len (int): The maximum sequence length of the model.
        data (list[dict[str, torch.Tensor]]): The preference data to collate.
    """
    if tokenizer.eos_token_id is None:
        raise ValueError('Tokenizer must have an EOS token.')
    if tokenizer.pad_token_id is None:
        raise ValueError('Tokenizer must have a PAD token.')

    # NOTE(review): mutates a shared tokenizer attribute; fine if every
    # collator in this file sets it explicitly, as they appear to do.
    tokenizer.padding_side = 'right'  # right

    ref_collate_fn = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        mlm_probability=0.0,
    )

    list_input_ids = [item['input_ids'] for item in data]
    # Right padded based on the longest sequence in the batch.
    ret = ref_collate_fn(list_input_ids)
    batch_input_ids = ret['input_ids']
    attention_masks = torch.logical_not(
        torch.eq(batch_input_ids, tokenizer.pad_token_id)
    ).to(torch.int64)

    batch_max_seq_len = batch_input_ids.shape[1]
    if batch_max_seq_len > max_seq_len:  # truncate both input_ids and attention_mask
        batch_input_ids = batch_input_ids[:,:max_seq_len]
        attention_masks = attention_masks[:,:max_seq_len]

        # Force EOS on every row that was actually truncated (rows whose last
        # token is PAD ended before the cut and are left untouched).
        for i in range(batch_input_ids.shape[0]):
            # check if this sequence is truncated
            if batch_input_ids[i,-1] != tokenizer.eos_token_id and batch_input_ids[i,-1] != tokenizer.pad_token_id:
                batch_input_ids[i,-1] = tokenizer.eos_token_id

    # Sum of all 1s in the attention mask, row-wise.
    sequence_lens = torch.sum(attention_masks, dim = -1)
    prompt_lens = torch.cat([item['prompt_len'] for item in data])

    masks, rewards, vstars, vstar_rewards, bonuses, vstar_bonus = [], [],[],[], [],[]

    # Optional fields are assumed to be present on either every sample or
    # none (only data[0] is probed).
    if 'reward' in data[0].keys():
        rewards = torch.cat([item['reward'] for item in data])
    if 'bonus' in data[0].keys():
        bonuses = torch.cat([item['bonus'] for item in data])
    if 'vstar' in data[0].keys():
        vstars = torch.cat([item['vstar'] for item in data])
    if 'vstar_rewards' in data[0].keys():
        vstar_rewards = torch.stack([item['vstar_rewards'] for item in data])
    if 'vstar_bonus' in data[0].keys():
        vstar_bonus = torch.stack([item['vstar_bonus'] for item in data])

    has_mask = 'mask' in data[0].keys()
    if has_mask:
        for i in range(batch_input_ids.shape[0]):
            mask_i = data[i]['mask']
            if len(mask_i) < len(batch_input_ids[i]):  # right padded
                # NOTE(review): torch.zeros defaults to float32 here, so a
                # padded mask row changes dtype — confirm this is intended
                # (the sibling collator above preserves mask.dtype).
                all_zeros = torch.zeros(len(batch_input_ids[i]))
                all_zeros[0:len(mask_i)] = mask_i
                mask_i = all_zeros
            else:  # truncated
                mask_i = mask_i[0:len(batch_input_ids[i])]
            masks.append(mask_i)
        masks = torch.stack(masks)

    return_dict = {
        'sequence_len': sequence_lens,
        'prompt_len': prompt_lens,
        'input_ids': batch_input_ids,
        'attention_mask': attention_masks,
    }
    if len(masks) > 0:
        assert masks.shape == attention_masks.shape
        return_dict['mask'] = masks
    if len(rewards) > 0:
        return_dict['reward'] = rewards
    if len(bonuses) > 0:
        return_dict['bonus'] = bonuses
    if len(vstars) > 0:
        return_dict['vstar'] = vstars
    if len(vstar_rewards) > 0:
        return_dict['vstar_rewards'] = vstar_rewards
    if len(vstar_bonus) > 0:
        return_dict['vstar_bonus'] = vstar_bonus

    return return_dict
class OfflineStreamingDataset(StreamingDataset):
    """Dataloader for streaming in preference data."""

    def __init__(self, max_seq_len: int, processor_name: Optional[str] = None, **kwargs: dict[str, Any]):
        # max_seq_len must be set before super().__init__ so that
        # _read_binary_tokenized_sample can use it once streaming starts.
        self.max_seq_len = max_seq_len
        super().__init__(**kwargs)
        # Counters used only for truncation logging.
        self.num_truncated = 0
        self.num_read = 0

        # For proper multimodal HF checkpointing
        self.processor = None
        if processor_name is not None:
            self.processor = AutoProcessor.from_pretrained(processor_name)

    def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str):
        """Decode a raw byte buffer under ``key`` into an int64 token tensor,
        truncated to ``self.max_seq_len`` tokens."""
        self.num_read += 1
        # NOTE(review): no dtype here, so np.frombuffer defaults to float64 —
        # this length check assumes 8-byte tokens (matches int64 below), but
        # temp_sample itself is only used for the length test.
        temp_sample = torch.from_numpy(np.frombuffer(sample[key]))
        if len(temp_sample) > self.max_seq_len:
            log.info(f'Truncating sample: {self.num_truncated} {self.num_read}')
            self.num_truncated += 1
            # NOTE(review): this slices the *byte* buffer at a token index
            # (should presumably be self.max_seq_len * 8 for int64) — the
            # logged "truncated" tail is therefore misaligned. Logging only;
            # the returned tensor below is unaffected. TODO confirm and fix.
            truncated = torch.from_numpy(
                np.frombuffer(sample[key][self.max_seq_len:], dtype=np.int64),
            )
            log.info(f'Truncating: {truncated}')
        # .copy() makes the array writable (np.frombuffer returns a read-only
        # view, which torch.from_numpy would otherwise reject for mutation).
        decoded_arr = torch.from_numpy(
            np.frombuffer(sample[key],
                          dtype=np.int64)[:self.max_seq_len].copy(),
        )
        return decoded_arr

    # How to process a sample
    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get an item from StreamingDataset at a given index.

        Args:
            idx (int): the index where we fetch the data in the StreamingDataset.
        """
        sample = super().__getitem__(idx)

        # Read Samples
        # Layout 1: separate 'prompt' and 'response' fields.
        if 'prompt' in sample:
            input_ids, prompt = [], []
            if isinstance(sample['prompt'], bytes):
                sample['input_ids'] = sample['prompt'] + sample['response']
                input_ids = self._read_binary_tokenized_sample(sample, 'input_ids')
                prompt = self._read_binary_tokenized_sample(sample, 'prompt')
            elif isinstance(sample['prompt'], np.ndarray):
                input_ids = np.concatenate([sample['prompt'], sample['response']])
                input_ids = torch.from_numpy(input_ids[:self.max_seq_len])
                prompt = torch.from_numpy(sample['prompt'])
            else:
                token_type = type(sample['input_ids'])
                raise ValueError(
                    f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}',
                )
            # Get Lenghts
            prompt_len = len(prompt)
            sequence_len = len(input_ids)

        # Layout 2: 'input' already combines prompt and response (e.g., used
        # in the tool-call setup), with a per-token assistant 'mask'.
        elif 'input' in sample and 'mask' in sample:
            input_ids, mask = [],[]
            if isinstance(sample['input'], bytes):
                input_ids = self._read_binary_tokenized_sample(sample, 'input')
                mask = self._read_binary_tokenized_sample(sample, 'mask')
            elif isinstance(sample['input'], np.ndarray):
                input_ids = torch.from_numpy(sample['input']).to(torch.int64)
                mask = torch.from_numpy(sample['mask']).to(torch.int64)
            else:
                token_type = type(sample['input'])
                raise ValueError(
                    f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}',
                )
            # Multi-turn samples carry no separate prompt; treat the whole
            # sequence as response.
            prompt_len = 0
            sequence_len = len(input_ids)

        # NOTE(review): if a sample matches neither layout, input_ids below is
        # unbound and this raises NameError — presumably unreachable for valid
        # shards; confirm.
        return_dict = {
            'input_ids': input_ids,
            'sequence_len': torch.Tensor([sequence_len]).to(torch.int64),
            'prompt_len': torch.Tensor([prompt_len]).to(torch.int64)
        }

        if 'mask' in sample and 'input' in sample:
            return_dict['mask'] = mask

        # If rewards are given, add them to the return dict
        if 'reward' in sample:
            return_dict['reward'] = torch.Tensor([sample['reward']])

        if 'bonus' in sample:
            return_dict['bonus'] = torch.Tensor([sample['bonus']])

        if 'vstar' in sample:
            assert 'vstar_rewards' not in sample
            return_dict['vstar'] = torch.Tensor([sample['vstar']])

        if 'vstar_bonus' in sample:
            return_dict['vstar_bonus'] = torch.from_numpy(sample["vstar_bonus"])

        # 'v-star' is a legacy spelling of 'vstar'; both map to the same key.
        if 'v-star' in sample:
            assert 'vstar_rewards' not in sample
            return_dict['vstar'] = torch.Tensor([sample['v-star']])

        if 'vstar_rewards' in sample:
            assert 'vstar' not in sample
            assert 'v-star' not in sample
            if isinstance(sample['vstar_rewards'], np.ndarray):
                return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards'])
            else:
                rewards_type = type(sample['vstar_rewards'])
                raise ValueError(
                    f'Expect vstar_rewards to be numpy.ndarray type, but got {rewards_type}',
                )

        # Gemma 3 Specific
        if 'pixel_values' in sample:
            assert 'image' not in sample
            if isinstance(sample['pixel_values'], np.ndarray):
                pixel_values = torch.from_numpy(sample['pixel_values'])
            elif isinstance(sample['pixel_values'], Image.Image):
                pil_to_tensor_transform = transforms.PILToTensor()
                pixel_values = pil_to_tensor_transform(sample['pixel_values'])
            else:
                pixel_values_type = type(sample['pixel_values'])
                raise ValueError(
                    f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}',
                )
            return_dict['pixel_values'] = pixel_values

        # Raw base64 image: re-run the processor over the decoded text + image
        # to produce pixel_values.
        if 'image' in sample:
            assert 'pixel_values' not in sample
            assert self.processor is not None

            text = self.processor.decode(return_dict['input_ids'])
            input_tensors = self.processor(
                text=text,
                images=base64_to_pil(sample['image']),
                return_tensors="pt",
                padding=False,
                truncation=False,
            )
            return_dict['pixel_values'] = input_tensors['pixel_values'][0]

        if 'token_type_ids' in sample:
            if isinstance(sample['token_type_ids'], bytes):
                token_type_ids = self._read_binary_tokenized_sample(
                    sample,
                    'token_type_ids',
                )
            elif isinstance(sample['token_type_ids'], np.ndarray):
                token_type_ids = torch.from_numpy(
                    sample['token_type_ids'][:self.max_seq_len],
                )
            else:
                token_type = type(sample['token_type_ids'])
                raise ValueError(
                    f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}',
                )
            return_dict['token_type_ids'] = token_type_ids

        return return_dict
cat_token_type_ids = None + # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 @@ -75,6 +92,13 @@ def pairwise_preference_dataset_collate_fn( pad_len = max_seq_len - chosen_len - rejected_len cat_batch = torch.cat([chosen, rejected], dim=-1) + if is_multimodal: + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) + if pad_len < 0: # We should truncate chosen and rejected by the same amount truncate_len = abs(pad_len // 2) + 1 @@ -92,6 +116,22 @@ def pairwise_preference_dataset_collate_fn( rejected = rejected[:-truncate_len] rejected[-1] = tokenizer.eos_token_id # type: ignore + if is_multimodal: + chosen_token_type_ids = chosen_token_type_ids[: + -truncate_len # type: ignore + ] + rejected_token_type_ids = rejected_token_type_ids[: # type: ignore + -truncate_len] + + # NOTE: GEMMA specific: 0 == text token + chosen_token_type_ids[-1] = 0 + rejected_token_type_ids[-1] = 0 + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) + cat_batch = torch.cat([chosen, rejected], dim=-1) chosen_len = torch.tensor([len(chosen)]) @@ -108,6 +148,17 @@ def pairwise_preference_dataset_collate_fn( ], dim=-1, # type: ignore ) + if is_multimodal: + cat_token_type_ids = torch.cat( + [ + cat_token_type_ids, # type: ignore + torch.zeros( + int(pad_len.item()), + dtype=cat_token_type_ids.dtype, # type: ignore + ), + ], + dim=-1, + ) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -127,6 +178,10 @@ def pairwise_preference_dataset_collate_fn( chosen_rewards.append(sample['chosen_reward']) rejected_rewards.append(sample['rejected_reward']) + if is_multimodal: + token_type_ids.append(cat_token_type_ids) # type: ignore + pixel_values.append(pixel_vals) + input_ids = ref_collate_fn(input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) sequence_id = torch.stack(sequence_id) 
@@ -147,6 +202,13 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards = torch.stack(rejected_rewards) return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards + + if is_multimodal: # type: ignore + token_type_ids = torch.stack(token_type_ids) + pixel_values = torch.stack(pixel_values) + return_dict['token_type_ids'] = token_type_ids + return_dict['pixel_values'] = pixel_values + return return_dict @@ -204,12 +266,17 @@ def finegrained_preference_dataset_collate_fn( class PairwisePreferenceStreamingDataset(StreamingDataset): """Dataloader for streaming in preference data.""" - def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): + def __init__(self, max_seq_len: int, processor_name: Optional[str] = None, **kwargs: dict[str, Any]): self.max_seq_len = max_seq_len super().__init__(**kwargs) self.num_truncated = 0 self.num_read = 0 + # For proper multimodal HF checkpointing + self.processor = None + if processor_name is not None: + self.processor = AutoProcessor.from_pretrained(processor_name) + def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) @@ -234,21 +301,50 @@ def __getitem__(self, idx: int) -> dict[str, Any]: idx (int): the index where we fetch the data in the StreamingDataset. 
""" sample = super().__getitem__(idx) + # Handle prompt if available - if 'prompt' in sample: + if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses - sample['chosen'] = sample['prompt'] + sample['chosen'] - sample['rejected'] = sample['prompt'] + sample['rejected'] - chosen = self._read_binary_tokenized_sample(sample, 'chosen') - rejected = self._read_binary_tokenized_sample(sample, 'rejected') - - if 'prompt' in sample: - prompt = self._read_binary_tokenized_sample(sample, 'prompt') - prompt_len = len(prompt) + if 'prompt' in sample: + sample['chosen'] = sample['prompt'] + sample['chosen'] + sample['rejected'] = sample['prompt'] + sample['rejected'] + chosen = self._read_binary_tokenized_sample(sample, 'chosen') + rejected = self._read_binary_tokenized_sample(sample, 'rejected') + + if 'prompt' in sample: + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + prompt_len = len(prompt) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) + + elif isinstance(sample['chosen'], np.ndarray): + if 'prompt' in sample: + sample['chosen'] = np.concatenate([ + sample['prompt'], + sample['chosen'], + ]) + sample['rejected'] = np.concatenate([ + sample['prompt'], + sample['rejected'], + ]) + + chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) + rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) + + if 'prompt' in sample: + prompt_len = len(sample['prompt']) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) else: - # Only use prefix matching version of prompt_len when - # 'prompt' is not directly given in the sample - prompt_len = self.find_prompt_length(chosen, rejected) + token_type = type(sample['chosen']) + raise ValueError( + f'Expect prompt and response 
to be bytes or numpy.ndarray type, but got {token_type}', + ) + chosen_len, rejected_len = len(chosen), len(rejected) return_dict = { 'chosen': chosen, @@ -263,6 +359,45 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward + + if 'pixel_values' in sample: + if isinstance(sample['pixel_values'], np.ndarray): + pixel_values = torch.from_numpy(sample['pixel_values']) + elif isinstance(sample['pixel_values'], Image.Image): + pil_to_tensor_transform = transforms.PILToTensor() + pixel_values = pil_to_tensor_transform(sample['pixel_values']) + else: + pixel_values_type = type(sample['pixel_values']) + raise ValueError( + f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', + ) + + if isinstance(sample['chosen_token_type_ids'], bytes): + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'rejected_token_type_ids', + ) + elif isinstance(sample['chosen_token_type_ids'], np.ndarray): + chosen_token_type_ids = torch.from_numpy( + sample['chosen_token_type_ids'][:self.max_seq_len], + ) + rejected_token_type_ids = torch.from_numpy( + sample['rejected_token_type_ids'][:self.max_seq_len], + ) + else: + token_type = type(sample['chosen_token_type_ids']) + raise ValueError( + f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', + ) + + return_dict['pixel_values'] = pixel_values + return_dict['chosen_token_type_ids'] = chosen_token_type_ids + return_dict['rejected_token_type_ids'] = rejected_token_type_ids + return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): diff --git a/compose_rl/data/prompt_data.py b/compose_rl/data/prompt_data.py index 82135bd0..20d1f2df 100644 --- a/compose_rl/data/prompt_data.py +++ 
def dataset_collate_fn(
    tokenizer: PreTrainedTokenizer,
    max_seq_len: int,
    data: list[dict[str, Any]],
) -> dict[str, Any]:
    """Collator for RL data.

    Supports three mutually exclusive sample layouts (probed on ``data[0]``):
      1. ``input_ids`` (+ ``prompt_len``): single-turn / flattened messages.
      2. ``turn_data``: per-turn dicts for non-flattened multi-turn messages;
         all turns are flattened into one batch and ``num_turns`` records how
         many rows belong to each trajectory.
      3. ``prompt`` (+ ``prompt_len``, ``prompt_id``), no response: online RL.

    Args:
        tokenizer (PreTrainedTokenizer): The model's tokenizer.
        max_seq_len (int): The maximum sequence length of the model.
        data (list[dict[str, Any]]): The RL data to collate.

    Returns:
        A dict with the batched tensors for whichever layout was detected,
        plus optional ``mask``/``reward``/``bonus``/``vstar_*``/
        ``verified_answer`` fields when present on the samples.
    """
    if tokenizer.eos_token_id is None:
        raise ValueError('Tokenizer must have an EOS token.')
    if tokenizer.pad_token_id is None:
        raise ValueError('Tokenizer must have a PAD token.')

    tokenizer.padding_side = 'right'  # full sequences are right padded
    ref_collate_fn = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        mlm_probability=0.0,
    )
    list_of_input_ids = []
    list_of_prompt_len = []
    list_of_prompts = []
    list_of_prompt_ids = []
    list_of_num_turns = []
    return_dict: dict[str, Any] = {}

    # case 1: input_ids key is present (single turn / flattened messages)
    if 'input_ids' in data[0]:
        list_of_input_ids = [item['input_ids'] for item in data]
        list_of_prompt_len = [item['prompt_len'] for item in data]

    # case 2: turn_data key is present (non-flattened messages)
    elif 'turn_data' in data[0]:
        for data_point in data:
            list_of_input_ids.extend([turn['input_ids'] for turn in data_point['turn_data']])
            list_of_prompt_len.extend([turn['prompt_len'] for turn in data_point['turn_data']])
            list_of_num_turns.append(torch.tensor([len(data_point['turn_data'])], dtype=torch.int64))

    # case 3: prompt key is present with no response key (online RL)
    elif 'prompt' in data[0]:
        list_of_prompts = [item['prompt'] for item in data]
        list_of_prompt_len = [item['prompt_len'] for item in data]
        list_of_prompt_ids = [item['prompt_id'] for item in data]

    # Batch, pad, and truncate full sequences (cases 1 and 2).
    if len(list_of_input_ids) > 0:
        batch_input_ids = ref_collate_fn(list_of_input_ids)['input_ids']
        attention_masks = torch.logical_not(
            torch.eq(batch_input_ids, tokenizer.pad_token_id),
        ).to(torch.int64)

        # Truncate if the length of the batch exceeds max_seq_len, forcing
        # EOS on rows that were actually cut (rows ending in PAD were not).
        batch_max_seq_len = batch_input_ids.shape[1]
        if batch_max_seq_len > max_seq_len:
            batch_input_ids = batch_input_ids[:, :max_seq_len]
            attention_masks = attention_masks[:, :max_seq_len]
            for i in range(batch_input_ids.shape[0]):
                if batch_input_ids[i, -1] != tokenizer.eos_token_id and batch_input_ids[i, -1] != tokenizer.pad_token_id:
                    batch_input_ids[i, -1] = tokenizer.eos_token_id

        sequence_lens = torch.sum(attention_masks, dim=-1)
        prompt_lens = torch.cat(list_of_prompt_len)

        # Per-row sequence_id: 0 marks real tokens, -1 marks padding
        # (same convention as offline_dataset_collate_fn).
        sequence_id = []
        for i in range(batch_input_ids.shape[0]):
            cur_seq_len = int(sequence_lens[i].item())
            pad_len = int(batch_input_ids.shape[1] - cur_seq_len)
            sequence_id.append(torch.tensor([0] * cur_seq_len + [-1] * pad_len))

        return_dict = {
            'input_ids': batch_input_ids,
            'attention_mask': attention_masks,
            'sequence_len': sequence_lens,
            'prompt_len': prompt_lens,
            'sequence_id': torch.stack(sequence_id),
        }

        # Optional per-token mask (assistant-turn mask for multi-turn data).
        if 'mask' in data[0]:
            masks = []
            for i in range(batch_input_ids.shape[0]):
                mask_i = data[i]['mask']
                if len(mask_i) < len(batch_input_ids[i]):  # input_ids got right padded
                    # FIX: preserve the mask's dtype when padding; torch.zeros
                    # defaults to float32, which silently changed padded rows
                    # to float (inconsistent with offline_dataset_collate_fn).
                    padded_mask = torch.zeros(
                        len(batch_input_ids[i]),
                        dtype=mask_i.dtype,
                    )
                    padded_mask[0:len(mask_i)] = mask_i
                    mask_i = padded_mask
                else:  # input_ids got truncated
                    mask_i = mask_i[0:len(batch_input_ids[i])]
                masks.append(mask_i)
            return_dict['mask'] = torch.stack(masks)

    # Prompts only (case 3): left pad so generation starts flush-right.
    if len(list_of_prompts) > 0:
        tokenizer.padding_side = 'left'
        return_dict['prompt'] = ref_collate_fn(list_of_prompts)['input_ids']
        prompt_attention_mask = torch.logical_not(
            torch.eq(return_dict['prompt'], tokenizer.pad_token_id),
        ).to(torch.int64)
        return_dict['prompt_attention_mask'] = prompt_attention_mask
        return_dict['prompt_id'] = torch.cat(list_of_prompt_ids)
        return_dict['prompt_len'] = torch.cat(list_of_prompt_len)

    # Turn-level bookkeeping for non-flattened messages (case 2).
    if len(list_of_num_turns) > 0:
        assert 'turn_data' in data[0], "turn_data must be present if num_turns is present"
        return_dict['num_turns'] = torch.cat(list_of_num_turns)
        # FIX: leftover debug print()s replaced with lazy debug logging.
        log.debug('num_turns per trajectory: %s', return_dict['num_turns'])
        assert return_dict['input_ids'].shape[0] == torch.sum(return_dict['num_turns']), "input_ids and num_turns must have the same length"

    # Optional trajectory-level fields, assumed uniform across the batch.
    if 'reward' in data[0]:
        return_dict['reward'] = torch.cat([item['reward'] for item in data])
    if 'bonus' in data[0]:
        return_dict['bonus'] = torch.cat([item['bonus'] for item in data])
    if 'vstar_rewards' in data[0]:
        return_dict['vstar_rewards'] = torch.stack([item['vstar_rewards'] for item in data])
    if 'vstar_bonus' in data[0]:
        return_dict['vstar_bonus'] = torch.stack([item['vstar_bonus'] for item in data])
    if 'verified_answer' in data[0]:
        return_dict['verified_answer'] = list(utils.flatten([item['verified_answer'] for item in data]))

    return return_dict
max_seq_len + self.tokenizer = tokenizer + self.flatten_messages = flatten_messages + + # Handle chat template (priority: file path > direct template > default) + if chat_template_path is not None: + # Load template from file + import os + + # Convert to absolute path for clarity + abs_template_path = os.path.abspath(chat_template_path) + + if not os.path.exists(abs_template_path): + raise FileNotFoundError(f"Chat template file not found: {chat_template_path} (resolved to: {abs_template_path})") + + with open(abs_template_path, 'r', encoding='utf-8') as f: + self.chat_template = f.read().strip() + log.info(f"Loaded chat template from: {abs_template_path}") + # Apply it to the tokenizer + self.tokenizer.chat_template = self.chat_template + + elif chat_template is not None: + # Use direct template string + self.chat_template = chat_template + # Apply it to the tokenizer + self.tokenizer.chat_template = chat_template + + else: + # Use tokenizer's default chat template + self.chat_template = getattr(tokenizer, 'chat_template', None) + + print(f"Using chat template: {self.chat_template}") + + # Handle tools (priority: file path > direct tools > None) + self.tools = [] + if tools_path is not None: + # Load tools from JSONL file (one JSON object per line) + import json + import os + + abs_tools_path = os.path.abspath(tools_path) + if not os.path.exists(abs_tools_path): + raise FileNotFoundError(f"Tools file not found: {tools_path} (resolved to: {abs_tools_path})") + + with open(abs_tools_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: # Skip empty lines + continue + try: + tool = json.loads(line) + if not isinstance(tool, dict): + raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") + self.tools.append(tool) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") + + log.info(f"Loaded {len(self.tools)} tools from 
JSONL file: {abs_tools_path}") + + elif tools is not None: + # Use direct tools list + if not isinstance(tools, list): + raise ValueError(f"Tools must be a list, but got {type(tools)}") + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") + + self.tools = tools + log.info(f"Using {len(self.tools)} tools provided directly") + + print(f"Using {len(self.tools)} tools, and tools are {self.tools}") + + def _convert_messages_to_turn_wise_data(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + #convert mesage to turn wise data + turn_data: list[dict[str, Any]] = [] + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == 'assistant': + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] # this makes sure that it ends with special generation token + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" # pyright: ignore[reportIndexIssue] + + input_ids = history_assistant + prompt_len = len(history) + sequence_len = len(input_ids) + + turn_data.append({ + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + }) + + return turn_data # list of dict + + def _convert_messages_to_traj_wise_data(self, messages: list[dict[str, Any]]) -> dict[str, Any]: + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + mask = [] + 
for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == "assistant": + history = self.tokenizer.apply_chat_template(messages[0:i], return_tensors='pt', add_generation_prompt=True, tokenize=True, tools=self.tools)[0] + history_assistant = self.tokenizer.apply_chat_template(messages[0:i+1], return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + generation_len = len(history_assistant) - len(history) + current_mask = [0]*len(history) + [1]*generation_len + current_mask[0:len(mask)] = mask + mask = current_mask + + input_ids = self.tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + assert len(input_ids) == len(mask), f"Input ids and mask must have the same length" + return_dict = { + 'input_ids': input_ids, + 'mask': torch.tensor(mask, dtype=torch.int64), + } + + return return_dict # dict + + + def __getitem__(self, idx: int) -> dict[str, Any]: + sample = super().__getitem__(idx) + + return_dict: dict[str, Any] = {} + + prompt_id = None + prompt = None + prompt_len = None + input_ids = None + sequence_len = None + mask = None + turn_data = None # list[dict[str, Any]] = [] + + # case 0: just contains prompt. 
This is for online RL setting + if 'prompt' in sample and 'response' not in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + prompt = torch.from_numpy(sample['prompt']) + prompt_id = idx + prompt_len = len(prompt) + + # return dict for case 0: + return_dict = { + 'prompt': prompt, + 'prompt_id': prompt_id, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + } + + # case 1: prompt + response, we assume both are tokenized ndarray; this is for standard single turn offline rl + elif 'prompt' in sample and 'response' in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + assert isinstance(sample['response'], np.ndarray), f"Response must be a numpy array, but got {type(sample['response'])}" + input_ids = np.concatenate([sample['prompt'], sample['response']]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt_len = len(torch.from_numpy(sample['prompt'])) + sequence_len = len(input_ids) + + # return dict for case 1: + return_dict = { + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 2: input + mask, this is can be for single turn or multi-turn offline RL. 
mask is used to mask out non-assistant turns + elif 'input' in sample and 'mask' in sample: + print('------case 2: input + mask------') + + assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" + assert isinstance(sample['mask'], np.ndarray), f"Mask must be a numpy array, but got {type(sample['mask'])}" + + input_ids = torch.from_numpy(sample['input']).to(torch.int64) + mask = torch.from_numpy(sample['mask']).to(torch.int64) + + prompt_len = 0 + sequence_len = len(input_ids) + + # return dict for case 2: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 3: for multi-turn data, and sample['messages] contains a list of messages in text + elif 'messages' in sample: + messages = sample['messages'] + if self.flatten_messages is False: + turn_data = self._convert_messages_to_turn_wise_data(messages) # list of dict, one dict per assistant turn + # return dict for case 3.a: + return_dict = { + 'turn_data': turn_data, + } + else: + traj_data = self._convert_messages_to_traj_wise_data(messages) # dict, flatten the message into a single trajectory + input_ids = traj_data['input_ids'] + mask = traj_data['mask'] + prompt_len = 0 + sequence_len = len(input_ids) + # return dict for case 3.b: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, # Already converted to tensor in _convert_messages_to_traj_wise_data + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + else: + raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") + + # now add additional keys + if 'reward' in sample: + return_dict['reward'] = torch.tensor([sample['reward']]) + if 'bonus' in sample: + return_dict['bonus'] = 
torch.tensor([sample['bonus']]) + if 'vstar_rewards' in sample: + assert isinstance(sample['vstar_rewards'], np.ndarray), f"Vstar rewards must be a numpy array, but got {type(sample['vstar_rewards'])}" + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards']) + if 'vstar_bonus' in sample: + assert isinstance(sample['vstar_bonus'], np.ndarray), f"Vstar bonus must be a numpy array, but got {type(sample['vstar_bonus'])}" + return_dict['vstar_bonus'] = torch.from_numpy(sample['vstar_bonus']) + if 'verified_answer' in sample: + assert isinstance(sample['verified_answer'], str), f"Verified answer must be a string, but got {type(sample['verified_answer'])}" + return_dict['verified_answer'] = sample['verified_answer'] + + return return_dict \ No newline at end of file diff --git a/compose_rl/metrics/offline_learning_metrics.py b/compose_rl/metrics/offline_learning_metrics.py new file mode 100644 index 00000000..2eac04ef --- /dev/null +++ b/compose_rl/metrics/offline_learning_metrics.py @@ -0,0 +1,156 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from torchmetrics import Metric + + +class TestTotalLossMetric(Metric): + """Metric for tracking total training loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'total' in batch: + self.loss_sum += batch['total'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestImplicitRewardsLossMetric(Metric): + """Metric for tracking implicit rewards loss.""" + + full_state_update = 
False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'implicit_rewards' in batch: + self.loss_sum += batch['implicit_rewards'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestKLDivergenceLossMetric(Metric): + """Metric for tracking KL divergence losses.""" + + full_state_update = False + + def __init__(self, kl_type: str = 'reverse', dist_sync_on_step: bool = False, **kwargs: Any): + """Initialize KL divergence loss metric. + + Args: + kl_type: Type of KL divergence to track ('reverse' or 'forward') + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.kl_type = kl_type + self.loss_key = f'{kl_type}_kl' + + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if self.loss_key in batch: + self.loss_sum += batch[self.loss_key].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestEstimatedRewardLossMetric(Metric): + """Metric for tracking estimated reward loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: 
torch.Tensor): + # print("keys in batch: ", batch.keys()) + # print("updating in estimated reward loss metric") + if 'estimated_reward' in batch: + self.loss_sum += batch['estimated_reward'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) + + +class TestSequenceEntropiesLossMetric(Metric): + """Metric for tracking sequence entropies loss.""" + + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False, **kwargs: Any): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state( + 'loss_sum', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) + self.add_state( + 'count', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + + def update(self, batch: dict, output_logits: torch.Tensor): + + if 'sequence_entropies' in batch: + self.loss_sum += batch['sequence_entropies'].detach().cpu().item() + self.count += 1 + + def compute(self): + return self.loss_sum / self.count if self.count > 0 else torch.tensor(0.0) \ No newline at end of file diff --git a/example_tools.jsonl b/example_tools.jsonl new file mode 100644 index 00000000..a8122875 --- /dev/null +++ b/example_tools.jsonl @@ -0,0 +1 @@ +{"type": "function", "function": {"name": "vector_search", "description": "Search for similar documents in a vector index using text queries.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Text query to search for similar documents"}}, "required": ["query"]}}} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f13f1efe..fbc0434b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ hf_classifier_rm = "compose_rl.algorithms.reward_modeling:ComposerHFClassifierRe hf_causal_classifier_rm = "compose_rl.algorithms.reward_modeling:ComposerHFCausalClassifierRewardModel" mpt_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerMPTPairwiseOfflinePolicyLM" 
hf_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerHFPairwiseOfflinePolicyLM" +mpt_offline_lm = "compose_rl.algorithms.offline:ComposerMPTOfflinePolicyLM" +hf_offline_lm = "compose_rl.algorithms.offline:ComposerHFOfflinePolicyLM" mpt_actor_critic_lm = "compose_rl.algorithms.online:ComposerMPTPolicyLM" hf_actor_critic_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM" hf_critic_free_lm = "compose_rl.algorithms.online:ComposerHFCriticFreePolicyLM" @@ -60,10 +62,14 @@ pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader" finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader" prompt = "compose_rl.data:build_prompt_dataloader" messages = "compose_rl.data:build_messages_dataloader" +offline = "compose_rl.data:build_offline_dataloader" +unified_rl = "compose_rl.data:build_rl_dataloader" [project.entry-points."llmfoundry_callbacks_with_config"] offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback" +pairwise_offline_rl = "compose_rl.algorithms.offline:PairwiseReferencePolicyCallback" on_policy_rl = "compose_rl.algorithms.online:OnPolicyCallback" + # Backwards Compatibility dpo = "compose_rl.algorithms.offline:ReferencePolicyCallback" ppo = "compose_rl.algorithms.online:OnPolicyCallback" diff --git a/test_rl_data_comprehensive.py b/test_rl_data_comprehensive.py new file mode 100644 index 00000000..d55d6468 --- /dev/null +++ b/test_rl_data_comprehensive.py @@ -0,0 +1,652 @@ +#!/usr/bin/env python3 +""" +Comprehensive test cases for RLStreamingDataset +Self-contained with mocked dependencies to avoid import issues. 
+""" + +import json +import os +import tempfile +import torch +import numpy as np +from typing import Any, Dict, List, Optional, Union +from unittest.mock import MagicMock + + +# Mock StreamingDataset base class +class MockStreamingDataset: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __getitem__(self, idx): + # Override in actual usage + pass + + +# Mock PreTrainedTokenizer +class MockTokenizer: + def __init__(self): + self.pad_token_id = 0 + self.eos_token_id = 2 + self.padding_side = 'right' + self.chat_template = None + + def apply_chat_template(self, messages, tokenize=False, tools=None, add_generation_prompt=False, return_tensors=None): + """Mock chat template application""" + # Simulate tokenization by converting to simple token IDs + if not tokenize: + return "mocked_template_string" + + # Create mock token sequences based on message content + total_tokens = [] + + for msg in messages: + content = msg.get('content', '') or '' + role = msg.get('role', 'unknown') + + # Add role-specific prefix tokens + if role == 'system': + total_tokens.extend([1001, 1002]) # system prefix tokens + elif role == 'user': + total_tokens.extend([1003, 1004]) # user prefix tokens + elif role == 'assistant': + total_tokens.extend([1005, 1006]) # assistant prefix tokens + + # Simple tokenization: each word becomes a token ID + if content: + words = content.split() + token_ids = [hash(word) % 800 + 100 for word in words] # Mock content token IDs + total_tokens.extend(token_ids) + + # Add role-specific suffix tokens + total_tokens.append(1010) # end of message token + + # Add special tokens based on generation prompt + if add_generation_prompt: + total_tokens.extend([1005, 1006]) # Add assistant prefix for generation + else: + # Only add EOS if this is the final completion (no generation prompt) + if messages and messages[-1].get('role') == 'assistant': + total_tokens.append(self.eos_token_id) + + result_tensor = torch.tensor(total_tokens) + + if return_tensors == 
'pt': + return result_tensor.unsqueeze(0) # Batch dimension + return result_tensor + + +# Copy the core RLStreamingDataset logic (simplified) +class RLStreamingDataset(MockStreamingDataset): + """Dataloader for streaming in RL data.""" + + def __init__(self, + max_seq_len: int, + tokenizer, + chat_template: Optional[str] = None, + chat_template_path: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tools_path: Optional[str] = None, + flatten_messages: bool = False, + **kwargs: Any): + super().__init__(**kwargs) + self.max_seq_len = max_seq_len + self.tokenizer = tokenizer + self.flatten_messages = flatten_messages + + # Handle chat template (priority: file path > direct template > default) + if chat_template_path is not None: + # Load template from file + abs_template_path = os.path.abspath(chat_template_path) + if not os.path.exists(abs_template_path): + raise FileNotFoundError(f"Chat template file not found: {chat_template_path} (resolved to: {abs_template_path})") + + with open(abs_template_path, 'r', encoding='utf-8') as f: + self.chat_template = f.read().strip() + # Apply it to the tokenizer + self.tokenizer.chat_template = self.chat_template + elif chat_template is not None: + # Use direct template string + self.chat_template = chat_template + self.tokenizer.chat_template = chat_template + else: + # Use tokenizer's default chat template + self.chat_template = getattr(tokenizer, 'chat_template', None) + + # Handle tools (priority: file path > direct tools > None) + self.tools = [] + if tools_path is not None: + # Load tools from JSONL file (one JSON object per line) + abs_tools_path = os.path.abspath(tools_path) + if not os.path.exists(abs_tools_path): + raise FileNotFoundError(f"Tools file not found: {tools_path} (resolved to: {abs_tools_path})") + + with open(abs_tools_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: # Skip empty lines + continue + try: + tool = 
json.loads(line) + if not isinstance(tool, dict): + raise ValueError(f"Tool on line {line_num} must be a dictionary, but got {type(tool)}") + self.tools.append(tool) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_num} in {abs_tools_path}: {e}") + + elif tools is not None: + # Use direct tools list + if not isinstance(tools, list): + raise ValueError(f"Tools must be a list, but got {type(tools)}") + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + raise ValueError(f"Tool {i} must be a dictionary, but got {type(tool)}") + + self.tools = tools + + def _convert_messages_to_turn_wise_data(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert messages to turn wise data""" + turn_data: List[Dict[str, Any]] = [] + assert isinstance(messages, list), f"Messages must be a list, but got {type(messages)}" + + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == 'assistant': + history = self.tokenizer.apply_chat_template(messages[:i], tokenize=True, tools=self.tools, add_generation_prompt=True, return_tensors='pt')[0] + history_assistant = self.tokenizer.apply_chat_template(messages[:i+1], tokenize=True, tools=self.tools, add_generation_prompt=False, return_tensors='pt')[0] + assert torch.allclose(history_assistant[:len(history)], history, atol=1e-5), f"History assistant must be the same as history" + + input_ids = history_assistant + prompt_len = len(history) + sequence_len = len(input_ids) + + turn_data.append({ + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + }) + + return turn_data + + def _convert_messages_to_traj_wise_data(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: + """Convert messages to trajectory wise data""" + assert isinstance(messages, list), 
f"Messages must be a list, but got {type(messages)}" + mask = [] + for i in range(len(messages)): + message = messages[i] + assert isinstance(message, dict), f"Message must be a dictionary, but got {type(message)}" + if message['role'] == "assistant": + history = self.tokenizer.apply_chat_template(messages[0:i], return_tensors='pt', add_generation_prompt=True, tokenize=True, tools=self.tools)[0] + history_assistant = self.tokenizer.apply_chat_template(messages[0:i+1], return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + generation_len = len(history_assistant) - len(history) + current_mask = [0]*len(history) + [1]*generation_len + current_mask[0:len(mask)] = mask + mask = current_mask + + input_ids = self.tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=False, tokenize=True, tools=self.tools)[0] + assert len(input_ids) == len(mask), f"Input ids and mask must have the same length, got {len(input_ids)} vs {len(mask)}" + return_dict = { + 'input_ids': input_ids, + 'mask': torch.tensor(mask, dtype=torch.int64), + } + return return_dict + + def __getitem__(self, idx: int) -> Dict[str, Any]: + # Mock the parent's __getitem__ to return test data + if hasattr(self, '_test_samples') and idx < len(self._test_samples): + sample = self._test_samples[idx] + else: + raise IndexError(f"Index {idx} out of range") + + return_dict: Dict[str, Any] = {} + + prompt_id = None + prompt = None + prompt_len = None + input_ids = None + sequence_len = None + mask = None + turn_data = None + + # case 0: just contains prompt. 
This is for online RL setting + if 'prompt' in sample and 'response' not in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + prompt = torch.from_numpy(sample['prompt']) + prompt_id = idx + prompt_len = len(prompt) + + # return dict for case 0: + return_dict = { + 'prompt': prompt, + 'prompt_id': prompt_id, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + } + + # case 1: prompt + response, we assume both are tokenized ndarray; this is for standard single turn offline rl + elif 'prompt' in sample and 'response' in sample: + assert isinstance(sample['prompt'], np.ndarray), f"Prompt must be a numpy array, but got {type(sample['prompt'])}" + assert isinstance(sample['response'], np.ndarray), f"Response must be a numpy array, but got {type(sample['response'])}" + input_ids = np.concatenate([sample['prompt'], sample['response']]) + input_ids = torch.from_numpy(input_ids[:self.max_seq_len]) + prompt_len = len(torch.from_numpy(sample['prompt'])) + sequence_len = len(input_ids) + + # return dict for case 1: + return_dict = { + 'input_ids': input_ids, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 2: input + mask, this is can be for single turn or multi-turn offline RL. 
mask is used to mask out non-assistant turns + elif 'input' in sample and 'mask' in sample: + assert isinstance(sample['input'], np.ndarray), f"Input must be a numpy array, but got {type(sample['input'])}" + assert isinstance(sample['mask'], np.ndarray), f"Mask must be a numpy array, but got {type(sample['mask'])}" + + input_ids = torch.from_numpy(sample['input']).to(torch.int64) + mask = torch.from_numpy(sample['mask']).to(torch.int64) + + prompt_len = 0 + sequence_len = len(input_ids) + + # return dict for case 2: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + # case 3: for multi-turn data, and sample['messages'] contains a list of messages in text + elif 'messages' in sample: + messages = sample['messages'] + if self.flatten_messages is False: + turn_data = self._convert_messages_to_turn_wise_data(messages) # list of dict, one dict per assistant turn + # return dict for case 3.a: + return_dict = { + 'turn_data': turn_data, + } + else: + traj_data = self._convert_messages_to_traj_wise_data(messages) # dict, flatten the message into a single trajectory + input_ids = traj_data['input_ids'] + mask = traj_data['mask'] + prompt_len = 0 + sequence_len = len(input_ids) + # return dict for case 3.b: + return_dict = { + 'input_ids': input_ids, + 'mask': mask, # Already converted to tensor in _convert_messages_to_traj_wise_data + 'prompt_len': torch.tensor([prompt_len], dtype=torch.int64), + 'sequence_len': torch.tensor([sequence_len], dtype=torch.int64), + } + + else: + raise ValueError(f"Sample must contain 'prompt', 'prompt'+'response', 'input'+'mask', or 'messages', but got keys: {list(sample.keys())}") + + # now add additional keys + if 'reward' in sample: + return_dict['reward'] = torch.tensor([sample['reward']]) + if 'bonus' in sample: + return_dict['bonus'] = torch.tensor([sample['bonus']]) + if 'vstar_rewards' in 
sample: + assert isinstance(sample['vstar_rewards'], np.ndarray), f"Vstar rewards must be a numpy array, but got {type(sample['vstar_rewards'])}" + return_dict['vstar_rewards'] = torch.from_numpy(sample['vstar_rewards']) + if 'vstar_bonus' in sample: + assert isinstance(sample['vstar_bonus'], np.ndarray), f"Vstar bonus must be a numpy array, but got {type(sample['vstar_bonus'])}" + return_dict['vstar_bonus'] = torch.from_numpy(sample['vstar_bonus']) + if 'verified_answer' in sample: + assert isinstance(sample['verified_answer'], str), f"Verified answer must be a string, but got {type(sample['verified_answer'])}" + return_dict['verified_answer'] = sample['verified_answer'] + + return return_dict + + +# Test cases +def test_case_0_prompt_only(): + """Test case 0: prompt only (online RL)""" + print("๐Ÿงช Testing Case 0: Prompt only...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'prompt': np.array([10, 11, 12, 13, 14]), + 'reward': 0.5 + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'prompt' in result + assert 'prompt_id' in result + assert 'prompt_len' in result + assert 'reward' in result + assert isinstance(result['prompt'], torch.Tensor) + assert result['prompt_id'] == 0 + assert result['prompt_len'].item() == 5 + assert result['reward'].item() == 0.5 + + print("โœ… Case 0 passed!") + + +def test_case_1_prompt_response(): + """Test case 1: prompt + response (single turn offline RL)""" + print("๐Ÿงช Testing Case 1: Prompt + Response...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'prompt': np.array([10, 11, 12]), + 'response': np.array([20, 21, 22, 23]), + 'reward': 1.0 + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'prompt_len' in result + assert 
'sequence_len' in result + assert 'reward' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert result['prompt_len'].item() == 3 + assert result['sequence_len'].item() == 7 # 3 + 4 = 7 + assert result['reward'].item() == 1.0 + + # Check concatenation + expected_input_ids = torch.tensor([10, 11, 12, 20, 21, 22, 23]) + assert torch.equal(result['input_ids'], expected_input_ids) + + print("โœ… Case 1 passed!") + + +def test_case_2_input_mask(): + """Test case 2: input + mask (multi-turn with mask)""" + print("๐Ÿงช Testing Case 2: Input + Mask...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data + test_sample = { + 'input': np.array([10, 11, 12, 20, 21, 22]), + 'mask': np.array([0, 0, 0, 1, 1, 1]), # Only last 3 tokens are assistant + 'vstar_rewards': np.array([0.1, 0.2, 0.3]) + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'mask' in result + assert 'prompt_len' in result + assert 'sequence_len' in result + assert 'vstar_rewards' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert isinstance(result['mask'], torch.Tensor) + assert result['prompt_len'].item() == 0 + assert result['sequence_len'].item() == 6 + + # Check mask values + expected_mask = torch.tensor([0, 0, 0, 1, 1, 1]) + assert torch.equal(result['mask'], expected_mask) + + print("โœ… Case 2 passed!") + + +def test_case_3a_messages_turn_wise(): + """Test case 3a: messages with turn-wise processing""" + print("๐Ÿงช Testing Case 3a: Messages (Turn-wise)...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, flatten_messages=False) + + # Mock test data + test_sample = { + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi there! 
How can I help you?'}, + {'role': 'user', 'content': 'Tell me a joke.'}, + {'role': 'assistant', 'content': 'Why did the chicken cross the road?'} + ] + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'turn_data' in result + assert isinstance(result['turn_data'], list) + assert len(result['turn_data']) == 2 # Two assistant turns + + # Check first turn + turn1 = result['turn_data'][0] + assert 'input_ids' in turn1 + assert 'prompt_len' in turn1 + assert 'sequence_len' in turn1 + assert isinstance(turn1['input_ids'], torch.Tensor) + assert turn1['prompt_len'].item() > 0 + assert turn1['sequence_len'].item() > 0 + + # Check second turn + turn2 = result['turn_data'][1] + assert 'input_ids' in turn2 + assert 'prompt_len' in turn2 + assert 'sequence_len' in turn2 + assert isinstance(turn2['input_ids'], torch.Tensor) + + print("โœ… Case 3a passed!") + + +def test_case_3b_messages_trajectory_wise(): + """Test case 3b: messages with trajectory-wise processing""" + print("๐Ÿงช Testing Case 3b: Messages (Trajectory-wise)...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, flatten_messages=True) + + # Mock test data + test_sample = { + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi there!'}, + {'role': 'user', 'content': 'Thanks.'}, + {'role': 'assistant', 'content': 'Welcome!'} + ] + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'input_ids' in result + assert 'mask' in result + assert 'prompt_len' in result + assert 'sequence_len' in result + assert isinstance(result['input_ids'], torch.Tensor) + assert isinstance(result['mask'], torch.Tensor) + assert result['prompt_len'].item() == 0 + assert result['sequence_len'].item() > 0 + + # Check mask - should have 1s for assistant tokens, 0s for others + assert len(result['mask']) 
== len(result['input_ids']) + assert torch.sum(result['mask']).item() > 0 # Some tokens should be marked as assistant + + print("โœ… Case 3b passed!") + + +def test_tools_from_list(): + """Test tools loading from direct list""" + print("๐Ÿงช Testing Tools from List...") + + tokenizer = MockTokenizer() + test_tools = [ + { + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object", "properties": {}} + } + } + ] + + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools=test_tools) + + # Assertions + assert len(dataset.tools) == 1 + assert dataset.tools[0]["type"] == "function" + assert dataset.tools[0]["function"]["name"] == "test_tool" + + print("โœ… Tools from list passed!") + + +def test_tools_from_jsonl_file(): + """Test tools loading from JSONL file""" + print("๐Ÿงช Testing Tools from JSONL file...") + + # Create temporary JSONL file + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f: + tool1 = {"type": "function", "function": {"name": "tool1", "description": "First tool"}} + tool2 = {"type": "function", "function": {"name": "tool2", "description": "Second tool"}} + f.write(json.dumps(tool1) + '\n') + f.write(json.dumps(tool2) + '\n') + temp_file = f.name + + try: + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools_path=temp_file) + + # Assertions + assert len(dataset.tools) == 2 + assert dataset.tools[0]["function"]["name"] == "tool1" + assert dataset.tools[1]["function"]["name"] == "tool2" + + print("โœ… Tools from JSONL file passed!") + + finally: + os.unlink(temp_file) + + +def test_chat_template_from_string(): + """Test chat template from direct string""" + print("๐Ÿงช Testing Chat Template from String...") + + tokenizer = MockTokenizer() + test_template = "Custom template: {{ content }}" + + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, chat_template=test_template) + + 
# Assertions + assert dataset.chat_template == test_template + assert dataset.tokenizer.chat_template == test_template + + print("โœ… Chat template from string passed!") + + +def test_error_cases(): + """Test error handling""" + print("๐Ÿงช Testing Error Cases...") + + tokenizer = MockTokenizer() + + # Test invalid tools + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools="invalid") + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Tools must be a list" in str(e) + + # Test invalid tool in list + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools=["invalid"]) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Tool 0 must be a dictionary" in str(e) + + # Test non-existent tools file + try: + RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer, tools_path="/non/existent/file.jsonl") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError as e: + assert "Tools file not found" in str(e) + + # Test invalid sample format + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + dataset._test_samples = [{'invalid': 'data'}] + + try: + dataset[0] + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Sample must contain" in str(e) + + print("โœ… Error cases passed!") + + +def test_additional_fields(): + """Test additional fields like bonus, verified_answer""" + print("๐Ÿงช Testing Additional Fields...") + + tokenizer = MockTokenizer() + dataset = RLStreamingDataset(max_seq_len=128, tokenizer=tokenizer) + + # Mock test data with additional fields + test_sample = { + 'prompt': np.array([10, 11, 12]), + 'response': np.array([20, 21]), + 'bonus': 2.5, + 'vstar_bonus': np.array([0.1, 0.2]), + 'verified_answer': 'This is verified' + } + dataset._test_samples = [test_sample] + + result = dataset[0] + + # Assertions + assert 'bonus' in result + assert 'vstar_bonus' in result + assert 'verified_answer' in 
result + assert result['bonus'].item() == 2.5 + assert isinstance(result['vstar_bonus'], torch.Tensor) + assert result['verified_answer'] == 'This is verified' + + print("โœ… Additional fields passed!") + + +def run_all_tests(): + """Run all test cases""" + print("๐Ÿš€ Running RLStreamingDataset Tests...\n") + + try: + test_case_0_prompt_only() + test_case_1_prompt_response() + test_case_2_input_mask() + test_case_3a_messages_turn_wise() + test_case_3b_messages_trajectory_wise() + test_tools_from_list() + test_tools_from_jsonl_file() + test_chat_template_from_string() + test_additional_fields() + test_variable_length_vstar_fields() + test_error_cases() + + print("\n๐ŸŽ‰ All tests passed! RLStreamingDataset is working correctly.") + + except Exception as e: + print(f"\nโŒ Test failed: {e}") + raise + + +if __name__ == "__main__": + run_all_tests() diff --git a/tests/test_offline.py b/tests/test_offline.py index 2838700b..fe8132a6 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -21,7 +21,8 @@ ComposerHFPairwiseOfflinePolicyLM, ComposerMPTPairwiseOfflinePolicyLM, ) -from compose_rl.algorithms.offline.callback import ReferencePolicyCallback +from compose_rl.algorithms.offline.callback import \ + PairwiseReferencePolicyCallback from compose_rl.data import pairwise_preference_dataset_collate_fn from tests.common import PairwisePreference, world_size @@ -64,7 +65,7 @@ def test_load_checkpoint_with_offline_callback( train_config = { 'model': model_config, } - reference_policy_callback = ReferencePolicyCallback( + reference_policy_callback = PairwiseReferencePolicyCallback( train_config=train_config, ) @@ -125,7 +126,7 @@ def test_reference_policy_callback_forward( 'fsdp_config': {}, 'seed': 17, } - callback = ReferencePolicyCallback(train_config=train_config) + callback = PairwiseReferencePolicyCallback(train_config=train_config) Trainer( model=model, callbacks=callback, @@ -209,7 +210,7 @@ def test_train( trainer = Trainer( model=model, 
train_dataloader=dataloader, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='1ep', ) @@ -295,7 +296,7 @@ def test_checkpoint_reloading( model=model, train_dataloader=dataloader, loggers=in_memory_logger, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='8ba', autoresume=True, @@ -317,7 +318,7 @@ def test_checkpoint_reloading( model=model, train_dataloader=dataloader, loggers=in_memory_logger, - callbacks=ReferencePolicyCallback(train_config=train_config), + callbacks=PairwiseReferencePolicyCallback(train_config=train_config), parallelism_config={'fsdp': fsdp_config}, max_duration='8ba', save_overwrite=True, diff --git a/yamls/local_dpo.yaml b/yamls/local_dpo.yaml index c142dde2..94de3507 100644 --- a/yamls/local_dpo.yaml +++ b/yamls/local_dpo.yaml @@ -7,20 +7,20 @@ model: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: meta-llama/Llama-3.1-8B-Instruct + pretrained_model_name_or_path: Qwen/Qwen2.5-7B loggers: mlflow: - experiment_name: brandon_dpo_test + experiment_name: wensun_dpo_test callbacks: - offline_rl: {} + pairwise_offline_rl: {} lr_monitor: {} speed_monitor: window_size: 10 memory_monitor: {} # hf_checkpointer: - # save_folder: # TODO: insert save path for huggingface checkpoints + # save_folder: # TODO: insert save path for huggingface checkpoints # save_interval: 1ep optimizer: @@ -39,7 +39,7 @@ scheduler: t_warmup: 0.1dur tokenizer: - name: meta-llama/Llama-3.1-8B-Instruct + name: Qwen/Qwen2.5-7B kwargs: model_max_length: ${max_seq_len} trust_remote_code: true @@ -59,7 +59,7 @@ fsdp_config: sharding_strategy: FULL_SHARD activation_cpu_offload: false -max_seq_len: 2048 +max_seq_len: 4096 save_folder: /tmp/dpo_model # TODO: 
update for a proper save path dist_timeout: 600 max_duration: 1ep @@ -68,8 +68,8 @@ progress_bar: false train_loader: name: pairwise_preference dataset: - # local: # TODO: insert local dataset path - # remote: # TODO: insert remote dataset path if applicable + # local: # TODO: insert local dataset path + # remote: # TODO: insert remote dataset path if applicable split: train shuffle: true diff --git a/yamls/offline_apo.yaml b/yamls/offline_apo.yaml new file mode 100644 index 00000000..8ba7cde3 --- /dev/null +++ b/yamls/offline_apo.yaml @@ -0,0 +1,179 @@ +name: apo-single-stream-openr1 +#name: apo-single-stream-math + +image: mosaicml/dle:nightly-latest +scheduling: + priority: high + max_retries: 0 + preemptible: true + retry_on_system_failure: false + +compute: + gpus: 16 #32 #16 + cluster: r5z2p1 + +#run_name: offline_apo_math_qwen +run_name: offline_apo_openr1_traj_wise_stream_deepseek + +parameters: + seed: 7338 + model: + name: hf_offline_lm + beta1: 1 + beta2: 0.01 + eta: 0.5 + loss_type: apo_critic + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B #deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + #Qwen/Qwen3-4B-Instruct-2507 + #/tmp/local_qwen + #Qwen/Qwen3-4B + #Qwen/Qwen2.5-Coder-32B-Instruct + #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + #Qwen/Qwen2.5-7B + + loggers: + mlflow: + tracking_uri: databricks + experiment_name: offline_apo_openr1_single_resposne + #/Users/j.chang@databricks.com/mlflow_experiments/offline_apo_single_response_openr1 + + callbacks: + offline_rl: {} + lr_monitor: {} + scheduled_gc: + batch_interval: 2000 + speed_monitor: + window_size: 10 + memory_monitor: {} + hf_checkpointer: + overwrite: true + precision: bfloat16 + save_folder: s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/{run_name} + #s3://data-force-one-datasets/mosaicml-internal-checkpoints/jchang/models/hf/{run_name} + save_interval: 1ep + 
+ optimizer: + lr: 1.0e-06 #7.0e-07 + eps: 1.0e-10 + name: decoupled_adamw + betas: + - 0.9 + - 0.95 + weight_decay: 1.0e-8 + precision: amp_bf16 + scheduler: + name: cosine_with_warmup + alpha_f: 0.05 + t_warmup: 0.02dur + #name: constant_with_warmup + #t_warmup: 0.01dur + + tokenizer: + name: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B #deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + #Qwen/Qwen3-4B-Instruct-2507 + #Qwen/Qwen2.5-7B + #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + kwargs: + trust_remote_code: true + padding_side: right + + algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1 + eval_first: false + + fsdp_config: + verbose: false + cpu_offload: false + mixed_precision: PURE + state_dict_type: sharded + forward_prefetch: true + backward_prefetch: BACKWARD_PRE + limit_all_gathers: true + sharding_strategy: FULL_SHARD + activation_cpu_offload: false + activation_checkpointing: true + activation_checkpointing_reentrant: false + + max_seq_len: 16384 #65536 #35000 #16384 #65K does not work due to oom. 
+ dist_timeout: 3600 + max_duration: 1ep + autoresume: true + progress_bar: false + train_loader: + name: unified_rl + dataset: + local: /tmp/dataset/ + split: train + remote: dbfs:/Volumes/datasets/wensun/data/offline_apo/traj_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v3_medium_0.5/ + #dbfs:/Volumes/datasets/wensun/data/offline_apo/step_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ + # dbfs:/Volumes/datasets/wensun/data/offline_apo/traj_wise_openr1/DeepSeek-R1-Distill-Qwen-7B/v3_medium_0.5/ + #dbfs:/Volumes/datasets/jchang/data/offline_apo/apo_openr1/DeepSeek-R1-Distill-Qwen-7B/v1_0.5/ + shuffle: true + max_seq_len: ${max_seq_len} + flatten_messages: true # Enable trajectory-wise flattened processing for messages + tools_path: ../../../compose-rl/example_tools.jsonl + shuffle_seed: 7338 + download_retry: 4 + download_timeout: 600 + drop_last: false + num_workers: 8 + save_folder: s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/mpt/{run_name} + #s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/{run_name} + save_overwrite: True + eval_interval: 1ep + save_interval: 1ep + log_to_console: true + python_log_level: debug + load_weights_only: true + console_log_interval: 5ba + device_eval_batch_size: 1 + eval_subset_num_batches: -1 + global_train_batch_size: 256 + device_train_microbatch_size: 1 + + + variables: + reference_model: + name: hf_offline_lm + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + + auxiliary_model: # Add your second model here + name: hf_offline_lm + pretrained: true + init_device: mixed + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B # Different model or same + num_bins: 32 # Number of bins for distributional value function + +integrations: +- integration_type: git_repo + path: custom-llm-foundry + 
git_repo: mosaicml/llm-foundry + git_branch: main + pip_install: . --no-deps +- integration_type: git_repo + path: compose-rl + git_repo: databricks/Compose-RL + git_branch: wensun/offline_apo + pip_install: .[gpu] --no-deps +env_variables: + AWS_PROFILE: data-force-one +command: |- + # Run llm foundry train + + #mkdir -p /tmp/local_qwen + #s5cmd --numworkers 32 cp s3://data-force-one-datasets/mosaicml-internal-checkpoints/wensun/models/offline_apo_tool_call-1-1.0-0.01-1e-06-4-256-4b_f_t/huggingface/ba689/* /tmp/local_qwen + + cd custom-llm-foundry/scripts/train/ + composer train.py /mnt/config/parameters.yaml