Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sevenn/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def data_defaults(config):
KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True,
KEY.IS_TRAIN_BEC: False,
KEY.TRAIN_BEC_FROM: 'last',
KEY.TRAIN_SHUFFLE: True,
KEY.ERROR_RECORD: [
['Energy', 'RMSE'],
Expand Down Expand Up @@ -317,6 +318,7 @@ def data_defaults(config):
KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool,
KEY.IS_TRAIN_BEC: bool,
KEY.TRAIN_BEC_FROM: str,
KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str,
Expand All @@ -336,6 +338,12 @@ def train_defaults(config):
if KEY.IS_TRAIN_BEC not in config:
config[KEY.IS_TRAIN_BEC] = defaults[KEY.IS_TRAIN_BEC]

if KEY.TRAIN_BEC_FROM not in config:
config[KEY.TRAIN_BEC_FROM] = defaults[KEY.TRAIN_BEC_FROM]

if config[KEY.TRAIN_BEC_FROM] not in ['last', 'penultimate']:
raise ValueError(f"train_bec_from must be 'last' or 'penultimate', got {config[KEY.TRAIN_BEC_FROM]}")

# Automatically add BEC metrics if enabled and default err record
if config[KEY.IS_TRAIN_BEC]:
# If the user didn't explicitly provide an ERROR_RECORD, or if they provided
Expand Down
1 change: 1 addition & 0 deletions sevenn/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@

IS_TRAIN_STRESS = 'is_train_stress'
IS_TRAIN_BEC = 'is_train_bec'
TRAIN_BEC_FROM = 'train_bec_from'

CONTINUE = 'continue'
CHECKPOINT = 'checkpoint'
Expand Down
63 changes: 48 additions & 15 deletions sevenn/model_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,17 @@ def build_E3_equivariant_model(
if interaction_type == 'nequip':
parity_mode = 'full'
fix_multiplicity = False

is_train_bec = config.get(KEY.IS_TRAIN_BEC, False)
train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last')

if t == num_convolution_layer - 1:
# If training BEC, we need vectors/tensors to survive the last layer
if config.get(KEY.IS_TRAIN_BEC, False):
# We need at least L=1 and L=2 for vectors and tensors.
if is_train_bec and train_bec_from == 'last':
# If training BEC from the last layer, we need vectors/tensors to survive
lmax_node = max(lmax_node, 2)
parity_mode = 'full'
else:
# Normal scalar output for final energy prediction
lmax_node = 0
parity_mode = 'even'
# TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
Expand Down Expand Up @@ -606,19 +610,48 @@ def build_E3_equivariant_model(
layers.update(interaction_builder(**param_interaction_block))
irreps_x = irreps_out

if config.get(KEY.IS_TRAIN_BEC, False):
irreps_in_bec = irreps_x
layers.update(
{
'predict_bec': IrrepsLinear(
irreps_in_bec,
Irreps('1x0e+1x1e+1x2e'),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES,
biases=config[KEY.USE_BIAS_IN_LINEAR],
is_train_bec = config.get(KEY.IS_TRAIN_BEC, False)
train_bec_from = config.get(KEY.TRAIN_BEC_FROM, 'last')

# Attach BEC head immediately after the specified convolution block
# num_convolution_layer >= 2 is assumed for 'penultimate'
if is_train_bec:
attach_bec = False
if train_bec_from == 'last' and t == num_convolution_layer - 1:
attach_bec = True
elif train_bec_from == 'penultimate' and t == num_convolution_layer - 2:
attach_bec = True

if attach_bec:
layers.update(
{
'predict_bec': IrrepsLinear(
irreps_x,
Irreps('1x0e+1x1e+1x2e'),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES,
biases=config[KEY.USE_BIAS_IN_LINEAR],
)
}
)
}
)

# If attached at penultimate, we must strip L>0 features from node features
# before the final layer. This prevents FlashTP from crashing (which happens
# when input has L>0 but output is forced to L=0) and saves significant memory.
if train_bec_from == 'penultimate':
irreps_x_scalar = Irreps([(mul, ir) for mul, ir in irreps_x if ir.l == 0])
layers.update(
{
'filter_scalar_features': IrrepsLinear(
irreps_x,
irreps_x_scalar,
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.NODE_FEATURE,
biases=config[KEY.USE_BIAS_IN_LINEAR],
)
}
)
irreps_x = irreps_x_scalar

layers.update(init_feature_reduce(config, irreps_x)) # type: ignore

Expand Down