Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions .github/workflows/install-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@ on:
branches: [main]
paths:
- 'pyproject.toml'
- 'pyproject.vllm.toml'
- 'uv.lock'
- 'uv.vllm.lock'
- 'areal/**'
- '.github/workflows/install-test.yml'
push:
branches: [main]
paths:
- 'pyproject.toml'
- 'pyproject.vllm.toml'
- 'uv.lock'
- 'uv.vllm.lock'
- 'areal/**'
- '.github/workflows/install-test.yml'
workflow_dispatch:
Expand All @@ -39,7 +43,9 @@ jobs:
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
cache-dependency-glob: 'uv.lock'
cache-dependency-glob: |
uv.lock
uv.vllm.lock

- name: Set up Python
run: uv python install ${{ matrix.python-version }}
Expand Down Expand Up @@ -93,15 +99,22 @@ jobs:
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
cache-dependency-glob: 'uv.lock'
cache-dependency-glob: |
uv.lock
uv.vllm.lock

- name: Set up Python
run: uv python install 3.12

- name: Install package with CUDA extras (excluding flash-attn)
# flash-attn requires CUDA toolkit for compilation, skip it in CI
# Test individual extras that have pre-built wheels
run: uv sync --extra ${{ matrix.variant }} --extra megatron --extra tms
run: |
if [ "${{ matrix.variant }}" = "vllm" ]; then
cp pyproject.vllm.toml pyproject.toml
cp uv.vllm.lock uv.lock
fi
uv sync --extra ${{ matrix.variant }} --extra megatron --extra tms

- name: Verify package import with CUDA extras
run: |
Expand Down
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ repos:
name: nbstripout - Strip notebook output
description: Strip output from Jupyter notebooks

# Check consistency between pyproject.toml variants (sglang vs vllm)
- repo: local
hooks:
- id: check-pyproject-consistency
name: Check pyproject.toml consistency
entry: python3 areal/tools/check_pyproject_consistency.py
language: system
files: ^pyproject(\.vllm)?\.toml$
pass_filenames: false
always_run: false
require_serial: true

# Generate CLI documentation
- repo: local
hooks:
Expand Down
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

```bash
# Environment
uv sync --extra cuda # CUDA + SGLang inference (default); for vLLM: --extra cuda-vllm
uv sync --extra cuda # CUDA + SGLang inference (default); for vLLM: cp pyproject.vllm.toml pyproject.toml && cp uv.vllm.lock uv.lock && uv sync --extra cuda
source .venv/bin/activate # activate venv BEFORE pre-commit or git commit if venv exists
pre-commit install --install-hooks # hooks: Ruff, clang-format, mdformat, nbstripout, conventional-commits
pre-commit run --all-files # lint + format everything
Expand Down
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ uv --version # Install: https://docs.astral.sh/uv/

# Sync dependencies
uv sync --extra cuda # CUDA + SGLang inference (default)
# For vLLM: cp pyproject.vllm.toml pyproject.toml && cp uv.vllm.lock uv.lock && uv sync --extra cuda
uv sync --group dev # Include dev/test packages
uv run python3 areal/tools/validate_installation.py # Validate installation

Expand Down
10 changes: 8 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,16 @@ RUN uv pip install nanobot-ai

# Install the project's dependencies (not the project itself)
# This adds packages without removing unlisted ones (like our C++ packages)
# VARIANT selects the inference backend (sglang or vllm)
# VARIANT selects the inference backend (sglang or vllm) via separate pyproject files
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv pip install --no-build-isolation -r pyproject.toml --extra cuda-train --extra ${VARIANT} --group dev
--mount=type=bind,source=pyproject.vllm.toml,target=pyproject.vllm.toml \
case "$VARIANT" in \
sglang) cp pyproject.toml /tmp/pyproject.toml ;; \
vllm) cp pyproject.vllm.toml /tmp/pyproject.toml ;; \
*) echo "Invalid VARIANT=$VARIANT (expected: sglang|vllm)" >&2; exit 1 ;; \
esac \
&& uv pip install --no-build-isolation -r /tmp/pyproject.toml --extra cuda --group dev

##############################################################
# STAGE 4: Misc fixes and final setup
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pip install uv
# (pick the wheel matching your Python version; see https://github.com/mjun0812/flash-attention-prebuild-wheels/releases)
uv pip install "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp312-cp312-linux_x86_64.whl"
uv sync --extra cuda # installs training packages + SGLang (default inference backend)
# For vLLM instead: cp pyproject.vllm.toml pyproject.toml && cp uv.vllm.lock uv.lock && uv sync --extra cuda
```

Our training scripts automatically download the required dataset (openai/gsm8k) and
Expand Down Expand Up @@ -277,8 +278,8 @@ pip install uv
uv pip install "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp312-cp312-linux_x86_64.whl"
# Use `--extra cuda` on Linux with CUDA (installs training packages + SGLang)
uv sync --extra cuda --group dev
# For vLLM instead (note: use torch2.10 flash-attn wheel):
# uv sync --extra cuda-vllm --group dev
# For vLLM instead:
# cp pyproject.vllm.toml pyproject.toml && cp uv.vllm.lock uv.lock && uv sync --extra cuda --group dev
# Or without CUDA support
# uv sync --group dev

Expand Down
4 changes: 3 additions & 1 deletion areal/engine/core/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def compute_total_loss_weight(
.clone()
.to(dtype=torch.float32)
)
assert total_weight != 0, "Total loss weight must be non-zero"
dist.all_reduce(total_weight, group=dp_group)
assert total_weight > 0, (
"Global total loss weight must be positive after all_reduce"
)
return total_weight


Expand Down
8 changes: 6 additions & 2 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,10 @@ def _compute_logprobs_and_loss(
loss_multiplier: float = 1.0,
) -> torch.Tensor:
"""Compute logprobs/entropy and return scaled loss."""
local_weight = loss_weight_fn(ctx.mb_input)
if local_weight == 0:
return logits.mean() * 0.0

if self.config.is_critic and self.enable_tree_training:
raise NotImplementedError(
"Tree training with critic model is not supported yet."
Expand All @@ -1669,7 +1673,7 @@ def _compute_logprobs_and_loss(
if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids:
# Return zero loss that maintains gradient connection to logits
# This ensures backward() works correctly for FSDP synchronization
return logits.sum() * 0.0
return logits.mean() * 0.0

# For tree training, use gather_packed_tree_vocab_stats to properly
# unpack vocab stats from tree structure back to per-sequence format.
Expand Down Expand Up @@ -1712,7 +1716,7 @@ def _compute_logprobs_and_loss(
values = values[: -ctx.pad_length]
loss = loss_fn(values, ctx.mb_input)

loss_scale = loss_weight_fn(ctx.mb_input) / total_loss_weight * loss_multiplier
loss_scale = local_weight / total_loss_weight * loss_multiplier
return loss * loss_scale

def _compute_forward_result(
Expand Down
17 changes: 14 additions & 3 deletions areal/engine/fsdp_utils/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,22 @@ def get_grad_norm_fp32(
norm_type = float(norm_type)
total_norm = 0.0

if not grads_for_norm:
return 0.0

device = current_platform.current_device()

if not grads_for_norm:
# Still participate in all_reduce with zero contribution so that
# ranks with grads don't hang waiting for this rank (e.g. LoRA frozen ranks).
total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device)
reduce_op = dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM
if data_parallel_group:
dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
if model_parallel_group is not None:
dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
total_norm = float(total_norm_cuda.item())
if norm_type != torch.inf and total_norm > 0:
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

if norm_type == torch.inf:
norms = [grad.abs().max() for grad in grads_for_norm]
total_norm = torch.max(torch.stack(norms)) if norms else 0.0
Expand Down
8 changes: 6 additions & 2 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,10 @@ def _compute_logprobs_and_loss(
total_loss_weight: torch.Tensor,
loss_multiplier: float = 1.0,
) -> torch.Tensor:
local_weight = loss_weight_fn(inputs)
if local_weight == 0:
return output.mean() * 0.0

if self.config.is_critic and self.enable_tree_training:
raise NotImplementedError(
"Tree training with critic model is not supported yet."
Expand All @@ -1658,7 +1662,7 @@ def _compute_logprobs_and_loss(
if trie_node is None or not trie_node.all_sequence_ids:
# Return zero loss that maintains gradient connection to output
# This ensures backward() works correctly for distributed synchronization
return output.sum() * 0.0
return output.mean() * 0.0

# For tree training, use gather_packed_tree_vocab_stats to properly
# unpack vocab stats from tree structure back to per-sequence format.
Expand Down Expand Up @@ -1699,7 +1703,7 @@ def _compute_logprobs_and_loss(
values = output.squeeze(-1)
loss = loss_fn(values, inputs)

loss_scale = loss_weight_fn(inputs) / total_loss_weight * loss_multiplier
loss_scale = local_weight / total_loss_weight * loss_multiplier
return loss * loss_scale

def _compute_forward_result(
Expand Down
5 changes: 5 additions & 0 deletions areal/engine/vllm_ext/areal_vllm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ async def areal_update_weight(request: UpdateWeightsRequest, raw_request: Reques
logger.info(f"API server starts areal_update_weight, {request.model_path}")
llm = raw_request.app.state.engine_client
await llm.pause_generation(wait_for_inflight_requests=False, clear_cache=True)
await llm.reset_mm_cache()
try:
ret_list = await llm.collective_rpc(
"areal_update_weights",
Expand All @@ -183,6 +184,7 @@ async def areal_update_weight_lora(
)
llm = raw_request.app.state.engine_client
await llm.pause_generation(wait_for_inflight_requests=False, clear_cache=True)
await llm.reset_mm_cache()

try:
ret_list = await llm.collective_rpc(
Expand All @@ -205,6 +207,7 @@ async def areal_update_weight_xccl(raw_request: Request):
logger.info("API server starts areal_update_weight_xccl")
llm = raw_request.app.state.engine_client
await llm.pause_generation(wait_for_inflight_requests=False, clear_cache=True)
await llm.reset_mm_cache()
try:
ret_list = await llm.collective_rpc("areal_update_weight_xccl")
finally:
Expand All @@ -219,6 +222,7 @@ async def areal_update_weight_lora_xccl(
logger.info("API server starts areal_update_weight_lora_xccl")
llm = raw_request.app.state.engine_client
await llm.pause_generation(wait_for_inflight_requests=False, clear_cache=True)
await llm.reset_mm_cache()

try:
ret_list = await llm.collective_rpc("areal_update_weight_lora_xccl")
Expand Down Expand Up @@ -310,6 +314,7 @@ async def areal_pause_generation(raw_request: Request):
wait_for_inflight_requests=False,
clear_cache=True,
)
await llm.reset_mm_cache()

return to_json_response(True, "Generation paused and all requests aborted")

Expand Down
51 changes: 46 additions & 5 deletions areal/engine/vllm_ext/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def areal_update_weight_lora_xccl(self):
f"LoRA adapter {lora_int_id} not found. Available: {adapter_ids}"
)

# Get the LoRA model
# Get the currently registered LoRA model (used for diagnostics).
lora_model = (
self.model_runner.lora_manager._adapter_manager._registered_adapters[
lora_int_id
Expand Down Expand Up @@ -211,12 +211,33 @@ def areal_update_weight_lora_xccl(self):

logger.info(f"Received {len(received_weights)} LoRA parameters via XCCL")

self.model_runner.lora_manager.remove_adapter(lora_int_id)

normalized_weights = {
k.replace("default.", ""): v for k, v in received_weights.items()
}

lora_partial_shard_key = (self.areal_lora_name, lora_int_id)

group_shards = self._lora_partial_shards.setdefault(
lora_partial_shard_key, {}
)
group_shards[self.areal_weight_meta_group_name] = normalized_weights
buffered_count = len(group_shards)

# Assumes that every registered weight update group contributes to the update cycle
if buffered_count < len(self.weight_update_groups):
logger.info(
"Buffered LoRA shard for "
f"{self.areal_lora_name}: group={self.areal_weight_meta_group_name}, "
f"buffered={buffered_count}/{len(self.weight_update_groups)} PP stages."
)
self.sync()
return True, "Success"

merged_weights: dict[str, torch.Tensor] = {}
for shard in group_shards.values():
merged_weights.update(shard)
self._lora_partial_shards.pop(lora_partial_shard_key, None)

peft_config = {
"r": self.areal_lora_rank,
"lora_alpha": self.areal_lora_alpha,
Expand All @@ -234,7 +255,7 @@ def areal_update_weight_lora_xccl(self):

new_lora_model = LoRAModel.from_lora_tensors(
lora_model_id=self.areal_lora_int_id,
tensors=normalized_weights,
tensors=merged_weights,
peft_helper=peft_helper,
device=self.model_runner.device,
dtype=self.model_runner.lora_manager.lora_config.lora_dtype,
Expand All @@ -244,13 +265,21 @@ def areal_update_weight_lora_xccl(self):
),
)

self.model_runner.lora_manager.remove_adapter(lora_int_id)

self.model_runner.lora_manager._adapter_manager._add_adapter(new_lora_model)
self.model_runner.lora_manager._adapter_manager.activate_adapter(
new_lora_model.id
)
logger.info(
f"Found LoRA model with {len(new_lora_model.loras)} LoRA modules"
f"Updated New LoRA model with {len(new_lora_model.loras)} LoRA modules "
f"from {len(merged_weights)} tensors across {len(self.weight_update_groups)} groups"
)
if len(new_lora_model.loras) != len(lora_model.loras):
logger.warning(
f"Number of modules in the new LoRA model ({len(new_lora_model.loras)}) "
f"does not match the old LoRA model ({len(lora_model.loras)})."
)

self.sync()
return True, "Success"
Expand All @@ -271,6 +300,18 @@ def areal_init_update_weight_group(
):
if not hasattr(self, "weight_update_groups"):
self.weight_update_groups: dict[str, dist.ProcessGroup] = {}

# This is required for buffering weights during lora weight update, as vLLM
# expects the partial PP shards to be buffered until all groups have sent their shards.
_is_vllm_lora_enabled = (
getattr(self.model_runner, "lora_manager", None) is not None
)
if _is_vllm_lora_enabled and not hasattr(self, "_lora_partial_shards"):
# (lora_name, lora_int_id) -> group_name -> normalized weight dict
self._lora_partial_shards: dict[
tuple[str, int], dict[str, dict[str, torch.Tensor]]
] = {}

try:
group = init_custom_process_group(
backend=backend,
Expand Down
8 changes: 6 additions & 2 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,10 +1154,14 @@ def _compute_logprobs_and_loss(
loss_multiplier: float = 1.0,
) -> torch.Tensor:
"""Compute logprobs/entropy and return scaled loss."""
local_weight = loss_weight_fn(ctx.mb_input)
if local_weight == 0:
return logits.mean() * 0.0

if not self.config.is_critic:
result = self._gather_actor_train_outputs(logits, ctx)
if result is None:
return logits.sum() * 0.0
return logits.mean() * 0.0
logprobs, entropy, vocab_min_logits, vocab_max_logits = result
loss = loss_fn(
logprobs,
Expand All @@ -1170,7 +1174,7 @@ def _compute_logprobs_and_loss(
values = self._gather_critic_output(logits, ctx)
loss = loss_fn(values, ctx.mb_input)

loss_scale = loss_weight_fn(ctx.mb_input) / total_loss_weight * loss_multiplier
loss_scale = local_weight / total_loss_weight * loss_multiplier
return loss * loss_scale

def _compute_forward_result(
Expand Down
Loading
Loading