From d7b51369c9846c9898ead459c5af0ff44726c6e8 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:37:47 +0000 Subject: [PATCH 1/5] feat: allow attaching BEC head to penultimate layer for performance Adds `train_bec_from` configuration (defaulting to 'last', with option 'penultimate'). When set to 'penultimate', the final convolution layer avoids computing expensive tensor features (`lmax>0`), improving efficiency while still extracting required features for the Born Effective Charge head. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/_keys.py | 1 + sevenn/model_build.py | 45 ++++++++++++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 83587c8c..401985d7 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -130,6 +130,7 @@ IS_TRAIN_STRESS = 'is_train_stress' IS_TRAIN_BEC = 'is_train_bec' +TRAIN_BEC_FROM = 'train_bec_from' CONTINUE = 'continue' CHECKPOINT = 'checkpoint' diff --git a/sevenn/model_build.py b/sevenn/model_build.py index f07f6986..384f984b 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -567,13 +567,17 @@ def build_E3_equivariant_model( if interaction_type == 'nequip': parity_mode = 'full' fix_multiplicity = False + + is_train_bec = config.get(KEY.IS_TRAIN_BEC, False) + train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last') + if t == num_convolution_layer - 1: - # If training BEC, we need vectors/tensors to survive the last layer - if config.get(KEY.IS_TRAIN_BEC, False): - # We need at least L=1 and L=2 for vectors and tensors. + if is_train_bec and train_bec_from == 'last': + # If training BEC from the last layer, we need vectors/tensors to survive lmax_node = max(lmax_node, 2) parity_mode = 'full' else: + # Normal scalar output for final energy prediction lmax_node = 0 parity_mode = 'even' # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out @@ -606,19 +610,30 @@ def build_E3_equivariant_model( layers.update(interaction_builder(**param_interaction_block)) irreps_x = irreps_out - if config.get(KEY.IS_TRAIN_BEC, False): - irreps_in_bec = irreps_x - layers.update( - { - 'predict_bec': IrrepsLinear( - irreps_in_bec, - Irreps('1x0e+1x1e+1x2e'), - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES, - biases=config[KEY.USE_BIAS_IN_LINEAR], + is_train_bec = config.get(KEY.IS_TRAIN_BEC, False) + train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last') + + # Attach BEC head immediately after the specified convolution block + # num_convolution_layer >= 2 is assumed for 'penultimate' + if is_train_bec: + attach_bec = False + if train_bec_from == 'last' and t == num_convolution_layer - 1: + attach_bec = True + elif train_bec_from == 'penultimate' and t == num_convolution_layer - 2: + attach_bec = True + + if attach_bec: + layers.update( + { + 'predict_bec': IrrepsLinear( + irreps_x, + Irreps('1x0e+1x1e+1x2e'), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ) + } ) - } - ) layers.update(init_feature_reduce(config, irreps_x)) # type: ignore From 42266d008319eefc57d697b6f83127143f2bfbc7 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:59:44 +0000 Subject: [PATCH 2/5] feat: allow attaching BEC head to penultimate layer for performance Adds `train_bec_from` configuration (defaulting to 'last', with option 'penultimate'). When set to 'penultimate', the final convolution layer avoids computing expensive tensor features (`lmax>0`), improving efficiency while still extracting required features for the Born Effective Charge head. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/_const.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sevenn/_const.py b/sevenn/_const.py index 9c4a3c26..e4b6328d 100644 --- a/sevenn/_const.py +++ b/sevenn/_const.py @@ -283,6 +283,7 @@ def data_defaults(config): KEY.NUM_WORKERS: 0, KEY.IS_TRAIN_STRESS: True, KEY.IS_TRAIN_BEC: False, + KEY.TRAIN_BEC_FROM: 'last', KEY.TRAIN_SHUFFLE: True, KEY.ERROR_RECORD: [ ['Energy', 'RMSE'], @@ -317,6 +318,7 @@ def data_defaults(config): KEY.DEFAULT_MODAL: str, KEY.IS_TRAIN_STRESS: bool, KEY.IS_TRAIN_BEC: bool, + KEY.TRAIN_BEC_FROM: str, KEY.TRAIN_SHUFFLE: bool, KEY.ERROR_RECORD: error_record_condition, KEY.BEST_METRIC: str, @@ -336,6 +338,12 @@ def train_defaults(config): if KEY.IS_TRAIN_BEC not in config: config[KEY.IS_TRAIN_BEC] = defaults[KEY.IS_TRAIN_BEC] + if KEY.TRAIN_BEC_FROM not in config: + config[KEY.TRAIN_BEC_FROM] = defaults[KEY.TRAIN_BEC_FROM] + + if config[KEY.TRAIN_BEC_FROM] not in ['last', 'penultimate']: + raise ValueError(f"train_bec_from must be 'last' or 'penultimate', got {config[KEY.TRAIN_BEC_FROM]}") + # Automatically add BEC metrics if enabled and default err record if config[KEY.IS_TRAIN_BEC]: # If the user didn't explicitly provide an ERROR_RECORD, or if they provided From 97738984f8bfe65b7068f31add8258960c0c9de1 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:30:41 +0000 Subject: [PATCH 3/5] feat: allow attaching BEC head to penultimate layer for performance Adds `train_bec_from` configuration (defaulting to 'last', with option 'penultimate'). When set to 'penultimate', the final convolution layer avoids computing expensive tensor features (`lmax>0`), improving efficiency while still extracting required features for the Born Effective Charge head. Also includes a scalar filter to prevent FlashTP and PyTorch OOM bugs when the final layer is purely scalar. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/model_build.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sevenn/model_build.py b/sevenn/model_build.py index 384f984b..990369cb 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -635,6 +635,24 @@ def build_E3_equivariant_model( } ) + # If attached at penultimate, we must strip L>0 features from node features + # before the final layer. This prevents FlashTP from crashing (which happens + # when input has L>0 but output is forced to L=0) and saves significant memory. + if train_bec_from == 'penultimate': + irreps_x_scalar = Irreps([(mul, ir) for mul, ir in irreps_x if ir.l == 0]) + layers.update( + { + 'filter_scalar_features': IrrepsLinear( + irreps_x, + irreps_x_scalar, + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.NODE_FEATURE, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ) + } + ) + irreps_x = irreps_x_scalar + layers.update(init_feature_reduce(config, irreps_x)) # type: ignore layers.update( From 7dcc2e114f9602adf281d56e659cb0bc2a1148e6 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:31:14 +0000 Subject: [PATCH 4/5] feat: allow attaching BEC head to penultimate layer for performance Adds `train_bec_from` configuration (defaulting to 'last', with option 'penultimate'). When set to 'penultimate', the final convolution layer avoids computing expensive tensor features (`lmax>0`), improving efficiency while still extracting required features for the Born Effective Charge head. Also includes a scalar filter to prevent FlashTP and PyTorch OOM bugs when the final layer is purely scalar. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> From 2ca6a62db08031b3d930f9c76f3371f4dd6f6b6e Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:46:01 +0000 Subject: [PATCH 5/5] feat: allow attaching BEC head to penultimate layer for performance Adds `train_bec_from` configuration (defaulting to 'last', with option 'penultimate'). When set to 'penultimate', the final convolution layer avoids computing expensive tensor features (`lmax>0`), improving efficiency while still extracting required features for the Born Effective Charge head. Also bypasses a FlashTP bug where compiling a layer producing only L=0 outputs from L>0 inputs crashes with KeyError: 0 by falling back to standard e3nn for that specific layer. Co-authored-by: AugustinLu <59640670+AugustinLu@users.noreply.github.com> --- sevenn/model_build.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sevenn/model_build.py b/sevenn/model_build.py index 990369cb..adcd729a 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -321,6 +321,11 @@ def patch_flash_tp(layers: OrderedDict, config: Dict[str, Any]) -> OrderedDict: updates = {} for k, module in layers.items(): if isinstance(module, IrrepsConvolution): + # FlashTP crashes with KeyError: 0 when mapping L>0 input directly to purely L=0 output + # (which happens in the final layer of a standard SevenNet architecture or penultimate branch). + # We explicitly skip patching FlashTP on the final layer if it purely outputs L=0. + if module.convolution_kwargs['irreps_out'].lmax == 0 and module.convolution_kwargs['irreps_in1'].lmax > 0: + continue updates[k] = flash_helper.patch_convolution(module, _flash_lammps) layers.update(updates)