Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 76 additions & 49 deletions tools/ace-qwen3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,22 +529,22 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
}

// Batched decode loop, partial LM head: only project [TOKEN_IM_END..V)
Timer t_decode;
int V_eff = V - TOKEN_IM_END; // 65559 vs 217204
std::vector<float> logits_cond((size_t) V_eff * N);
std::vector<float> logits_uncond((size_t) V_eff * N);
std::vector<int> tokens(N);
Timer t_decode;
int V_eff = V - TOKEN_IM_END;

// CFG: single forward with 2*N (cond + uncond)
int N2 = use_cfg ? 2 * N : N;
std::vector<int> tokens_2n(N2), sets_2n(N2);
std::vector<float> logits_2n((size_t) V_eff * N2);
if (use_cfg) {
for (int i = 0; i < N; i++) {
sets_2n[i] = cond_sets[i];
sets_2n[N + i] = uncond_sets[i];
}
}
// Pre-allocate batched arrays for the maximum possible size (N or 2*N for CFG)
int max_N2 = use_cfg ? 2 * N : N;
std::vector<int> batch_tokens(max_N2);
std::vector<int> batch_sets(max_N2);
std::vector<float> batch_logits((size_t) V_eff * max_N2);

// This array maps the compact "active" index back to the original sequence index (0 to N-1)
std::vector<int> active_to_orig(N);

// Compact array for CPU sampling: slot 0 = EOS (TOKEN_IM_END), slots 1..AUDIO_CODE_COUNT = audio codes.
// Only these entries of each V_eff-wide logits row are copied in and sampled.
int audio_code_offset = AUDIO_CODE_BASE - TOKEN_IM_END;
int compact_V = AUDIO_CODE_COUNT + 1;
std::vector<float> compact_logits(compact_V);

int n_active = N;
for (int i = 0; i < N; i++) {
Expand All @@ -554,58 +554,85 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
}

for (int step = 0; step < max_tokens && n_active > 0; step++) {
// Collect tokens (done sequences feed their last token, result ignored)
for (int i = 0; i < N; i++) {
tokens[i] = seqs[i].last_token;
}
int current_active = 0;

if (use_cfg) {
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
for (int i = 0; i < N; i++) {
tokens_2n[i] = tokens[i];
tokens_2n[N + i] = tokens[i];
// 1. DYNAMIC COMPACTION: Loop through all N sequences, but only gather the active ones!
for (int i = 0; i < N; i++) {
if (!seqs[i].done) {
active_to_orig[current_active] = i; // Remember that this slot belongs to sequence 'i'

if (use_cfg) {
// Place the Cond token/set in the first half
batch_tokens[current_active] = seqs[i].last_token;
batch_sets[current_active] = cond_sets[i];

// Place the Uncond token/set exactly n_active elements later
batch_tokens[n_active + current_active] = seqs[i].last_token;
batch_sets[n_active + current_active] = uncond_sets[i];
} else {
batch_tokens[current_active] = seqs[i].last_token;
batch_sets[current_active] = cond_sets[i];
}
current_active++;
}
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data(), TOKEN_IM_END, V_eff);
memcpy(logits_cond.data(), logits_2n.data(), (size_t) V_eff * N * sizeof(float));
memcpy(logits_uncond.data(), logits_2n.data() + (size_t) V_eff * N, (size_t) V_eff * N * sizeof(float));
} else {
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data(), TOKEN_IM_END, V_eff);
}

// Per-sequence: CFG combine + sample (logits are [V_eff] starting at TOKEN_IM_END)
for (int i = 0; i < N; i++) {
if (seqs[i].done) {
continue;
}
// 2. FORWARD PASS: run the model on the active rows only (n_active, or 2 * n_active with CFG)
int actual_batch_size = use_cfg ? (2 * n_active) : n_active;
qw3lm_forward_batch(m, batch_tokens.data(), batch_sets.data(), actual_batch_size, batch_logits.data(),
TOKEN_IM_END, V_eff);

// 3. TARGETED CFG & LOGIT EXTRACTION
for (int a = 0; a < n_active; a++) {
int orig_i = active_to_orig[a]; // Map back to original sequence object

// Pointer to the conditional logits for THIS active sequence
float * lc = batch_logits.data() + (size_t) a * V_eff;

float * lc = logits_cond.data() + (size_t) i * V_eff;
if (use_cfg) {
float * lu = logits_uncond.data() + (size_t) i * V_eff;
for (int v = 0; v < V_eff; v++) {
lc[v] = lu[v] + cfg_scale * (lc[v] - lu[v]);
// Pointer to the unconditional logits (offset by n_active)
float * lu = batch_logits.data() + (size_t) (n_active + a) * V_eff;

// Targeted CFG: blend only the entries that get sampled (EOS + audio codes).
// The remaining entries of the V_eff-wide row are not used by the compact sampling below.
lc[0] = lu[0] + cfg_scale * (lc[0] - lu[0]); // EOS token
for (int c = 0; c < AUDIO_CODE_COUNT; c++) {
int idx = audio_code_offset + c;
lc[idx] = lu[idx] + cfg_scale * (lc[idx] - lu[idx]);
}
}

// Mask the 24-token gap: indices 1..AUDIO_CODE_BASE-TOKEN_IM_END-1
// (index 0 = TOKEN_IM_END = EOS, index 24+ = audio codes)
for (int v = 1; v < AUDIO_CODE_BASE - TOKEN_IM_END; v++) {
lc[v] = -1e9f;
// Extract ONLY the valid target tokens into the tiny compact array
compact_logits[0] = lc[0];
for (int c = 0; c < AUDIO_CODE_COUNT; c++) {
compact_logits[c + 1] = lc[audio_code_offset + c];
}
int tok = sample_top_k_p(lc, V_eff, temperature, top_p, top_k, seqs[i].rng) + TOKEN_IM_END;
seqs[i].last_token = tok;

// Sample over compact_V (= AUDIO_CODE_COUNT + 1) candidates rather than the full V_eff-wide row
int compact_tok =
sample_top_k_p(compact_logits.data(), compact_V, temperature, top_p, top_k, seqs[orig_i].rng);

// Map the sampled index back to global vocabulary ID
int tok = (compact_tok == 0) ? TOKEN_IM_END : (AUDIO_CODE_BASE + compact_tok - 1);

seqs[orig_i].last_token = tok;

if (tok == TOKEN_IM_END) {
seqs[i].done = true;
n_active--;
} else if (tok >= AUDIO_CODE_BASE && tok < AUDIO_CODE_BASE + AUDIO_CODE_COUNT) {
seqs[i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
seqs[orig_i].done = true;
} else {
seqs[orig_i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
}
}

int total_codes = 0;
// 4. UPDATE ACTIVE COUNT for the next loop iteration
int next_active_count = 0;
int total_codes = 0;
for (int i = 0; i < N; i++) {
if (!seqs[i].done) {
next_active_count++;
}
total_codes += (int) seqs[i].audio_codes.size();
}
n_active = next_active_count;

if ((step + 1) % 50 == 0) {
double elapsed = t_decode.ms() / 1000.0;
Expand Down
Loading