examples/lalm/README.md (new file, +171 lines)

# LALM

LALM is a multimodal language model that combines a pretrained audio encoder with a causal LLM via a learned projector. The model follows the standard HuggingFace `from_pretrained` / `save_pretrained` interface throughout.

Supported audio encoders: Whisper, Qwen2.5-Omni audio encoder, Qwen3-Omni audio encoder.

---

## Overview

```
audio waveform
      │
      ▼
Audio Encoder (frozen or fine-tuned)
      │  packed features (C, T_total)
      ▼
LALMProjector (learned)
      │  audio embeddings inserted into LLM token sequence
      ▼
Causal LLM (frozen or fine-tuned)
      │
      ▼
text output
```

Training follows a two-stage recipe (sketched after this list):
- **Stage 1** — freeze encoder + LLM, train projector only (`frozen_modules: [audio_tower, language_model]`)
- **Stage 2** — unfreeze all, full fine-tuning (`frozen_modules: []`)

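For illustration, the two stages can differ only in the `frozen_modules` setting of `scripts/train.sh` (a hypothetical excerpt; the variable name follows the table in Step 3):

```bash
# Stage 1: train the projector only (hypothetical excerpt from scripts/train.sh)
frozen_modules=audio_tower,language_model

# Stage 2: full fine-tuning, nothing frozen
frozen_modules=
```
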
---

## Step 1 — Build Model

Assemble a LALM checkpoint from a pretrained LLM and audio encoder. This is a one-time step that produces a self-contained HF checkpoint.

```bash
bash scripts/build_model.sh
```

Key options in `scripts/build_model.sh`:

| Variable | Description |
|---|---|
| `llm` | Path to pretrained LLM (e.g. Qwen2.5-7B-Instruct) |
| `encoder` | Path to audio encoder (e.g. Qwen3-Omni, Qwen2.5-Omni, Whisper) |
| `projector_downsample_rate` | Frame concat rate in projector (higher = fewer LLM audio tokens) |
| `output_dir` | Where to save the assembled checkpoint |

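Equivalently, `build_model.py` can be run directly; its flags mirror the script variables (the model ids below are examples):

```bash
python build_model.py \
  --llm Qwen/Qwen2.5-7B-Instruct \
  --encoder openai/whisper-large-v3 \
  --projector_downsample_rate 4 \
  --dtype bfloat16 \
  --output_dir ./lalm_checkpoint
```
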
The assembled checkpoint can be loaded like any HF model:

```python
from lalm_core.model import LALMForConditionalGeneration, LALMProcessor

model = LALMForConditionalGeneration.from_pretrained(output_dir)
processor = LALMProcessor.from_pretrained(output_dir)
```

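As a quick smoke test, the snippet below sketches single-utterance inference. It assumes `LALMProcessor` follows the common HF multimodal calling convention (`text=` plus `audios=` and `sampling_rate=`, returning tensors that `generate` accepts); these argument names are assumptions, not a confirmed API:

```python
import librosa  # any waveform loader works; librosa is just an example

# Assumed 16 kHz input; check the feature extractor's expected sampling rate.
waveform, sr = librosa.load("sample.wav", sr=16000)

# `audios=` / `sampling_rate=` follow common HF processor conventions and are
# assumptions for LALMProcessor rather than confirmed parameter names.
inputs = processor(
    text="Please transcribe speech.",
    audios=waveform,
    sampling_rate=sr,
    return_tensors="pt",
)

output_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
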
---

## Step 2 — Prepare Manifest

Each training/evaluation sample needs a `conversation` field attached to its Lhotse cut. Run `prepare_conversation.py` once per dataset split.

```bash
bash scripts/prepare_manifest.sh
```

Or run directly:

```bash
python prepare_conversation.py \
--input_manifest /path/to/cuts.jsonl.gz \
--output_manifest data/train/cuts_conversation.jsonl.gz \
--tokenizer /path/to/llm \
--instruction "Please transcribe speech."
```

Key options:

| Argument | Description |
|---|---|
| `--input_manifest` | Input Lhotse CutSet manifest |
| `--output_manifest` | Output manifest with `conversation` field added |
| `--tokenizer` | Tokenizer path (used to estimate token counts for batching) |
| `--instruction` | Optional user instruction appended to each sample |
| `--system` | Optional system prompt |

The prepared manifest stores two fields per cut (example below):
- `cut.conversation` — structured message list (OpenAI chat format)
- `cut.rendered_conversation` — rendered chat string used directly during training

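For a speech-recognition cut, the stored `conversation` might look like the following sketch (the system prompt is an illustrative placeholder, and `<|audio|>` is the placeholder token added by `build_model.py`; the exact rendering depends on the tokenizer's chat template):

```python
cut.conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "<|audio|>Please transcribe speech."},
    {"role": "assistant", "content": "the reference transcript for this cut"},
]
```
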
Register your prepared manifests in `configs/train_data_config.yaml` and `configs/valid_data_config.yaml`:

```yaml
- name: aishell1
manifest: data/train/aishell1_cuts_conversation.jsonl.gz
```

---

## Step 3 — Train

```bash
bash scripts/train.sh
```

Multi-node example:

```bash
NNODES=2 NODE_RANK=0 MASTER_ADDR=<host0> bash scripts/train.sh
```

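On the second node, run the same command with `NODE_RANK=1`; `MASTER_ADDR` stays the first node's host:

```bash
NNODES=2 NODE_RANK=1 MASTER_ADDR=<host0> bash scripts/train.sh
```
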
Key options in `scripts/train.sh`:

| Variable | Description |
|---|---|
| `model_dir` | Path to checkpoint from Step 1 |
| `frozen_modules` | Comma-separated modules to freeze (e.g. `audio_tower,language_model`) |
| `mixed_precision` | `bf16` or `fp16` |
| `exp_name` | Experiment name; checkpoints saved to `exp/<exp_name>/` |

Key training config options (override on the command line or in `configs/train.yaml`; see the example after the table):

| Config key | Description |
|---|---|
| `trainer.optimizer.lr` | Learning rate |
| `trainer.num_steps` | Total training steps |
| `trainer.grad_accum_steps` | Gradient accumulation steps |
| `trainer.valid_interval` | Validate every N steps |
| `trainer.save_every_n` | Save checkpoint every N validation intervals |
| `data.sampler.max_tokens` | Max LLM tokens per batch (controls batch size) |
| `data.feature` | Feature type: `whisper_v3_fbank` (128-dim) or `whisper_fbank` (80-dim) |

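Assuming `scripts/train.sh` forwards extra arguments as config overrides (an assumption; otherwise edit `configs/train.yaml` directly), a run might override keys like this (illustrative values):

```bash
bash scripts/train.sh \
  trainer.optimizer.lr=1e-4 \
  trainer.num_steps=100000 \
  data.sampler.max_tokens=4000
```
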
Checkpoints are saved to `exp/<exp_name>/checkpoint-{step}.pt`. The HF config and processor are saved once at the start of training to `exp/<exp_name>/hf/`.

To resume training, set `trainer.start_batch` to the step of the checkpoint you want to resume from.

---

## Step 4 — Evaluate

```bash
bash scripts/evaluate.sh
```

Or run directly:

```bash
python evaluate.py \
exp_dir=./exp/my_experiment \
checkpoint.iter=16000 \
data.test_data_config=configs/valid_data_config.yaml \
decoding_method=greedy_search
```

Key options:

| Config key | Description |
|---|---|
| `exp_dir` | Experiment directory containing checkpoints |
| `checkpoint.iter` | Load `checkpoint-{iter}.pt` |
| `checkpoint.epoch` | Load `epoch-{epoch}.pt` |
| `checkpoint.model_dir` | Load directly from a HF model directory (skip export) |
| `decoding_method` | `greedy_search` or `beam_search` |
| `num_beams` | Beam size (used when `decoding_method=beam_search`) |
| `dtype` | Inference dtype: `fp16`, `bf16`, or `fp32` |

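For example, beam-search decoding in bf16 combines the keys above (the beam size is illustrative):

```bash
python evaluate.py \
  exp_dir=./exp/my_experiment \
  checkpoint.iter=16000 \
  decoding_method=beam_search \
  num_beams=4 \
  dtype=bf16
```
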
On first run, the trainer checkpoint is exported to an HF checkpoint at `<exp_dir>/export/iter-{iter}/` and reused on subsequent runs.

Results are written to `<exp_dir>/<decoding_method>/`.
examples/lalm/build_model.py (new file, +212 lines)

"""
One-time script: assemble a LALM checkpoint from pretrained components.

After running this script, load the model with the standard HuggingFace
``from_pretrained`` interface::

    model = LALMForConditionalGeneration.from_pretrained(output_dir)
    processor = LALMProcessor.from_pretrained(output_dir)

Usage::

    python build_model.py \\
--llm Qwen/Qwen2-7B-Instruct \\
--encoder openai/whisper-large-v3 \\
--output_dir ./lalm_checkpoint
"""

import torch
from lalm_core.model import LALMConfig, LALMForConditionalGeneration, LALMProcessor
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForCausalLM,
AutoTokenizer,
)


def assemble_and_save(
llm_name: str,
encoder_name: str,
output_dir: str,
torch_dtype: torch.dtype = torch.bfloat16,
projector_downsample_rate: int = 4,
):
print(f"Loading LLM: {llm_name}")
llm = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch_dtype)

print(f"Loading encoder: {encoder_name}")
encoder, audio_dim = _load_encoder(encoder_name, torch_dtype)

    # Build the processor first so the new <|audio|> token is added to the
    # vocabulary; its id is recorded below as ``audio_token_id``.
print("Building processor ...")
tokenizer = AutoTokenizer.from_pretrained(llm_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
processor = LALMProcessor(
feature_extractor,
tokenizer,
encoder_name=_get_encoder_model_type(encoder),
projector_downsample_rate=projector_downsample_rate,
)

# Resize LLM embeddings to cover the newly added audio token.
llm.resize_token_embeddings(len(processor.tokenizer))

text_dim = llm.config.hidden_size

config = LALMConfig(
text_config=llm.config,
audio_config=encoder.config,
audio_dim=audio_dim,
text_dim=text_dim,
projector_downsample_rate=projector_downsample_rate,
audio_token_id=processor.tokenizer.convert_tokens_to_ids(processor.audio_token),
)

print("Assembling model ...")
model = LALMForConditionalGeneration(
config,
language_model=llm,
audio_tower=encoder,
)
# Projector starts randomly initialised and is learned during training.

print(f"Saving to {output_dir} ...")
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
print("Done.")

return model, processor


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _load_encoder(
model_name_or_path: str, torch_dtype: torch.dtype
) -> tuple[torch.nn.Module, int]:
"""
Load an audio encoder and return ``(encoder_module, audio_dim)``.

Do not rely on AutoModel here because some encoder checkpoints are not fully
registered for generic auto loading. Use model-family-specific APIs instead.
"""
hf_cfg = AutoConfig.from_pretrained(model_name_or_path)
model_type = getattr(hf_cfg, "model_type", "")

if model_type == "whisper":
from transformers.models.whisper.modeling_whisper import WhisperModel

try:
model = WhisperModel.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
encoder = model.encoder
except Exception:
# Some checkpoints may contain encoder-only weights.
from transformers.models.whisper.modeling_whisper import WhisperEncoder

encoder = WhisperEncoder.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
return encoder, int(encoder.config.d_model)

if model_type in {"qwen2_5_omni", "qwen2_5_omni_audio_encoder"}:
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
Qwen2_5OmniAudioEncoder,
Qwen2_5OmniThinkerForConditionalGeneration,
)

if model_type == "qwen2_5_omni":
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
encoder = model.audio_tower
else:
encoder = Qwen2_5OmniAudioEncoder.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
return encoder, int(encoder.config.output_dim)

if model_type in {"qwen3_omni_moe", "qwen3_omni_moe_audio_encoder"}:
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
Qwen3OmniMoeAudioEncoder,
Qwen3OmniMoeThinkerForConditionalGeneration,
)

if model_type == "qwen3_omni_moe":
model = Qwen3OmniMoeThinkerForConditionalGeneration.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
encoder = model.audio_tower
else:
encoder = Qwen3OmniMoeAudioEncoder.from_pretrained(
model_name_or_path, torch_dtype=torch_dtype
)
return encoder, int(encoder.config.output_dim)

raise ValueError(
f"Unsupported encoder model_type={model_type!r}. "
"Supported values: 'whisper', 'qwen2_5_omni', "
"'qwen2_5_omni_audio_encoder', 'qwen3_omni_moe', "
"'qwen3_omni_moe_audio_encoder'."
)


def _get_encoder_model_type(encoder) -> str:
hf_model_type = getattr(encoder.config, "model_type", "")
supported = {
"whisper",
"qwen2_5_omni_audio_encoder",
"qwen3_omni_moe_audio_encoder",
}
if hf_model_type not in supported:
known = ", ".join(f'"{k}"' for k in sorted(supported))
raise ValueError(
f"Unsupported encoder model_type={hf_model_type!r}. "
f"Supported values: {known}."
)
return hf_model_type


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(
description="Assemble LALM from pretrained components."
)
parser.add_argument(
"--llm", required=True, help="HF model id or local path for the LLM"
)
parser.add_argument(
"--encoder",
required=True,
help="HF model id or local path for the audio encoder",
)
parser.add_argument(
"--output_dir", required=True, help="Where to save the assembled checkpoint"
)
parser.add_argument(
"--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"]
)
parser.add_argument("--projector_downsample_rate", default=4, type=int)
args = parser.parse_args()

dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
assemble_and_save(
llm_name=args.llm,
encoder_name=args.encoder,
output_dir=args.output_dir,
torch_dtype=dtype_map[args.dtype],
projector_downsample_rate=args.projector_downsample_rate,
)