diff --git a/sevenn/_const.py b/sevenn/_const.py index 9c4a3c26..e4b6328d 100644 --- a/sevenn/_const.py +++ b/sevenn/_const.py @@ -283,6 +283,7 @@ def data_defaults(config): KEY.NUM_WORKERS: 0, KEY.IS_TRAIN_STRESS: True, KEY.IS_TRAIN_BEC: False, + KEY.TRAIN_BEC_FROM: 'last', KEY.TRAIN_SHUFFLE: True, KEY.ERROR_RECORD: [ ['Energy', 'RMSE'], @@ -317,6 +318,7 @@ def data_defaults(config): KEY.DEFAULT_MODAL: str, KEY.IS_TRAIN_STRESS: bool, KEY.IS_TRAIN_BEC: bool, + KEY.TRAIN_BEC_FROM: str, KEY.TRAIN_SHUFFLE: bool, KEY.ERROR_RECORD: error_record_condition, KEY.BEST_METRIC: str, @@ -336,6 +338,12 @@ def train_defaults(config): if KEY.IS_TRAIN_BEC not in config: config[KEY.IS_TRAIN_BEC] = defaults[KEY.IS_TRAIN_BEC] + if KEY.TRAIN_BEC_FROM not in config: + config[KEY.TRAIN_BEC_FROM] = defaults[KEY.TRAIN_BEC_FROM] + + if config[KEY.TRAIN_BEC_FROM] not in ['last', 'penultimate']: + raise ValueError(f"train_bec_from must be 'last' or 'penultimate', got {config[KEY.TRAIN_BEC_FROM]}") + # Automatically add BEC metrics if enabled and default err record if config[KEY.IS_TRAIN_BEC]: # If the user didn't explicitly provide an ERROR_RECORD, or if they provided diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 83587c8c..401985d7 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -130,6 +130,7 @@ IS_TRAIN_STRESS = 'is_train_stress' IS_TRAIN_BEC = 'is_train_bec' +TRAIN_BEC_FROM = 'train_bec_from' CONTINUE = 'continue' CHECKPOINT = 'checkpoint' diff --git a/sevenn/model_build.py b/sevenn/model_build.py index f07f6986..990369cb 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -567,13 +567,17 @@ def build_E3_equivariant_model( if interaction_type == 'nequip': parity_mode = 'full' fix_multiplicity = False + + is_train_bec = config.get(KEY.IS_TRAIN_BEC, False) + train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last') + if t == num_convolution_layer - 1: - # If training BEC, we need vectors/tensors to survive the last layer - if config.get(KEY.IS_TRAIN_BEC, False): - # We need at least L=1 and L=2 for vectors and tensors. + if is_train_bec and train_bec_from == 'last': + # If training BEC from the last layer, we need vectors/tensors to survive lmax_node = max(lmax_node, 2) parity_mode = 'full' else: + # Normal scalar output for final energy prediction lmax_node = 0 parity_mode = 'even' # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out @@ -606,19 +610,48 @@ def build_E3_equivariant_model( layers.update(interaction_builder(**param_interaction_block)) irreps_x = irreps_out - if config.get(KEY.IS_TRAIN_BEC, False): - irreps_in_bec = irreps_x - layers.update( - { - 'predict_bec': IrrepsLinear( - irreps_in_bec, - Irreps('1x0e+1x1e+1x2e'), - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES, - biases=config[KEY.USE_BIAS_IN_LINEAR], + is_train_bec = config.get(KEY.IS_TRAIN_BEC, False) + train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last') + + # Attach BEC head immediately after the specified convolution block + # num_convolution_layer >= 2 is assumed for 'penultimate' + if is_train_bec: + attach_bec = False + if train_bec_from == 'last' and t == num_convolution_layer - 1: + attach_bec = True + elif train_bec_from == 'penultimate' and t == num_convolution_layer - 2: + attach_bec = True + + if attach_bec: + layers.update( + { + 'predict_bec': IrrepsLinear( + irreps_x, + Irreps('1x0e+1x1e+1x2e'), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ) + } ) - } - ) + + # If attached at penultimate, we must strip L>0 features from node features + # before the final layer. This prevents FlashTP from crashing (which happens + # when input has L>0 but output is forced to L=0) and saves significant memory. + if train_bec_from == 'penultimate': + irreps_x_scalar = Irreps([(mul, ir) for mul, ir in irreps_x if ir.l == 0]) + layers.update( + { + 'filter_scalar_features': IrrepsLinear( + irreps_x, + irreps_x_scalar, + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.NODE_FEATURE, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ) + } + ) + irreps_x = irreps_x_scalar layers.update(init_feature_reduce(config, irreps_x)) # type: ignore