From 1c5fd7102b1b6e33b62406cead341ac72adc63b4 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 13:33:28 +0000 Subject: [PATCH 1/5] feat: add band gap and magnetic moments prediction heads Adds two new prediction heads to the SevenNet model: - `predict_atomic_bandgap`: predicts an intensive scalar per atom, followed by an `AtomReduce` (mean pooled) layer to predict a global band gap. - `predict_magmoms`: predicts a scalar magnetic moment per atom. Both of these branch off the final hidden node features (`irreps_x`) just prior to the total energy `init_feature_reduce`. Changes include: - `sevenn/_keys.py`: Define new keys for `BANDGAP` and `MAGMOMS` and associated training toggles. - `sevenn/nn/linear.py`: Update `AtomReduce` to support `reduce="mean"`. - `sevenn/model_build.py`: Append the two new prediction heads (mapping `irreps_x` to `1x0e`). - `sevenn/train/loss.py`: Implement `BandGapLoss` (intensive) and `MagmomsLoss` (extensive node-scalar) and selectively add them to the configured loss functions. - `sevenn/calculator.py`: Expose `bandgap` and `magmoms` in `SevenNetCalculator`'s output results dictionary if they exist in the model output. - `tests/unit_tests/test_model.py`: Adjust expected parameter counts due to the new linear layers. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/_keys.py | 11 +++++ sevenn/calculator.py | 9 +++- sevenn/model_build.py | 24 ++++++++++ sevenn/nn/linear.py | 13 +++++- sevenn/train/loss.py | 85 ++++++++++++++++++++++++++++++++++ tests/unit_tests/test_model.py | 30 ++++++------ 6 files changed, 155 insertions(+), 17 deletions(-) diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 0c9af7b7..4dadb54b 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -49,6 +49,8 @@ ENERGY: Final[str] = 'total_energy' # (1) FORCE: Final[str] = 'force_of_atoms' # (N, 3) STRESS: Final[str] = 'stress' # (6) +BANDGAP: Final[str] = 'bandgap' # (1) +MAGMOMS: Final[str] = 'magmoms' # (N, 1) # This is for training, per atom scale. SCALED_ENERGY: Final[str] = 'scaled_total_energy' @@ -58,6 +60,11 @@ ATOMIC_ENERGY: Final[str] = 'atomic_energy' PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' +ATOMIC_BANDGAP: Final[str] = 'atomic_bandgap' +PRED_BANDGAP: Final[str] = 'inferred_bandgap' + +PRED_MAGMOMS: Final[str] = 'inferred_magmoms' + PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' @@ -124,6 +131,10 @@ TRAIN_SHUFFLE = 'train_shuffle' IS_TRAIN_STRESS = 'is_train_stress' +IS_TRAIN_BANDGAP = 'is_train_bandgap' +IS_TRAIN_MAGMOMS = 'is_train_magmoms' +BANDGAP_WEIGHT = 'bandgap_loss_weight' +MAGMOMS_WEIGHT = 'magmoms_loss_weight' CONTINUE = 'continue' CHECKPOINT = 'checkpoint' diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 2f4a3d59..4f7e421b 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -207,7 +207,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation ) # Store results - return { + results = { 'free_energy': energy, 'energy': energy, 'energies': atomic_energies, @@ -215,6 +215,13 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: 'stress': stress, 'num_edges': output[KEY.EDGE_IDX].shape[1], } + if KEY.PRED_BANDGAP in output: + results['bandgap'] = output[KEY.PRED_BANDGAP].detach().cpu().item() + if KEY.PRED_MAGMOMS in output: + results['magmoms'] = ( + output[KEY.PRED_MAGMOMS].detach().cpu().numpy()[:num_atoms].flatten() + ) + return results def calculate(self, atoms=None, properties=None, system_changes=all_changes): is_ts_type = isinstance(self.model, torch_script_type) diff --git a/sevenn/model_build.py b/sevenn/model_build.py index c548c34e..2f22cef1 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -599,6 +599,30 @@ def build_E3_equivariant_model( layers.update(interaction_builder(**param_interaction_block)) irreps_x = irreps_out + layers.update( + { + 'predict_atomic_bandgap': IrrepsLinear( + irreps_in=irreps_x, + irreps_out=Irreps('1x0e'), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.ATOMIC_BANDGAP, + biases=use_bias_in_linear, + ), + 'reduce_total_bandgap': AtomReduce( + data_key_in=KEY.ATOMIC_BANDGAP, + data_key_out=KEY.PRED_BANDGAP, + reduce='mean', + ), + 'predict_magmoms': IrrepsLinear( + irreps_in=irreps_x, + irreps_out=Irreps('1x0e'), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.PRED_MAGMOMS, + biases=use_bias_in_linear, + ), + } + ) + layers.update(init_feature_reduce(config, irreps_x)) # type: ignore layers.update( diff --git a/sevenn/nn/linear.py b/sevenn/nn/linear.py index 543c6a7e..46f39a4c 100644 --- a/sevenn/nn/linear.py +++ b/sevenn/nn/linear.py @@ -134,9 +134,20 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: device=src.device, ) output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum') + if self.reduce == 'mean': + counts = torch.zeros(size, dtype=src.dtype, device=src.device) + ones = torch.ones_like(src) + counts.scatter_reduce_(0, data[KEY.BATCH], ones, reduce='sum') + counts = counts.clamp(min=1.0) + output = output / counts data[self.key_output] = output * self.constant else: - data[self.key_output] = torch.sum(data[self.key_input]) * self.constant + if self.reduce == 'mean': + v = torch.mean(data[self.key_input]) * self.constant + data[self.key_output] = v + else: + v = torch.sum(data[self.key_input]) * self.constant + data[self.key_output] = v return data diff --git a/sevenn/train/loss.py b/sevenn/train/loss.py index a6f8a769..8c68dd38 100644 --- a/sevenn/train/loss.py +++ b/sevenn/train/loss.py @@ -201,6 +201,85 @@ def _preprocess( return pred, ref, w_tensor +class BandGapLoss(LossDefinition): + """ + Loss for intensive band gap + """ + + def __init__( + self, + name: str = 'BandGap', + unit: str = 'eV', + criterion: Optional[Callable] = None, + ref_key: str = KEY.BANDGAP, + pred_key: str = KEY.PRED_BANDGAP, + **kwargs, + ) -> None: + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + pred = torch.reshape(batch_data[self.pred_key], (-1,)) + ref = torch.reshape(batch_data[self.ref_key], (-1,)) + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = torch.repeat_interleave(weight, 1) + + return pred, ref, w_tensor + + +class MagmomsLoss(LossDefinition): + """ + Loss for magnetic moments (node-level scalar) + """ + + def __init__( + self, + name: str = 'Magmoms', + unit: str = 'mu_B', + criterion: Optional[Callable] = None, + ref_key: str = KEY.MAGMOMS, + pred_key: str = KEY.PRED_MAGMOMS, + **kwargs, + ) -> None: + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + pred = torch.reshape(batch_data[self.pred_key], (-1,)) + ref = torch.reshape(batch_data[self.ref_key], (-1,)) + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = weight[batch_data[KEY.BATCH]] + w_tensor = torch.repeat_interleave(w_tensor, 1) + + return pred, ref, w_tensor + + def get_loss_functions_from_config( config: Dict[str, Any], ) -> List[Tuple[LossDefinition, float]]: @@ -222,6 +301,12 @@ def get_loss_functions_from_config( loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT])) if config[KEY.IS_TRAIN_STRESS]: loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT])) + if config.get(KEY.IS_TRAIN_BANDGAP, False): + w = config.get(KEY.BANDGAP_WEIGHT, 1.0) + loss_functions.append((BandGapLoss(**commons), w)) + if config.get(KEY.IS_TRAIN_MAGMOMS, False): + w = config.get(KEY.MAGMOMS_WEIGHT, 1.0) + loss_functions.append((MagmomsLoss(**commons), w)) for loss_function, _ in loss_functions: # why do these? if loss_function.criterion is None: diff --git a/tests/unit_tests/test_model.py b/tests/unit_tests/test_model.py index d75976f8..944809ed 100644 --- a/tests/unit_tests/test_model.py +++ b/tests/unit_tests/test_model.py @@ -162,16 +162,16 @@ def test_batch(): _n_param_tests = [ - ({}, 20642), - ({'train_denominator': True}, 20642 + 3), - ({'train_shift_scale': True}, 20642 + 2), - ({'shift': [1.0] * 4}, 20642), - ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8), - ({'num_convolution_layer': 4}, 33458), - ({'lmax': 3}, 26866), - ({'channel': 2}, 16883), - ({'is_parity': False}, 20386), - ({'self_connection_type': 'linear'}, 20114), + ({}, 20642 + 8), + ({'train_denominator': True}, 20642 + 8 + 3), + ({'train_shift_scale': True}, 20642 + 8 + 2), + ({'shift': [1.0] * 4}, 20642 + 8), + ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8 + 8), + ({'num_convolution_layer': 4}, 33458 + 8), + ({'lmax': 3}, 26866 + 8), + ({'channel': 2}, 16883 + 4), + ({'is_parity': False}, 20386 + 8), + ({'self_connection_type': 'linear'}, 20114 + 8), ] @@ -183,11 +183,11 @@ def test_num_params(cf, ref): _n_modal_param_tests = [ - ({}, 20642), - ({'use_modal_node_embedding': True}, 20642 + 8), - ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3), - ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)), - ({'use_modal_output_block': True}, 20642 + 2 * 4 / 2), + ({}, 20642 + 8), + ({'use_modal_node_embedding': True}, 20642 + 8 + 8), + ({'use_modal_self_inter_intro': True}, 20642 + 8 + 2 * 4 * 3), + ({'use_modal_self_inter_outro': True}, 20642 + 8 + 2 * (12 + 20 + 4)), + ({'use_modal_output_block': True}, 20642 + 8 + 2 * 4 / 2), ] From b9a9bc8f522dc5d19e2f509fe0d54b01e5c91f64 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 13:46:07 +0000 Subject: [PATCH 2/5] fix: allow loading older checkpoints without missing prediction head keys Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to gracefully load checkpoints by ignoring these missing keys during `load_state_dict(strict=False)`. It also updates `tests/unit_tests/test_calculator.py` to disregard `bandgap` and `magmoms` keys when comparing output results between instances versus checkpoints/deployed models, as older deployments and torchscript models won't have those keys populated in their results dicts while the instantiated models will have these branches. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/checkpoint.py | 3 ++- tests/unit_tests/test_calculator.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sevenn/checkpoint.py b/sevenn/checkpoint.py index e0422ec2..e0cef452 100644 --- a/sevenn/checkpoint.py +++ b/sevenn/checkpoint.py @@ -347,7 +347,8 @@ def build_model( state_dict = compat.patch_state_dict_if_old( self.model_state_dict, self.config, model ) - missing, not_used = model.load_state_dict(state_dict, strict=True) + missing, not_used = model.load_state_dict(state_dict, strict=False) + missing = [m for m in missing if not (m.startswith('predict_atomic_bandgap.') or m.startswith('predict_magmoms.') or m.startswith('reduce_total_bandgap.'))] assert len(missing) == 0, f'Missing keys: {missing}' if len(not_used) > 0: warnings.warn(f'Some keys are not used: {not_used}', UserWarning) diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index 9b19308d..d4e3a7f6 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -122,6 +122,8 @@ def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): res_script = copy.copy(atoms_pbc.calc.results) for k in res_cp: + if k in ['bandgap', 'magmoms']: + continue assert np.allclose(res_cp[k], res_script[k], rtol=1e-4, atol=1e-4) @@ -141,6 +143,8 @@ def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): res_script = copy.copy(atoms_pbc.calc.results) for k in res_cp: + if k in ['bandgap', 'magmoms']: + continue assert np.allclose(res_cp[k], res_script[k]) From a011813388f6f58e4d2fd00ed5f61cc68b190bea Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 13:49:07 +0000 Subject: [PATCH 3/5] style: fix flake8 E501 line too long in sevenn/checkpoint.py The line added in `sevenn/checkpoint.py` to filter missing state_dict keys exceeded the 85 character limit enforced by flake8, which resulted in a `prek` (pre-commit) failure in CI. This commit reformats the list comprehension to satisfy the line length limit. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/checkpoint.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sevenn/checkpoint.py b/sevenn/checkpoint.py index e0cef452..4f29c139 100644 --- a/sevenn/checkpoint.py +++ b/sevenn/checkpoint.py @@ -348,7 +348,15 @@ def build_model( self.model_state_dict, self.config, model ) missing, not_used = model.load_state_dict(state_dict, strict=False) - missing = [m for m in missing if not (m.startswith('predict_atomic_bandgap.') or m.startswith('predict_magmoms.') or m.startswith('reduce_total_bandgap.'))] + missing = [ + m + for m in missing + if not ( + m.startswith('predict_atomic_bandgap.') + or m.startswith('predict_magmoms.') + or m.startswith('reduce_total_bandgap.') + ) + ] assert len(missing) == 0, f'Missing keys: {missing}' if len(not_used) > 0: warnings.warn(f'Some keys are not used: {not_used}', UserWarning) From 4a1713246c56bb37f476bff3fa85f237fb20a0ae Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 15:17:54 +0000 Subject: [PATCH 4/5] fix: allow loading older checkpoints without missing prediction head keys Older models do not have `predict_atomic_bandgap`, `reduce_total_bandgap`, and `predict_magmoms` in their state dictionaries. This commit updates `sevenn/checkpoint.py` to gracefully load checkpoints by ignoring these missing keys during `load_state_dict(strict=False)`. It also updates `tests/unit_tests/test_calculator.py` to disregard `bandgap` and `magmoms` keys when comparing output results between instances versus checkpoints/deployed models, as older deployments and torchscript models won't have those keys populated in their results dicts while the instantiated models will have these branches. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> From 37e171a678a9a65abd7170bae395d4e226005439 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:50:30 +0000 Subject: [PATCH 5/5] feat: native support for band_gap and magnetic_moments in dataloaders Adds native dataloader support for extracting `.info['bandgap']` and `.arrays['magmoms']` from ASE Atoms objects into the SevenNet training pipeline `KEY.BANDGAP` and `KEY.MAGMOMS` fields. - Modifies `_set_atoms_y` and `atoms_to_graph` in `sevenn/train/dataload.py` to optionally pull these properties out of `info` and `arrays`. - Updates `run_stat` loops inside `sevenn/train/atoms_dataset.py` and `sevenn/train/graph_dataset.py` to ensure dataset statistics calculation ignores empty lists or skips successfully. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/train/atoms_dataset.py | 13 +++++++++++-- sevenn/train/dataload.py | 12 ++++++++++++ sevenn/train/graph_dataset.py | 19 +++++++++++++------ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/sevenn/train/atoms_dataset.py b/sevenn/train/atoms_dataset.py index d5ffe153..dde4bbfa 100644 --- a/sevenn/train/atoms_dataset.py +++ b/sevenn/train/atoms_dataset.py @@ -182,7 +182,10 @@ def run_stat(self) -> None: """ if self._scanned is True: return # statistics already computed - y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] + y_keys: List[str] = [ + KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS, + KEY.BANDGAP, KEY.MAGMOMS + ] natoms_counter = Counter() composition = np.zeros((len(self), NUM_UNIV_ELEMENT)) stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys} @@ -202,9 +205,15 @@ def run_stat(self) -> None: dct['_array'].append(atoms.arrays['y_force'].reshape(-1)) elif y == KEY.STRESS: dct['_array'].append(atoms.info['y_stress'].reshape(-1)) + elif y == KEY.BANDGAP and 'y_bandgap' in atoms.info: + dct['_array'].append(atoms.info['y_bandgap']) + elif y == KEY.MAGMOMS and 'y_magmoms' in atoms.arrays: + dct['_array'].append(atoms.arrays['y_magmoms'].reshape(-1)) for y, dct in stats.items(): - if y == KEY.FORCE: + if len(dct['_array']) == 0: + continue + if y in [KEY.FORCE, KEY.MAGMOMS]: array = np.concatenate(dct['_array']) else: array = np.array(dct['_array']).reshape(-1) diff --git a/sevenn/train/dataload.py b/sevenn/train/dataload.py index 545131ee..9062cfd6 100644 --- a/sevenn/train/dataload.py +++ b/sevenn/train/dataload.py @@ -206,6 +206,11 @@ def atoms_to_graph( KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), } + if 'y_bandgap' in atoms.info: + data[KEY.BANDGAP] = _correct_scalar(atoms.info['y_bandgap']) + if 'y_magmoms' in atoms.arrays: + data[KEY.MAGMOMS] = atoms.arrays['y_magmoms'] + if with_shift: data[KEY.CELL_SHIFT] = shift data[KEY.CELL] = cell @@ -302,6 +307,8 @@ def _set_atoms_y( energy_key: Optional[str] = None, force_key: Optional[str] = None, stress_key: Optional[str] = None, + bandgap_key: Optional[str] = 'bandgap', + magmoms_key: Optional[str] = 'magmoms', ) -> List[ase.Atoms]: """ Define how SevenNet reads ASE.atoms object for its y label @@ -345,6 +352,11 @@ def _set_atoms_y( else: atoms.info['y_stress'] = from_calc['stress'] + if bandgap_key and bandgap_key in atoms.info: + atoms.info['y_bandgap'] = atoms.info[bandgap_key] + if magmoms_key and magmoms_key in atoms.arrays: + atoms.arrays['y_magmoms'] = atoms.arrays[magmoms_key] + return atoms_list diff --git a/sevenn/train/graph_dataset.py b/sevenn/train/graph_dataset.py index 224e6e22..b83969d0 100644 --- a/sevenn/train/graph_dataset.py +++ b/sevenn/train/graph_dataset.py @@ -65,7 +65,10 @@ def _run_stat( """ Loop over dataset and init any statistics might need """ - y_keys = y_keys or [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] + y_keys = y_keys or [ + KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS, + KEY.BANDGAP, KEY.MAGMOMS + ] n_neigh = [] natoms_counter = Counter() composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) @@ -79,14 +82,18 @@ def _run_stat( composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) for y, dct in stats.items(): - dct['_array'].append( - graph[y].reshape( - -1, + if y in graph: + dct['_array'].append( + graph[y].reshape( + -1, + ) ) - ) stats.update({'num_neighbor': {'_array': n_neigh}}) - for y, dct in stats.items(): + for y, dct in list(stats.items()): + if not dct['_array']: + del stats[y] + continue array = torch.cat(dct['_array']) if array.dtype == torch.int64: # because of n_neigh array = array.to(torch.float)