Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions areal/api/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def forward_backward_batch(
@abc.abstractmethod
def train_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> dict[str, float]:
Expand All @@ -372,9 +372,11 @@ def train_batch(

Parameters
----------
input_ : dict[str, Any]
The input data for model forward pass and the loss function.
Redundant entries are allowed.
input_ : list[dict[str, Any]] | dict[str, Any]
Input data for model forward pass and loss computation.
Preferred format is ``list[dict[str, Any]]`` (trajectory list).
Backward compatibility: a pre-batched ``dict[str, Any]`` is
also accepted.
loss_fn : Callable[..., torch.Tensor]
The loss function. For actor (is_critic=False), it receives
(logprobs, entropy, input_data). For critic (is_critic=True),
Expand All @@ -397,7 +399,7 @@ def train_batch(
@abc.abstractmethod
def eval_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
Expand All @@ -409,9 +411,11 @@ def eval_batch(

Parameters
----------
input_ : dict[str, Any]
The input data for model forward pass and the loss function.
Redundant entries are allowed.
input_ : list[dict[str, Any]] | dict[str, Any]
Input data for model forward pass and loss computation.
Preferred format is ``list[dict[str, Any]]`` (trajectory list).
Backward compatibility: a pre-batched ``dict[str, Any]`` is
also accepted.
loss_fn : Callable[..., torch.Tensor]
The loss function. For actor (is_critic=False), it receives
(logprobs, entropy, input_data). For critic (is_critic=True),
Expand All @@ -434,10 +438,10 @@ def eval_batch(
@abc.abstractmethod
def forward_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> torch.Tensor:
aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat,
) -> torch.Tensor | list[torch.Tensor]:
"""Run the forward pass or inference on the model.

Note
Expand All @@ -446,29 +450,34 @@ def forward_batch(

Parameters
----------
input_ : dict[str, Any]
The input data for model forward pass. Redundant entries are allowed.
input_ : list[dict[str, Any]] | dict[str, Any]
Input data for model forward pass. Redundant entries are allowed.
``list[dict[str, Any]]`` and pre-batched ``dict[str, Any]``
are both supported.
output_seqlens : list[int], optional
The desired output sequence lengths. If None, assumes that the output
has the same lengths as inputs, by default None.
aggregate_fn : Callable[[list[Any]], Any], optional
aggregate_fn : Callable[[list[torch.Tensor]], torch.Tensor], optional
A function to aggregate micro-batched outputs, by default torch.cat.
It should preserve batch dimension 0.

Returns
-------
Any
For actor (is_critic=False): logprobs tensor aggregated by `aggregate_fn`.
For critic (is_critic=True): values tensor aggregated by `aggregate_fn`.
torch.Tensor | list[torch.Tensor]
Batched tensor output for dict input.
Per-trajectory tensor list for list input.
For actor (is_critic=False), return logprobs tensors.
For critic (is_critic=True), return value tensors.
"""
raise NotImplementedError()

@torch.no_grad()
def forward(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> torch.Tensor:
aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat,
) -> torch.Tensor | list[torch.Tensor]:
return self.forward_batch(input_, output_seqlens, aggregate_fn)

@abc.abstractmethod
Expand Down
61 changes: 50 additions & 11 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@
MicroBatchItem,
MicroBatchList,
amend_position_ids,
concat_batch,
pack_tensor_dict,
pad_mb_list,
split_batch,
split_padded_tensor_dict_into_mb_list,
unsqueeze_mb_list,
)
Expand Down Expand Up @@ -615,15 +617,17 @@ def forward_backward_batch(

def train_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> dict[str, float]:
self._ensure_ready()
self.optimizer_zero_grad()

input_batched, _ = self._normalize_batch_input(input_)

# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
total_loss_weight = compute_total_loss_weight(
Expand Down Expand Up @@ -652,14 +656,16 @@ def process_output(
@torch.no_grad()
def eval_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
self._ensure_ready()

input_batched, _ = self._normalize_batch_input(input_)

# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
total_loss_weight = compute_total_loss_weight(
Expand Down Expand Up @@ -691,21 +697,33 @@ def process_output(
@torch.no_grad()
def forward_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> torch.Tensor:
aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat,
) -> torch.Tensor | list[torch.Tensor]:
self._ensure_ready()

input_batched, meta = self._normalize_batch_input(input_)

# Step 1: Prepare sequence lengths
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
if meta is not None:
assert isinstance(input_, list)
inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_]
if output_seqlens is not None and output_seqlens != inferred_seqlens:
raise ValueError(
f"output_seqlens mismatch for list input: "
f"given {output_seqlens}, "
f"inferred {inferred_seqlens} from attention_mask shapes."
)
output_seqlens = inferred_seqlens
cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"]
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
assert output_seqlens is not None
batch_size = len(output_seqlens)

# Step 2: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 3: Forward using process_output_fn callback, collecting results
outputs: list[torch.Tensor] = []
Expand All @@ -720,8 +738,15 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None:

# Step 4: Aggregate and reorder outputs
if self.enable_tree_training:
return merge_packed_tree_results(outputs, batch_size)
return reorder_and_pad_outputs(outputs, output_seqlens, mb_list, aggregate_fn)
result = merge_packed_tree_results(outputs, batch_size)
else:
result = reorder_and_pad_outputs(
outputs, output_seqlens, mb_list, aggregate_fn
)

if meta is None:
return result
return split_batch(result, meta)

def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=self.data_parallel_group)
Expand Down Expand Up @@ -992,6 +1017,20 @@ def _ensure_ready(self) -> None:
if self.parallel_helper.sp_size > 1:
set_ulysses_sequence_parallel_group(self.sp_group)

@staticmethod
def _normalize_batch_input(
    input_: list[dict[str, Any]] | dict[str, Any],
) -> tuple[dict[str, Any], Any | None]:
    """Coerce a trajectory list or a pre-batched dict into one batched dict.

    Parameters
    ----------
    input_ : list[dict[str, Any]] | dict[str, Any]
        Either a list of per-trajectory dicts or an already-batched dict.

    Returns
    -------
    tuple[dict[str, Any], Any | None]
        ``(batched_input, meta)``. A dict input is handed back untouched
        (the very same object) with ``meta`` set to ``None``. A list input
        is delegated to ``concat_batch`` and its result is returned as-is;
        ``forward_batch`` later passes that ``meta`` to ``split_batch`` to
        turn the batched output back into per-trajectory results.
    """
    # Guard clause: anything that is not a list is treated as pre-batched.
    if not isinstance(input_, list):
        return input_, None
    return concat_batch(input_)

def _get_model_name_parameters(
self, meta: WeightUpdateMeta
) -> Iterator[tuple[str, nn.Parameter]]:
Expand Down
48 changes: 38 additions & 10 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@
MicroBatchList,
amend_position_ids,
broadcast_tensor,
concat_batch,
pack_tensor_dict,
pad_mb_list,
split_batch,
split_padded_tensor_dict_into_mb_list,
unpad_logits,
)
Expand Down Expand Up @@ -729,15 +731,17 @@ def _process_output(input_, output_):

def train_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> dict[str, float]:
self._ensure_ready()
self.optimizer_zero_grad()

input_batched, _ = self._normalize_batch_input(input_)

# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
total_loss_weight = compute_total_loss_weight(
Expand Down Expand Up @@ -769,14 +773,16 @@ def process_output(
@torch.no_grad()
def eval_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
self._ensure_ready()

input_batched, _ = self._normalize_batch_input(input_)

# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
total_loss_weight = compute_total_loss_weight(
Expand Down Expand Up @@ -805,21 +811,33 @@ def process_output(
@torch.no_grad()
def forward_batch(
self,
input_: dict[str, Any],
input_: list[dict[str, Any]] | dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> torch.Tensor:
aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat,
) -> torch.Tensor | list[torch.Tensor]:
self._ensure_ready()

input_batched, meta = self._normalize_batch_input(input_)

# Step 1: Prepare sequence lengths
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
if meta is not None:
assert isinstance(input_, list)
inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_]
if output_seqlens is not None and output_seqlens != inferred_seqlens:
raise ValueError(
f"output_seqlens mismatch for list input: "
f"given {output_seqlens}, "
f"inferred {inferred_seqlens} from attention_mask shapes."
)
output_seqlens = inferred_seqlens
cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"]
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
assert output_seqlens is not None
batch_size = len(output_seqlens)

# Step 2: Prepare micro-batches
mb_list = self._prepare_mb_list(input_).to(self.device)
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 3: Forward using Megatron's pipeline function, collecting results
outputs: list[torch.Tensor] = []
Expand All @@ -845,7 +863,9 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None:
src_rank=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
)
return res
if meta is None:
return res
return split_batch(res, meta)

def export_stats(self) -> dict[str, float]:
data = stats_tracker.export_all(reduce_group=self.data_parallel_group)
Expand Down Expand Up @@ -1114,6 +1134,14 @@ def _check_rollout_engine_connected(self) -> None:
" before using rollout/update_weight methods."
)

@staticmethod
def _normalize_batch_input(
    input_: list[dict[str, Any]] | dict[str, Any],
) -> tuple[dict[str, Any], Any | None]:
    """Return a single batched dict plus optional split metadata.

    A dict input is assumed to be pre-batched and is returned unchanged
    together with ``None``. A list input is forwarded to ``concat_batch``
    and whatever that helper returns is passed through; ``forward_batch``
    treats a non-None second element as the marker of list input and
    feeds it to ``split_batch`` to produce per-trajectory outputs.
    """
    # Guard clause: anything that is not a list is treated as pre-batched.
    if not isinstance(input_, list):
        return input_, None
    return concat_batch(input_)

def _ensure_ready(self) -> None:
if self.is_offload:
self.onload()
Expand Down
Loading
Loading