
mtmd: add granite-speech support (ibm-granite/granite-4.0-1b-speech)#22101

Open
ReinforcedKnowledge wants to merge 7 commits into ggml-org:master from ReinforcedKnowledge:add-granite-speech-support

Conversation


@ReinforcedKnowledge ReinforcedKnowledge commented Apr 19, 2026

Overview

Adds support for ibm-granite/granite-4.0-1b-speech.

  • Conformer encoder + QFormer projector (graph builder in granite-speech.cpp)
  • Audio preprocessor: log-mel spectrogram, dynamic range compression, frame stacking
  • GGUF converter: batch norm folding, K/V split, Conv1d reshape
  • Follows existing conformer/whisper patterns
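
The dynamic range compression and 2x frame stacking steps can be sketched in a few lines. This is a minimal pure-Python illustration, not the PR's implementation; the clip value and the drop-the-tail behaviour are assumptions:

```python
import math

def dynamic_range_compression(mel, clip_val=1e-10):
    # Log-compress mel energies, clamping to avoid log(0).
    # clip_val is a hypothetical default, not taken from the PR.
    return [[math.log(max(x, clip_val)) for x in frame] for frame in mel]

def stack_frames(mel, factor=2):
    # Concatenate `factor` consecutive frames: frame count shrinks by
    # `factor`, per-frame bins grow by `factor` (80 -> 160 here).
    # Trailing frames that do not fill a full group are dropped in
    # this sketch; the real preprocessor may pad instead.
    return [
        sum((mel[i + j] for j in range(factor)), [])
        for i in range(0, len(mel) - factor + 1, factor)
    ]

mel = [[1.0] * 80 for _ in range(10)]                  # 10 frames x 80 mel bins
stacked = stack_frames(dynamic_range_compression(mel)) # 5 frames x 160 bins
```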

Tested with greedy decoding on 30s/60s/120s/180s/360s clips; token-for-token match against HF transformers (following the script on the model card) for 30s and 60s. Running longer clips through HF is too heavy for my setup, but at 120s/180s there is noticeable degradation, and at 360s the output loops completely.

Test command:

ffmpeg -i input.wav -t 30 -ar 16000 -ac 1 test.wav

python convert_hf_to_gguf.py models/granite-4.0-1b-speech --outtype f16
python convert_hf_to_gguf.py models/granite-4.0-1b-speech --outtype f16 --mmproj

./build/bin/llama-mtmd-cli -m models/granite-4.0-1b-speech/granite-4.0-1B-speech-F16.gguf --mmproj models/granite-4.0-1b-speech/mmproj-granite-4.0-1b-speech-F16.gguf --audio test.wav -p "can you transcribe the speech into a written format?" --jinja --temp 0 -c 4096

Also tested the UI:

./build/bin/llama-server -m models/granite-4.0-1b-speech/granite-4.0-1B-speech-F16.gguf --mmproj models/granite-4.0-1b-speech/mmproj-granite-4.0-1b-speech-F16.gguf --jinja -c 4096

Uploading an audio file and using the prompt above produces the same transcription as the CLI.

Note: --jinja is required, and the prompt "can you transcribe the speech into a written format?" is taken from the model card.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES. Used as a search tool to find relevant implementations in similar models, and to find similar models unknown to me; for example, I'd ask codex to find where and how the conformer calls the log-mel spectrogram. I copy-pasted as much as possible from existing architectures, when it made sense. Also used codex to find the integration points across the multimodal part of the codebase; I was familiar with the text-only part, but not so much with how the multimodal part is structured.

EDIT: Added the comment on testing the chat UI.

@ReinforcedKnowledge ReinforcedKnowledge requested review from a team and CISC as code owners April 19, 2026 01:24
@github-actions github-actions bot added examples python python script changes labels Apr 19, 2026
Conformer encoder with Shaw relative position encoding,
QFormer projector, log-mel spectrogram with frame stacking.

Encoder uses GLU gating, folded batch norm, and SSM depthwise
conv. QFormer compresses encoder output via windowed
cross-attention (window=15, queries=3) into the LLM embedding
space.
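
One consequence of the windowed cross-attention: every window of 15 encoder frames collapses to 3 query outputs, a 5x reduction. A rough token-count helper (the behaviour for a partial trailing window is an assumption, not verified against the code):

```python
import math

def n_projected_tokens(n_frames, window=15, queries=3):
    # Each window of encoder frames is compressed to `queries`
    # learned-query outputs; assumes a partial trailing window
    # still emits a full set of queries.
    return math.ceil(n_frames / window) * queries
```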

Audio preprocessing: reflect-padded STFT, 80-bin mel filterbank,
dynamic range compression, 2x frame stacking (80->160 mel).

GGUF converter handles batch norm folding at export time,
fused K/V split, and Conv1d weight reshaping.

Tested against HF transformers reference: token-for-token match
on 30s/60s audio clips with greedy decoding.
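
The batch-norm folding done at export time can be sketched per output channel: for y = BN(Wx + b), the scale gamma / sqrt(var + eps) folds into W and the shift into b. A minimal pure-Python version (not the converter's actual code, which operates on the HF checkpoint tensors):

```python
import math

def fold_batch_norm(w, b, gamma, beta, mean, var, eps=1e-5):
    # Fold BatchNorm(gamma, beta, mean, var) into a preceding
    # linear layer y = W x + b, per output channel:
    #   w' = w * gamma / sqrt(var + eps)
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    out_w, out_b = [], []
    for row, bi, g, be, m, v in zip(w, b, gamma, beta, mean, var):
        s = g / math.sqrt(v + eps)
        out_w.append([x * s for x in row])
        out_b.append((bi - m) * s + be)
    return out_w, out_b

# One output channel: gamma=2, var=4 -> scale 1; mean=1 cancels b=1.
w2, b2 = fold_batch_norm([[2.0, 0.0]], [1.0], [2.0], [0.5], [1.0], [4.0], eps=0.0)
```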
@ReinforcedKnowledge ReinforcedKnowledge force-pushed the add-granite-speech-support branch from f4c14e1 to 7b313dc on April 19, 2026 01:28
@taronaeo taronaeo requested a review from gabe-l-hart April 20, 2026 03:13
Comment on lines +67 to +97
ggml_tensor * Q = build_mm(layer.q_w, normed);
ggml_tensor * K = build_mm(layer.k_w, normed);
ggml_tensor * V = build_mm(layer.v_w, normed);

Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, context_size, num_blocks);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, context_size, num_blocks);
V = ggml_reshape_4d(ctx0, V, d_head, n_head, context_size, num_blocks);

ggml_tensor * Q_perm = ggml_permute(ctx0, Q, 0, 2, 1, 3);
ggml_tensor * K_perm = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));

ggml_tensor * kq = ggml_mul_mat(ctx0, K_perm, Q_perm);

// Shaw RPE: pos_emb ne[2]=1 broadcasts against Q ne[2]=num_blocks in mul_mat
ggml_tensor * pos_emb = ggml_get_rows(ctx0, layer.attn_rel_pos_emb, attn_dists);
pos_emb = ggml_reshape_3d(ctx0, pos_emb, d_head, context_size, context_size);
pos_emb = ggml_reshape_4d(ctx0, pos_emb, d_head, context_size, 1, context_size);

ggml_tensor * Q_shaw = ggml_permute(ctx0, Q, 0, 1, 3, 2);
ggml_tensor * pos_attn = ggml_mul_mat(ctx0, pos_emb, Q_shaw);
pos_attn = ggml_cont(ctx0, ggml_permute(ctx0, pos_attn, 0, 2, 3, 1));

ggml_tensor * scores = ggml_add(ctx0, kq, pos_attn);
ggml_tensor * attn_weights = ggml_soft_max_ext(ctx0, scores, attn_mask,
kq_scale, 0.0f);

ggml_tensor * V_perm = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
ggml_tensor * attn_out = ggml_mul_mat(ctx0, V_perm, attn_weights);

attn_out = ggml_permute(ctx0, attn_out, 0, 2, 1, 3);
attn_out = ggml_cont_2d(ctx0, attn_out, n_embd, padded_len);
Contributor
Is it possible to use build_attn here? I suspect the only thing missing from build_attn was the mask, right?

Author
Unfortunately no :/ Shaw RPE injects pos_attn = mul_mat(pos_emb, Q) between the KQ product and softmax, and build_attn goes directly from mul_mat(k, q) to soft_max_ext with no hook for that. Flash attention path also fuses the whole thing into one op. But the QFormer attention already uses build_attn.
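
In scalar form, the extra Shaw term means each score is the usual q·k plus q·e_rel for the (query, key) distance, added before softmax. A tiny pure-Python sketch of one query row, illustrative only and not the ggml graph above:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def shaw_attention_row(q, ks, rels):
    # One query against all keys: the usual dot product plus a dot
    # product with a learned relative-position embedding for each
    # (query, key) distance -- the extra term build_attn has no hook for.
    scores = [
        sum(qi * ki for qi, ki in zip(q, k)) +
        sum(qi * ri for qi, ri in zip(q, r))
        for k, r in zip(ks, rels)
    ]
    return softmax(scores)

# With zero relative embeddings this reduces to plain attention.
w = shaw_attention_row([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
                       [[0.0, 0.0], [0.0, 0.0]])
```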

@gabe-l-hart
Collaborator

@ReinforcedKnowledge Thanks for tackling this support! I'd been slowly working through Granite 3.3 Speech support, but had stalled out badly. I'll pull this down and give it a shot on both the new 4.0-based model and the older 3.2 and 3.3 models.

@gabe-l-hart
Collaborator

gabe-l-hart commented Apr 20, 2026

One other feature of this model to be aware of: It uses a modality-conditional LoRA adapter to add the speech processing capabilities to the base LLM. This preserves the text-only capabilities of the model when run without speech inputs. Currently, adapter toggling must be done manually. I've worked on a branch to add automatic toggling for modality-conditional adapters, but was waiting on making more progress on Granite 4.0 Vision before submitting a PR since I haven't had a way to test it well.

We should not conflate this auto-toggling functionality with the basic model support, so no need to address it with this branch, but once the code is reviewed and fully working, I'll use this as a test for the auto-toggling as well.

🤦 Nope, I'm wrong here! The 3.x speech models used the conditional adapter while the 3.x vision models did not. It appears that this swapped for 4.0 (I didn't realize speech had dropped the conditional adapter for the HF release).

@gabe-l-hart
Collaborator

Confirmed that this is working nicely for 4.0 with the embedded multilingual sample from the repo:

python convert_hf_to_gguf.py ~/models/ibm-granite/granite-4.0-1b-speech/ --outtype bf16
python convert_hf_to_gguf.py ~/models/ibm-granite/granite-4.0-1b-speech/ --outtype bf16 --mmproj
./build-rel/bin/llama-mtmd-cli -m ~/models/ibm-granite/granite-4.0-1b-speech/granite-4.0-1B-speech-BF16.gguf --mmproj ~/models/ibm-granite/granite-4.0-1b-speech/mmproj-granite-4.0-1b-speech-BF16.gguf --audio ~/models/ibm-granite/granite-4.0-1b-speech/multilingual_sample.wav -p "can you transcribe the speech into a written format?" --jinja --temp 0
for timothy was a spoiled cat and he allowed no one to interfere everybody waited upon him moving their chairs even for he was monarch of the hearth dinarzade la nuit suivante appela sa soeur quand il en fut temps si vous ne dormez pas ma soeur lui dit-elle je vous prie en attendant le jour qui paraîtra bientôt de continuer le conte du pêcheur
full logs
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 6.830 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 55662.79 MB
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected to use 3632 MiB of device memory vs. 53083 MiB of free device memory
llama_params_fit_impl: will leave 49451 >= 1024 MiB of free device memory, no changes needed
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 0.08 seconds
llama_model_load_from_file_impl: using device MTL0 (Apple M3 Max) (unknown id) - 53083 MiB free
llama_model_loader: loaded meta data with 40 key-value pairs and 363 tensors from /Users/ghart/models/ibm-granite/granite-4.0-1b-speech/granite-4.0-1B-speech-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = granite
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Granite 4.0 1b Speech
llama_model_loader: - kv   3:                           general.finetune str              = speech
llama_model_loader: - kv   4:                           general.basename str              = granite-4.0
llama_model_loader: - kv   5:                         general.size_label str              = 1B
llama_model_loader: - kv   6:                            general.license str              = apache-2.0
llama_model_loader: - kv   7:                   general.base_model.count u32              = 1
llama_model_loader: - kv   8:                  general.base_model.0.name str              = Granite 4.0 1b Base
llama_model_loader: - kv   9:          general.base_model.0.organization str              = Ibm Granite
llama_model_loader: - kv  10:              general.base_model.0.repo_url str              = https://huggingface.co/ibm-granite/gr...
llama_model_loader: - kv  11:                          general.languages arr[str,7]       = ["multilingual", "en", "fr", "de", "e...
llama_model_loader: - kv  12:                        granite.block_count u32              = 40
llama_model_loader: - kv  13:                     granite.context_length u32              = 4096
llama_model_loader: - kv  14:                   granite.embedding_length u32              = 2048
llama_model_loader: - kv  15:                granite.feed_forward_length u32              = 4096
llama_model_loader: - kv  16:               granite.attention.head_count u32              = 16
llama_model_loader: - kv  17:            granite.attention.head_count_kv u32              = 4
llama_model_loader: - kv  18:                     granite.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  19:   granite.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  20:                          general.file_type u32              = 32
llama_model_loader: - kv  21:                         granite.vocab_size u32              = 100353
llama_model_loader: - kv  22:               granite.rope.dimension_count u32              = 128
llama_model_loader: - kv  23:                    granite.attention.scale f32              = 0.007812
llama_model_loader: - kv  24:                    granite.embedding_scale f32              = 12.000000
llama_model_loader: - kv  25:                     granite.residual_scale f32              = 0.220000
llama_model_loader: - kv  26:                        granite.logit_scale f32              = 8.000000
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - kv  28:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  29:                         tokenizer.ggml.pre str              = dbrx
llama_model_loader: - kv  30:                      tokenizer.ggml.tokens arr[str,100353]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  31:                  tokenizer.ggml.token_type arr[i32,100353]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  32:                      tokenizer.ggml.merges arr[str,100000]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  33:                tokenizer.ggml.bos_token_id u32              = 100257
llama_model_loader: - kv  34:                tokenizer.ggml.eos_token_id u32              = 100257
llama_model_loader: - kv  35:            tokenizer.ggml.unknown_token_id u32              = 100269
llama_model_loader: - kv  36:            tokenizer.ggml.padding_token_id u32              = 100256
llama_model_loader: - kv  37:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  38:                    tokenizer.chat_template str              = {% for message in messages %}{% if me...
llama_model_loader: - kv  39:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - type  f32:   81 tensors
llama_model_loader: - type bf16:  282 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = BF16
print_info: file size   = 3.42 GiB (16.00 BPW) 
load: 69 unused tokens
load: printing all EOG tokens:
load:   - 100257 ('<|end_of_text|>')
load:   - 100261 ('<|fim_pad|>')
load: special tokens cache size = 97
load: token to piece cache size = 0.6152 MB
print_info: arch                  = granite
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 4096
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 40
print_info: n_head                = 16
print_info: n_head_kv             = 4
print_info: n_rot                 = 128
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 128
print_info: n_embd_head_v         = 128
print_info: n_gqa                 = 4
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-05
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 8.0e+00
print_info: f_attn_scale          = 7.8e-03
print_info: n_ff                  = 4096
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 0
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 4096
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = yes
print_info: model type            = 3B
print_info: model params          = 1.84 B
print_info: general.name          = Granite 4.0 1b Speech
print_info: f_embedding_scale     = 12.000000
print_info: f_residual_scale      = 0.220000
print_info: f_attention_scale     = 0.007812
print_info: n_ff_shexp            = 0
print_info: vocab type            = BPE
print_info: n_vocab               = 100353
print_info: n_merges              = 100000
print_info: BOS token             = 100257 '<|end_of_text|>'
print_info: EOS token             = 100257 '<|end_of_text|>'
print_info: EOT token             = 100257 '<|end_of_text|>'
print_info: UNK token             = 100269 '<|unk|>'
print_info: PAD token             = 100256 '<|pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 100258 '<|fim_prefix|>'
print_info: FIM SUF token         = 100260 '<|fim_suffix|>'
print_info: FIM MID token         = 100259 '<|fim_middle|>'
print_info: FIM PAD token         = 100261 '<|fim_pad|>'
print_info: EOG token             = 100257 '<|end_of_text|>'
print_info: EOG token             = 100261 '<|fim_pad|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 39 repeating layers to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   392.00 MiB
load_tensors:  MTL0_Mapped model buffer size =  3112.64 MiB
................................................................................
common_init_result: added <|end_of_text|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_seq     = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
llama_context:        CPU  output buffer size =     0.38 MiB
llama_kv_cache:       MTL0 KV buffer size =   320.00 MiB
llama_kv_cache: size =  320.00 MiB (  4096 cells,  40 layers,  1/1 seqs), K (f16):  160.00 MiB, V (f16):  160.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 128
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 128
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:       MTL0 compute buffer size =   200.00 MiB
sched_reserve:        CPU compute buffer size =    16.01 MiB
sched_reserve: graph nodes  = 1329
sched_reserve: graph splits = 2
sched_reserve: reserve took 9.59 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
mtmd_cli_context: chat template example:
USER: You are a helpful assistant
Hello
 ASSISTANT:Hi thereUSER: How are you?
 ASSISTANT:
clip_model_loader: model name:   Granite 4.0 1b Speech
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    559
clip_model_loader: n_kv:         23

clip_model_loader: has audio encoder
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
clip_ctx: CLIP using MTL0 backend
load_hparams: projector:          granite_speech
load_hparams: n_embd:             1024
load_hparams: n_head:             8
load_hparams: n_ff:               4096
load_hparams: n_layer:            16
load_hparams: ffn_op:             gelu_quick
load_hparams: projection_dim:     2048

--- audio hparams ---
load_hparams: n_mel_bins:         160
load_hparams: proj_stack_factor:  0
load_hparams: audio_chunk_len:    0
load_hparams: audio_sample_rate:  16000
load_hparams: audio_n_fft:        512
load_hparams: audio_window_len:   400
load_hparams: audio_hop_len:      160

load_hparams: model size:         1105.61 MiB
load_hparams: metadata size:      0.20 MiB
warmup: warmup with audio size = 3000
alloc_compute_meta:       MTL0 compute buffer size =   177.39 MiB
alloc_compute_meta:        CPU compute buffer size =     1.98 MiB
alloc_compute_meta: graph splits = 1, nodes = 1313
warmup: flash attention is enabled
init_audio: audio input is in experimental stage and may have reduced quality:
    https://github.com/ggml-org/llama.cpp/discussions/13759
main: loading model: /Users/ghart/models/ibm-granite/granite-4.0-1b-speech/granite-4.0-1B-speech-BF16.gguf
WARN: This is an experimental CLI for testing multimodal capability.
      For normal use cases, please use the standard llama-cli
encoding audio slice...
audio slice encoded in 255 ms
decoding audio batch 1/1, n_tokens_batch = 252
audio decoded (batch 1/1) in 44 ms

for timothy was a spoiled cat and he allowed no one to interfere everybody waited upon him moving their chairs even for he was monarch of the hearth dinarzade la nuit suivante appela sa soeur quand il en fut temps si vous ne dormez pas ma soeur lui dit-elle je vous prie en attendant le jour qui paraîtra bientôt de continuer le conte du pêcheur


llama_perf_context_print:        load time =    1567.53 ms
llama_perf_context_print: prompt eval time =     397.67 ms /   270 tokens (    1.47 ms per token,   678.95 tokens per second)
llama_perf_context_print:        eval time =     977.52 ms /    86 runs   (   11.37 ms per token,    87.98 tokens per second)
llama_perf_context_print:       total time =    1930.72 ms /   356 tokens
llama_perf_context_print:    graphs reused =         85
ggml_metal_free: deallocating
ggml_metal_free: deallocating

@gabe-l-hart
Collaborator

gabe-l-hart commented Apr 20, 2026

This is also working nicely for 3.3-2b! Note that for that model you do need the adapter (though, interestingly, it does seem to transcribe the English without the adapter before apparently translating the French to English).

Convert

python convert_hf_to_gguf.py ~/models/granite-speech-3.3-2b/ --outtype bf16
python convert_hf_to_gguf.py ~/models/granite-speech-3.3-2b/ --outtype bf16 --mmproj
python convert_lora_to_gguf.py ~/models/granite-speech-3.3-2b/ --outtype bf16

Run with adapter

./build-rel/bin/llama-mtmd-cli -m ~/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf --mmproj ~/models/granite-speech-3.3-2b/mmproj-granite-speech-3.3-2b-BF16.gguf --lora ~/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16-LoRA.gguf --audio ~/models/ibm-granite/granite-4.0-1b-speech/multilingual_sample.wav -p "can you transcribe the speech into a written format?" --jinja --temp 0
for timothy was a spoiled cat and he allowed no one to interfere everybody waited upon him moving their chairs even for he was monarch of the hearth dinarzade la nuit suivante appela sa soeur quand il en fut temps si vous ne dormez pas ma soeur lui dit-elle je vous prie en attendant le jour qui paraîtra bientôt de continuer le compte du pêcheur
full logs
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.011 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 55662.79 MB
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected to use 15460 MiB of device memory vs. 53083 MiB of free device memory
llama_params_fit_impl: will leave 37623 >= 1024 MiB of free device memory, no changes needed
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 0.05 seconds
llama_model_load_from_file_impl: using device MTL0 (Apple M3 Max) (unknown id) - 53083 MiB free
llama_model_loader: loaded meta data with 39 key-value pairs and 362 tensors from /Users/ghart/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = granite
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Granite Speech 3.3 2b
llama_model_loader: - kv   3:                           general.basename str              = granite-speech-3.3
llama_model_loader: - kv   4:                         general.size_label str              = 2B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv   7:                  general.base_model.0.name str              = Granite 3.3 2b Instruct
llama_model_loader: - kv   8:          general.base_model.0.organization str              = Ibm Granite
llama_model_loader: - kv   9:              general.base_model.0.repo_url str              = https://huggingface.co/ibm-granite/gr...
llama_model_loader: - kv  10:                          general.languages arr[str,1]       = ["multilingual"]
llama_model_loader: - kv  11:                        granite.block_count u32              = 40
llama_model_loader: - kv  12:                     granite.context_length u32              = 131072
llama_model_loader: - kv  13:                   granite.embedding_length u32              = 2048
llama_model_loader: - kv  14:                granite.feed_forward_length u32              = 8192
llama_model_loader: - kv  15:               granite.attention.head_count u32              = 32
llama_model_loader: - kv  16:            granite.attention.head_count_kv u32              = 8
llama_model_loader: - kv  17:                     granite.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  18:   granite.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  19:                          general.file_type u32              = 32
llama_model_loader: - kv  20:                         granite.vocab_size u32              = 49160
llama_model_loader: - kv  21:               granite.rope.dimension_count u32              = 64
llama_model_loader: - kv  22:                    granite.attention.scale f32              = 0.015625
llama_model_loader: - kv  23:                    granite.embedding_scale f32              = 12.000000
llama_model_loader: - kv  24:                     granite.residual_scale f32              = 0.220000
llama_model_loader: - kv  25:                        granite.logit_scale f32              = 8.000000
llama_model_loader: - kv  26:               general.quantization_version u32              = 2
llama_model_loader: - kv  27:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  28:                         tokenizer.ggml.pre str              = refact
llama_model_loader: - kv  29:                      tokenizer.ggml.tokens arr[str,49160]   = ["<|end_of_text|>", "<fim_prefix>", "...
llama_model_loader: - kv  30:                  tokenizer.ggml.token_type arr[i32,49160]   = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  31:                      tokenizer.ggml.merges arr[str,48891]   = ["Ġ Ġ", "ĠĠ ĠĠ", "ĠĠĠĠ ĠĠ...
llama_model_loader: - kv  32:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  33:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  34:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  35:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  36:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  37:                    tokenizer.chat_template str              = {# Alias tools -> available_tools #}\n...
llama_model_loader: - kv  38:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - type  f32:   81 tensors
llama_model_loader: - type bf16:  281 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = BF16
print_info: file size   = 4.72 GiB (16.00 BPW) 
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 0 ('<|end_of_text|>')
load:   - 4 ('<fim_pad>')
load:   - 18 ('<reponame>')
load: special tokens cache size = 27
load: token to piece cache size = 0.2827 MB
print_info: arch                  = granite
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 131072
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 40
print_info: n_head                = 32
print_info: n_head_kv             = 8
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 64
print_info: n_embd_head_v         = 64
print_info: n_gqa                 = 4
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-05
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 8.0e+00
print_info: f_attn_scale          = 1.6e-02
print_info: n_ff                  = 8192
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 0
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 131072
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = yes
print_info: model type            = 3B
print_info: model params          = 2.53 B
print_info: general.name          = Granite Speech 3.3 2b
print_info: f_embedding_scale     = 12.000000
print_info: f_residual_scale      = 0.220000
print_info: f_attention_scale     = 0.015625
print_info: n_ff_shexp            = 0
print_info: vocab type            = BPE
print_info: n_vocab               = 49160
print_info: n_merges              = 48891
print_info: BOS token             = 0 '<|end_of_text|>'
print_info: EOS token             = 0 '<|end_of_text|>'
print_info: EOT token             = 0 '<|end_of_text|>'
print_info: UNK token             = 0 '<|end_of_text|>'
print_info: PAD token             = 0 '<|end_of_text|>'
print_info: LF token              = 203 'Ċ'
print_info: FIM PRE token         = 1 '<fim_prefix>'
print_info: FIM SUF token         = 3 '<fim_suffix>'
print_info: FIM MID token         = 2 '<fim_middle>'
print_info: FIM PAD token         = 4 '<fim_pad>'
print_info: FIM REP token         = 18 '<reponame>'
print_info: EOG token             = 0 '<|end_of_text|>'
print_info: EOG token             = 4 '<fim_pad>'
print_info: EOG token             = 18 '<reponame>'
print_info: max token length      = 512
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 39 repeating layers to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   192.03 MiB
load_tensors:  MTL0_Mapped model buffer size =  4832.66 MiB
...............................................................................................
llama_adapter_lora_init_impl: loading lora adapter from '/Users/ghart/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16-LoRA.gguf' ...
llama_adapter_lora_init_impl: Dumping metadata keys/values.
llama_adapter_lora_init_impl: - kv   0:                       general.architecture str              = granite
llama_adapter_lora_init_impl: - kv   1:                               general.type str              = adapter
llama_adapter_lora_init_impl: - kv   2:                               adapter.type str              = lora
llama_adapter_lora_init_impl: - kv   3:                               general.name str              = Granite Speech 3.3 2b
llama_adapter_lora_init_impl: - kv   4:                           general.basename str              = granite-speech-3.3
llama_adapter_lora_init_impl: - kv   5:                         general.size_label str              = 2B
llama_adapter_lora_init_impl: - kv   6:                            general.license str              = apache-2.0
llama_adapter_lora_init_impl: - kv   7:                   general.base_model.count u32              = 1
llama_adapter_lora_init_impl: - kv   8:                  general.base_model.0.name str              = Granite 3.3 2b Instruct
llama_adapter_lora_init_impl: - kv   9:          general.base_model.0.organization str              = Ibm Granite
llama_adapter_lora_init_impl: - kv  10:              general.base_model.0.repo_url str              = https://huggingface.co/ibm-granite/gr...
llama_adapter_lora_init_impl: - kv  11:                          general.languages arr[str,1]       = ["multilingual"]
llama_adapter_lora_init_impl: - kv  12:                         adapter.lora.alpha f32              = 32.000000
llama_adapter_lora_init_impl: - kv  13:               general.quantization_version u32              = 2
llama_adapter_lora_init_impl: MTL0_Mapped LoRA buffer size =    32.50 MiB
llama_adapter_lora_init_impl: loaded 160 tensors from lora file
common_init_result: added <|end_of_text|> logit bias = -inf
common_init_result: added <fim_pad> logit bias = -inf
common_init_result: added <reponame> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
llama_context:        CPU  output buffer size =     0.19 MiB
llama_kv_cache:       MTL0 KV buffer size = 10240.00 MiB
llama_kv_cache: size = 10240.00 MiB (131072 cells,  40 layers,  1/1 seqs), K (f16): 5120.00 MiB, V (f16): 5120.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 64
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 64
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:       MTL0 compute buffer size =   388.01 MiB
sched_reserve:        CPU compute buffer size =   264.01 MiB
sched_reserve: graph nodes  = 1329
sched_reserve: graph splits = 2
sched_reserve: reserve took 76.84 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
sched_reserve: reserving ...
sched_reserve:       MTL0 compute buffer size =   388.01 MiB
sched_reserve:        CPU compute buffer size =   264.01 MiB
sched_reserve: graph nodes  = 1649
sched_reserve: graph splits = 2
sched_reserve: reserve took 104.49 ms, sched copies = 1
mtmd_cli_context: chat template example:
<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>
<|start_of_role|>user<|end_of_role|>How are you?<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>
clip_model_loader: model name:   Granite Speech 3.3 2b
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    559
clip_model_loader: n_kv:         23

clip_model_loader: has audio encoder
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
clip_ctx: CLIP using MTL0 backend
load_hparams: projector:          granite_speech
load_hparams: n_embd:             1024
load_hparams: n_head:             8
load_hparams: n_ff:               4096
load_hparams: n_layer:            16
load_hparams: ffn_op:             gelu_quick
load_hparams: projection_dim:     2048

--- audio hparams ---
load_hparams: n_mel_bins:         160
load_hparams: proj_stack_factor:  0
load_hparams: audio_chunk_len:    0
load_hparams: audio_sample_rate:  16000
load_hparams: audio_n_fft:        512
load_hparams: audio_window_len:   400
load_hparams: audio_hop_len:      160

load_hparams: model size:         1105.25 MiB
load_hparams: metadata size:      0.20 MiB
warmup: warmup with audio size = 3000
alloc_compute_meta:       MTL0 compute buffer size =   177.39 MiB
alloc_compute_meta:        CPU compute buffer size =     1.98 MiB
alloc_compute_meta: graph splits = 1, nodes = 1313
warmup: flash attention is enabled
init_audio: audio input is in experimental stage and may have reduced quality:
    https://github.com/ggml-org/llama.cpp/discussions/13759
main: loading model: /Users/ghart/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf
WARN: This is an experimental CLI for testing multimodal capability.
      For normal use cases, please use the standard llama-cli
encoding audio slice...
audio slice encoded in 196 ms
decoding audio batch 1/1, n_tokens_batch = 252
audio decoded (batch 1/1) in 5 ms

for timothy was a spoiled cat and he allowed no one to interfere everybody waited upon him moving their chairs even for he was monarch of the hearth dinarzade la nuit suivante appela sa soeur quand il en fut temps si vous ne dormez pas ma soeur lui dit-elle je vous prie en attendant le jour qui paraîtra bientôt de continuer le compte du pêcheur


llama_perf_context_print:        load time =    3556.35 ms
llama_perf_context_print: prompt eval time =     358.18 ms /   322 tokens (    1.11 ms per token,   898.98 tokens per second)
llama_perf_context_print:        eval time =    1688.09 ms /    99 runs   (   17.05 ms per token,    58.65 tokens per second)
llama_perf_context_print:       total time =    2966.16 ms /   421 tokens
llama_perf_context_print:    graphs reused =         98
ggml_metal_free: deallocating
ggml_metal_free: deallocating

Run without adapter

./build-rel/bin/llama-mtmd-cli -m ~/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf --mmproj ~/models/granite-speech-3.3-2b/mmproj-granite-speech-3.3-2b-BF16.gguf --audio ~/models/ibm-granite/granite-4.0-1b-speech/multilingual_sample.wav -p "can you transcribe the speech into a written format?" --jinja --temp 0
Sure, I'd be happy to transcribe the speech into written format. Here's the transcription:

---

For Timothy was a spoiled cat, and he allowed no one to interfere. Everybody waited upon him, moving their chairs even, for he was the monarch of the hearth.

The next night, Timothy's sister called him when he was still awake. "Sister," he said, "if you don't sleep, I beg you, wait until the day that will soon appear to continue the tale of the pecker."

---

This transcription maintains the original rhythm and tone of the text, preserving the poetic language and the sense of formality in the dialogue.
Full logs:
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.012 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal4  (5002)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 55662.79 MB
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected to use 15460 MiB of device memory vs. 53083 MiB of free device memory
llama_params_fit_impl: will leave 37623 >= 1024 MiB of free device memory, no changes needed
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 0.05 seconds
llama_model_load_from_file_impl: using device MTL0 (Apple M3 Max) (unknown id) - 53083 MiB free
llama_model_loader: loaded meta data with 39 key-value pairs and 362 tensors from /Users/ghart/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = granite
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Granite Speech 3.3 2b
llama_model_loader: - kv   3:                           general.basename str              = granite-speech-3.3
llama_model_loader: - kv   4:                         general.size_label str              = 2B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv   7:                  general.base_model.0.name str              = Granite 3.3 2b Instruct
llama_model_loader: - kv   8:          general.base_model.0.organization str              = Ibm Granite
llama_model_loader: - kv   9:              general.base_model.0.repo_url str              = https://huggingface.co/ibm-granite/gr...
llama_model_loader: - kv  10:                          general.languages arr[str,1]       = ["multilingual"]
llama_model_loader: - kv  11:                        granite.block_count u32              = 40
llama_model_loader: - kv  12:                     granite.context_length u32              = 131072
llama_model_loader: - kv  13:                   granite.embedding_length u32              = 2048
llama_model_loader: - kv  14:                granite.feed_forward_length u32              = 8192
llama_model_loader: - kv  15:               granite.attention.head_count u32              = 32
llama_model_loader: - kv  16:            granite.attention.head_count_kv u32              = 8
llama_model_loader: - kv  17:                     granite.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  18:   granite.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  19:                          general.file_type u32              = 32
llama_model_loader: - kv  20:                         granite.vocab_size u32              = 49160
llama_model_loader: - kv  21:               granite.rope.dimension_count u32              = 64
llama_model_loader: - kv  22:                    granite.attention.scale f32              = 0.015625
llama_model_loader: - kv  23:                    granite.embedding_scale f32              = 12.000000
llama_model_loader: - kv  24:                     granite.residual_scale f32              = 0.220000
llama_model_loader: - kv  25:                        granite.logit_scale f32              = 8.000000
llama_model_loader: - kv  26:               general.quantization_version u32              = 2
llama_model_loader: - kv  27:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  28:                         tokenizer.ggml.pre str              = refact
llama_model_loader: - kv  29:                      tokenizer.ggml.tokens arr[str,49160]   = ["<|end_of_text|>", "<fim_prefix>", "...
llama_model_loader: - kv  30:                  tokenizer.ggml.token_type arr[i32,49160]   = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  31:                      tokenizer.ggml.merges arr[str,48891]   = ["Ġ Ġ", "ĠĠ ĠĠ", "ĠĠĠĠ ĠĠ...
llama_model_loader: - kv  32:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  33:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  34:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  35:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  36:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  37:                    tokenizer.chat_template str              = {# Alias tools -> available_tools #}\n...
llama_model_loader: - kv  38:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - type  f32:   81 tensors
llama_model_loader: - type bf16:  281 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = BF16
print_info: file size   = 4.72 GiB (16.00 BPW) 
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 0 ('<|end_of_text|>')
load:   - 4 ('<fim_pad>')
load:   - 18 ('<reponame>')
load: special tokens cache size = 27
load: token to piece cache size = 0.2827 MB
print_info: arch                  = granite
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 131072
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 40
print_info: n_head                = 32
print_info: n_head_kv             = 8
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 64
print_info: n_embd_head_v         = 64
print_info: n_gqa                 = 4
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-05
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 8.0e+00
print_info: f_attn_scale          = 1.6e-02
print_info: n_ff                  = 8192
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 0
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 131072
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = yes
print_info: model type            = 3B
print_info: model params          = 2.53 B
print_info: general.name          = Granite Speech 3.3 2b
print_info: f_embedding_scale     = 12.000000
print_info: f_residual_scale      = 0.220000
print_info: f_attention_scale     = 0.015625
print_info: n_ff_shexp            = 0
print_info: vocab type            = BPE
print_info: n_vocab               = 49160
print_info: n_merges              = 48891
print_info: BOS token             = 0 '<|end_of_text|>'
print_info: EOS token             = 0 '<|end_of_text|>'
print_info: EOT token             = 0 '<|end_of_text|>'
print_info: UNK token             = 0 '<|end_of_text|>'
print_info: PAD token             = 0 '<|end_of_text|>'
print_info: LF token              = 203 'Ċ'
print_info: FIM PRE token         = 1 '<fim_prefix>'
print_info: FIM SUF token         = 3 '<fim_suffix>'
print_info: FIM MID token         = 2 '<fim_middle>'
print_info: FIM PAD token         = 4 '<fim_pad>'
print_info: FIM REP token         = 18 '<reponame>'
print_info: EOG token             = 0 '<|end_of_text|>'
print_info: EOG token             = 4 '<fim_pad>'
print_info: EOG token             = 18 '<reponame>'
print_info: max token length      = 512
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 39 repeating layers to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   192.03 MiB
load_tensors:  MTL0_Mapped model buffer size =  4832.66 MiB
...............................................................................................
common_init_result: added <|end_of_text|> logit bias = -inf
common_init_result: added <fim_pad> logit bias = -inf
common_init_result: added <reponame> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
llama_context:        CPU  output buffer size =     0.19 MiB
llama_kv_cache:       MTL0 KV buffer size = 10240.00 MiB
llama_kv_cache: size = 10240.00 MiB (131072 cells,  40 layers,  1/1 seqs), K (f16): 5120.00 MiB, V (f16): 5120.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 64
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 64
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:       MTL0 compute buffer size =   388.01 MiB
sched_reserve:        CPU compute buffer size =   264.01 MiB
sched_reserve: graph nodes  = 1329
sched_reserve: graph splits = 2
sched_reserve: reserve took 13.42 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
mtmd_cli_context: chat template example:
<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>
<|start_of_role|>user<|end_of_role|>How are you?<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>
clip_model_loader: model name:   Granite Speech 3.3 2b
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    559
clip_model_loader: n_kv:         23

clip_model_loader: has audio encoder
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: use fusion         = true
ggml_metal_init: use concurrency    = true
ggml_metal_init: use graph optimize = true
clip_ctx: CLIP using MTL0 backend
load_hparams: projector:          granite_speech
load_hparams: n_embd:             1024
load_hparams: n_head:             8
load_hparams: n_ff:               4096
load_hparams: n_layer:            16
load_hparams: ffn_op:             gelu_quick
load_hparams: projection_dim:     2048

--- audio hparams ---
load_hparams: n_mel_bins:         160
load_hparams: proj_stack_factor:  0
load_hparams: audio_chunk_len:    0
load_hparams: audio_sample_rate:  16000
load_hparams: audio_n_fft:        512
load_hparams: audio_window_len:   400
load_hparams: audio_hop_len:      160

load_hparams: model size:         1105.25 MiB
load_hparams: metadata size:      0.20 MiB
warmup: warmup with audio size = 3000
alloc_compute_meta:       MTL0 compute buffer size =   177.39 MiB
alloc_compute_meta:        CPU compute buffer size =     1.98 MiB
alloc_compute_meta: graph splits = 1, nodes = 1313
warmup: flash attention is enabled
init_audio: audio input is in experimental stage and may have reduced quality:
    https://github.com/ggml-org/llama.cpp/discussions/13759
main: loading model: /Users/ghart/models/granite-speech-3.3-2b/granite-speech-3.3-2B-BF16.gguf
WARN: This is an experimental CLI for testing multimodal capability.
      For normal use cases, please use the standard llama-cli
encoding audio slice...
audio slice encoded in 188 ms
decoding audio batch 1/1, n_tokens_batch = 252
audio decoded (batch 1/1) in 2 ms

Sure, I'd be happy to transcribe the speech into written format. Here's the transcription:

---

For Timothy was a spoiled cat, and he allowed no one to interfere. Everybody waited upon him, moving their chairs even, for he was the monarch of the hearth.

The next night, Timothy's sister called him when he was still awake. "Sister," he said, "if you don't sleep, I beg you, wait until the day that will soon appear to continue the tale of the pecker."

---

This transcription maintains the original rhythm and tone of the text, preserving the poetic language and the sense of formality in the dialogue.


llama_perf_context_print:        load time =     959.91 ms
llama_perf_context_print: prompt eval time =     344.32 ms /   322 tokens (    1.07 ms per token,   935.18 tokens per second)
llama_perf_context_print:        eval time =    2635.49 ms /   163 runs   (   16.17 ms per token,    61.85 tokens per second)
llama_perf_context_print:       total time =    3293.26 ms /   485 tokens
llama_perf_context_print:    graphs reused =        162
ggml_metal_free: deallocating
ggml_metal_free: deallocating

@ngxson (Contributor) commented Apr 20, 2026

@gabe-l-hart If I understand correctly, the model contains specific adapters for audio / vision input, and the adapter is only activated during prompt processing of the corresponding modality input, right?

IIRC there was also a discussion about having a built-in LoRA adapter (currently adapters are loaded as separate files, which is not very convenient in terms of UX). I don't remember exactly where that discussion was, but it may be worth revisiting.

@gabe-l-hart (Collaborator) left a comment

Thank you SO much for putting this together! It's been on my TODO list for a very long time and just hasn't made it to the top.

I've got a number of nitty questions about things that should maybe be hparams instead of being hard-coded, as well as a few structural questions for @ngxson about any future plans for further model-specific modularity in the codebase. The only concrete change request (besides the naming conventions from @ngxson) is to update the base GraniteModel in convert_hf_to_gguf.py rather than introducing a special text model for Granite Speech.

Comment thread convert_hf_to_gguf.py
Comment thread convert_hf_to_gguf.py
Comment thread convert_hf_to_gguf.py Outdated
Comment thread tools/mtmd/clip.cpp
} break;
case PROJECTOR_TYPE_GRANITE_SPEECH:
{
hparams.audio_chunk_len = 0;
Collaborator:

@ngxson I've been curious about these hard-coded values. These seem like properties of the model instance rather than the model architecture, and thus something that would make sense as hparam values in the GGUF for the specific model. Is there something I'm missing that explicitly links the projector architecture to these specific values? I know that the upstream transformers models hard-code them, but it might make sense to proactively put them in the GGUF so that if the architecture is reused with different values in the future, we don't need a code change and/or reconverted GGUFs to support it. The fields are already there in the internal clip_hparams (the ones being set here), so I think it would just be a matter of defining the string constants for conversion and then adding these as the default values in the convert_hf_to_gguf.py stack.

Comment thread tools/mtmd/clip.cpp Outdated
const int n_layer_orig = hparams.n_layer;
if (model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|| model.proj_type == PROJECTOR_TYPE_GRANITE_SPEECH) {
hparams.n_layer = 0; // these models do not use the generic layer structure
Collaborator:

I was confused reading this comment (which I know you didn't add). It would be helpful to flesh out the comment to indicate that hparams.n_layer will be re-set below, so this is just a workaround to skip the generic layer processing.

Also, it seems like this workaround is a bit of a hack. Would it make more sense to have bool skip_standard_layers = false; and then check it directly below before performing the default layer functionality? Or better yet, put a conditional around the default layer logic that makes it explicit that these architectures are skipping it?

Author:

I also agree that it's a bit of a hack; it doesn't sit comfortably with me, and if I can get rid of it with minimal changes to the code, that'd be great. Currently, a reader won't understand why we're doing const int n_layer_orig = hparams.n_layer; until (and unless) they reach the Granite-specific code further down.

I think the best approach is the one you suggested: wrap the loop in a conditional and avoid any mutation and save/restore. On it!

Author:

Just pushed, let me know if it's good 😄

Collaborator:

@ngxson owns this area, so his opinion counts for a lot more than mine, but IMO, this clean-up is worth the slight scope-creep on the PR rather than perpetuating a pattern that is hard to maintain.

Comment thread tools/mtmd/clip.cpp Outdated
Comment thread tools/mtmd/clip.cpp
}
set_input_f32("pos_emb", pos_emb);
} break;
case PROJECTOR_TYPE_GRANITE_SPEECH:
Collaborator:

Question for @ngxson: Is there any plan to break up clip.cpp so that this kind of model-specific code can live in a <model-name>.cpp file? Right now, it looks like the arch-specific files are only for graph building, but it seems like it could go a lot further to encode this sort of logic as well (this is probably a much bigger conversation that bleeds into the model-modularity conversation in the core as well).

Comment thread tools/mtmd/mtmd-audio.h
mtmd_audio_cache cache;
};

struct mtmd_audio_preprocessor_granite_speech : mtmd_audio_preprocessor {
Collaborator:

Similar question to @ngxson about the modularity plans. This also seems ripe for isolation.

Comment thread tools/mtmd/models/granite-speech.cpp Outdated
Comment thread tools/mtmd/models/granite-speech.cpp Outdated
@gabe-l-hart (Collaborator):

@gabe-l-hart If I understand correctly, the model contains specific adapters for audio / vision input, and the adapter is only activated during prompt processing of the corresponding modality input, right?

IIRC there was also a discussion about having a built-in LoRA adapter (currently adapters are loaded as separate files, which is not very convenient in terms of UX). I don't remember exactly where that discussion was, but it may be worth revisiting.

Right, that's the goal of these modular models. I was clearly a bit confused in thinking that 4.0 speech had kept the adapter separate like 3.3 did; I know that 4.0 vision did keep them separate. The ultimate goal is a single running model with modality-specific adapters that toggle on/off automatically, allowing one model to serve all modalities without sacrificing text quality for text-only use.

Now that we've got this working for the 3.3 model, I'll use that as a testbed for my modality-conditional-adapter branch.

@gabe-l-hart (Collaborator):

@ngxson I've opened a PR for the modality-conditional adapter logic (#22184) and tested it using this branch and granite-speech-3.3-2b. It seems to be working like a charm!

@CISC (Member) commented Apr 20, 2026

IIRC there was also a discussion about having a built-in LoRA adapter (currently adapters are loaded as separate files, which is not very convenient in terms of UX). I don't remember exactly where that discussion was, but it may be worth revisiting.

That would be #13693

@gabe-l-hart (Collaborator):

That would be #13693

Right! Thanks for the reminder. I'll look back over that and make sure I haven't duplicated anything

Labels: examples, python (python script changes)