perf: optimize Phase 2 batch generation with dynamic compaction by 3-12% (#20)
Conversation
📝 Walkthrough
The pull request refactors the Batched Phase 2 logic in ace-qwen3.cpp, replacing a per-step two-pass forward strategy with dynamic, compacted batching. It introduces active-to-original mapping, compact logits extraction, and CPU-side sampling to reduce GPU compute and memory footprint.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Batch as Dynamic Batch<br/>(Phase 2)
    participant GPU as GPU Forward<br/>Pass
    participant Logits as Logits<br/>Processing
    participant CPU as CPU Sampling
    participant Map as ID Mapper
    rect rgba(100, 150, 200, 0.5)
        Note over Batch,Map: New Optimized Flow
        Batch->>Batch: Collect active sequences<br/>via active_to_orig mapping
        Batch->>GPU: Forward pass on compact<br/>batch (actual_batch_size)
        GPU-->>Logits: Return logits tensor
        Logits->>Logits: Extract compact subset<br/>(EOS + audio codes)
        alt CFG Enabled
            Logits->>Logits: Apply CFG scale to<br/>compact logits only
        end
        Logits->>CPU: Pass compact_logits<br/>to CPU
        CPU->>CPU: Sample from compact<br/>vocabulary
        CPU-->>Map: Local sampled IDs
        Map->>Map: Map local IDs to<br/>global audio codes
        Map-->>Batch: Global token IDs
        Batch->>Batch: Update sequences &<br/>compute next n_active
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
🧹 Nitpick comments (1)
tools/ace-qwen3.cpp (1)
619-623: Keep `n_active` and `total_codes` incremental in the hot path.

This loop already knows when a sequence flips to `done`, so the extra full `for (i = 0; i < N; ++i)` pass to rebuild `next_active_count` is avoidable. `total_codes` is also recomputed on every step even though it is only emitted every 50 steps.

Also applies to: 626-635
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/ace-qwen3.cpp` around lines 619 - 623, The loop that checks tok against TOKEN_IM_END should update the running counters in-place instead of recomputing them later: when seqs[orig_i].done transitions to true, decrement n_active immediately (and adjust any next_active_count tracking used later); when pushing an audio code (seqs[orig_i].audio_codes.push_back(tok - AUDIO_CODE_BASE)) increment total_codes immediately so you only recompute totals when needed for the 50-step emission; remove the subsequent full for-loop used to rebuild next_active_count/total_codes and ensure any logic that relied on that pass now reads the updated n_active and total_codes. Apply the same incremental updates in the analogous block around the code at lines 626-635.
That's interesting! I'll try it and merge it.
Bonus: compact sampling computes the softmax over only the 2049 valid tokens (EOS + audio codes) instead of the full ~65k vocabulary, eliminating probability-mass leakage to impossible text tokens and producing a sharper, more faithful distribution at every decode step.
…12% (#20)

* perf: improve batch generation in step 1 by 3-12%
* remove comments
* remove comments
Tested with `--batch 4`