Add Band Gap and Magnetic Moments Prediction Heads #3
Open
AugustinLu wants to merge 5 commits into `main`
Adds two new prediction heads to the SevenNet model:

- `predict_atomic_bandgap`: predicts an intensive scalar per atom, followed by an `AtomReduce` (mean-pooled) layer to predict a global band gap.
- `predict_magmoms`: predicts a scalar magnetic moment per atom.

Both heads branch off the final hidden node features (`irreps_x`) just before the total-energy `init_feature_reduce`.

Changes include:

- `sevenn/_keys.py`: Define new keys for `BANDGAP` and `MAGMOMS` and the associated training toggles.
- `sevenn/nn/linear.py`: Update `AtomReduce` to support `reduce="mean"`.
- `sevenn/model_build.py`: Append the two new prediction heads (mapping `irreps_x` to `1x0e`).
- `sevenn/train/loss.py`: Implement `BandGapLoss` (intensive) and `MagmomsLoss` (extensive node-scalar) and selectively add them to the configured loss functions.
- `sevenn/calculator.py`: Expose `bandgap` and `magmoms` in `SevenNetCalculator`'s output results dictionary if they exist in the model output.
- `tests/unit_tests/test_model.py`: Adjust the expected parameter counts to account for the new linear layers.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
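To illustrate why mean pooling suits an intensive property such as the band gap, here is a minimal, framework-free sketch of a sum/mean reduction over per-atom scalars. It is not the `AtomReduce` implementation from this PR (which operates on tensors); the function name `atom_reduce` and its arguments are hypothetical and chosen only for illustration.

```python
from collections import defaultdict


def atom_reduce(per_atom, batch, reduce="sum"):
    """Pool per-atom scalars into one value per structure.

    per_atom: one scalar per atom.
    batch:    structure index of each atom (same length as per_atom).
    reduce:   "sum" for extensive targets, "mean" for intensive ones.
    """
    if reduce not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce mode: {reduce!r}")

    totals = defaultdict(float)
    counts = defaultdict(int)
    for value, structure in zip(per_atom, batch):
        totals[structure] += value
        counts[structure] += 1

    structures = sorted(totals)
    if reduce == "sum":
        return [totals[s] for s in structures]
    # Mean pooling: divide each structure's total by its atom count.
    return [totals[s] / counts[s] for s in structures]


# Two structures: atoms 0-1 belong to structure 0, atom 2 to structure 1.
pooled_sum = atom_reduce([1.0, 3.0, 2.0], [0, 0, 1], reduce="sum")
pooled_mean = atom_reduce([1.0, 3.0, 2.0], [0, 0, 1], reduce="mean")
```

Doubling a cell (repeating the same per-atom values) doubles the summed result but leaves the mean unchanged, which is the behavior expected of an intensive quantity.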
…keys

Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to load such checkpoints gracefully by ignoring these missing keys during `load_state_dict(strict=False)`. It also updates `tests/unit_tests/test_calculator.py` to disregard the `bandgap` and `magmoms` keys when comparing output results between instantiated models and checkpoints or deployed models: older deployments and TorchScript models won't have those keys populated in their results dicts, while instantiated models will have these branches.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
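The key filtering described in this commit can be sketched as follows. PyTorch's `load_state_dict(strict=False)` returns a result listing `missing_keys` and `unexpected_keys`; the sketch below only shows how such a list of missing keys might be split into tolerated and genuinely unexpected ones. The helper name `split_missing_keys` and the example parameter names are hypothetical, not the code in `sevenn/checkpoint.py`.

```python
# Module names of the new heads, as listed in the commit message.
OPTIONAL_HEAD_PREFIXES = (
    "predict_atomic_bandgap",
    "reduce_total_bandgap",
    "predict_magmoms",
)


def split_missing_keys(missing_keys):
    """Separate keys that older checkpoints may legitimately lack."""
    tolerated = [
        key for key in missing_keys
        if key.startswith(OPTIONAL_HEAD_PREFIXES)
    ]
    unexpected = [
        key for key in missing_keys
        if not key.startswith(OPTIONAL_HEAD_PREFIXES)
    ]
    return tolerated, unexpected


# Hypothetical parameter names, for illustration only.
tolerated, unexpected = split_missing_keys([
    "predict_magmoms.linear.weight",
    "predict_atomic_bandgap.linear.weight",
    "some_other_layer.weight",
])
```

Keeping the second list lets the loader still raise or warn when a checkpoint is missing something other than the new heads.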
The line added in `sevenn/checkpoint.py` to filter missing state_dict keys exceeded the 85-character limit enforced by flake8, which caused a `prek` (pre-commit) failure in CI. This commit reformats the list comprehension to satisfy the line-length limit.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
…keys

Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to load such checkpoints gracefully by ignoring these missing keys during `load_state_dict(strict=False)`. It also updates `tests/unit_tests/test_calculator.py` to disregard the `bandgap` and `magmoms` keys when comparing output results between instantiated models and checkpoints or deployed models: older deployments and TorchScript models won't have those keys populated in their results dicts, while instantiated models will have these branches.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
Adds native dataloader support for extracting `.info['bandgap']` and `.arrays['magmoms']` from ASE Atoms objects into the `KEY.BANDGAP` and `KEY.MAGMOMS` fields of the SevenNet training pipeline.

- Modifies `_set_atoms_y` and `atoms_to_graph` in `sevenn/train/dataload.py` to optionally pull these properties out of `info` and `arrays`.
- Updates the `run_stat` loops in `sevenn/train/atoms_dataset.py` and `sevenn/train/graph_dataset.py` so that the dataset statistics calculation skips these properties when their value lists are empty.

Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com>
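The optional extraction described above might look roughly like the sketch below. It relies only on the fact that ASE `Atoms` objects carry an `info` dict and an `arrays` dict; to stay dependency-free, a `SimpleNamespace` stands in for a real `Atoms` object. The helper `extract_optional_labels` is hypothetical and is not the `_set_atoms_y` code from this PR.

```python
from types import SimpleNamespace


def extract_optional_labels(atoms):
    """Collect band gap and magnetic moments only when they are present."""
    labels = {}
    if "bandgap" in atoms.info:
        # One intensive scalar per structure.
        labels["bandgap"] = float(atoms.info["bandgap"])
    if "magmoms" in atoms.arrays:
        # One scalar per atom.
        labels["magmoms"] = [float(m) for m in atoms.arrays["magmoms"]]
    return labels


# Stand-ins for ASE Atoms objects (assumption: only .info and .arrays are used).
labelled = SimpleNamespace(
    info={"bandgap": 1.1},
    arrays={"magmoms": [0.0, 2.2]},
)
unlabelled = SimpleNamespace(info={}, arrays={})

labelled_out = extract_optional_labels(labelled)
unlabelled_out = extract_optional_labels(unlabelled)
```

Returning an empty dict for structures without these labels is one way to let downstream statistics code skip the properties instead of failing.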
Added Band Gap (intensive scalar) and Magnetic Moments (node-level scalar) prediction heads alongside the Energy and BEC branches. Both heads branch off the final hidden node features (`irreps_x`). Also updated `AtomReduce` to support mean pooling for intensive properties like Band Gap. Updated `SevenNetCalculator` to conditionally extract these properties into the ASE `results` dictionary. Modified `get_loss_functions_from_config` to append `BandGapLoss` and `MagmomsLoss` if enabled in the training config. Fixed the expected model parameter counts in the unit tests.

PR created automatically by Jules for task 9256120961331400436 started by @AugustinLu
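As a rough illustration of the conditional extraction into the `results` dictionary, the sketch below copies the optional outputs only when the model produced them. The function name `collect_optional_results` and the plain-dict model output are assumptions made for this example; the actual `SevenNetCalculator` code may differ.

```python
OPTIONAL_OUTPUT_KEYS = ("bandgap", "magmoms")


def collect_optional_results(model_output, results):
    """Copy optional head outputs into results, skipping absent ones."""
    for key in OPTIONAL_OUTPUT_KEYS:
        if key in model_output:
            results[key] = model_output[key]
    return results


# A newer model exposes both heads; an older deployment exposes neither.
new_results = collect_optional_results(
    {"energy": -3.2, "bandgap": 1.1, "magmoms": [0.0, 2.2]},
    {"energy": -3.2},
)
old_results = collect_optional_results({"energy": -3.2}, {"energy": -3.2})
```

Because absent keys are simply skipped, older deployed models keep producing the same results dictionary as before.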