From fd16685f204de3abe4d6753423bc56aadd5cfb7a Mon Sep 17 00:00:00 2001 From: Yiwen Shao Date: Wed, 25 Mar 2026 07:01:18 +0800 Subject: [PATCH] add hf lalm example --- examples/lalm/README.md | 171 +++++++++ examples/lalm/build_model.py | 212 +++++++++++ examples/lalm/configs/evaluate.yaml | 30 ++ examples/lalm/configs/train.yaml | 79 +++++ examples/lalm/configs/train_data_config.yaml | 8 + examples/lalm/configs/valid_data_config.yaml | 2 + examples/lalm/evaluate.py | 228 ++++++++++++ examples/lalm/lalm_core/__init__.py | 11 + examples/lalm/lalm_core/data_module.py | 223 ++++++++++++ examples/lalm/lalm_core/model/__init__.py | 5 + .../lalm_core/model/configuration_lalm.py | 107 ++++++ .../lalm/lalm_core/model/modeling_lalm.py | 332 ++++++++++++++++++ .../lalm/lalm_core/model/processing_lalm.py | 326 +++++++++++++++++ examples/lalm/lalm_core/trainer.py | 160 +++++++++ examples/lalm/prepare_conversation.py | 157 +++++++++ examples/lalm/results_utils.py | 275 +++++++++++++++ examples/lalm/scripts/build_model.sh | 36 ++ examples/lalm/scripts/evaluate.sh | 53 +++ examples/lalm/scripts/prepare_manifest.sh | 5 + examples/lalm/scripts/train.sh | 74 ++++ examples/lalm/train.py | 117 ++++++ src/auden/data/lhotse_datamodule.py | 15 +- src/auden/trainer/ddp_trainer.py | 53 ++- 23 files changed, 2666 insertions(+), 13 deletions(-) create mode 100644 examples/lalm/README.md create mode 100644 examples/lalm/build_model.py create mode 100644 examples/lalm/configs/evaluate.yaml create mode 100644 examples/lalm/configs/train.yaml create mode 100644 examples/lalm/configs/train_data_config.yaml create mode 100644 examples/lalm/configs/valid_data_config.yaml create mode 100644 examples/lalm/evaluate.py create mode 100644 examples/lalm/lalm_core/__init__.py create mode 100644 examples/lalm/lalm_core/data_module.py create mode 100644 examples/lalm/lalm_core/model/__init__.py create mode 100644 examples/lalm/lalm_core/model/configuration_lalm.py create mode 100644 examples/lalm/lalm_core/model/modeling_lalm.py create mode 100644 examples/lalm/lalm_core/model/processing_lalm.py create mode 100644 examples/lalm/lalm_core/trainer.py create mode 100644 examples/lalm/prepare_conversation.py create mode 100644 examples/lalm/results_utils.py create mode 100755 examples/lalm/scripts/build_model.sh create mode 100755 examples/lalm/scripts/evaluate.sh create mode 100755 examples/lalm/scripts/prepare_manifest.sh create mode 100755 examples/lalm/scripts/train.sh create mode 100644 examples/lalm/train.py diff --git a/examples/lalm/README.md b/examples/lalm/README.md new file mode 100644 index 0000000..b1301d9 --- /dev/null +++ b/examples/lalm/README.md @@ -0,0 +1,171 @@ +# LALM + +LALM is a multimodal language model that combines a pretrained audio encoder with a causal LLM via a learned projector. The model follows the standard HuggingFace `from_pretrained` / `save_pretrained` interface throughout. + +Supported audio encoders: Whisper, Qwen2.5-Omni audio encoder, Qwen3-Omni audio encoder. 
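+
+For orientation, a minimal inference sketch (the checkpoint directory comes from Step 1 below; the audio path and `soundfile` loading are placeholders — any 16 kHz mono waveform works):
+
+```python
+import soundfile as sf
+
+from lalm_core.model import LALMForConditionalGeneration, LALMProcessor
+
+model = LALMForConditionalGeneration.from_pretrained("./lalm_checkpoint").eval()
+processor = LALMProcessor.from_pretrained("./lalm_checkpoint")
+processor.tokenizer.padding_side = "left"  # decoder-only generation pads on the left
+
+audio, sr = sf.read("sample.wav")  # hypothetical 16 kHz mono file
+conversation = [
+    {"role": "user", "content": [
+        {"type": "audio", "audio": "sample.wav"},
+        {"type": "text", "text": "Please transcribe speech."},
+    ]},
+]
+prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+inputs = processor(
+    text=[prompt], audio=[audio], sampling_rate=sr, return_tensors="pt", padding=True
+)
+output_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
+print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
+```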
+
+---
+
+## Overview
+
+```
+audio waveform
+      │
+      ▼
+Audio Encoder (frozen or fine-tuned)
+      │  packed features (C, T_total)
+      ▼
+LALMProjector (learned)
+      │  audio embeddings inserted into LLM token sequence
+      ▼
+Causal LLM (frozen or fine-tuned)
+      │
+      ▼
+text output
+```
+
+Training follows a two-stage recipe:
+- **Stage 1** — freeze encoder + LLM, train projector only (`frozen_modules: [audio_tower, language_model]`)
+- **Stage 2** — unfreeze all, full fine-tuning (`frozen_modules: []`)
+
+---
+
+## Step 1 — Build Model
+
+Assemble a LALM checkpoint from a pretrained LLM and audio encoder. This is a one-time step that produces a self-contained HF checkpoint.
+
+```bash
+bash scripts/build_model.sh
+```
+
+Key options in `scripts/build_model.sh`:
+
+| Variable | Description |
+|---|---|
+| `llm` | Path to pretrained LLM (e.g. Qwen2.5-7B-Instruct) |
+| `encoder` | Path to audio encoder (e.g. Qwen3-Omni, Qwen2.5-Omni, Whisper) |
+| `projector_downsample_rate` | Frame concat rate in projector (higher = fewer LLM audio tokens) |
+| `output_dir` | Where to save the assembled checkpoint |
+
+The assembled checkpoint can be loaded like any HF model:
+
+```python
+from lalm_core.model import LALMForConditionalGeneration, LALMProcessor
+
+model = LALMForConditionalGeneration.from_pretrained(output_dir)
+processor = LALMProcessor.from_pretrained(output_dir)
+```
+
+---
+
+## Step 2 — Prepare Manifest
+
+Each training/evaluation sample needs a `conversation` field attached to its Lhotse cut. Run `prepare_conversation.py` once per dataset split.
+
+```bash
+bash scripts/prepare_manifest.sh
+```
+
+Or run directly:
+
+```bash
+python prepare_conversation.py \
+    --input_manifest /path/to/cuts.jsonl.gz \
+    --output_manifest data/train/cuts_conversation.jsonl.gz \
+    --tokenizer /path/to/llm \
+    --instruction "Please transcribe speech."
+```
+
+Key options:
+
+| Argument | Description |
+|---|---|
+| `--input_manifest` | Input Lhotse CutSet manifest |
+| `--output_manifest` | Output manifest with `conversation` field added |
+| `--tokenizer` | Tokenizer path (used to estimate token counts for batching) |
+| `--instruction` | Optional user instruction appended to each sample |
+| `--system` | Optional system prompt |
+
+The prepared manifest stores two fields per cut (see the example at the end of this step):
+- `cut.conversation` — structured message list (OpenAI chat format)
+- `cut.rendered_conversation` — rendered chat string used directly during training
+
+Register your prepared manifests in `configs/train_data_config.yaml` and `configs/valid_data_config.yaml`:
+
+```yaml
+- name: aishell1
+  manifest: data/train/aishell1_cuts_conversation.jsonl.gz
+```
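+
+For reference, a prepared cut carries data shaped roughly like this (a sketch with made-up values; the exact rendered string depends on the LLM's chat template — Qwen-style tags shown here):
+
+```python
+cut.conversation = [
+    {"role": "user", "content": [
+        {"type": "audio", "audio": "/path/to/audio.wav"},
+        {"type": "text", "text": "Please transcribe speech."},
+    ]},
+    {"role": "assistant", "content": "the reference transcript"},
+]
+cut.rendered_conversation = (
+    "<|im_start|>user\n<|audio|> Please transcribe speech.<|im_end|>\n"
+    "<|im_start|>assistant\nthe reference transcript<|im_end|>\n"
+)
+```
+
+`prepare_conversation.py` also stores a text-token count per cut, which the token-based sampler uses during training to size batches.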
+
+---
+
+## Step 3 — Train
+
+```bash
+bash scripts/train.sh
+```
+
+Multi-node example (`<master-node-ip>` is the address of the rank-0 node):
+
+```bash
+NNODES=2 NODE_RANK=0 MASTER_ADDR=<master-node-ip> bash scripts/train.sh
+```
+
+Key options in `scripts/train.sh`:
+
+| Variable | Description |
+|---|---|
+| `model_dir` | Path to checkpoint from Step 1 |
+| `frozen_modules` | Comma-separated modules to freeze (e.g. `audio_tower,language_model`) |
+| `mixed_precision` | `bf16` or `fp16` |
+| `exp_name` | Experiment name; checkpoints saved to `exp/<exp_name>/` |
+
+Key training config options (override via command line or `configs/train.yaml`):
+
+| Config key | Description |
+|---|---|
+| `trainer.optimizer.lr` | Learning rate |
+| `trainer.num_steps` | Total training steps |
+| `trainer.grad_accum_steps` | Gradient accumulation steps |
+| `trainer.valid_interval` | Validate every N steps |
+| `trainer.save_every_n` | Save checkpoint every N validation intervals |
+| `data.sampler.max_tokens` | Max LLM tokens per batch (controls batch size) |
+| `data.feature` | Feature type: `whisper_v3_fbank` (128-dim) or `whisper_fbank` (80-dim) |
+
+Checkpoints are saved to `exp/<exp_name>/checkpoint-{step}.pt`. The HF config and processor are saved once at the start of training to `exp/<exp_name>/hf/`.
+
+To resume training, set `trainer.start_batch=<step>` to the checkpoint step you want to resume from.
+
+---
+
+## Step 4 — Evaluate
+
+```bash
+bash scripts/evaluate.sh
+```
+
+Or run directly:
+
+```bash
+python evaluate.py \
+    exp_dir=./exp/my_experiment \
+    checkpoint.iter=16000 \
+    data.test_data_config=configs/valid_data_config.yaml \
+    decoding_method=greedy_search
+```
+
+Key options:
+
+| Config key | Description |
+|---|---|
+| `exp_dir` | Experiment directory containing checkpoints |
+| `checkpoint.iter` | Load `checkpoint-{iter}.pt` |
+| `checkpoint.epoch` | Load `epoch-{epoch}.pt` |
+| `checkpoint.model_dir` | Load directly from a HF model directory (skip export) |
+| `decoding_method` | `greedy_search` or `beam_search` |
+| `num_beams` | Beam size (used when `decoding_method=beam_search`) |
+| `dtype` | Inference dtype: `fp16`, `bf16`, or `fp32` |
+
+On first run, the trainer checkpoint is exported to a HF checkpoint at `exp/<exp_name>/export/iter-{iter}/` and reused on subsequent runs.
+
+Results are written to `exp/<exp_name>/<decoding_method>/`.
diff --git a/examples/lalm/build_model.py b/examples/lalm/build_model.py
new file mode 100644
index 0000000..d8b5820
--- /dev/null
+++ b/examples/lalm/build_model.py
@@ -0,0 +1,212 @@
+"""
+One-time script: assemble a LALM checkpoint from pretrained components.
+
+After running this script, load the model with the standard HuggingFace
+``from_pretrained`` API::
+
+    model = LALMForConditionalGeneration.from_pretrained(output_dir)
+    processor = LALMProcessor.from_pretrained(output_dir)
+
+Usage::
+
+    python build_model.py \\
+        --llm Qwen/Qwen2-7B-Instruct \\
+        --encoder openai/whisper-large-v3 \\
+        --output_dir ./lalm_checkpoint
+"""
+
+import torch
+from lalm_core.model import LALMConfig, LALMForConditionalGeneration, LALMProcessor
+from transformers import (
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
+
+
+def assemble_and_save(
+    llm_name: str,
+    encoder_name: str,
+    output_dir: str,
+    torch_dtype: torch.dtype = torch.bfloat16,
+    projector_downsample_rate: int = 4,
+):
+    print(f"Loading LLM: {llm_name}")
+    llm = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch_dtype)
+
+    print(f"Loading encoder: {encoder_name}")
+    encoder, audio_dim = _load_encoder(encoder_name, torch_dtype)
+
+    # Build processor first to obtain the audio_token_id after
+    # the new <|audio|> token has been added to the vocabulary.
+ print("Building processor ...") + tokenizer = AutoTokenizer.from_pretrained(llm_name) + feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name) + processor = LALMProcessor( + feature_extractor, + tokenizer, + encoder_name=_get_encoder_model_type(encoder), + projector_downsample_rate=projector_downsample_rate, + ) + + # Resize LLM embeddings to cover the newly added audio token. + llm.resize_token_embeddings(len(processor.tokenizer)) + + text_dim = llm.config.hidden_size + + config = LALMConfig( + text_config=llm.config, + audio_config=encoder.config, + audio_dim=audio_dim, + text_dim=text_dim, + projector_downsample_rate=projector_downsample_rate, + audio_token_id=processor.tokenizer.convert_tokens_to_ids(processor.audio_token), + ) + + print("Assembling model ...") + model = LALMForConditionalGeneration( + config, + language_model=llm, + audio_tower=encoder, + ) + # Projector starts randomly initialised and is learned during training. + + print(f"Saving to {output_dir} ...") + model.save_pretrained(output_dir) + processor.save_pretrained(output_dir) + print("Done.") + + return model, processor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _load_encoder( + model_name_or_path: str, torch_dtype: torch.dtype +) -> tuple[torch.nn.Module, int]: + """ + Load an audio encoder and return ``(encoder_module, audio_dim)``. + + Do not rely on AutoModel here because some encoder checkpoints are not fully + registered for generic auto loading. Use model-family-specific APIs instead. + """ + hf_cfg = AutoConfig.from_pretrained(model_name_or_path) + model_type = getattr(hf_cfg, "model_type", "") + + if model_type == "whisper": + from transformers.models.whisper.modeling_whisper import WhisperModel + + try: + model = WhisperModel.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + encoder = model.encoder + except Exception: + # Some checkpoints may contain encoder-only weights. + from transformers.models.whisper.modeling_whisper import WhisperEncoder + + encoder = WhisperEncoder.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + return encoder, int(encoder.config.d_model) + + if model_type in {"qwen2_5_omni", "qwen2_5_omni_audio_encoder"}: + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder, + Qwen2_5OmniThinkerForConditionalGeneration, + ) + + if model_type == "qwen2_5_omni": + model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + encoder = model.audio_tower + else: + encoder = Qwen2_5OmniAudioEncoder.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + return encoder, int(encoder.config.output_dim) + + if model_type in {"qwen3_omni_moe", "qwen3_omni_moe_audio_encoder"}: + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder, + Qwen3OmniMoeThinkerForConditionalGeneration, + ) + + if model_type == "qwen3_omni_moe": + model = Qwen3OmniMoeThinkerForConditionalGeneration.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + encoder = model.audio_tower + else: + encoder = Qwen3OmniMoeAudioEncoder.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + return encoder, int(encoder.config.output_dim) + + raise ValueError( + f"Unsupported encoder model_type={model_type!r}. 
" + "Supported values: 'whisper', 'qwen2_5_omni', " + "'qwen2_5_omni_audio_encoder', 'qwen3_omni_moe', " + "'qwen3_omni_moe_audio_encoder'." + ) + + +def _get_encoder_model_type(encoder) -> str: + hf_model_type = getattr(encoder.config, "model_type", "") + supported = { + "whisper", + "qwen2_5_omni_audio_encoder", + "qwen3_omni_moe_audio_encoder", + } + if hf_model_type not in supported: + known = ", ".join(f'"{k}"' for k in sorted(supported)) + raise ValueError( + f"Unsupported encoder model_type={hf_model_type!r}. " + f"Supported values: {known}." + ) + return hf_model_type + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Assemble LALM from pretrained components." + ) + parser.add_argument( + "--llm", required=True, help="HF model id or local path for the LLM" + ) + parser.add_argument( + "--encoder", + required=True, + help="HF model id or local path for the audio encoder", + ) + parser.add_argument( + "--output_dir", required=True, help="Where to save the assembled checkpoint" + ) + parser.add_argument( + "--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"] + ) + parser.add_argument("--projector_downsample_rate", default=4, type=int) + args = parser.parse_args() + + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + assemble_and_save( + llm_name=args.llm, + encoder_name=args.encoder, + output_dir=args.output_dir, + torch_dtype=dtype_map[args.dtype], + projector_downsample_rate=args.projector_downsample_rate, + ) diff --git a/examples/lalm/configs/evaluate.yaml b/examples/lalm/configs/evaluate.yaml new file mode 100644 index 0000000..7ebba9e --- /dev/null +++ b/examples/lalm/configs/evaluate.yaml @@ -0,0 +1,30 @@ +exp_dir: + +checkpoint: + # [Checkpoint Loading Priority] + # 1. If 'model_dir' exists and has weights, use it directly (from_pretrained). + # 2. Else if 'filename' is set, export from {exp_dir}/{filename}. + # 3. Else if 'iter' > 0, export from {exp_dir}/checkpoint-{iter}.pt. + # 4. Else if 'epoch' > 0, export from {exp_dir}/epoch-{epoch}.pt. + # Exported weights are cached in {exp_dir}/export/{suffix}/ for reuse. 
+ model_dir: # e.g., "/path/to/exported_model" (skip export if set) + filename: # e.g., "checkpoint-5000.pt" + iter: 0 # e.g., 5000 + epoch: 0 # e.g., 3 + +# "greedy_search" or "beam_search" +decoding_method: greedy_search +num_beams: 4 # used only when decoding_method=beam_search +max_new_tokens: 200 + +# Mixed precision for inference: fp16, bf16, or fp32 +dtype: fp16 + +data: + test_data_config: configs/valid_data_config.yaml + max_duration: 1000 + num_workers: 4 + sampling_rate: 16000 + # Feature extraction settings (should match training) + on_the_fly_feats: true + feature: whisper_v3_fbank diff --git a/examples/lalm/configs/train.yaml b/examples/lalm/configs/train.yaml new file mode 100644 index 0000000..9448fd7 --- /dev/null +++ b/examples/lalm/configs/train.yaml @@ -0,0 +1,79 @@ +exp_dir: + +hydra: + run: + dir: ${exp_dir}/logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + +seed: 114514 + +model: + # Path to checkpoint assembled by examples/lalm/build_model.py + pretrained_model: + +trainer: + optimizer: + type: adamw + lr: 1.0e-3 + weight_decay: 0.0 + betas: [0.9, 0.95] + eps: 1.0e-8 + fused: true + + scheduler: + type: cosine + t_max: 10 + eta_min: 0.0 + + num_epochs: 10 + start_epoch: 1 + num_steps: 500000 + start_batch: 0 + + mixed_precision: fp16 + frozen_modules: [audio_tower, language_model] + + use_averaged_model: true + average_period: 100 + ema_decay: 0.9999 + + grad_accum_steps: 1 + max_grad_norm: 1.0 + + valid_interval: 1000 + save_every_n: 4 + keep_last_k: 5 + + log_interval: 50 + reset_interval: 200 + tensorboard: true + find_unused_parameters: false + +data: + # YAML files listing manifests. + train_data_config: configs/train_data_config.yaml + valid_data_config: configs/valid_data_config.yaml + + on_the_fly_feats: true + sampling_rate: 16000 + feature: whisper_v3_fbank + fault_tolerant: false + audio_token_rate: 12.5 + + min_duration: 0.5 + max_duration: 30.0 + + data_augmentation: + enable_spec_aug: true + enable_musan: false + enable_speed_perturb: false + musan: + + sampler: + num_buckets: 30 + max_tokens: 600 + max_duration: + shuffle: true + drop_last: true + + num_workers: 8 + use_infinite_dataset: true diff --git a/examples/lalm/configs/train_data_config.yaml b/examples/lalm/configs/train_data_config.yaml new file mode 100644 index 0000000..f84fe42 --- /dev/null +++ b/examples/lalm/configs/train_data_config.yaml @@ -0,0 +1,8 @@ +# - hours: 1000 +# manifest: /apdcephfs_cq12/share_302080740/data/asr_train_data/manifests/chinese/open_source/aishell2/aishell2_cuts.jsonl.gz +# weights: 1 +# source: patch +- hours: 150 + manifest: data/train/aishell1_cuts_conversation.jsonl.gz + weights: 1 + source: patch \ No newline at end of file diff --git a/examples/lalm/configs/valid_data_config.yaml b/examples/lalm/configs/valid_data_config.yaml new file mode 100644 index 0000000..db3b2f7 --- /dev/null +++ b/examples/lalm/configs/valid_data_config.yaml @@ -0,0 +1,2 @@ +- name: aishell + manifest: data/test/aishell_test_conversation.jsonl.gz diff --git a/examples/lalm/evaluate.py b/examples/lalm/evaluate.py new file mode 100644 index 0000000..4488412 --- /dev/null +++ b/examples/lalm/evaluate.py @@ -0,0 +1,228 @@ +"""Evaluation / decoding script for lalm. + +Assumes test manifests have been prepared by prepare_conversation.py. 
+ +Usage:: + + python evaluate.py \\ + exp_dir=/path/to/exp \\ + checkpoint.iter=5000 \\ + data.test_data_config=configs/valid_data_config.yaml +""" + +import logging +import os +from collections import defaultdict +from pathlib import Path + +import hydra +import torch +import yaml +from lalm_core.model import LALMConfig, LALMForConditionalGeneration, LALMProcessor +from lhotse import CutSet, set_audio_duration_mismatch_tolerance +from lhotse.dataset import DynamicBucketingSampler +from omegaconf import DictConfig, OmegaConf +from results_utils import save_results +from transformers.modeling_utils import no_init_weights + +from auden.utils.text_normalization import text_normalization + +# --------------------------------------------------------------------------- +# Model loading (standard MLLM from_pretrained pattern) +# --------------------------------------------------------------------------- + + +def _has_hf_weights(model_dir: str) -> bool: + return any( + os.path.exists(os.path.join(model_dir, n)) + for n in ( + "model.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model.bin.index.json", + ) + ) + + +def _resolve_checkpoint_path(cfg: DictConfig) -> tuple[str, str]: + ckpt_cfg = cfg.checkpoint + filename = ckpt_cfg.get("filename", None) + if filename: + path = ( + filename if os.path.isabs(filename) else os.path.join(cfg.exp_dir, filename) + ) + return path, Path(filename).stem + + iters = int(ckpt_cfg.get("iter", 0)) + epoch = int(ckpt_cfg.get("epoch", 0)) + if iters > 0: + return os.path.join(cfg.exp_dir, f"checkpoint-{iters}.pt"), f"iter-{iters}" + if epoch > 0: + return os.path.join(cfg.exp_dir, f"epoch-{epoch}.pt"), f"epoch-{epoch}" + raise ValueError( + "[evaluate] Specify checkpoint.filename, checkpoint.iter, or checkpoint.epoch." + ) + + +def prepare_model_dir(cfg: DictConfig) -> tuple[str, str]: + """Return (model_dir, results_suffix). + + If the target dir already has HF weights, use it directly (from_pretrained). + Otherwise export from the trainer checkpoint and save via save_pretrained, + caching to {exp_dir}/export/{suffix}/ for reuse. + """ + explicit_dir = cfg.checkpoint.get("model_dir", None) + checkpoint_path, results_suffix = _resolve_checkpoint_path(cfg) + + model_dir = explicit_dir or os.path.join(cfg.exp_dir, "export", results_suffix) + if explicit_dir: + results_suffix = Path(explicit_dir).name + + if _has_hf_weights(model_dir): + logging.info(f"[evaluate] Using existing model dir: {model_dir}") + return model_dir, results_suffix + + logging.info(f"[evaluate] Exporting {checkpoint_path} -> {model_dir}") + hf_dir = os.path.join(cfg.exp_dir, "hf") + if not os.path.isdir(hf_dir): + raise FileNotFoundError( + f"[evaluate] hf_dir not found: {hf_dir}. Run train.py first." 
+ ) + + with no_init_weights(): + model = LALMForConditionalGeneration(LALMConfig.from_pretrained(hf_dir)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + state_dict = checkpoint["model"] + if any(k.startswith("module.") for k in state_dict): + state_dict = {k[len("module.") :]: v for k, v in state_dict.items()} + model.load_state_dict(state_dict, strict=True) + + os.makedirs(model_dir, exist_ok=True) + model.save_pretrained(model_dir) + LALMProcessor.from_pretrained(hf_dir).save_pretrained(model_dir) + logging.info(f"[evaluate] Saved to {model_dir}") + del model, checkpoint, state_dict + return model_dir, results_suffix + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +_ASST_PREFIX = "<|im_start|>assistant\n" + + +def _strip_last_assistant(rendered: str) -> str: + """Remove the last assistant turn and return the prompt up to (and including) + the generation prefix '<|im_start|>assistant\\n'.""" + idx = rendered.rfind(_ASST_PREFIX) + if idx == -1: + return rendered + _ASST_PREFIX + return rendered[: idx + len(_ASST_PREFIX)] + + +@hydra.main(version_base=None, config_path="configs", config_name="evaluate") +@torch.no_grad() +def main(cfg: DictConfig): + logging.info("\n" + OmegaConf.to_yaml(cfg)) + set_audio_duration_mismatch_tolerance(0.1) + + model_dir, results_file_suffix = prepare_model_dir(cfg) + + dtype_name = cfg.get("dtype", "fp16") + dtype = ( + torch.float16 + if dtype_name == "fp16" + else torch.bfloat16 if dtype_name == "bf16" else torch.float32 + ) + + model = LALMForConditionalGeneration.from_pretrained(model_dir, torch_dtype=dtype) + processor = LALMProcessor.from_pretrained(model_dir) + # Decoder-only generation requires left-padding so every sample's last real + # token sits at the same position (L_max - 1) and the next generated token + # is correctly placed at L_max for all samples in the batch. + processor.tokenizer.padding_side = "left" + device = ( + torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu") + ) + model = model.to(device).eval() + logging.info( + f"[evaluate] {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params" + ) + + max_new_tokens = int(cfg.get("max_new_tokens", 200)) + generate_config = ( + dict(max_new_tokens=max_new_tokens, num_beams=1, do_sample=False) + if cfg.decoding_method == "greedy_search" + else dict( + max_new_tokens=max_new_tokens, + num_beams=int(cfg.get("num_beams", 4)), + do_sample=False, + ) + ) + + sampling_rate = int(cfg.data.get("sampling_rate", 16000)) + + with open(cfg.data.test_data_config) as f: + test_data_config = yaml.load(f, Loader=yaml.FullLoader) + + res_dir = Path(cfg.exp_dir) / cfg.decoding_method + os.makedirs(res_dir, exist_ok=True) + + for test_set in test_data_config: + logging.info(f"[evaluate] Test set: {test_set['name']}") + cutset = CutSet.from_file(test_set["manifest"]).resample(sampling_rate) + sampler = DynamicBucketingSampler( + cutset, max_duration=cfg.data.max_duration, shuffle=False + ) + results = defaultdict(list) + num_cuts = 0 + + for batch_idx, cuts in enumerate(sampler): + cuts = cuts.sort_by_duration(ascending=False) + + # Load raw audio (processor handles feature extraction + packing internally) + audios = [cut.load_audio()[0] for cut in cuts] + + # Build prompts from pre-rendered text (matches training exactly). + # Strip the last assistant turn and append the generation prefix. 
+ texts = [_strip_last_assistant(cut.rendered_conversation) for cut in cuts] + + inputs = processor( + text=texts, + audio=audios, + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + ).to(device) + + output_ids = model.generate( + **inputs, + **generate_config, + ) + + hyps = processor.batch_decode(output_ids, skip_special_tokens=True) + refs = [cut.supervisions[0].text for cut in cuts] + cut_ids = [cut.id for cut in cuts] + + def norm(s): + return text_normalization( + s, + case="lower", + remove_diacritics=True, + simplified_chinese=True, + space_between_cjk=True, + ).split() + + for cut_id, ref, hyp in zip(cut_ids, refs, hyps): + results[cfg.decoding_method].append((cut_id, norm(ref), norm(hyp))) + + num_cuts += len(cut_ids) + if batch_idx % 50 == 0: + logging.info(f" batch {batch_idx}, cuts: {num_cuts}") + + save_results(res_dir, test_set["name"], results, suffix=results_file_suffix) + + +if __name__ == "__main__": + main() diff --git a/examples/lalm/lalm_core/__init__.py b/examples/lalm/lalm_core/__init__.py new file mode 100644 index 0000000..fad4569 --- /dev/null +++ b/examples/lalm/lalm_core/__init__.py @@ -0,0 +1,11 @@ +from .data_module import LALMDataModule +from .model import LALMConfig, LALMForConditionalGeneration, LALMProcessor +from .trainer import LALMTrainer + +__all__ = [ + "LALMConfig", + "LALMDataModule", + "LALMForConditionalGeneration", + "LALMProcessor", + "LALMTrainer", +] diff --git a/examples/lalm/lalm_core/data_module.py b/examples/lalm/lalm_core/data_module.py new file mode 100644 index 0000000..19d66d4 --- /dev/null +++ b/examples/lalm/lalm_core/data_module.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging + +import torch +import yaml +from lhotse import CutSet, set_audio_duration_mismatch_tolerance +from lhotse.dataset import DynamicBucketingSampler +from lhotse.dataset.sampling.base import TokenConstraint +from lhotse.workarounds import Hdf5MemoryIssueFix +from torch.utils.data import DataLoader + +from auden.data.lhotse_datamodule import BaseLhotseDatamodule, _SeedWorkers + + +def estimate_cut_tokens(cut, audio_token_rate: float): + """Estimate total tokens from duration and prepared num_text_tokens.""" + num_text_tokens = getattr(cut, "num_text_tokens", None) + if num_text_tokens is None: + raise ValueError( + f"Cut {getattr(cut, 'id', '')} missing cut.num_text_tokens. " + "Please run prepare_conversation.py first." 
+ ) + + num_audio_tokens = int(round(float(cut.duration) * float(audio_token_rate))) + num_tokens = num_audio_tokens + int(num_text_tokens) + cut.num_tokens = num_tokens + return cut + + +class LALMDataset(torch.utils.data.Dataset): + def __init__( + self, + input_strategy, + processor, + cut_transforms=None, + input_transforms=None, + return_cuts: bool = False, + ): + self.input_strategy = input_strategy + self.processor = processor + self.cut_transforms = cut_transforms + self.input_transforms = input_transforms + self.return_cuts = return_cuts + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts) -> dict: + self.hdf5_fix.update() + + if self.cut_transforms is not None: + for transform in self.cut_transforms: + cuts = transform(cuts) + cuts = cuts.sort_by_duration(ascending=False) + + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + features, _, cuts = input_tpl + else: + features, _ = input_tpl + + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + feature_lens = supervision_intervals["num_frames"] + feature_lens = feature_lens.to(dtype=torch.long) + + # lhotse returns (N, T, C); processor expects padded (N, C, T_max) + if features.ndim == 3: + features = features.transpose(1, 2).contiguous() + + rendered_texts = [] + for cut in cuts: + rendered_text = getattr(cut, "rendered_conversation", None) + if rendered_text is None: + raise ValueError( + f"Cut {getattr(cut, 'id', '')} missing " + "cut.rendered_conversation. Please run prepare_conversation.py first." + ) + rendered_texts.append(rendered_text) + + inputs = self.processor( + text=rendered_texts, + audio_feature=(features, feature_lens), + prepare_labels=True, + return_tensors="pt", + padding=True, + ) + + batch = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "features": inputs["input_features"], + "feature_lens": inputs["feature_lens"], + "labels": inputs["labels"], + "batch_size": inputs["input_ids"].size(0), + } + + flat_cuts = [cut for cut in cuts for _ in cut.supervisions] + if self.return_cuts: + batch["cuts"] = flat_cuts + return batch + + def __len__(self) -> int: + return 0 + + +class LALMDataModule(BaseLhotseDatamodule): + def __init__(self, cfg, processor): + set_audio_duration_mismatch_tolerance(1) + self.processor = processor + super().__init__(cfg) + + def _filter_cutset(self, cutset: CutSet, split: str = "train") -> CutSet: + min_dur = float(self.cfg.get("min_duration", 0.5)) + max_dur = float(self.cfg.get("max_duration", 30.0)) + + def keep(c): + return min_dur <= c.duration <= max_dur + + audio_token_rate = float(self.cfg.get("audio_token_rate", 12.5)) + return cutset.filter(keep).map( + lambda cut: estimate_cut_tokens(cut, audio_token_rate=audio_token_rate) + ) + + def setup_train(self): + with open(self.cfg.train_data_config, "r", encoding="utf-8") as f: + train_data_config = yaml.load(f, Loader=yaml.FullLoader) + + train_cutset = self._build_train_mux_cutset(train_data_config) + train_cutset = self._filter_cutset(train_cutset, split="train") + + max_tokens = self.cfg.sampler.get("max_tokens", None) + max_duration = self.cfg.sampler.get("max_duration", None) + num_buckets = self.cfg.sampler.get("num_buckets", 30) + common = dict( + shuffle=self.cfg.sampler.shuffle, + num_buckets=num_buckets, + buffer_size=num_buckets * 2000, + shuffle_buffer_size=num_buckets * 5000, + drop_last=self.cfg.sampler.get("drop_last", True), + ) + + if max_tokens is not None: + logging.info( + f"[data] Train sampler: 
TokenConstraint max_tokens={max_tokens}" + ) + train_sampler = DynamicBucketingSampler( + train_cutset, + constraint=TokenConstraint(max_tokens=max_tokens), + **common, + ) + elif max_duration is not None: + logging.info(f"[data] Train sampler: max_duration={max_duration}s") + train_sampler = DynamicBucketingSampler( + train_cutset, max_duration=max_duration, **common + ) + else: + raise ValueError( + "sampler must set max_tokens (recommended) or max_duration" + ) + + train_dataset = LALMDataset( + input_strategy=self.input_strategy, + processor=self.processor, + cut_transforms=self.transforms, + input_transforms=self.input_transforms, + return_cuts=True, + ) + seed = torch.randint(0, 100_000, ()).item() + worker_init_fn = _SeedWorkers(seed) + self.train_dl = DataLoader( + train_dataset, + sampler=train_sampler, + batch_size=None, + num_workers=self.cfg.get("num_workers", 4), + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + def setup_valid(self): + valid_cfg_path = self.cfg.get("valid_data_config", None) + if not valid_cfg_path: + self.valid_dls = [] + self.valid_names = [] + return + + with open(valid_cfg_path, "r", encoding="utf-8") as f: + valid_data_config = yaml.load(f, Loader=yaml.FullLoader) + + self.valid_dls = [] + self.valid_names = [] + for valid_set in valid_data_config: + cutset = CutSet.from_file(valid_set["manifest"]).resample( + self.sampling_rate + ) + cutset = self._filter_cutset(cutset, split="valid") + valid_name = valid_set.get("name", "valid") + + max_tokens = self.cfg.sampler.get("max_tokens", None) + max_duration = self.cfg.sampler.get("max_duration", None) + if max_tokens is not None: + valid_sampler = DynamicBucketingSampler( + cutset, + constraint=TokenConstraint(max_tokens=max_tokens), + shuffle=False, + ) + else: + valid_sampler = DynamicBucketingSampler( + cutset, max_duration=max_duration, shuffle=False + ) + + valid_dataset = LALMDataset( + input_strategy=self.input_strategy, + processor=self.processor, + return_cuts=True, + ) + valid_dl = DataLoader( + valid_dataset, + sampler=valid_sampler, + batch_size=None, + num_workers=self.cfg.get("num_workers", 4), + persistent_workers=False, + ) + self.valid_names.append(valid_name) + self.valid_dls.append(valid_dl) diff --git a/examples/lalm/lalm_core/model/__init__.py b/examples/lalm/lalm_core/model/__init__.py new file mode 100644 index 0000000..571b497 --- /dev/null +++ b/examples/lalm/lalm_core/model/__init__.py @@ -0,0 +1,5 @@ +from .configuration_lalm import LALMConfig +from .modeling_lalm import LALMForConditionalGeneration +from .processing_lalm import LALMProcessor + +__all__ = ["LALMConfig", "LALMForConditionalGeneration", "LALMProcessor"] diff --git a/examples/lalm/lalm_core/model/configuration_lalm.py b/examples/lalm/lalm_core/model/configuration_lalm.py new file mode 100644 index 0000000..169ed4a --- /dev/null +++ b/examples/lalm/lalm_core/model/configuration_lalm.py @@ -0,0 +1,107 @@ +""" +LALM (Language-Audio Language Model) configuration. + +Combines a speech/audio encoder with a causal LLM and a linear projector. +Supports arbitrary encoder and LLM types via AutoConfig. +""" + +from __future__ import annotations + +from typing import Any, Optional, Union + +from transformers import AutoConfig, PretrainedConfig + +DEFAULT_TEXT_MODEL_TYPE = "qwen2" +DEFAULT_AUDIO_MODEL_TYPE = "whisper" + + +class LALMConfig(PretrainedConfig): + """ + Configuration for LALM (Language-Audio Language Model). + + LALM combines: + - An audio encoder (e.g. 
Whisper, Zipformer) + - A linear projector from audio dim to text dim + - A causal LLM (e.g. Qwen2, LLaMA) + + Both ``text_config`` and ``audio_config`` accept a dict (with ``model_type``) + or a ``PretrainedConfig`` subclass, so you can freely swap encoders and LLMs. + + Args: + text_config: LLM config. Defaults to Qwen2. + audio_config: Audio encoder config. Defaults to Whisper. + audio_dim: Hidden size output by the audio encoder (projector input dim). + text_dim: Hidden size of the LLM (projector output dim). + projector_downsample_rate: Number of consecutive encoder frames merged into + one token by the projector. Higher values = shorter audio sequences fed + to the LLM. Defaults to 4. + audio_token_id: Token id used as placeholder for audio embeddings. + """ + + model_type = "lalm" + + def __init__( + self, + text_config: Optional[Union[dict, PretrainedConfig]] = None, + audio_config: Optional[Union[dict, PretrainedConfig]] = None, + audio_dim: Optional[int] = None, + text_dim: Optional[int] = None, + projector_downsample_rate: int = 4, + audio_token_id: Optional[int] = None, + **kwargs: Any, + ) -> None: + self.text_config = _resolve_config(text_config, DEFAULT_TEXT_MODEL_TYPE) + self.audio_config = _resolve_config(audio_config, DEFAULT_AUDIO_MODEL_TYPE) + self.audio_dim = audio_dim + self.text_dim = text_dim + self.projector_downsample_rate = projector_downsample_rate + self.audio_token_id = audio_token_id + super().__init__(**kwargs) + + def to_dict(self) -> dict: + output = super().to_dict() + output["text_config"] = self.text_config.to_dict() + output["audio_config"] = self.audio_config.to_dict() + return output + + +def _resolve_config( + config: Optional[Union[dict, PretrainedConfig]], + default_model_type: str, +) -> PretrainedConfig: + """Convert a dict or None into a concrete PretrainedConfig.""" + if isinstance(config, PretrainedConfig): + return config + if isinstance(config, dict): + config = config.copy() + config.setdefault("model_type", default_model_type) + model_type = config["model_type"] + try: + return AutoConfig.for_model(**config) + except ValueError: + return _resolve_known_config(config, model_type) + return AutoConfig.for_model(default_model_type) + + +def _resolve_known_config(config: dict, model_type: str) -> PretrainedConfig: + """Handle HF config classes that exist but are not AutoConfig-registered.""" + if model_type == "whisper": + from transformers.models.whisper.configuration_whisper import WhisperConfig + + return WhisperConfig.from_dict(config) + + if model_type == "qwen2_5_omni_audio_encoder": + from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoderConfig, + ) + + return Qwen2_5OmniAudioEncoderConfig.from_dict(config) + + if model_type == "qwen3_omni_moe_audio_encoder": + from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoderConfig, + ) + + return Qwen3OmniMoeAudioEncoderConfig.from_dict(config) + + raise ValueError(f"Unrecognized model identifier: {model_type}.") diff --git a/examples/lalm/lalm_core/model/modeling_lalm.py b/examples/lalm/lalm_core/model/modeling_lalm.py new file mode 100644 index 0000000..ccb004f --- /dev/null +++ b/examples/lalm/lalm_core/model/modeling_lalm.py @@ -0,0 +1,332 @@ +import torch +import torch.nn as nn +from transformers import AutoModel, AutoModelForCausalLM, PreTrainedModel + +from .configuration_lalm import LALMConfig + + +class LALMProjector(nn.Module): + """ + Temporal downsampling projector that maps audio encoder 
outputs into the LLM's
+    embedding space.
+
+    ``downsample_rate`` consecutive frames are concatenated along the feature axis
+    before the linear projection, reducing the sequence length fed to the LLM by
+    that factor. E.g. with Whisper (50 frames/sec) and ``downsample_rate=5``, the
+    LLM receives 10 audio tokens per second.
+    """
+
+    def __init__(self, audio_dim: int, text_dim: int, downsample_rate: int = 5):
+        super().__init__()
+        self.downsample_rate = downsample_rate
+        self.linear1 = nn.Linear(audio_dim * downsample_rate, text_dim)
+        self.act = nn.GELU()
+        self.linear2 = nn.Linear(text_dim, text_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, T, D = x.shape
+        # Drop trailing frames so T is divisible by the downsample rate.
+        remainder = T % self.downsample_rate
+        if remainder:
+            x = x[:, :-remainder, :]
+        x = x.reshape(B, T // self.downsample_rate, D * self.downsample_rate)
+        return self.linear2(self.act(self.linear1(x)))
+
+
+class LALMForConditionalGeneration(PreTrainedModel):
+    """
+    LALM: audio encoder + projector + causal LLM.
+
+    During forward, audio features are encoded, projected, then inserted at
+    positions marked by ``audio_token_id`` in the input token sequence.
+
+    Note on audio encoder:
+        ``audio_tower`` should be an *encoder-only* module whose output has a
+        ``last_hidden_state`` attribute (e.g. ``WhisperEncoder``, or
+        ``WhisperModel.encoder``). Full seq2seq models such as ``WhisperModel``
+        return the *decoder*'s ``last_hidden_state`` and must be unwrapped first
+        (``build_model.py`` handles this automatically).
+    """
+
+    config_class = LALMConfig
+
+    def __init__(
+        self,
+        config: LALMConfig,
+        language_model: nn.Module | None = None,
+        audio_tower: nn.Module | None = None,
+    ):
+        super().__init__(config)
+        self.language_model = (
+            language_model
+            if language_model is not None
+            else AutoModelForCausalLM.from_config(config.text_config)
+        )
+        self.audio_tower = (
+            audio_tower
+            if audio_tower is not None
+            else self._build_audio_tower(config.audio_config)
+        )
+        self.projector = LALMProjector(
+            config.audio_dim, config.text_dim, config.projector_downsample_rate
+        )
+
+    @staticmethod
+    def _build_audio_tower(audio_config):
+        """Build encoder explicitly to avoid relying on AutoModel registration."""
+        model_type = getattr(audio_config, "model_type", "")
+
+        if model_type == "whisper":
+            from transformers.models.whisper.modeling_whisper import WhisperEncoder
+
+            return WhisperEncoder(audio_config)
+
+        if model_type == "qwen2_5_omni_audio_encoder":
+            from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
+                Qwen2_5OmniAudioEncoder,
+            )
+
+            return Qwen2_5OmniAudioEncoder(audio_config)
+
+        if model_type == "qwen3_omni_moe_audio_encoder":
+            from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
+                Qwen3OmniMoeAudioEncoder,
+            )
+
+            return Qwen3OmniMoeAudioEncoder(audio_config)
+
+        return AutoModel.from_config(audio_config)
+
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        input_features: torch.Tensor | None = None,
+        feature_lens: torch.Tensor | None = None,
+        **generate_kwargs,
+    ) -> torch.Tensor:
+        """Generate text auto-regressively, optionally conditioned on audio.
+
+        Parameter names mirror the processor's output keys so that the standard
+        ``model.generate(**inputs, feature_lens=..., **gen_cfg)`` pattern works.
+
+        Args:
+            input_ids: Token ids with audio-token placeholders, (N, L).
+            attention_mask: Padding mask, (N, L).
+            input_features: Packed audio features (C, T_total).
+            feature_lens: Actual frame counts per sample (N,).
+ **generate_kwargs: Forwarded to ``language_model.generate``. + + Returns: + Generated token ids, (N, T_gen). + """ + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if input_features is not None: + audio_param = next(self.audio_tower.parameters(), None) + if audio_param is not None and input_features.dtype != audio_param.dtype: + input_features = input_features.to(dtype=audio_param.dtype) + inputs_embeds = self._merge_input_ids_with_audio_features( + input_ids, inputs_embeds, input_features, feature_lens=feature_lens + ) + return self.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **generate_kwargs, + ) + + def forward( + self, + input_ids: torch.Tensor, + audio_features: torch.Tensor = None, + feature_lens: torch.Tensor = None, + attention_mask: torch.Tensor = None, + position_ids: torch.Tensor = None, + labels: torch.Tensor = None, + **kwargs, + ): + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + if audio_features is not None: + inputs_embeds = self._merge_input_ids_with_audio_features( + input_ids, + inputs_embeds, + audio_features, + feature_lens=feature_lens, + ) + + if labels is not None and attention_mask is not None: + outputs, packed_labels = self._forward_packed( + inputs_embeds, attention_mask, labels, **kwargs + ) + outputs.packed_labels = packed_labels + return outputs + + return self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + labels=labels, + **kwargs, + ) + + def encode_audio( + self, + audio_features: torch.Tensor, + feature_lens: torch.Tensor = None, + ) -> torch.Tensor: + """Run the packed audio tower and return padded hidden states.""" + model_type = getattr(self.config.audio_config, "model_type", "") + if feature_lens is None: + raise ValueError("Packed-only LALM requires feature_lens.") + + feature_lens = feature_lens.to(device=audio_features.device, dtype=torch.long) + if model_type == "qwen2_5_omni_audio_encoder": + aftercnn_lens = (feature_lens - 1) // 2 + 1 + encoded = self.audio_tower( + input_features=audio_features, + feature_lens=feature_lens, + aftercnn_lens=aftercnn_lens, + ).last_hidden_state + output_lens = (aftercnn_lens - 2) // 2 + 1 + return self._pad_packed_audio_outputs(encoded, output_lens) + + if model_type == "qwen3_omni_moe_audio_encoder": + encoded = self.audio_tower( + input_features=audio_features, + feature_lens=feature_lens, + ).last_hidden_state + output_lens = self._get_audio_output_lengths(feature_lens) + return self._pad_packed_audio_outputs(encoded, output_lens) + + raise NotImplementedError( + "Packed-only LALM currently supports " + "'qwen2_5_omni_audio_encoder' and 'qwen3_omni_moe_audio_encoder'." 
+ ) + + def _get_audio_output_lengths(self, feature_lens: torch.Tensor) -> torch.Tensor: + model_type = getattr(self.config.audio_config, "model_type", "") + + if model_type == "whisper": + return (feature_lens - 1) // 2 + 1 + + if model_type == "qwen2_5_omni_audio_encoder": + after_first = (feature_lens - 1) // 2 + 1 + return (after_first - 2) // 2 + 1 + + if model_type == "qwen3_omni_moe_audio_encoder": + remainder = feature_lens % 100 + feat = (remainder - 1) // 2 + 1 + return ((feat - 1) // 2 + 1 - 1) // 2 + 1 + (feature_lens // 100) * 13 + + raise ValueError(f"Unsupported audio encoder model_type: {model_type!r}") + + def _pad_packed_audio_outputs( + self, encoded: torch.Tensor, output_lens: torch.Tensor + ) -> torch.Tensor: + if encoded.ndim != 2: + return encoded + + lengths = [int(x) for x in output_lens.tolist()] + max_len = max(lengths) + padded = encoded.new_zeros((len(lengths), max_len, encoded.size(-1))) + chunks = encoded.split(lengths, dim=0) + for i, chunk in enumerate(chunks): + padded[i, : chunk.size(0)] = chunk + return padded + + def _forward_packed( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor, + **kwargs, + ): + device = inputs_embeds.device + + seq_lens = attention_mask.sum(dim=1).long() + segment_lens = seq_lens.tolist() + + valid = attention_mask.bool() + packed_embeds = inputs_embeds[valid].unsqueeze(0) + packed_labels = labels[valid].unsqueeze(0) + + T_total = packed_embeds.size(1) + ones = torch.ones(T_total, device=device, dtype=torch.long) + if seq_lens.size(0) > 1: + starts = seq_lens.cumsum(0)[:-1] + ones[starts] -= seq_lens[:-1] + position_ids = ones.cumsum(0).sub_(1).unsqueeze(0) + + attn_mask = _block_diagonal_causal_mask( + segment_lens, device=device, dtype=inputs_embeds.dtype + ) + outputs = self.language_model( + inputs_embeds=packed_embeds, + attention_mask=attn_mask, + position_ids=position_ids, + labels=packed_labels, + **kwargs, + ) + return outputs, packed_labels + + def _merge_input_ids_with_audio_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + audio_features: torch.Tensor, + feature_lens: torch.Tensor = None, + ) -> torch.Tensor: + """Replace audio token positions in inputs_embeds with projected audio embeddings.""" + audio_embeds = self.projector( + self.encode_audio( + audio_features, + feature_lens=feature_lens, + ) + ) + mask = input_ids == self.config.audio_token_id + if not mask.any(): + return inputs_embeds + + B, audio_len, D = audio_embeds.shape + if B != input_ids.shape[0]: + raise ValueError( + f"Batch size mismatch: input_ids batch={input_ids.shape[0]}, " + f"audio_embeds batch={B}." 
+ ) + if audio_len == 0: + return inputs_embeds + + rank = mask.long().cumsum(dim=1) - 1 + valid = mask & (rank < audio_len) + gather_index = rank.clamp(min=0, max=audio_len - 1) + gathered_audio = audio_embeds.gather( + 1, gather_index.unsqueeze(-1).expand(-1, -1, D) + ).to(inputs_embeds.dtype) + + out = inputs_embeds.reshape(-1, D) + src = gathered_audio.reshape(-1, D) + flat_valid = valid.reshape(-1) + out[flat_valid] = src[flat_valid] + return out.view_as(inputs_embeds) + + +def _block_diagonal_causal_mask( + segment_lens: list[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + T = sum(segment_lens) + seg_ids = torch.repeat_interleave( + torch.arange(len(segment_lens), device=device), + torch.tensor(segment_lens, device=device), + ) + + q_pos = torch.arange(T, device=device).unsqueeze(1) + k_pos = torch.arange(T, device=device).unsqueeze(0) + q_seg = seg_ids.unsqueeze(1) + k_seg = seg_ids.unsqueeze(0) + + allow = (k_pos <= q_pos) & (k_seg == q_seg) + + mask = torch.zeros(1, 1, T, T, device=device, dtype=dtype) + mask.masked_fill_(~allow, float("-inf")) + return mask diff --git a/examples/lalm/lalm_core/model/processing_lalm.py b/examples/lalm/lalm_core/model/processing_lalm.py new file mode 100644 index 0000000..15a0633 --- /dev/null +++ b/examples/lalm/lalm_core/model/processing_lalm.py @@ -0,0 +1,326 @@ +import torch +from transformers.audio_utils import AudioInput +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin +from transformers.tokenization_utils_base import TextInput + +_ENCODER_LENGTH_FNS: dict[str, callable] = {} +_AUDIO_TOKEN_LENGTH_FNS: dict[str, callable] = {} + + +def _whisper_output_length(mel_frames: int) -> int: + return (mel_frames - 1) // 2 + 1 + + +def _qwen25_ae_output_length(mel_frames: int) -> int: + after_first = (mel_frames - 1) // 2 + 1 + return (after_first - 2) // 2 + 1 + + +def _qwen3_aut_output_length(mel_frames: int) -> int: + remainder = mel_frames % 100 + feat = (remainder - 1) // 2 + 1 + return ((feat - 1) // 2 + 1 - 1) // 2 + 1 + (mel_frames // 100) * 13 + + +_ENCODER_LENGTH_FNS = { + "whisper": _whisper_output_length, + "qwen2_5_omni_audio_encoder": _qwen25_ae_output_length, + "qwen3_omni_moe_audio_encoder": _qwen3_aut_output_length, +} + + +def make_audio_token_length_fn( + encoder_name: str, projector_downsample_rate: int = 4 +) -> callable: + """ + Return a ``(mel_frames: int) -> int`` function combining encoder downsampling + and projector frame-concat rate. + """ + if encoder_name not in _ENCODER_LENGTH_FNS: + raise ValueError( + f"Unknown encoder {encoder_name!r}. Known: {list(_ENCODER_LENGTH_FNS)}." 
+ ) + if projector_downsample_rate <= 0: + raise ValueError("projector_downsample_rate must be a positive integer.") + name = f"{encoder_name}_ds{projector_downsample_rate}" + if name not in _AUDIO_TOKEN_LENGTH_FNS: + encoder_fn = _ENCODER_LENGTH_FNS[encoder_name] + ds = projector_downsample_rate + + def fn(mel_frames: int) -> int: + return encoder_fn(mel_frames) // ds + + fn.__name__ = name + _AUDIO_TOKEN_LENGTH_FNS[name] = fn + return _AUDIO_TOKEN_LENGTH_FNS[name] + + +class LALMProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": {"padding": False}, + "audio_kwargs": { + "sampling_rate": 16000, + "padding": True, + "truncation": False, + "return_attention_mask": True, + }, + } + + +class LALMProcessor(ProcessorMixin): + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "AutoFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + feature_extractor, + tokenizer, + encoder_name: str | None = None, + projector_downsample_rate: int = 4, + audio_token: str = "<|audio|>", + ): + self.audio_token = audio_token + self.encoder_name = encoder_name or "whisper" + self.projector_downsample_rate = projector_downsample_rate + if self.encoder_name not in _ENCODER_LENGTH_FNS: + raise ValueError( + f"Unknown encoder {self.encoder_name!r}. Known: {list(_ENCODER_LENGTH_FNS)}." + ) + if self.projector_downsample_rate <= 0: + raise ValueError("projector_downsample_rate must be a positive integer.") + + if self.audio_token not in tokenizer.get_vocab(): + tokenizer.add_special_tokens( + {"additional_special_tokens": [self.audio_token]} + ) + super().__init__(feature_extractor, tokenizer) + + @property + def audio_token_length_fn(self) -> callable: + return make_audio_token_length_fn( + self.encoder_name, + self.projector_downsample_rate, + ) + + def apply_chat_template( + self, conversations, chat_template=None, tokenize=False, **kwargs + ) -> str | list[str]: + is_batched = isinstance(conversations[0], list) + conversations = conversations if is_batched else [conversations] + results = [ + self.tokenizer.apply_chat_template( + self._normalize_conversation(conv), + tokenize=tokenize, + chat_template=chat_template, + **kwargs, + ) + for conv in conversations + ] + return results if is_batched else results[0] + + def _normalize_conversation(self, conversation: list) -> list: + out = [] + for turn in conversation: + content = turn["content"] + if isinstance(content, list): + parts = [] + for item in content: + if item["type"] == "audio": + parts.append(self.audio_token) + elif item["type"] == "text": + parts.append(item["text"]) + else: + raise ValueError( + f"Unknown content type {item.get('type')!r} in multimodal message." 
+ ) + content = " ".join(parts) + out.append({**turn, "content": content}) + return out + + def __call__( + self, + text: TextInput | None = None, + audio: AudioInput | None = None, + audio_feature: tuple | None = None, + prepare_labels: bool = False, + **kwargs, + ) -> BatchFeature: + if text is None: + raise ValueError("You need to specify a `text` input to process.") + if audio is not None and audio_feature is not None: + raise ValueError("Provide either `audio` or `audio_feature`, not both.") + + output_kwargs = self._merge_kwargs( + LALMProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if audio is not None: + out = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + # feature_lens = number of non-padded mel frames per sample + feature_lens = out.attention_mask.sum(dim=-1).long() + feature_lens_list = feature_lens.tolist() + # Pack padded (N, C, T_max) → flat (C, T_total) as the audio encoders expect + packed_features = self._pack_from_padded(out.input_features, feature_lens) + audio_lengths = iter( + self.audio_token_length_fn(int(length)) for length in feature_lens_list + ) + audio_inputs = { + "input_features": packed_features, + "feature_lens": feature_lens, + } + elif audio_feature is not None: + features, feature_lens = audio_feature + if torch.is_tensor(feature_lens): + feature_lens = feature_lens.to(dtype=torch.long) + feature_lens_list = feature_lens.tolist() + else: + feature_lens_list = [int(x) for x in feature_lens] + feature_lens = torch.tensor(feature_lens_list, dtype=torch.long) + + packed_features = self._pack_from_padded(features, feature_lens) + audio_lengths = iter( + self.audio_token_length_fn(int(length)) for length in feature_lens_list + ) + audio_inputs = { + "input_features": packed_features, + "feature_lens": feature_lens, + } + else: + audio_inputs = {} + audio_lengths = iter([]) + + if not isinstance(text, list): + text = [text] + + text = self.replace_multimodal_special_tokens(text, audio_lengths) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if prepare_labels: + text_inputs["labels"] = self._prepare_labels( + text_inputs["input_ids"], text_inputs["attention_mask"] + ) + + return BatchFeature( + data={**text_inputs, **audio_inputs}, + tensor_type=kwargs.get("return_tensors"), + ) + + def _prepare_labels(self, input_ids, attention_mask): + labels = torch.full_like(input_ids, -100) + + im_start_id = self.tokenizer.convert_tokens_to_ids("<|im_start|>") + im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>") + assistant_header = self.tokenizer.encode( + "assistant\n", add_special_tokens=False + ) + header_len = len(assistant_header) + header_tensor = torch.tensor( + assistant_header, dtype=input_ids.dtype, device=input_ids.device + ) + + B, L = input_ids.shape + + for b in range(B): + ids = input_ids[b] # [L] + + # 1. Find positions where <|im_start|> is followed by "assistant\n". + # unfold(0, header_len, 1) on ids[1:] gives windows where + # windows[i] == ids[i+1 : i+1+header_len]. 
+ im_start_mask = ids == im_start_id # [L] + if header_len > 0 and L > header_len: + windows = ids[1:].unfold(0, header_len, 1) # [L-header_len, header_len] + header_match = (windows == header_tensor).all(dim=-1) # [L-header_len] + header_match_full = torch.cat( + [header_match, header_match.new_zeros(header_len)] + ) # [L] + else: + header_match_full = ids.new_zeros(L, dtype=torch.bool) + + assistant_starts = (im_start_mask & header_match_full).nonzero( + as_tuple=True + )[0] # positions of matching <|im_start|> + if len(assistant_starts) == 0: + continue + + # response content starts right after "<|im_start|>assistant\n" + response_starts = (assistant_starts + 1 + header_len).clamp(max=L) + + # 2. For each response_start, find the first <|im_end|> at or after it. + # searchsorted on the sorted im_end_pos tensor replaces list.index(). + im_end_pos = (ids == im_end_id).nonzero(as_tuple=True)[0] # [M], sorted + if len(im_end_pos) > 0: + idx = torch.searchsorted(im_end_pos, response_starts) # [N] + has_end = idx < len(im_end_pos) + safe_idx = idx.clamp(max=len(im_end_pos) - 1) + # +1 to include the <|im_end|> token itself (matches original logic) + response_ends = torch.where( + has_end, + im_end_pos[safe_idx] + 1, + ids.new_full((), L), + ).clamp(max=L) + else: + response_ends = ids.new_full((len(response_starts),), L) + + # 3. Fill [response_start, response_end) ranges using the cumsum trick: + # +1 at each start, -1 at each end → cumsum > 0 marks label positions. + signal = ids.new_zeros(L + 1) + signal.scatter_add_( + 0, response_starts, torch.ones_like(response_starts) + ) + signal.scatter_add_( + 0, response_ends, -torch.ones_like(response_ends) + ) + label_mask = signal[:L].cumsum(0).bool() # [L] + + labels[b] = ids.masked_fill(~label_mask, -100) + + return labels + + @staticmethod + def _pack_from_padded( + features: torch.Tensor, feature_lens: torch.Tensor + ) -> torch.Tensor: + """Pack padded (N, C, T_max) into flat (C, T_total) expected by audio encoders.""" + # features: (N, C, T_max) — HF feature extractor output layout + chunks = [ + features[i, :, : int(feature_lens[i].item())] + for i in range(features.size(0)) + ] + return torch.cat(chunks, dim=1).contiguous() # (C, T_total) + + def replace_multimodal_special_tokens( + self, text: list[str], audio_lengths + ) -> list[str]: + processed = [] + for sample in text: + if self.audio_token not in sample: + processed.append(sample) + continue + + parts = sample.split(self.audio_token) + rebuilt = [parts[0]] + for i in range(1, len(parts)): + n = max(int(next(audio_lengths, 1)), 1) + rebuilt.append(self.audio_token * n) + rebuilt.append(parts[i]) + processed.append("".join(rebuilt)) + return processed + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self) -> list[str]: + return list( + dict.fromkeys( + self.tokenizer.model_input_names + + self.feature_extractor.model_input_names + + ["feature_lens"] + ) + ) diff --git a/examples/lalm/lalm_core/trainer.py b/examples/lalm/lalm_core/trainer.py new file mode 100644 index 0000000..6697fb2 --- /dev/null +++ b/examples/lalm/lalm_core/trainer.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import copy +import logging + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP + +from auden.trainer.ddp_trainer import BaseTrainer +from auden.utils.metric_tracker import 
MetricsTracker + + +class LALMTrainer(BaseTrainer): + + def __init__(self, cfg, model, data_module, rank=0, local_rank=0, world_size=1): + t = cfg.trainer + self._grad_accum_steps: int = int(getattr(t, "grad_accum_steps", 1)) + self._max_grad_norm: float = float(getattr(t, "max_grad_norm", 1.0)) + self._ema_decay: float = float(getattr(t, "ema_decay", 0.9999)) + self._accum_step: int = 0 + super().__init__(cfg, model, data_module, rank, local_rank, world_size) + + def setup_model(self, model: nn.Module): + """float32 EMA on rank-0 (not float64), find_unused_parameters=False.""" + if self.rank == 0: + model_avg = copy.deepcopy(model).to(torch.float32).to("cpu") + else: + model_avg = None + + model = model.to(self.device) + + if self.world_size > 1: + model = DDP( + model, + device_ids=[self.local_rank], + find_unused_parameters=self.cfg.trainer.get( + "find_unused_parameters", False + ), + ) + + num_param = sum(p.numel() for p in model.parameters()) / 1e6 + num_trainable = ( + sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 + ) + logging.info( + f"Parameters: {num_param:.2f}M total, {num_trainable:.2f}M trainable" + ) + return model, model_avg + + def _maybe_update_model_average(self): + """EMA over trainable params only — skips frozen weights to avoid copying + the entire model (e.g. 7B LLM backbone) from GPU to CPU every N steps.""" + if ( + self.rank != 0 + or not self.cfg.trainer.get("use_averaged_model", False) + or self.global_step == 0 + or self.global_step % self.cfg.trainer.average_period != 0 + ): + return + + model_cur = self.model.module if isinstance(self.model, DDP) else self.model + decay = self._ema_decay + with torch.no_grad(): + for (_, avg_p), (_, cur_p) in zip( + self.model_avg.named_parameters(), + model_cur.named_parameters(), + ): + if not cur_p.requires_grad: + continue + avg_p.data.mul_(decay).add_(cur_p.data.float().cpu(), alpha=1.0 - decay) + + def _forward_backward_optimize(self, batch: dict): + """zero_grad at start of accumulation window, grad accumulation, grad clipping.""" + amp_dtype = ( + torch.float16 + if self.mixed_precision == "fp16" + else torch.bfloat16 if self.mixed_precision == "bf16" else None + ) + + if self._accum_step == 0: + self.optimizer.zero_grad(set_to_none=True) + + with torch.amp.autocast("cuda", enabled=amp_dtype is not None, dtype=amp_dtype): + loss, batch_metrics = self._forward_one_batch(batch, is_training=True) + + scaled_loss = loss / self._grad_accum_steps + self.scaler.scale(scaled_loss).backward() + + self._accum_step += 1 + + if self._accum_step >= self._grad_accum_steps: + self._accum_step = 0 + + self.scaler.unscale_(self.optimizer) + + if self._max_grad_norm > 0: + nn.utils.clip_grad_norm_( + (p for p in self.model.parameters() if p.grad is not None), + self._max_grad_norm, + ) + + self.scheduler.step_batch(self.global_step) + self.scaler.step(self.optimizer) + self.scaler.update() + + return loss, batch_metrics + + def _forward_one_batch(self, batch, is_training=True): + device = self.device + input_ids = batch["input_ids"].to(device, non_blocking=True) + attention_mask = batch["attention_mask"].to(device, non_blocking=True) + labels = batch["labels"].to(device, non_blocking=True) + audio_features = batch["features"].to(device, non_blocking=True) + feature_lens = batch["feature_lens"].to(device=device, dtype=torch.long, non_blocking=True) + + model_ref = self.model.module if isinstance(self.model, DDP) else self.model + audio_param = next(model_ref.audio_tower.parameters(), None) + if audio_param is not 
None and audio_features.dtype != audio_param.dtype: + audio_features = audio_features.to(dtype=audio_param.dtype) + + with torch.set_grad_enabled(is_training): + outputs = self.model( + input_ids=input_ids, + audio_features=audio_features, + feature_lens=feature_lens, + attention_mask=attention_mask, + labels=labels, + ) + loss = outputs.loss + logits = outputs.logits + packed_labels = getattr(outputs, "packed_labels", labels) + + info = MetricsTracker() + B = int(input_ids.size(0)) + info.set_value("samples", B, normalization="sum") + info.set_value( + "tokens", int((attention_mask > 0).sum().item()), normalization="sum" + ) + info.set_value( + "loss", float(loss.detach().cpu().item()), normalization="sample_avg" + ) + info.set_value( + "acc", + float(self._token_accuracy(logits, packed_labels).detach().cpu().item()), + normalization="sample_avg", + ) + return loss, info + + @staticmethod + def _token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + valid = shift_labels != -100 + n = valid.sum() + if n == 0: + return logits.new_zeros(()) + preds = shift_logits.argmax(-1)[valid] + correct = (preds == shift_labels[valid]).sum() + return correct.float() / n diff --git a/examples/lalm/prepare_conversation.py b/examples/lalm/prepare_conversation.py new file mode 100644 index 0000000..dca3fcc --- /dev/null +++ b/examples/lalm/prepare_conversation.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +"""Offline prepare cut.conversation for LALM training. + +This script reads a Lhotse CutSet manifest, builds a conversation object for each +cut, writes it to ``cut.custom["conversation"]`` (accessible as ``cut.conversation``), +and saves a new manifest. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from lhotse import CutSet +from transformers import AutoTokenizer + +from auden.utils.text_normalization import text_normalization + + +def _build_conversation(cut, instruction=None, system=None): + response = cut.supervisions[0].text # set it to your own response field + response = text_normalization( + response, + case="lower", + space_between_cjk=False, + remove_diacritics=True, + remove_symbols=True, + remove_in_parenthesis=True, + remove_in_brackets=True, + ) + + audio_source = cut.recording.sources[0].source + + messages: list[dict] = [] + if system: + messages.append({"role": "system", "content": system}) + + user_content = [{"type": "audio", "audio": audio_source}] + if instruction: + user_content.append({"type": "text", "text": str(instruction)}) + messages.append({"role": "user", "content": user_content}) + messages.append({"role": "assistant", "content": response}) + return messages + + +def _render_conversation( + conversation: list[dict], audio_token: str = "<|audio|>" +) -> str: + """Render conversation to Qwen-style chat text with im tags. 
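+
+    Each {"type": "audio"} content item is rendered as a single audio_token
+    placeholder; the processor's replace_multimodal_special_tokens later
+    expands each placeholder to the actual number of audio token positions.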
+ + Example: + conversation = [ + { + "role": "user", + "content": [{"type": "audio"}, {"type": "text", "text": "Please transcribe."}], + }, + {"role": "assistant", "content": "Trading has almost stalled."}, + ] + + rendered_text = + <|im_start|>user + <|audio|>Please transcribe.<|im_end|> + <|im_start|>assistant + Trading has almost stalled.<|im_end|> + """ + chunks = [] + for turn in conversation: + role = turn["role"] + content = turn["content"] + if isinstance(content, list): + parts = [] + for item in content: + item_type = item.get("type") + if item_type == "audio": + parts.append(audio_token) + elif item_type == "text": + parts.append(str(item.get("text", ""))) + else: + raise ValueError(f"Unknown content type: {item_type!r}") + rendered_content = "".join(parts) + else: + rendered_content = str(content) + chunks.append(f"<|im_start|>{role}\n{rendered_content}<|im_end|>\n") + return "".join(chunks) + + +def _estimate_text_tokens(text: str, tokenizer) -> int: + """Estimate text tokens with a real tokenizer.""" + text = text.strip() + if not text: + return 0 + return len(tokenizer(text, add_special_tokens=False).input_ids) + + +def prepare_cut(cut, tokenizer, instruction=None, system=None): + if cut.custom is None: + cut.custom = {} + conversation = _build_conversation(cut, instruction, system) + rendered_conversation = _render_conversation(conversation) + cut.custom["conversation"] = conversation + cut.custom["rendered_conversation"] = rendered_conversation + cut.custom["num_text_tokens"] = _estimate_text_tokens( + rendered_conversation, + tokenizer, + ) + return cut + + +def main(): + parser = argparse.ArgumentParser( + description="Prepare cut.conversation offline for a CutSet." + ) + parser.add_argument( + "--input_manifest", + required=True, + help="Input CutSet manifest path (e.g. *.jsonl.gz).", + ) + parser.add_argument( + "--output_manifest", + required=True, + help="Output CutSet manifest path.", + ) + parser.add_argument( + "--instruction", + help="Additional user prompt", + ) + parser.add_argument( + "--system", + help="System for the conversation.", + ) + parser.add_argument( + "--tokenizer", + required=True, + help="Tokenizer name or path used to estimate num_text_tokens.", + ) + args = parser.parse_args() + + out_path = Path(args.output_manifest) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + + cuts = CutSet.from_file(args.input_manifest) + prepared = cuts.map( + lambda c: prepare_cut( + c, + tokenizer=tokenizer, + instruction=args.instruction, + system=args.system, + ) + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + prepared.to_file(str(out_path)) + print(f"Saved prepared CutSet to: {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/lalm/results_utils.py b/examples/lalm/results_utils.py new file mode 100644 index 0000000..eed8966 --- /dev/null +++ b/examples/lalm/results_utils.py @@ -0,0 +1,275 @@ +"""Utilities for saving ASR decoding results and computing error stats (WER/CER). 
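+
+Per decoding setting, save_results below writes recogs-* (ref/hyp pairs),
+errs-* (alignments and error statistics), and wer-summary-* roll-up files.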
+
+Most of this file is adapted from Icefall's utilities with minor adjustments:
+https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+"""
+
+import logging
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
+
+import kaldialign
+
+
+def store_transcripts(
+    filename: Path,
+    texts: Iterable[Tuple[str, Union[str, List[str]], Union[str, List[str]]]],
+    char_level: bool = False,
+) -> None:
+    """Save predicted results and reference transcripts to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+        If it is a multi-talker ASR system, the ref and hyp may also be lists of
+        strings.
+      char_level:
+        If True, compare at character level (used for CER instead of WER).
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf8") as f:
+        for cut_id, ref, hyp in texts:
+            if char_level:
+                ref = list("".join(ref))
+                hyp = list("".join(hyp))
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def save_results(
+    res_dir: Path,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    suffix: Optional[str] = None,
+) -> None:
+    """Save recognition outputs and write WER summaries for different settings.
+
+    Parameters
+    ----------
+    res_dir : Path
+        Directory to write result files into.
+    test_set_name : str
+        Name of the test set (used in filenames and logs).
+    results_dict : Dict[str, List[Tuple[str, List[str], List[str]]]]
+        Mapping from setting name (e.g., 'greedy_search') to a list of tuples
+        (cut_id, ref_tokens, hyp_tokens).
+    suffix : Optional[str]
+        Optional suffix added to filenames to distinguish checkpoints.
+    """
+    test_set_wers: Dict[str, float] = {}
+    suffix_str = f"-{suffix}" if suffix else ""
+    for key, results in results_dict.items():
+        recog_path = res_dir / f"recogs-{test_set_name}-{key}{suffix_str}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
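+        # e.g. a substitution is shown as "(EDISON->ADDISON)" and a deletion
+        # as "(SIR->*)"; see write_error_stats for the full format.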
+        errs_filename = res_dir / f"errs-{test_set_name}-{key}{suffix_str}.txt"
+        with open(errs_filename, "w", encoding="utf8") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = res_dir / f"wer-summary-{test_set_name}{suffix_str}.txt"
+    with open(errs_info, "w", encoding="utf8") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    wer_info = res_dir / f"wer-summary-all{suffix_str}.txt"
+    if not os.path.exists(wer_info):
+        with open(wer_info, "w", encoding="utf8") as f:
+            print("dataset\tsettings\tWER", file=f)
+    with open(wer_info, "a+", encoding="utf8") as f:
+        for key, val in test_set_wers:
+            print("{}\t{}\t{}".format(test_set_name, key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+def write_error_stats(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, List[str], List[str]]],
+    enable_log: bool = True,
+    compute_CER: bool = False,
+    sclite_mode: bool = False,
+) -> float:
+    """Write statistics based on predicted results and reference transcripts.
+
+    It will write the following to the given file:
+
+      - WER
+      - number of insertions, deletions, substitutions, corrects and total
+        reference words. For example::
+
+          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+          reference words (2337 correct)
+
+      - The difference between the reference transcript and predicted result.
+        An instance is given below::
+
+          THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+        The above example shows that the reference word is `EDISON`,
+        but it is predicted as `ADDISON` (a substitution error).
+
+        Another example is::
+
+          FOR THE FIRST DAY (SIR->*) I THINK
+
+        The reference word `SIR` is missing in the predicted
+        results (a deletion error).
+
+    Args:
+      results:
+        A list of (cut_id, ref_tokens, hyp_tokens), where token lists are words
+        or subwords depending on your tokenizer/normalization.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+      compute_CER:
+        If True, split ref/hyp into characters and report CER instead of WER.
+      sclite_mode:
+        If True, use sclite-style alignment in kaldialign.align.
+    Returns:
+      Return the total error rate (WER or CER) in percent as a float.
+ """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + 
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) diff --git a/examples/lalm/scripts/build_model.sh b/examples/lalm/scripts/build_model.sh new file mode 100755 index 0000000..deb44bb --- /dev/null +++ b/examples/lalm/scripts/build_model.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Assemble a LALM checkpoint from pretrained LLM + audio encoder components. +# +# Run once before training. Saves a self-contained HF checkpoint to $output_dir. +# After this, load the model with: +# AutoModelForCausalLM.from_pretrained(output_dir) +# LALMProcessor.from_pretrained(output_dir) +# +# Usage: +# cd examples/lalm && bash scripts/build_model.sh + +set -euo pipefail + +# ── Components ──────────────────────────────────────────────────────────────── +llm=/apdcephfs_cq12/share_302080740/model/Qwen2.5-7B-Instruct +encoder=/apdcephfs_cq12/share_302080740/model/Qwen3-Omni-30B-A3B-Instruct + +# ── Downsampling ────────────────────────────────────────────────────────────── +# projector_downsample_rate: frame concat in LALMProjector (higher = fewer LLM tokens) +projector_downsample_rate=1 + +# ── Output ──────────────────────────────────────────────────────────────────── +output_dir=./models/aut_qwen2_7b_ds${projector_downsample_rate} + +echo "LLM: $llm" +echo "Encoder: $encoder" +echo "Output: $output_dir" + +python build_model.py \ + --llm "$llm" \ + --encoder "$encoder" \ + --output_dir "$output_dir" \ + --dtype bfloat16 \ + --projector_downsample_rate "$projector_downsample_rate" + +echo "Done. Checkpoint saved to $output_dir" diff --git a/examples/lalm/scripts/evaluate.sh b/examples/lalm/scripts/evaluate.sh new file mode 100755 index 0000000..7cdec3b --- /dev/null +++ b/examples/lalm/scripts/evaluate.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# LALM evaluation script (single GPU). +# +# Usage (from examples/lalm/): +# bash scripts/evaluate.sh +# +# Override checkpoint via env vars, e.g.: +# ITER=5000 bash scripts/evaluate.sh +# EPOCH=3 bash scripts/evaluate.sh + +set -euo pipefail + +# ── Environment ─────────────────────────────────────────────────────────────── +export TOKENIZERS_PARALLELISM=false +export TRANSFORMERS_NO_ADVISORY_WARNINGS=1 +export PYTHONPATH=/apdcephfs_cq10/share_1603164/user/yiwenyshao/lhotse:${PYTHONPATH:-} + +# ── GPU ─────────────────────────────────────────────────────────────────────── +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} + +# ── Experiment ──────────────────────────────────────────────────────────────── +exp_name="lalm_8gpus_aut_qwen2_ds4" +exp_dir="./exp/$exp_name" + +# ── Checkpoint ──────────────────────────────────────────────────────────────── +# Set ITER or EPOCH to select a checkpoint; leave both 0 to specify a filename. +iter=16000 +epoch=${EPOCH:-0} +# filename="" # uncomment to load a specific file, e.g. 
"checkpoint-5000.pt" + +# ── Decoding ────────────────────────────────────────────────────────────────── +decoding_method="greedy_search" # greedy_search | beam_search +max_new_tokens=200 +dtype="fp16" + +# ── Data ────────────────────────────────────────────────────────────────────── +test_data_config="configs/valid_data_config.yaml" +max_duration=1000 + +echo "========================================================" +echo " Exp: ${exp_dir}" +echo " Iter: ${iter} | Epoch: ${epoch}" +echo " Decoding: ${decoding_method}" +echo "========================================================" + +python evaluate.py \ + exp_dir="$exp_dir" \ + checkpoint.iter="$iter" \ + decoding_method="$decoding_method" \ + max_new_tokens="$max_new_tokens" \ + dtype="$dtype" \ + data.test_data_config="$test_data_config" \ + data.max_duration="$max_duration" diff --git a/examples/lalm/scripts/prepare_manifest.sh b/examples/lalm/scripts/prepare_manifest.sh new file mode 100755 index 0000000..4e18ed0 --- /dev/null +++ b/examples/lalm/scripts/prepare_manifest.sh @@ -0,0 +1,5 @@ +python prepare_conversation.py \ + --input_manifest /apdcephfs_cq12/share_302080740/data/asr_train_data/manifests/chinese/open_source/aishell1/aishell1_cuts.jsonl.gz \ + --output_manifest data/train/aishell1_cuts_conversation.jsonl.gz \ + --tokenizer /apdcephfs_cq12/share_302080740/model/Qwen2-7B-Instruct \ + --instruction "Please transcribe speech." \ \ No newline at end of file diff --git a/examples/lalm/scripts/train.sh b/examples/lalm/scripts/train.sh new file mode 100755 index 0000000..1fd0545 --- /dev/null +++ b/examples/lalm/scripts/train.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# LALM training script. +# +# Single-node multi-GPU: +# bash scripts/train.sh +# +# Multi-node (set NNODES / NODE_RANK / MASTER_ADDR / MASTER_PORT externally): +# NNODES=2 NODE_RANK=0 MASTER_ADDR= bash scripts/train.sh +# +# Assumes the current working directory is examples/lalm: +# bash scripts/train.sh + +set -euo pipefail + +# ── Environment ─────────────────────────────────────────────────────────────── +export TOKENIZERS_PARALLELISM=false +export TRANSFORMERS_NO_ADVISORY_WARNINGS=1 +export PYTHONPATH=/apdcephfs_cq10/share_1603164/user/yiwenyshao/lhotse:${PYTHONPATH:-} + +# ── GPU / node ──────────────────────────────────────────────────────────────── +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} +# export CUDA_VISIBLE_DEVICES=0 +num_gpus=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}') +num_nodes=${NNODES:-1} +node_rank=${NODE_RANK:-0} +master_addr=${MASTER_ADDR:-127.0.0.1} +master_port=${MASTER_PORT:-29500} + +# ── Model ───────────────────────────────────────────────────────────────────── +# Point to the checkpoint created by scripts/build_model.sh. +model_dir="./models/aut_qwen2_7b_ds1" +mixed_precision=bf16 + +# Modules to freeze: "audio_tower" for stage-1 (projector-only training). +# Set to "" to unfreeze all (stage-2, full fine-tuning). 
+frozen_modules="audio_tower,language_model" + +# ── Experiment ──────────────────────────────────────────────────────────────── +exp_name=aut_qwen2_7b_ds1_asr +exp_dir="./exp/$exp_name" + +echo "========================================================" +echo " Nodes: ${num_nodes} | GPUs/node: ${num_gpus}" +echo " Master: ${master_addr}:${master_port}" +echo " Model: ${model_dir}" +echo " Exp: ${exp_dir}" +echo "========================================================" + +torchrun \ + --nnodes="$num_nodes" \ + --nproc_per_node="$num_gpus" \ + --node_rank="$node_rank" \ + --master_addr="$master_addr" \ + --master_port="$master_port" \ + train.py \ + exp_dir="$exp_dir" \ + model.pretrained_model="$model_dir" \ + trainer.mixed_precision="$mixed_precision" \ + trainer.frozen_modules="[$frozen_modules]" \ + trainer.optimizer.lr=1e-3 \ + trainer.optimizer.weight_decay=0 \ + trainer.scheduler.type=cosine \ + trainer.use_averaged_model=true \ + trainer.valid_interval=1000 \ + trainer.save_every_n=4 \ + trainer.keep_last_k=5 \ + trainer.log_interval=50 \ + data.feature=whisper_v3_fbank \ + data.audio_token_rate=12.5 \ + data.min_duration=0.5 \ + data.max_duration=30.0 \ + data.use_infinite_dataset=true \ + data.num_workers=8 \ + data.sampler.max_tokens=2000 diff --git a/examples/lalm/train.py b/examples/lalm/train.py new file mode 100644 index 0000000..a60b1f5 --- /dev/null +++ b/examples/lalm/train.py @@ -0,0 +1,117 @@ +import logging +import os + +import hydra +import torch +import torch.distributed as dist +from lalm_core import LALMDataModule, LALMTrainer +from lalm_core.model import LALMConfig, LALMForConditionalGeneration, LALMProcessor +from lhotse.utils import fix_random_seed +from omegaconf import DictConfig, OmegaConf + + +def _resolve_dtype(name: str | None) -> torch.dtype: + if name == "bf16": + return torch.bfloat16 + if name == "fp16": + return torch.float16 + return torch.float32 + + +def _freeze_modules(model: torch.nn.Module, names: list[str]) -> None: + for name in names: + module = model + for part in name.split("."): + module = getattr(module, part, None) + if module is None: + logging.warning(f"[train] Module not found, skip freeze: {name}") + break + if module is not None: + for p in module.parameters(): + p.requires_grad = False + + +@hydra.main(version_base=None, config_path="configs", config_name="train") +def main(cfg: DictConfig): + logging.info("\n" + OmegaConf.to_yaml(cfg)) + + # 1) Fix random seed + if "seed" in cfg: + fix_random_seed(cfg.seed) + + # 2) Gather torchrun environment variables + rank = int(os.environ.get("RANK", 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", init_method="env://") + + if cfg.get("exp_dir"): + os.makedirs(cfg.exp_dir, exist_ok=True) + + dtype = _resolve_dtype(cfg.trainer.get("mixed_precision", None)) + start_batch = int(cfg.trainer.get("start_batch", 0)) + start_epoch = int(cfg.trainer.get("start_epoch", 0)) + hf_dir = os.path.join(cfg.exp_dir, "hf") if cfg.get("exp_dir") else None + + # 3) Match BaseTrainer semantics: + # - fresh start when start_batch == 0 and start_epoch <= 1 + # - resume from step checkpoint when start_batch > 0 + # - resume from epoch checkpoint when start_epoch > 1 + fresh_start = start_batch == 0 and start_epoch == 1 + if not fresh_start: + if not hf_dir or not os.path.isdir(hf_dir): + raise FileNotFoundError( + f"[train] Resume requested 
(start_batch={start_batch}, " + f"start_epoch={start_epoch}) but hf_dir not found: {hf_dir}" + ) + logging.info( + f"[train] Resume mode: init model from config in {hf_dir}; " + "trainer will restore checkpoint weights." + ) + model_cfg = LALMConfig.from_pretrained(hf_dir) + from transformers.modeling_utils import no_init_weights + + with no_init_weights(): + model = LALMForConditionalGeneration(model_cfg) + model = model.to(dtype) + processor = LALMProcessor.from_pretrained(hf_dir) + else: + model_dir = cfg.model.pretrained_model + logging.info(f"[train] Fresh start: loading model from {model_dir}") + model = LALMForConditionalGeneration.from_pretrained( + model_dir, torch_dtype=dtype + ) + processor = LALMProcessor.from_pretrained(model_dir) + if rank == 0: + os.makedirs(hf_dir, exist_ok=True) + model.config.save_pretrained(hf_dir) + processor.save_pretrained(hf_dir) + logging.info(f"[train] Saved config + processor to {hf_dir}") + + # 4) Freeze requested modules + frozen = list(cfg.trainer.get("frozen_modules", []) or []) + if frozen: + _freeze_modules(model, frozen) + + # 5) Initialize data module + data_module = LALMDataModule(cfg.data, processor) + + # 6) Create trainer and run + trainer = LALMTrainer( + cfg, + model, + data_module, + rank=rank, + local_rank=local_rank, + world_size=world_size, + ) + trainer.run() + + if world_size > 1: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/src/auden/data/lhotse_datamodule.py b/src/auden/data/lhotse_datamodule.py index 965bf3d..4164f2f 100644 --- a/src/auden/data/lhotse_datamodule.py +++ b/src/auden/data/lhotse_datamodule.py @@ -113,21 +113,26 @@ def _setup_feature_extraction(self): feature_type = self.cfg.get("feature", "fbank") if feature_type == "fbank": self.input_strategy = OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) + Fbank(FbankConfig(num_mel_bins=80)), + fault_tolerant=self.cfg.get("fault_tolerant", False), ) logging.info("Using default kaldi-fbank") elif feature_type == "whisper_fbank": self.input_strategy = OnTheFlyFeatures( - WhisperFbank(WhisperFbankConfig(num_filters=80)) + WhisperFbank(WhisperFbankConfig(num_filters=80)), + fault_tolerant=self.cfg.get("fault_tolerant", False), ) logging.info("Using Whisper fbank (80 dims)") elif feature_type == "whisper_v3_fbank": self.input_strategy = OnTheFlyFeatures( - WhisperFbank(WhisperFbankConfig(num_filters=128)) + WhisperFbank(WhisperFbankConfig(num_filters=128)), + fault_tolerant=self.cfg.get("fault_tolerant", False), ) logging.info("Using Whisper v3 fbank (128 dims)") elif feature_type == "wav": - self.input_strategy = AudioSamples() + self.input_strategy = AudioSamples( + fault_tolerant=self.cfg.get("fault_tolerant", False) + ) logging.info("Using raw waveform") else: self.input_strategy = PrecomputedFeatures() @@ -277,7 +282,7 @@ def _build_train_sampler(self, train_cutset): train_sampler = DynamicBucketingSampler( train_cutset, max_duration=self.cfg.sampler.max_duration, - max_cuts=getattr(self.cfg.sampler, 'max_cuts', None), + max_cuts=getattr(self.cfg.sampler, "max_cuts", None), shuffle=self.cfg.sampler.shuffle, num_buckets=self.cfg.sampler.num_buckets, buffer_size=self.cfg.sampler.num_buckets * 2000, diff --git a/src/auden/trainer/ddp_trainer.py b/src/auden/trainer/ddp_trainer.py index f417b31..c5bb50e 100644 --- a/src/auden/trainer/ddp_trainer.py +++ b/src/auden/trainer/ddp_trainer.py @@ -88,7 +88,18 @@ def __init__(self, cfg, model, data_module, rank=0, local_rank=0, world_size=1): self.local_rank = local_rank self.world_size 
= world_size
         self.device = torch.device("cuda", local_rank)
-        self.use_fp16 = cfg.trainer.use_fp16
+        mixed_precision = None
+        if "mixed_precision" in cfg.trainer and cfg.trainer.mixed_precision is not None:
+            mixed_precision = str(cfg.trainer.mixed_precision).lower()
+        elif "use_fp16" in cfg.trainer:  # deprecated, for backward compatibility
+            mixed_precision = "fp16" if bool(cfg.trainer.use_fp16) else None
+
+        if mixed_precision not in (None, "fp16", "bf16"):
+            raise ValueError(
+                f"Invalid mixed_precision: {mixed_precision}. "
+                "Expected one of: None, 'fp16', 'bf16'."
+            )
+        self.mixed_precision = mixed_precision
         self.global_step = cfg.trainer.start_batch
         self.tb_writer = None
         if self.rank == 0 and cfg.trainer.tensorboard:
@@ -101,7 +112,9 @@ def __init__(self, cfg, model, data_module, rank=0, local_rank=0, world_size=1):
         self.model, self.model_avg = self.setup_model(model)

         # optimizer and scheduler
-        self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_fp16)
+        self.scaler = torch.amp.GradScaler(
+            "cuda", enabled=self.mixed_precision == "fp16"
+        )
         self.optimizer = self.build_optimizer(self.model)
         self.scheduler = self.build_scheduler(self.optimizer)
@@ -281,6 +294,13 @@ def __init__(self, torch_scheduler, update_on="epoch"):
     def get_last_lr(self):
         return self._last_lr

+    def state_dict(self):
+        return self._sch.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._sch.load_state_dict(state_dict)
+        self._last_lr = [g["lr"] for g in self._sch.optimizer.param_groups]
+
     def step_batch(self, batch: int | None = None):
         if self._update_on == "batch":
             self._sch.step()
@@ -483,7 +503,14 @@ def train_one_epoch(self, epoch: int):
             batch_idx = self.global_step
             self.global_step += 1

-            batch_size = batch["inputs"].size(0)
+            if "batch_size" in batch:
+                batch_size = batch["batch_size"]
+            elif "inputs" in batch:
+                batch_size = batch["inputs"].size(0)
+            else:
+                raise ValueError(
+                    f"Batch does not contain 'batch_size' or 'inputs': {batch}"
+                )

             loss, batch_metrics = self._forward_backward_optimize(batch)
@@ -663,7 +690,17 @@ def _forward_backward_optimize(self, batch):
-        This method uses automatic mixed precision when self.use_fp16 is True.
+        This method uses automatic mixed precision when self.mixed_precision
+        is set ("fp16" or "bf16").
         The gradient scaler is used to prevent gradient underflow in FP16 training.
         """
-        with torch.amp.autocast("cuda", enabled=self.use_fp16):
+        amp_dtype = (
+            torch.float16
+            if self.mixed_precision == "fp16"
+            else torch.bfloat16 if self.mixed_precision == "bf16" else None
+        )
+        with torch.amp.autocast(
+            "cuda",
+            enabled=self.mixed_precision is not None,
+            dtype=amp_dtype,
+        ):
             loss, batch_metrics = self._forward_one_batch(batch=batch, is_training=True)

         # Backprop and optimization step
@@ -768,7 +805,9 @@ def _maybe_rescale_grad_amp(self, batch_idx: int):
         If the scale becomes extremely small (< 1e-5), training will be
         terminated as this indicates severe numerical instability.
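+        The check only applies to fp16 runs, where a GradScaler is active;
+        bf16/fp32 training returns immediately.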
""" - if not self.use_fp16 or batch_idx % 100 != 0: + if self.mixed_precision != "fp16" or batch_idx % 100 != 0: return cur_scale = self.scaler.get_scale() @@ -809,19 +845,20 @@ def _maybe_log_training_status( return cur_lr = max(self.scheduler.get_last_lr()) - cur_grad_scale = self.scaler.get_scale() if self.use_fp16 else 1.0 + use_fp16 = self.mixed_precision == "fp16" + cur_grad_scale = self.scaler.get_scale() if use_fp16 else 1.0 logging.info( f"Epoch {epoch}, " f"batch {batch_idx}, info[{batch_metrics}], " f"tot_info[{total_metrics}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {cur_grad_scale}" if self.use_fp16 else "") + + (f"grad_scale: {cur_grad_scale}" if use_fp16 else "") ) if self.tb_writer is not None: self.tb_writer.add_scalar("train/learning_rate", cur_lr, self.global_step) - if self.use_fp16: + if use_fp16: self.tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, self.global_step )