fix: enable sample_steps > num_gpus_dit with modulo cycling #12
imperatormk wants to merge 2 commits into Alibaba-Quark:main
Conversation
Previously, multi-GPU inference crashed with a KeyError when sample_steps exceeded num_gpus_dit (e.g., 8 steps with 4 DiT GPUs). The issue was the 1:1 mapping of step i to rank i, which left steps 4+ with no GPU.

Fix:

- Modulo cycling: step i runs on rank (i % num_gpus_dit)
- Dynamic send/recv: the cycle wraps back to rank 0, and the final step sends to the VAE rank
- Scheduler index fix: set _step_index = i for the correct timestep

Tested with 8 and 16 steps on 4 DiT GPUs + 1 VAE GPU.
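As a rough illustration of the schedule this produces, here is a minimal sketch; the variable names and the print are illustrative, not this repository's code:

```python
# Which rank runs each step, and where it sends its output, under modulo cycling.
num_gpus_dit = 4      # DiT workers on ranks 0..3; rank 4 runs the VAE
sample_steps = 8

for i in range(sample_steps):
    owner = i % num_gpus_dit                                        # rank that runs step i
    dst = num_gpus_dit if i == sample_steps - 1 else (i + 1) % num_gpus_dit
    print(f"step {i}: runs on rank {owner}, sends its output to rank {dst}")
```

Steps 0-3 land on ranks 0-3, step 4 wraps back to rank 0, and the final step hands off to the VAE rank.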
Force-pushed from f62c1e7 to 5714523
Fixes lipsync desync when using sample_steps > num_gpus_dit.

With modulo cycling, each GPU maintains its own KV cache. The denoising loop processes the same temporal block across all steps, with each step overwriting the same cache positions. Without synchronization, when GPU 0 processes step 4 it still holds stale KV values from step 0 instead of step 3's values.

This commit adds _sync_kv_cache(), which broadcasts the updated cache from the GPU that just completed a step to all other DiT GPUs.
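A minimal sketch of what such a sync could look like, assuming the cache is a flat list of (key, value) tensor pairs and the DiT ranks are known; the actual _sync_kv_cache() in this PR may be structured differently:

```python
import torch.distributed as dist

def _sync_kv_cache(kv_cache, src_rank, dit_ranks):
    """Push the cache updated on src_rank to every other DiT rank via send/recv."""
    my_rank = dist.get_rank()
    for k, v in kv_cache:                   # assumed: flat list of (key, value) tensors
        for dst in dit_ranks:
            if dst == src_rank:
                continue
            if my_rank == src_rank:         # the rank that just finished the step sends...
                dist.send(k, dst=dst)
                dist.send(v, dst=dst)
            elif my_rank == dst:            # ...and every other DiT rank receives in place
                dist.recv(k, src=src_rank)
                dist.recv(v, src=src_rank)
```

Every DiT rank would call this after each step, which is why the full-cache-per-step traffic raised in the review comment below applies directly.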
Thanks a lot for the detailed investigation and the fix — enabling sample_steps > num_gpus_dit is definitely a meaningful and useful improvement.

I tried merging your changes locally, but in my setup multi-GPU inference consistently hangs at the communication between step 1 → step 2 (progress stalls after finishing step 1). I'm still debugging whether this is a corner case in the current send/recv logic or something specific to my environment.

One additional concern I have is that synchronizing the full KV cache across all DiT GPUs at every denoising step could introduce substantial communication overhead, which may significantly reduce multi-GPU parallel efficiency. While correctness-wise the fix makes sense, the performance trade-off might be non-trivial.

That said, the flexibility to decouple sample_steps from num_gpus_dit is very valuable. An alternative direction could be to reserve multiple KV caches per GPU, or to adapt the pipeline to more GPUs (e.g., 8+1 for 8-step denoising) to avoid frequent cross-GPU KV synchronization.

Thanks again for the contribution — I think this is an important direction to explore further.
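One way to read the "reserve multiple KV caches per GPU" suggestion, sketched here with illustrative names only (this is not code from the PR): each DiT rank keeps an independent cache for every step it owns under modulo cycling, instead of overwriting a single shared cache.

```python
class PerStepKVCache:
    """Illustrative sketch: one KV cache per denoising step owned by this rank."""

    def __init__(self, make_empty_cache):
        self._make_empty_cache = make_empty_cache   # factory that builds a fresh, empty cache
        self._caches = {}                           # global step index -> that step's cache

    def for_step(self, step_idx):
        # Under modulo cycling a rank owns steps {r, r + num_gpus_dit, ...};
        # each owned step gets its own cache.
        if step_idx not in self._caches:
            self._caches[step_idx] = self._make_empty_cache()
        return self._caches[step_idx]
```

Whether per-rank caches alone remove the need for cross-GPU traffic depends on how the temporal attention consumes earlier steps' caches, so this is only a direction sketch.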
Problem
Multi-GPU inference crashes when `sample_steps > num_gpus_dit`: `--sample_steps 8` with `--num_gpus_dit 4` causes `KeyError: 'cond_shape'` or `KeyError: '5'`.

Root cause: the denoising loop maps step `i` directly to rank `i` (`if i != dist.get_rank(): continue`), so steps 4+ have no GPU to run them.
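For reference, a simplified sketch of the original mapping; only the rank check is taken from the description above, the rest is placeholder:

```python
import torch.distributed as dist

timesteps = list(range(8))        # stands in for the scheduler's timesteps
for i, t in enumerate(timesteps):
    if i != dist.get_rank():      # 1:1 step-to-rank mapping
        continue
    # ... denoise step i on this rank ...
```

With 4 DiT ranks, no rank satisfies the check for steps 4-7, so those steps never run and the pipeline later fails with the `KeyError`s above.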
Solution

Implement modulo cycling so steps wrap around the available GPUs.
Changes:

- Step-to-rank check: `if i % num_gpus_dit != my_rank` instead of `if i != dist.get_rank()`
- Scheduler index: set `sample_scheduler._step_index = i` for correct timestep handling

The modified loop is sketched below.
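A minimal sketch of the loop with these changes applied; `denoise_one_step`, `vae_rank`, and the surrounding setup are placeholders rather than names from the repository:

```python
import torch.distributed as dist

my_rank = dist.get_rank()
vae_rank = num_gpus_dit                         # the extra rank that runs the VAE

for i, t in enumerate(timesteps):
    if i % num_gpus_dit != my_rank:             # modulo cycling: step i belongs to rank i % N
        continue

    if i > 0:                                   # receive latents from the rank that ran step i-1
        dist.recv(latents, src=(i - 1) % num_gpus_dit)

    sample_scheduler._step_index = i            # keep the scheduler on step i's timestep
    latents = denoise_one_step(latents, t)      # DiT forward + scheduler step (placeholder)

    # Dynamic destination: the rank owning step i+1, or the VAE rank after the final step.
    dst = vae_rank if i == len(timesteps) - 1 else (i + 1) % num_gpus_dit
    dist.send(latents, dst=dst)
```

The recv-before-denoise / send-after-denoise ordering shown here is only one possible arrangement; the PR's actual handshake may differ.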
Why KV cache sync is needed

With modulo cycling, each GPU maintains its own KV cache. The denoising loop processes the same temporal block across all steps, with each step overwriting the same cache positions.
Without sync, GPU 0 at step 4 reads stale KV values from step 0 instead of step 3's values. This breaks temporal attention alignment, causing lipsync drift.
The fix broadcasts the updated KV cache after each step using `dist.send`/`dist.recv`.

Testing
Tested on 5 GPUs (4 DiT + 1 VAE) with:
- `--sample_steps 8`
- `--sample_steps 16`

Video quality improved with more denoising steps, as expected.