
Model graph message-passing aggregation ops so GNN layer outputs are typed instead of ⊤ #570

@khatchad

Description


Observed

The new testGcnCall fixture (a GCN message-passing model vendored verbatim from kyzhouhzau/NLPGNN) recovers node_embeddings, the input parameter of the decorated GCNLayer.call, concretely as (4, 8) float32. However, the layer output locals (the gc1/gc2 results) come back as ⊤ on both axes: unknown shape and unknown dtype.

Cause

The graph-convolution forward pass runs inside MessagePassing.propagate. There it aggregates messages with tf.math.unsorted_segment_sum, builds degree tensors with tf.scatter_nd, and gathers with tf.gather. These aggregation/scatter ops have no TensorGenerator, so the result's classification is dropped entirely: not just the shape, but the dtype too.
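The sketch below shows what the unmodeled aggregation computes. It is plain Python following TensorFlow's documented semantics of tf.math.unsorted_segment_sum, restricted to rank-2 data (one row per message) and rank-1 segment ids. It is an illustration, not TF's implementation; the standalone helper merely borrows the op's name. It shows that the result's element type is fully determined by the input: every output element is a sum of input elements, or zero for an empty segment. The dtype is therefore recoverable even where the shape is not.

```python
from typing import List, Sequence, TypeVar

Num = TypeVar("Num", int, float)


def unsorted_segment_sum(data: Sequence[Sequence[Num]],
                         segment_ids: Sequence[int],
                         num_segments: int) -> List[List[Num]]:
    """Reference sketch of segment-sum aggregation (rank-2 data, rank-1 ids)."""
    width = len(data[0]) if data else 0
    # The zero element takes the type of the input elements, so the
    # output element type always matches the input ("dtype passthrough").
    zero = type(data[0][0])() if data and width else 0
    # Output shape is (num_segments, width): one row per destination segment.
    out = [[zero] * width for _ in range(num_segments)]
    for row, seg in zip(data, segment_ids):
        if seg < 0:
            continue  # TF documents that values with negative ids are dropped.
        # Each message row is added into the row of its destination segment.
        for j, value in enumerate(row):
            out[seg][j] += value
    return out
```

For example, summing three messages `[[1, 2], [3, 4], [5, 6]]` into segments `[0, 1, 0]` with `num_segments=2` yields `[[6, 8], [3, 4]]`, and the elements stay `int`.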

Relation to Other Gaps

This is distinct from #530. There, built-in Dense/LSTM layer outputs keep their float32 dtype and lose only their shape. Here the output loses both axes, because the underlying ops are entirely unmodeled rather than merely shape-imprecise. It may overlap with #491 (unregistered generator audit) if unsorted_segment_sum/scatter_nd are in that set; cross-link if so.
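The next sketch makes the two failure modes concrete using a per-axis abstract domain, in which shape and dtype are each independently either known or ⊤. All names here are illustrative assumptions, not the analysis's actual TensorGenerator API: `Top`, `TOP`, `TensorType`, `REGISTRY`, and `infer`. The Dense entry only mimics the #530 behavior described above (dtype kept, shape lost). Any op without a registered generator falls back to ⊤ on both axes, which is what this issue observes for the aggregation ops.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union


class Top:
    """The unknown (⊤) element, used independently on each axis."""

    def __repr__(self) -> str:
        return "⊤"


TOP = Top()  # single shared instance; compare with `is`

Shape = Union[Tuple[int, ...], Top]
DType = Union[str, Top]


@dataclass(frozen=True)
class TensorType:
    shape: Shape
    dtype: DType


UNKNOWN = TensorType(TOP, TOP)

# Hypothetical op -> transfer-function registry (not the real TensorGenerator API).
Generator = Callable[[Tuple[TensorType, ...]], TensorType]
REGISTRY: Dict[str, Generator] = {
    # Mimics the #530 situation: dtype survives, shape is lost.
    "tf.keras.layers.Dense": lambda args: TensorType(TOP, args[0].dtype),
}


def infer(op: str, args: Tuple[TensorType, ...]) -> TensorType:
    gen = REGISTRY.get(op)
    if gen is None:
        # No generator: nothing is known about the result, so both axes are ⊤.
        return UNKNOWN
    return gen(args)
```

For example, take `emb = TensorType((4, 8), "float32")`. Then `infer("tf.keras.layers.Dense", (emb,))` keeps dtype `"float32"` with a ⊤ shape (the #530 case). By contrast, `infer("tf.math.unsorted_segment_sum", (emb,))` returns `UNKNOWN` (the case in this issue).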

Scope

Model the message-passing aggregation ops (tf.math.unsorted_segment_{sum,max,mean}, tf.scatter_nd) with at least dtype passthrough, so that GNN layer outputs are typed. Dtype is the load-bearing axis; shape can follow. This gap was surfaced by testGcnCall, which currently documents the ⊤ behavior; its assertions should be tightened once this lands.
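What follows is a minimal sketch of the proposed rules over a hypothetical per-axis domain. The names `Top`, `TensorType`, `segment_reduce`, `scatter_nd`, and `GENERATORS` are illustrative, not the project's API. Dtype is passed through from `data` for the segment reductions and from `updates` for tf.scatter_nd. Shape is computed only when the needed inputs are known, and is otherwise left at ⊤. For the segment reductions, the shape follows TensorFlow's documented rule: the leading `segment_ids.rank` dimensions of `data` are replaced by a single `num_segments` dimension. For tf.scatter_nd, the output shape is its `shape` argument.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


class Top:
    """The unknown (⊤) element for either axis."""

    def __repr__(self) -> str:
        return "⊤"


TOP = Top()


@dataclass(frozen=True)
class TensorType:
    shape: Union[Tuple[int, ...], Top]
    dtype: Union[str, Top]


def segment_reduce(data: TensorType, segment_ids: TensorType,
                   num_segments: Optional[int]) -> TensorType:
    """Shared rule for tf.math.unsorted_segment_{sum,max,mean}."""
    # Load-bearing axis: the reduction keeps data's dtype unchanged.
    dtype = data.dtype
    # Best-effort shape: [num_segments] + data.shape[rank(segment_ids):].
    if (num_segments is None or isinstance(data.shape, Top)
            or isinstance(segment_ids.shape, Top)):
        shape: Union[Tuple[int, ...], Top] = TOP
    else:
        shape = (num_segments,) + data.shape[len(segment_ids.shape):]
    return TensorType(shape, dtype)


def scatter_nd(indices: TensorType, updates: TensorType,
               shape: Optional[Tuple[int, ...]]) -> TensorType:
    """Rule for tf.scatter_nd(indices, updates, shape)."""
    # Output dtype follows `updates`; output shape is `shape` when it is a constant.
    return TensorType(TOP if shape is None else shape, updates.dtype)


# Hypothetical registration table keyed by op name.
GENERATORS = {
    "tf.math.unsorted_segment_sum": segment_reduce,
    "tf.math.unsorted_segment_max": segment_reduce,
    "tf.math.unsorted_segment_mean": segment_reduce,
    "tf.scatter_nd": scatter_nd,
}
```

As an illustration, suppose the fixture's (4, 8) float32 embeddings are gathered into per-edge messages of shape (10, 8). The edge count of 10 is made up. Summing those messages by destination node with `num_segments=4` gives (4, 8) float32. If `num_segments` is not a known constant, the result is still float32 with a ⊤ shape, which is the dtype-first outcome this issue asks for.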
