Skip to content

Worldmodel#1

Open
YassineYousfi wants to merge 16 commits into
mainfrom
worldmodel
Open

Worldmodel#1
YassineYousfi wants to merge 16 commits into
mainfrom
worldmodel

Conversation

@YassineYousfi

Copy link
Copy Markdown

No description provided.

YassineYousfi pushed a commit that referenced this pull request Jun 5, 2026
…nto canonicalize_graph_pass (pytorch#3488)

## Summary

Two related graph-pass changes in `graph_trainer`:

1. **Add `remove_b2b_transpose_pass`** — collapses back-to-back
`aten.t(aten.t(x))` transpose pairs. These appear in traced fwd+bwd
graphs from `F.linear` when `simple_fsdp` redistributes weight tensors.
Two consecutive `aten.t` form an involution (identical shape *and*
strides), so removing them is **bitwise numerics-preserving**. The pass
also handles chains (`t(t(t(x))) -> t(x)`) and keeps the inner transpose
when it still feeds other consumers.

2. **Bundle the no-op cleanup passes into a single
`canonicalize_graph_pass` entry** in `compile_time_passes`:
   - `remove_detach_pass`
   - `remove_identity_view_pass`
   - `remove_b2b_transpose_pass`
   - `remove_identity_slice_pass`
   - `normalize_view_ops_as_reshape`

`normalize_view_ops_as_reshape` moves from `passes.py` into
`remove_noop_passes.py` alongside the other graph-cleanup passes. The
sub-passes stay public so each is unit-tested in isolation.

## Why

The cleanup passes are all numerics-preserving local rewrites that ran
as four separate pass-list entries. Folding them into one
`canonicalize_graph_pass` keeps `compile_time_passes` readable and
groups them as one logical step, while the new b2b-transpose rewrite
removes redundant transpose pairs that `F.linear` + FSDP introduce.

## Behavior note

`--compile.disable_passes` now toggles `canonicalize_graph_pass` as a
whole rather than the individual sub-passes. No callsite
(configs/scripts/README) disabled them individually, so nothing breaks.

## Verification

Ran the llama3 debug model (FSDP=4, TP=2, `aot_fx_trace`,
`--compile.debug_graph_passes`). `remove_b2b_transpose_pass` logs its
removal count directly:

```
[rank0]:[titan] - root - INFO - Removed 129 back-to-back transpose node(s) from the graph
```

`canonicalize_graph_pass` is pass #1; its op-count diff confirms the
same 129 nodes (`t.default: 215 -> 86`):

```
nodes: 2460 -> 1851 (-609)
  t.default:        215 -> 86  (-129)   <- remove_b2b_transpose_pass
  detach.default:    40 -> 0   (-40)
  view.default:     634 -> 0   (-634)
  _unsafe_view:      43 -> 0   (-43)
  slice.Tensor:      40 -> 0   (-40)
  reshape.default:    0 -> 277 (+277)
```

3 training steps, loss `8.11353 -> 7.80330 -> 7.09355`.

## Test plan

```bash
pytest torchtitan/experiments/graph_trainer/tests/test_passes.py \
  -k "TestRemoveB2BTransposePass or TestCanonicalizeGraphPass or \
      TestNormalizeViewOpsAsReshape or TestRemoveDetachPass or \
      TestRemoveIdentityViewPass or TestRemoveIdentitySlicePass"
```

Added `TestRemoveB2BTransposePass` (pair removal + numerics, lone
transpose kept, inner transpose with another user kept, chain collapse,
no-op on transpose-free graphs) and `TestCanonicalizeGraphPass`
(end-to-end bundle). All pass; `pre-commit run --all-files` clean.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant