Fix causality violation: use per-token weights instead of full-sequence mean pooling #3

Open
sippycoder wants to merge 1 commit into main from claude/fix-causality-tensor-reshape-itfva

Conversation

@sippycoder

Contributor

Dynamic weights (H_pre, H_post, H_res) were computed from H.mean(dim=1), which
averages over all sequence positions. For autoregressive LLMs this leaks future
token information into the mixing weights applied at position t, breaking causality.

Fix: replace H.mean(dim=1).reshape(batch, n*dim) with H.reshape(batch*seq, n*dim)
so each token's weights are derived solely from its own hidden state. Weight shapes
change from (batch, n) to (batch, seq, n) throughout, matching the paper's intent
and reference implementations (tokenbender, VatsaDev).
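
The leak and the fix can be illustrated with a minimal sketch. It uses NumPy rather than the repository's PyTorch code, and it assumes H has layout (batch, seq, n, dim) and that the weights come from a single linear projection `W`. `W` and both function names are hypothetical stand-ins, not the module's actual `_compute_weights` paths.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, n, dim = 2, 5, 4, 8

H = rng.standard_normal((batch, seq, n, dim))  # hidden streams (assumed layout)
W = rng.standard_normal((n * dim, n))          # hypothetical projection matrix


def pooled_weights(H):
    """Old behaviour: average over the sequence axis, giving (batch, n)."""
    b, s, n_, d = H.shape
    pooled = H.mean(axis=1).reshape(b, n_ * d)  # torch: H.mean(dim=1).reshape(...)
    return pooled @ W


def per_token_weights(H):
    """Fixed behaviour: one weight vector per position, giving (batch, seq, n)."""
    b, s, n_, d = H.shape
    flat = H.reshape(b * s, n_ * d)
    return (flat @ W).reshape(b, s, n_)


# Perturb only the last position and check whether the weights
# applied at position t = 0 change.
H_future = H.copy()
H_future[:, -1] += 1.0

t = 0
pooled_changed = not np.allclose(pooled_weights(H), pooled_weights(H_future))
per_token_changed = not np.allclose(
    per_token_weights(H)[:, t], per_token_weights(H_future)[:, t]
)
print("pooled weights react to a future token:   ", pooled_changed)
print("per-token weights react to a future token:", per_token_changed)
```

The pooled weights are shared by every position, so a change at the last position alters what is applied at position 0. The per-token weights at position 0 depend only on H[:, 0].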

Changes:

  • module.py / _torch_baseline.py: all three _compute_weights paths (static, fused,
    separate-projections) now produce per-position weights (batch, seq, n)
  • _kernels.py: stream_mix and add_residual forward kernels index weights by pid_bs
    (b*seq+s) instead of b
  • _backward.py: same index fix in backward kernels; gradient reductions now sum
    only over d_blocks (dim=2) to preserve the per-position (batch, seq, n) shape
  • ops.py / _torch_baseline.py: updated einsum signatures and docstrings
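
The indexing and reduction changes listed above can be sketched in NumPy as follows. This is not the repository's Triton code. The mix is simplified to a weighted sum over streams, and `block_d`, the `(batch, seq, d_blocks, n)` layout of the partial gradients, and the einsum signatures are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, n, dim = 2, 3, 4, 8
block_d = 4                    # hypothetical feature-block size
d_blocks = dim // block_d

H = rng.standard_normal((batch, seq, n, dim))
w = rng.standard_normal((batch, seq, n))        # per-position weights
grad_out = rng.standard_normal((batch, seq, dim))

# Forward: each (b, s) pair reads its own weight row through the flat
# index pid_bs = b*seq + s, instead of sharing row b across the sequence.
w_flat = w.reshape(batch * seq, n)
H_flat = H.reshape(batch * seq, n, dim)
out_flat = np.empty((batch * seq, dim))
for b in range(batch):
    for s in range(seq):
        pid_bs = b * seq + s
        out_flat[pid_bs] = w_flat[pid_bs] @ H_flat[pid_bs]  # (n,) @ (n, dim)
out = out_flat.reshape(batch, seq, dim)
reference = np.einsum("bsn,bsnd->bsd", w, H)

# Backward: each feature block contributes a partial weight gradient.
# Summing only over the d_blocks axis keeps the (batch, seq, n) shape;
# the sequence axis is no longer reduced.
partials = np.empty((batch, seq, d_blocks, n))
for k in range(d_blocks):
    sl = slice(k * block_d, (k + 1) * block_d)
    partials[:, :, k, :] = np.einsum(
        "bsd,bsnd->bsn", grad_out[..., sl], H[..., sl]
    )
grad_w = partials.sum(axis=2)
grad_w_reference = np.einsum("bsd,bsnd->bsn", grad_out, H)

print(np.allclose(out, reference), np.allclose(grad_w, grad_w_reference))
```

With (batch, n) weights, the gradient would also have to be summed over the sequence axis, because every position shared the same row.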

https://claude.ai/code/session_016YVdHfTQm3GA8aqcj8ws25

@sippycoder

Contributor Author
  • Verify that tests pass and check numerical stability
  • Run sample training jobs and track layer norms to check training stability
