diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 09d86622f6..32cf7a10b6 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -360,7 +360,7 @@ def forward_backward_batch( @abc.abstractmethod def train_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> dict[str, float]: @@ -372,9 +372,11 @@ def train_batch( Parameters ---------- - input_ : dict[str, Any] - The input data for model forward pass and the loss function. - Redundant entries are allowed. + input_ : list[dict[str, Any]] | dict[str, Any] + Input data for model forward pass and loss computation. + Preferred format is ``list[dict[str, Any]]`` (trajectory list). + Backward compatibility: a pre-batched ``dict[str, Any]`` is + also accepted. loss_fn : Callable[..., torch.Tensor] The loss function. For actor (is_critic=False), it receives (logprobs, entropy, input_data). For critic (is_critic=True), @@ -397,7 +399,7 @@ def train_batch( @abc.abstractmethod def eval_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: @@ -409,9 +411,11 @@ def eval_batch( Parameters ---------- - input_ : dict[str, Any] - The input data for model forward pass and the loss function. - Redundant entries are allowed. + input_ : list[dict[str, Any]] | dict[str, Any] + Input data for model forward pass and loss computation. + Preferred format is ``list[dict[str, Any]]`` (trajectory list). + Backward compatibility: a pre-batched ``dict[str, Any]`` is + also accepted. loss_fn : Callable[..., torch.Tensor] The loss function. For actor (is_critic=False), it receives (logprobs, entropy, input_data). For critic (is_critic=True), @@ -434,10 +438,10 @@ def eval_batch( @abc.abstractmethod def forward_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], output_seqlens: list[int] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> torch.Tensor: + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: """Run the forward pass or inference on the model. Note @@ -446,29 +450,34 @@ def forward_batch( Parameters ---------- - input_ : dict[str, Any] - The input data for model forward pass. Redundant entries are allowed. + input_ : list[dict[str, Any]] | dict[str, Any] + Input data for model forward pass. Redundant entries are allowed. + ``list[dict[str, Any]]`` and pre-batched ``dict[str, Any]`` + are both supported. output_seqlens : list[int], optional The desired output sequence lengths. If None, assumes that the output has the same lengths as inputs, by default None. - aggregate_fn : Callable[[list[Any]], Any], optional + aggregate_fn : Callable[[list[torch.Tensor]], torch.Tensor], optional A function to aggregate micro-batched outputs, by default torch.cat. + It should preserve batch dimension 0. Returns ------- - Any - For actor (is_critic=False): logprobs tensor aggregated by `aggregate_fn`. - For critic (is_critic=True): values tensor aggregated by `aggregate_fn`. + torch.Tensor | list[torch.Tensor] + Batched tensor output for dict input. + Per-trajectory tensor list for list input. + For actor (is_critic=False), return logprobs tensors. + For critic (is_critic=True), return value tensors. """ raise NotImplementedError() @torch.no_grad() def forward( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], output_seqlens: list[int] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> torch.Tensor: + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: return self.forward_batch(input_, output_seqlens, aggregate_fn) @abc.abstractmethod diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 15c7e66308..d78c644e2f 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -112,8 +112,10 @@ MicroBatchItem, MicroBatchList, amend_position_ids, + concat_batch, pack_tensor_dict, pad_mb_list, + split_batch, split_padded_tensor_dict_into_mb_list, unsqueeze_mb_list, ) @@ -615,15 +617,17 @@ def forward_backward_batch( def train_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> dict[str, float]: self._ensure_ready() self.optimizer_zero_grad() + input_batched, _ = self._normalize_batch_input(input_) + # Step 1: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 2: Compute total loss weight total_loss_weight = compute_total_loss_weight( @@ -652,14 +656,16 @@ def process_output( @torch.no_grad() def eval_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: self._ensure_ready() + input_batched, _ = self._normalize_batch_input(input_) + # Step 1: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 2: Compute total loss weight total_loss_weight = compute_total_loss_weight( @@ -691,21 +697,33 @@ def process_output( @torch.no_grad() def forward_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], output_seqlens: list[int] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> torch.Tensor: + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: self._ensure_ready() + input_batched, meta = self._normalize_batch_input(input_) + # Step 1: Prepare sequence lengths - cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] + if meta is not None: + assert isinstance(input_, list) + inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_] + if output_seqlens is not None and output_seqlens != inferred_seqlens: + raise ValueError( + f"output_seqlens mismatch for list input: " + f"given {output_seqlens}, " + f"inferred {inferred_seqlens} from attention_mask shapes." + ) + output_seqlens = inferred_seqlens + cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] if output_seqlens is None: output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() assert output_seqlens is not None batch_size = len(output_seqlens) # Step 2: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 3: Forward using process_output_fn callback, collecting results outputs: list[torch.Tensor] = [] @@ -720,8 +738,15 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None: # Step 4: Aggregate and reorder outputs if self.enable_tree_training: - return merge_packed_tree_results(outputs, batch_size) - return reorder_and_pad_outputs(outputs, output_seqlens, mb_list, aggregate_fn) + result = merge_packed_tree_results(outputs, batch_size) + else: + result = reorder_and_pad_outputs( + outputs, output_seqlens, mb_list, aggregate_fn + ) + + if meta is None: + return result + return split_batch(result, meta) def export_stats(self) -> dict[str, float]: return stats_tracker.export_all(reduce_group=self.data_parallel_group) @@ -992,6 +1017,20 @@ def _ensure_ready(self) -> None: if self.parallel_helper.sp_size > 1: set_ulysses_sequence_parallel_group(self.sp_group) + @staticmethod + def _normalize_batch_input( + input_: list[dict[str, Any]] | dict[str, Any], + ) -> tuple[dict[str, Any], Any | None]: + """Normalize list/dict batch input to a single batched dict. + + Returns ``(batched_input, meta)`` where ``meta`` is non-None only when + input is list-based and can be used to split forward outputs back into + per-trajectory results. + """ + if isinstance(input_, list): + return concat_batch(input_) + return input_, None + def _get_model_name_parameters( self, meta: WeightUpdateMeta ) -> Iterator[tuple[str, nn.Parameter]]: diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index eb486d8530..9eb88cb0d6 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -93,8 +93,10 @@ MicroBatchList, amend_position_ids, broadcast_tensor, + concat_batch, pack_tensor_dict, pad_mb_list, + split_batch, split_padded_tensor_dict_into_mb_list, unpad_logits, ) @@ -729,15 +731,17 @@ def _process_output(input_, output_): def train_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> dict[str, float]: self._ensure_ready() self.optimizer_zero_grad() + input_batched, _ = self._normalize_batch_input(input_) + # Step 1: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 2: Compute total loss weight total_loss_weight = compute_total_loss_weight( @@ -769,14 +773,16 @@ def process_output( @torch.no_grad() def eval_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: self._ensure_ready() + input_batched, _ = self._normalize_batch_input(input_) + # Step 1: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 2: Compute total loss weight total_loss_weight = compute_total_loss_weight( @@ -805,21 +811,33 @@ def process_output( @torch.no_grad() def forward_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], output_seqlens: list[int] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> torch.Tensor: + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: self._ensure_ready() + input_batched, meta = self._normalize_batch_input(input_) + # Step 1: Prepare sequence lengths - cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] + if meta is not None: + assert isinstance(input_, list) + inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_] + if output_seqlens is not None and output_seqlens != inferred_seqlens: + raise ValueError( + f"output_seqlens mismatch for list input: " + f"given {output_seqlens}, " + f"inferred {inferred_seqlens} from attention_mask shapes." + ) + output_seqlens = inferred_seqlens + cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] if output_seqlens is None: output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() assert output_seqlens is not None batch_size = len(output_seqlens) # Step 2: Prepare micro-batches - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) # Step 3: Forward using Megatron's pipeline function, collecting results outputs: list[torch.Tensor] = [] @@ -845,7 +863,9 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: src_rank=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group(), ) - return res + if meta is None: + return res + return split_batch(res, meta) def export_stats(self) -> dict[str, float]: data = stats_tracker.export_all(reduce_group=self.data_parallel_group) @@ -1114,6 +1134,14 @@ def _check_rollout_engine_connected(self) -> None: " before using rollout/update_weight methods." ) + @staticmethod + def _normalize_batch_input( + input_: list[dict[str, Any]] | dict[str, Any], + ) -> tuple[dict[str, Any], Any | None]: + if isinstance(input_, list): + return concat_batch(input_) + return input_, None + def _ensure_ready(self) -> None: if self.is_offload: self.onload() diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 5a02eca139..a246b607be 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -91,8 +91,10 @@ MicroBatchList, amend_position_ids, broadcast_tensor, + concat_batch, pack_tensor_dict, pad_mb_list, + split_batch, split_padded_tensor_dict_into_mb_list, unsqueeze_mb_list, ) @@ -489,7 +491,7 @@ def forward_backward_batch( def train_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> dict[str, float]: @@ -497,7 +499,9 @@ def train_batch( assert self._initialized self.optimizer_zero_grad() - mb_list = self._prepare_mb_list(input_).to(self.device) + input_batched, _ = self._normalize_batch_input(input_) + + mb_list = self._prepare_mb_list(input_batched).to(self.device) total_loss_weight = compute_total_loss_weight( mb_list, loss_weight_fn, self.data_parallel_group @@ -523,14 +527,16 @@ def process_output( @torch.no_grad() def eval_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: """Evaluate on a batch of data.""" assert self._initialized - mb_list = self._prepare_mb_list(input_).to(self.device) + input_batched, _ = self._normalize_batch_input(input_) + + mb_list = self._prepare_mb_list(input_batched).to(self.device) total_loss_weight = compute_total_loss_weight( mb_list, loss_weight_fn, self.data_parallel_group @@ -563,20 +569,32 @@ def process_output( @torch.no_grad() def forward_batch( self, - input_: dict[str, Any], + input_: list[dict[str, Any]] | dict[str, Any], output_seqlens: list[int] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> torch.Tensor: + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: """Forward pass without gradient computation.""" assert self._initialized - cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] + input_batched, meta = self._normalize_batch_input(input_) + + if meta is not None: + assert isinstance(input_, list) + inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_] + if output_seqlens is not None and output_seqlens != inferred_seqlens: + raise ValueError( + f"output_seqlens mismatch for list input: " + f"given {output_seqlens}, " + f"inferred {inferred_seqlens} from attention_mask shapes." + ) + output_seqlens = inferred_seqlens + cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] if output_seqlens is None: output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() assert output_seqlens is not None batch_size = len(output_seqlens) - mb_list = self._prepare_mb_list(input_).to(self.device) + mb_list = self._prepare_mb_list(input_batched).to(self.device) def process_output( logits: torch.Tensor, ctx_dict: dict[str, Any] @@ -606,7 +624,9 @@ def process_output( group=self.parallel_dims.get_group("pp"), ) assert res is not None - return res + if meta is None: + return res + return split_batch(res, meta) def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): """Connect to an inference engine for rollout.""" @@ -1067,6 +1087,14 @@ def _create_optimizer(self, ft_spec: FinetuneSpec): self.logger.info(f"Created optimizer in {time.perf_counter() - tik:.2f}s") + @staticmethod + def _normalize_batch_input( + input_: list[dict[str, Any]] | dict[str, Any], + ) -> tuple[dict[str, Any], Any | None]: + if isinstance(input_, list): + return concat_batch(input_) + return input_, None + def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: assert "attention_mask" in input_ and "input_ids" in input_ input_ = input_.copy()