feat: replace action tokenizer with windowed attention#16
imitation-alpha wants to merge 1 commit into AlmondGod:main
Conversation
93ed906 to 05765b0 (force-pushed)
this looks great! can you train a working world model to confirm the impact of the change?
Sorry to interrupt. I am not an expert, but I am curious whether there are "KPIs" to monitor to evaluate the impact of a change like this?
yes, I'll add a README section to the PR specifying the necessary criteria
### Testing Results: Windowed Attention vs Mean-Pool+Concat

Tested the full 3-stage pipeline (Video Tokenizer → Latent Actions → Dynamics) on the PicoDoom dataset (17,935 frames, 30% preload), batch_size=16, on CPU (M4 Pro, 64 GB RAM). Ran both 1K and 10K steps per stage to evaluate short and long training behavior.

**Stage 1: Video Tokenizer (identical model, 0.14M params)**

Both branches converge identically (~0.006 at 10K steps), as expected since this stage is unchanged.

**Stage 2: Latent Actions**
**Stage 3: Dynamics (0.17M params)**
**Training Curves (10K steps)**

[Plots: Video Tokenizer (both branches), Latent Actions (PR / Main), Dynamics (PR / Main)]

**Summary**

The windowed attention replacement converges faster early on (21% lower dynamics loss at 1K steps), but both approaches reach similar quality given enough training (~1.2% difference at 10K steps). The PR adds ~45% latency per step in the latent actions stage due to the attention computation. The full pipeline trains end-to-end without issues on both branches.

Environment: Mac M4 Pro 64GB, CPU-only, PicoDoom dataset, Python 3.13, PyTorch 2.10.0
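The per-step latency delta attributed to the attention computation can be checked with a micro-benchmark along these lines (a rough sketch with made-up tensor shapes, not the project's actual benchmark harness; `nn.MultiheadAttention` stands in for the repo's `SpatialAttention`):

```python
import time

import torch
import torch.nn as nn


def per_step_latency(fn, x, steps=20, warmup=3):
    """Average wall-clock seconds per forward pass on CPU."""
    for _ in range(warmup):
        fn(x)
    t0 = time.perf_counter()
    for _ in range(steps):
        fn(x)
    return (time.perf_counter() - t0) / steps


# hypothetical shapes: (batch, tokens in a length-2 window, embed dim)
x = torch.randn(16, 128, 64)
attn = nn.MultiheadAttention(64, 4, batch_first=True)

pool_t = per_step_latency(lambda t: t.mean(dim=1), x)          # mean-pool baseline
attn_t = per_step_latency(lambda t: attn(t, t, t)[0], x)       # attention variant
print(f"mean-pool: {pool_t*1e3:.3f} ms/step, attention: {attn_t*1e3:.3f} ms/step")
```

The absolute numbers depend on the machine; only the relative overhead of attention over pooling is meaningful here.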
Summary
This PR replaces the "mean pool + concat" mechanism in the `LatentActionsEncoder` with a "length-2 windowed attention + mean" mechanism. This change aims to better capture temporal dependencies between adjacent frames during action tokenization.

### Changes
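The two mechanisms can be sketched roughly as follows (a minimal illustration, not the repository's actual `LatentActionsEncoder`; the tensor shapes and the use of `nn.MultiheadAttention` in place of `SpatialAttention` are assumptions):

```python
import torch
import torch.nn as nn


class MeanPoolConcat(nn.Module):
    """Old mechanism: mean-pool each frame's tokens, then concat adjacent frames."""

    def forward(self, frames):           # frames: (B, T, N, D) patch tokens
        pooled = frames.mean(dim=2)      # (B, T, D): one vector per frame
        # pair frame t with frame t+1
        return torch.cat([pooled[:, :-1], pooled[:, 1:]], dim=-1)  # (B, T-1, 2D)


class WindowedAttention(nn.Module):
    """New mechanism: attend over the concatenated tokens of a length-2 window
    (frame t and frame t+1), then mean over the window's tokens."""

    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, frames):                        # frames: (B, T, N, D)
        B, T, N, D = frames.shape
        # length-2 windows of adjacent frames: (B, T-1, 2N, D)
        win = torch.cat([frames[:, :-1], frames[:, 1:]], dim=2)
        win = win.reshape(B * (T - 1), 2 * N, D)
        out, _ = self.attn(win, win, win)             # attention within each window
        return out.mean(dim=1).reshape(B, T - 1, D)   # mean over window tokens
```

Note the output dimensionality also differs: the old path yields a `2D` concatenated vector per transition, the new path a single `D`-dim vector after the mean.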
- `models/latent_actions.py`: imports `SpatialAttention` from `models.st_transformer`; updates `LatentActionsEncoder` to use `SpatialAttention` on concatenated windows of current and next frames.

### Verification
- `scripts/verify_latent_actions.py` (temporary verification script, deleted after verification).

### Notes
- Not backward compatible with existing `LatentActionsEncoder` checkpoints.
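If weights from an old checkpoint need to be partially reused despite the incompatibility, the matching keys can still be loaded while mismatched encoder keys are reported and skipped. This is a generic PyTorch pattern, not a migration path this PR provides; the model and key names below are hypothetical:

```python
import torch
import torch.nn as nn

# hypothetical stand-in for the new encoder architecture
model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))

# simulate an old checkpoint: some keys match, one no longer exists in the model
old_state = {
    "0.weight": torch.randn(16, 8),
    "0.bias": torch.randn(16),
    "stale.weight": torch.randn(4, 4),  # key absent from the new architecture
}

# strict=False loads the matching keys and reports the rest instead of raising
result = model.load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)        # expected by the model, absent from ckpt
print("unexpected:", result.unexpected_keys)  # present in ckpt, ignored by the model
```

Whether partially loaded weights are actually useful after an architecture change is a separate question; retraining the encoder from scratch may be the safer option.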