Incorporate GRIT model #22

Open
ttolhurst wants to merge 103 commits into gridfm:main from ttolhurst:ft1125_incorporate_grit

@ttolhurst

Member

Incorporates GRIT from L. Ma et al., "Graph Inductive Biases in Transformers without Message Passing".

ttolhurst and others added 30 commits November 17, 2025 14:07
Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
@romeokienzler

Collaborator

Code Review Findings

  1. gridfm_graphkit/models/grit_transformer.py: GritHeteroAdapter.aggregate_pg reads self.grit.mask_value[0].item() to fill fully-masked buses. When learn_mask=True, mask_value is registered as a learnable nn.Parameter, but .item() converts it to a plain Python float and detaches it from the autograd graph, so the learnable mask token receives no gradient through the only place it is consumed. learn_mask=True is therefore silently a no-op.

  2. gridfm_graphkit/training/loss.py: PBELoss computes S_injection = torch.diag(V) @ Y_bus_conj @ V_conj. torch.diag(V) materializes a dense N×N matrix and the matmul densifies the sparse Y-bus; on case2000 this is O(N²) memory and defeats the sparse Y-bus construction directly above. Use the equivalent elementwise form V * (Y_bus_conj @ V_conj) instead.

  3. gridfm_graphkit/training/loss.py: PBELoss builds Y_diag = Y_diag + bus_orig[:, GS] + 1j * bus_orig[:, BS], pulling shunt admittance from the normalized input x_dict["bus"]. If the normalizer scales GS/BS by a different factor than the Y-bus edge attributes YFF/YFT, the assembled Y-bus is physically inconsistent and the PBE residual is biased.

  4. gridfm_graphkit/datasets/rrwp.py: Walk-length semantics are inconsistent with add_identity: with add_identity=True the PE has walk_length powers [I, A, A², …]; with add_identity=False it has walk_length-1 powers. Downstream RRWPLinearNodeEncoder(emb_dim=ksteps) assumes one fixed dimension, so toggling add_identity produces a silent size mismatch.

  5. gridfm_graphkit/models/grit_layer.py: torch_scatter is imported under try/except and stored as None on failure, but pyg_softmax/propagate_attention call scatter, scatter_max, scatter_add directly without a guard. A missing torch-scatter install raises TypeError: 'NoneType' object is not callable instead of following the explicit ImportError pattern used in rrwp_encoder._check_scatter.

  6. gridfm_graphkit/training/loss.py: MaskedGenMSE was changed from pred_dict["gen"][mask_dict["gen"][:, :(PG_H+1)]] to slicing pred/target first and then masking. This is a behavior change for any config with output_gen_dim > 1 (the old form would broadcast/error, the new form computes a different quantity). There's no test asserting either path; confirm intent and add coverage.

  7. gridfm_graphkit/datasets/rrwp.py: deg = adj.sum(dim=1) followed by deg_inv = 1.0 / adj.sum(dim=1) recomputes the row sum; reuse deg. Minor, but it executes on every PE computation (mitigated by caching).

  8. gridfm_graphkit/models/rrwp_encoder.py: RRWPLinearEdgeEncoder.__init__ accepts fill_value and uses it to build padding, but then unconditionally sets self.fill_value = 0.0. The stored attribute (and __repr__) always reports 0.0 regardless of the argument; either honor it or drop the parameter.

  9. Side note: GritHeteroAdapter.__init__ mutates the shared args object (args.model.input_dim, args.model.gt.dim_hidden, args.model.encoder.posenc_RWSE.kernel.times). Fine for a single-model run, but it's a footgun for any future code that constructs multiple models from the same config.
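
For finding 1, a minimal sketch of why .item() blocks the gradient, assuming PyTorch is installed. The names mask_value, pg, and fully_masked are stand-ins chosen for illustration and are not taken from the PR's code.

```python
import torch

# Hypothetical stand-in for self.grit.mask_value when learn_mask=True.
mask_value = torch.nn.Parameter(torch.zeros(1))
pg = torch.tensor([0.3, 0.0, 1.2])                  # illustrative per-bus values
fully_masked = torch.tensor([False, True, False])   # illustrative mask

# Pattern described in the finding: .item() yields a Python float,
# so the filled tensor is no longer connected to mask_value.
filled_item = torch.where(
    fully_masked, torch.full_like(pg, mask_value[0].item()), pg
)
print(filled_item.requires_grad)  # False

# Keeping the parameter as a tensor lets autograd reach it.
filled_tensor = torch.where(fully_masked, mask_value[0].expand_as(pg), pg)
filled_tensor.sum().backward()
print(mask_value.grad)  # tensor([1.]) because exactly one bus is fully masked
```

Whether the fill should stay differentiable is a design decision for the author; the sketch only shows that the tensor form is the one that lets learn_mask=True have an effect.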
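
For finding 2, the suggested replacement is the same quantity written elementwise. A short derivation follows, with V the complex bus-voltage vector and Y the bus admittance matrix (notation assumed here; in the code these appear as V, Y_bus_conj, and V_conj):

```math
S_i = \bigl(\operatorname{diag}(V)\,\overline{Y}\,\overline{V}\bigr)_i
    = \sum_{j}\sum_{k} \operatorname{diag}(V)_{ij}\,\overline{Y}_{jk}\,\overline{V}_k
    = V_i \sum_{k} \overline{Y}_{ik}\,\overline{V}_k
    = \bigl(V \odot (\overline{Y}\,\overline{V})\bigr)_i
```

The last form needs only one sparse matrix-vector product and an elementwise multiply, so no dense N×N intermediate has to be formed.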
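
For findings 4 and 7, a dependency-free sketch of the channel-count behavior as described above, plus an explicit dimension check. The functions rrwp_channels and check_pe_dim are hypothetical and written on nested lists for illustration only; they are not the PR's implementation.

```python
def matmul(a, b):
    """Dense matrix product on nested lists (kept dependency-free for the sketch)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def rrwp_channels(adj, walk_length, add_identity=True):
    """Return random-walk matrices [I, M, M^2, ...], optionally without I."""
    n = len(adj)
    deg = [sum(row) for row in adj]  # computed once and reused (finding 7)
    # Row-normalize: M = D^-1 A; isolated nodes keep an all-zero row.
    walk = [[a / d if d else 0.0 for a in row] for row, d in zip(adj, deg)]
    identity = [[float(i == j) for j in range(n)] for i in range(n)]
    powers = [identity]
    for _ in range(walk_length - 1):
        powers.append(matmul(powers[-1], walk))
    # Behavior described in finding 4: dropping I shrinks the PE by one channel.
    return powers if add_identity else powers[1:]


def check_pe_dim(channels, ksteps):
    """Fail loudly when the PE width and the encoder's expected width disagree."""
    if len(channels) != ksteps:
        raise ValueError(
            f"RRWP has {len(channels)} channels but the encoder expects {ksteps}"
        )


adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # 3-node path graph
with_identity = rrwp_channels(adj, walk_length=4, add_identity=True)
without_identity = rrwp_channels(adj, walk_length=4, add_identity=False)
print(len(with_identity), len(without_identity))  # 4 3
check_pe_dim(with_identity, ksteps=4)  # passes; without_identity would raise
```

A check of this kind at dataset or encoder construction would turn the silent mismatch into an immediate error; making both settings yield the same number of channels is the alternative fix.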
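
For finding 5, one possible shape for the guard, sketched with the standard library only. The helpers optional_import, require, and scatter_add_guarded are hypothetical names; the actual signature of rrwp_encoder._check_scatter is not shown in this review, so this is not a copy of it.

```python
import importlib


def optional_import(name):
    """Return the module if it can be imported, else None."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def require(module, package):
    """Raise an explicit ImportError instead of a later 'NoneType' failure."""
    if module is None:
        raise ImportError(
            f"{package} is required for this code path but is not installed"
        )
    return module


torch_scatter = optional_import("torch_scatter")  # None when the package is missing


def scatter_add_guarded(*args, **kwargs):
    # Guard at call time so importing this module still works without torch-scatter.
    return require(torch_scatter, "torch-scatter").scatter_add(*args, **kwargs)
```

Guarding at call time keeps the optional dependency optional while making the failure message name the missing package.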
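
For finding 9, a small sketch of one way to avoid mutating the shared config: copy it before writing derived values. AdapterSketch and the num_bus_features argument are hypothetical, and SimpleNamespace stands in for whatever config type the project uses.

```python
import copy
from types import SimpleNamespace


class AdapterSketch:
    """Illustrative adapter that keeps derived settings on a private config copy."""

    def __init__(self, args, num_bus_features):
        args = copy.deepcopy(args)               # the caller's object stays untouched
        args.model.input_dim = num_bus_features  # derived value lives on the copy
        self.args = args


shared = SimpleNamespace(model=SimpleNamespace(input_dim=None))
first = AdapterSketch(shared, num_bus_features=9)
second = AdapterSketch(shared, num_bus_features=12)
print(shared.model.input_dim)                                   # None
print(first.args.model.input_dim, second.args.model.input_dim)  # 9 12
```

If other components currently rely on reading the mutated values from the shared args, they would need to read them from the model instead.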

@romeokienzler

Collaborator

@ttolhurst I had a go with Claude Code; can you please have a look at the issues it found and comment?


3 participants