
feat: replace action tokenizer with windowed attention#16

Open
imitation-alpha wants to merge 1 commit into AlmondGod:main from imitation-alpha:feature/action-tokenizer-window-attention

Conversation

@imitation-alpha

Summary

This PR replaces the "mean pool + concat" mechanism in the LatentActionsEncoder with a "length-2 windowed attention + mean" mechanism. This change aims to better capture temporal dependencies between adjacent frames during action tokenization.

Changes

  • Modified models/latent_actions.py:
    • Imported SpatialAttention from models.st_transformer.
    • Updated LatentActionsEncoder to use SpatialAttention on concatenated windows of current and next frames.
    • Removed the old mean pooling and concatenation logic.
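
As a rough illustration of the new mechanism (a minimal NumPy sketch, not the repo's SpatialAttention module; the single-head attention with identity Q/K/V projections and the `(T, N, D)` shapes are assumptions for clarity):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def windowed_attention_pool(frames):
    """For each adjacent pair (f_t, f_{t+1}), run self-attention over the
    length-2 window of patch tokens, then mean-pool the window into a
    single action embedding.

    frames: (T, N, D) array of T frames, N patch tokens each, dim D.
    Returns (T-1, D): one latent-action embedding per frame transition.
    """
    T, N, D = frames.shape
    actions = []
    for t in range(T - 1):
        window = np.concatenate([frames[t], frames[t + 1]], axis=0)  # (2N, D)
        # single-head attention with identity projections (illustration only)
        scores = window @ window.T / np.sqrt(D)       # (2N, 2N)
        attended = softmax(scores, axis=-1) @ window  # (2N, D)
        actions.append(attended.mean(axis=0))         # mean over the window
    return np.stack(actions)                          # (T-1, D)
```

Unlike the old mean-pool+concat path, every token in frame t can attend to every token in frame t+1 before pooling, which is where the extra per-step cost comes from.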

Verification

  • Verified the implementation with a synthetic test script (scripts/verify_latent_actions.py - deleted after verification).
  • Confirmed that the model processes input frames and produces output actions with the correct dimensions.
  • Loss calculation works as expected.

Notes

  • This is a breaking change for LatentActionsEncoder checkpoints.

@imitation-alpha force-pushed the feature/action-tokenizer-window-attention branch from 93ed906 to 05765b0 on November 29, 2025 06:09
@AlmondGod
Owner

this looks great! can you train a working world model to confirm the impact of the change?

@NewJerseyStyle

Sorry to interrupt. I am not an expert, but I am curious whether there are "KPIs" to monitor to evaluate the impact of a change. For example:

  • To confirm it does not get worse, should we monitor the number of steps needed to converge?
  • To confirm it gets better, should we monitor the model's loss?

@AlmondGod
Owner

> Sorry to interrupt. I am not an expert, but I am curious whether there are "KPIs" to monitor to evaluate the impact of a change. For example:
>
> • To confirm it does not get worse, should we monitor the number of steps needed to converge?
> • To confirm it gets better, should we monitor the model's loss?

Yes, I'll add a README section in this PR specifying the necessary criteria.
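
The "steps to converge" KPI discussed above can be sketched as a simple helper (a hypothetical metric, assuming "converged" means the first step at which the loss drops below a chosen threshold):

```python
def steps_to_converge(losses, threshold):
    """Return the first step index at which loss drops below `threshold`,
    or None if the run never converges. The 'does it get worse' check then
    reduces to comparing this count between two branches."""
    for step, loss in enumerate(losses):
        if loss < threshold:
            return step
    return None

steps_to_converge([1.2, 0.5, 0.08, 0.03], threshold=0.1)  # → 2
```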

@imitation-alpha
Author

Testing Results: Windowed Attention vs Mean-Pool+Concat

Tested the full 3-stage pipeline (Video Tokenizer → Latent Actions → Dynamics) on the PicoDoom dataset (17,935 frames, 30% preload) with batch_size=16 on CPU (Mac M4 Pro, 64 GB RAM). Ran both 1K and 10K steps per stage to evaluate short and long training behavior.

Stage 1: Video Tokenizer (identical model, 0.14M params)

Both branches converge identically (~0.006 at 10K steps), as expected since this stage is unchanged.

Stage 2: Latent Actions

| | PR (windowed attn, 78K params) | Main (mean-pool+concat, 74K params) |
| --- | --- | --- |
| 1K steps | Loss: 0.041, Codebook: 75%, Enc Var: 0.101 | Loss: 0.031, Codebook: 50%, Enc Var: 0.024 |
| 10K steps | Loss: 0.028, Codebook: 50%, Enc Var: 0.054 | Loss: 0.027, Codebook: 75%, Enc Var: 0.138 |
| Speed | ~3.8 it/s | ~5.5 it/s |
  • At 1K steps, PR shows higher codebook usage (75% vs 50%) and 4x higher encoder variance
  • At 10K steps, both converge to similar loss (~0.027-0.028) and both achieve 100% codebook usage at some point during training
  • Main branch is ~45% faster per step
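
For reference, the codebook-usage figures above can be computed as the fraction of VQ codebook entries hit at least once in a batch (a sketch; it assumes usage = unique codes / codebook size, which may differ from the repo's exact metric):

```python
import numpy as np

def codebook_usage(indices, codebook_size):
    """Fraction of codebook entries selected at least once in a batch
    of vector-quantizer code indices."""
    return np.unique(indices).size / codebook_size

codebook_usage(np.array([0, 1, 1, 3, 0, 3]), codebook_size=4)  # → 0.75
```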

Stage 3: Dynamics (0.17M params)

| Steps | PR (windowed attn) | Main (mean-pool+concat) | Difference |
| --- | --- | --- | --- |
| 1K | 2.848 | 3.594 | PR 21% lower |
| 10K | 4.009 | 3.960 | Main 1.2% lower |
  • At 1K steps, PR shows 21% lower dynamics loss — faster early convergence
  • At 10K steps, the gap closes to ~1.2% with main slightly ahead — both approaches converge to similar quality

Training Curves (10K steps)

Video Tokenizer (both branches):

```
Step     0 → 1K → 2K → 3K → 4K → 5K → 6K → 7K → 8K → 9K
Loss  0.31  0.036 0.011 0.009 0.008 0.007 0.006 0.006 0.006 0.006
```

Latent Actions (PR / Main):

```
Step      0        1K           5K           9K
Loss   1.23/1.21  0.030/0.027  0.023/0.030  0.028/0.027
Cdbk   50%/25%    50%/50%      75%/100%     50%/75%
EncVar 0.00/0.00  0.035/0.017  0.224/0.207  0.054/0.138
```

Dynamics (PR / Main):

```
Step      0        1K         3K         5K         7K         9K
Loss   7.03/7.09  5.29/5.30  4.41/4.39  4.26/4.22  4.12/4.07  4.01/3.96
```

Summary

The windowed attention replacement converges faster early on (21% lower dynamics loss at 1K steps) but both approaches reach similar quality given enough training (~1.2% difference at 10K steps). The PR adds ~45% latency per step in the latent actions stage due to the attention computation. The full pipeline trains end-to-end without issues on both branches.

Environment: Mac M4 Pro 64GB, CPU-only, PicoDoom dataset, Python 3.13, PyTorch 2.10.0
