Skip to content

Add Band Gap and Magnetic Moments Prediction Heads#3

Open
AugustinLu wants to merge 5 commits intomainfrom
f-predict-gap-magmon-9256120961331400436
Open

Add Band Gap and Magnetic Moments Prediction Heads#3
AugustinLu wants to merge 5 commits intomainfrom
f-predict-gap-magmon-9256120961331400436

Conversation

@AugustinLu
Copy link
Copy Markdown
Owner

Added Band Gap (intensive scalar) and Magnetic Moments (node-level scalar) prediction heads alongside Energy and BEC branches. Both heads branch off the final hidden node features (irreps_x). Also updated AtomReduce to support mean pooling for intensive properties like Band Gap. Updated SevenNetCalculator to conditionally extract these properties into ASE results dictionary. Modified get_loss_functions_from_config to append BandGapLoss and MagmomsLoss if enabled in the training config. Fixed unit tests expected model parameter counts.


PR created automatically by Jules for task 9256120961331400436 started by @AugustinLu

google-labs-jules Bot and others added 5 commits April 11, 2026 13:33
Adds two new prediction heads to the SevenNet model:
- `predict_atomic_bandgap`: predicts an intensive scalar per atom, followed by an `AtomReduce` (mean pooled) layer to predict a global band gap.
- `predict_magmoms`: predicts a scalar magnetic moment per atom.

Both of these branch off the final hidden node features (`irreps_x`) just prior to the total energy `init_feature_reduce`.

Changes include:
- `sevenn/_keys.py`: Define new keys for `BANDGAP` and `MAGMOMS` and associated training toggles.
- `sevenn/nn/linear.py`: Update `AtomReduce` to support `reduce="mean"`.
- `sevenn/model_build.py`: Append the two new prediction heads (mapping `irreps_x` to `1x0e`).
- `sevenn/train/loss.py`: Implement `BandGapLoss` (intensive) and `MagmomsLoss` (extensive node-scalar) and selectively add them to the configured loss functions.
- `sevenn/calculator.py`: Expose `bandgap` and `magmoms` in `SevenNetCalculator`'s output results dictionary if they exist in the model output.
- `tests/unit_tests/test_model.py`: Adjust expected parameter counts due to the new linear layers.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
…keys

Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to gracefully load checkpoints by ignoring these missing keys during `load_state_dict(strict=False)`.

It also updates `tests/unit_tests/test_calculator.py` to disregard `bandgap` and `magmoms` keys when comparing output results between instances versus checkpoints/deployed models, as older deployments and torchscript models won't have those keys populated in their results dicts while the instantiated models will have these branches.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
The line added in `sevenn/checkpoint.py` to filter missing state_dict keys exceeded the 85 character limit enforced by flake8, which resulted in a `prek` (pre-commit) failure in CI. This commit reformats the list comprehension to satisfy the line length limit.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
…keys

Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to gracefully load checkpoints by ignoring these missing keys during `load_state_dict(strict=False)`.

It also updates `tests/unit_tests/test_calculator.py` to disregard `bandgap` and `magmoms` keys when comparing output results between instances versus checkpoints/deployed models, as older deployments and torchscript models won't have those keys populated in their results dicts while the instantiated models will have these branches.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
Adds native dataloader support for extracting `.info['bandgap']` and `.arrays['magmoms']` from ASE Atoms objects into the SevenNet training pipeline `KEY.BANDGAP` and `KEY.MAGMOMS` fields.

- Modifies `_set_atoms_y` and `atoms_to_graph` in `sevenn/train/dataload.py` to optionally pull these properties out of `info` and `arrays`.
- Updates `run_stat` loops inside `sevenn/train/atoms_dataset.py` and `sevenn/train/graph_dataset.py` to ensure dataset statistics calculation ignores empty lists or skips successfully.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant