Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sevenn/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
ENERGY: Final[str] = 'total_energy' # (1)
FORCE: Final[str] = 'force_of_atoms' # (N, 3)
STRESS: Final[str] = 'stress' # (6)
BANDGAP: Final[str] = 'bandgap' # (1)
MAGMOMS: Final[str] = 'magmoms' # (N, 1)

# This is for training, per atom scale.
SCALED_ENERGY: Final[str] = 'scaled_total_energy'
Expand All @@ -58,6 +60,11 @@
ATOMIC_ENERGY: Final[str] = 'atomic_energy'
PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy'

ATOMIC_BANDGAP: Final[str] = 'atomic_bandgap'
PRED_BANDGAP: Final[str] = 'inferred_bandgap'

PRED_MAGMOMS: Final[str] = 'inferred_magmoms'

PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy'
PER_ATOM_ENERGY: Final[str] = 'per_atom_energy'

Expand Down Expand Up @@ -124,6 +131,10 @@
TRAIN_SHUFFLE = 'train_shuffle'

IS_TRAIN_STRESS = 'is_train_stress'
IS_TRAIN_BANDGAP = 'is_train_bandgap'
IS_TRAIN_MAGMOMS = 'is_train_magmoms'
BANDGAP_WEIGHT = 'bandgap_loss_weight'
MAGMOMS_WEIGHT = 'magmoms_loss_weight'

CONTINUE = 'continue'
CHECKPOINT = 'checkpoint'
Expand Down
9 changes: 8 additions & 1 deletion sevenn/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,21 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]:
.numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation
)
# Store results
return {
results = {
'free_energy': energy,
'energy': energy,
'energies': atomic_energies,
'forces': forces,
'stress': stress,
'num_edges': output[KEY.EDGE_IDX].shape[1],
}
if KEY.PRED_BANDGAP in output:
results['bandgap'] = output[KEY.PRED_BANDGAP].detach().cpu().item()
if KEY.PRED_MAGMOMS in output:
results['magmoms'] = (
output[KEY.PRED_MAGMOMS].detach().cpu().numpy()[:num_atoms].flatten()
)
return results

def calculate(self, atoms=None, properties=None, system_changes=all_changes):
is_ts_type = isinstance(self.model, torch_script_type)
Expand Down
11 changes: 10 additions & 1 deletion sevenn/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,16 @@ def build_model(
state_dict = compat.patch_state_dict_if_old(
self.model_state_dict, self.config, model
)
missing, not_used = model.load_state_dict(state_dict, strict=True)
missing, not_used = model.load_state_dict(state_dict, strict=False)
missing = [
m
for m in missing
if not (
m.startswith('predict_atomic_bandgap.')
or m.startswith('predict_magmoms.')
or m.startswith('reduce_total_bandgap.')
)
]
assert len(missing) == 0, f'Missing keys: {missing}'
if len(not_used) > 0:
warnings.warn(f'Some keys are not used: {not_used}', UserWarning)
Expand Down
24 changes: 24 additions & 0 deletions sevenn/model_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,30 @@ def build_E3_equivariant_model(
layers.update(interaction_builder(**param_interaction_block))
irreps_x = irreps_out

layers.update(
{
'predict_atomic_bandgap': IrrepsLinear(
irreps_in=irreps_x,
irreps_out=Irreps('1x0e'),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.ATOMIC_BANDGAP,
biases=use_bias_in_linear,
),
'reduce_total_bandgap': AtomReduce(
data_key_in=KEY.ATOMIC_BANDGAP,
data_key_out=KEY.PRED_BANDGAP,
reduce='mean',
),
'predict_magmoms': IrrepsLinear(
irreps_in=irreps_x,
irreps_out=Irreps('1x0e'),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.PRED_MAGMOMS,
biases=use_bias_in_linear,
),
}
)

layers.update(init_feature_reduce(config, irreps_x)) # type: ignore

layers.update(
Expand Down
13 changes: 12 additions & 1 deletion sevenn/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,20 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
device=src.device,
)
output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum')
if self.reduce == 'mean':
counts = torch.zeros(size, dtype=src.dtype, device=src.device)
ones = torch.ones_like(src)
counts.scatter_reduce_(0, data[KEY.BATCH], ones, reduce='sum')
counts = counts.clamp(min=1.0)
output = output / counts
data[self.key_output] = output * self.constant
else:
data[self.key_output] = torch.sum(data[self.key_input]) * self.constant
if self.reduce == 'mean':
v = torch.mean(data[self.key_input]) * self.constant
data[self.key_output] = v
else:
v = torch.sum(data[self.key_input]) * self.constant
data[self.key_output] = v

return data

Expand Down
13 changes: 11 additions & 2 deletions sevenn/train/atoms_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def run_stat(self) -> None:
"""
if self._scanned is True:
return # statistics already computed
y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
y_keys: List[str] = [
KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS,
KEY.BANDGAP, KEY.MAGMOMS
]
natoms_counter = Counter()
composition = np.zeros((len(self), NUM_UNIV_ELEMENT))
stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys}
Expand All @@ -202,9 +205,15 @@ def run_stat(self) -> None:
dct['_array'].append(atoms.arrays['y_force'].reshape(-1))
elif y == KEY.STRESS:
dct['_array'].append(atoms.info['y_stress'].reshape(-1))
elif y == KEY.BANDGAP and 'y_bandgap' in atoms.info:
dct['_array'].append(atoms.info['y_bandgap'])
elif y == KEY.MAGMOMS and 'y_magmoms' in atoms.arrays:
dct['_array'].append(atoms.arrays['y_magmoms'].reshape(-1))

for y, dct in stats.items():
if y == KEY.FORCE:
if len(dct['_array']) == 0:
continue
if y in [KEY.FORCE, KEY.MAGMOMS]:
array = np.concatenate(dct['_array'])
else:
array = np.array(dct['_array']).reshape(-1)
Expand Down
12 changes: 12 additions & 0 deletions sevenn/train/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ def atoms_to_graph(
KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)),
}

if 'y_bandgap' in atoms.info:
data[KEY.BANDGAP] = _correct_scalar(atoms.info['y_bandgap'])
if 'y_magmoms' in atoms.arrays:
data[KEY.MAGMOMS] = atoms.arrays['y_magmoms']

if with_shift:
data[KEY.CELL_SHIFT] = shift
data[KEY.CELL] = cell
Expand Down Expand Up @@ -302,6 +307,8 @@ def _set_atoms_y(
energy_key: Optional[str] = None,
force_key: Optional[str] = None,
stress_key: Optional[str] = None,
bandgap_key: Optional[str] = 'bandgap',
magmoms_key: Optional[str] = 'magmoms',
) -> List[ase.Atoms]:
"""
Define how SevenNet reads ASE.atoms object for its y label
Expand Down Expand Up @@ -345,6 +352,11 @@ def _set_atoms_y(
else:
atoms.info['y_stress'] = from_calc['stress']

if bandgap_key and bandgap_key in atoms.info:
atoms.info['y_bandgap'] = atoms.info[bandgap_key]
if magmoms_key and magmoms_key in atoms.arrays:
atoms.arrays['y_magmoms'] = atoms.arrays[magmoms_key]

return atoms_list


Expand Down
19 changes: 13 additions & 6 deletions sevenn/train/graph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ def _run_stat(
"""
Loop over dataset and init any statistics might need
"""
y_keys = y_keys or [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
y_keys = y_keys or [
KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS,
KEY.BANDGAP, KEY.MAGMOMS
]
n_neigh = []
natoms_counter = Counter()
composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT))
Expand All @@ -79,14 +82,18 @@ def _run_stat(
composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT)
n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1])
for y, dct in stats.items():
dct['_array'].append(
graph[y].reshape(
-1,
if y in graph:
dct['_array'].append(
graph[y].reshape(
-1,
)
)
)

stats.update({'num_neighbor': {'_array': n_neigh}})
for y, dct in stats.items():
for y, dct in list(stats.items()):
if not dct['_array']:
del stats[y]
continue
array = torch.cat(dct['_array'])
if array.dtype == torch.int64: # because of n_neigh
array = array.to(torch.float)
Expand Down
85 changes: 85 additions & 0 deletions sevenn/train/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,85 @@ def _preprocess(
return pred, ref, w_tensor


class BandGapLoss(LossDefinition):
    """
    Loss for the intensive (per-structure) band gap.

    Compares the predicted band gap (``KEY.PRED_BANDGAP``) against the
    reference label (``KEY.BANDGAP``). Both are flattened to 1D so the
    criterion sees one scalar per structure in the batch.
    """

    def __init__(
        self,
        name: str = 'BandGap',
        unit: str = 'eV',
        criterion: Optional[Callable] = None,
        ref_key: str = KEY.BANDGAP,
        pred_key: str = KEY.PRED_BANDGAP,
        **kwargs,
    ) -> None:
        super().__init__(
            name=name,
            unit=unit,
            criterion=criterion,
            ref_key=ref_key,
            pred_key=pred_key,
            **kwargs,
        )

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (pred, ref, per-sample weight tensor or None) for the criterion."""
        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
        pred = torch.reshape(batch_data[self.pred_key], (-1,))
        ref = torch.reshape(batch_data[self.ref_key], (-1,))
        w_tensor = None

        if self.use_weight:
            loss_type = self.name.lower()
            # Band gap is one scalar per structure, so the per-structure data
            # weights already align one-to-one with pred/ref. The previous
            # torch.repeat_interleave(weight, 1) was a no-op and is removed.
            w_tensor = batch_data[KEY.DATA_WEIGHT][loss_type]

        return pred, ref, w_tensor


class MagmomsLoss(LossDefinition):
    """
    Loss for magnetic moments (one scalar per atom / graph node).

    Compares predicted per-atom magnetic moments (``KEY.PRED_MAGMOMS``)
    against the reference labels (``KEY.MAGMOMS``), both flattened to 1D.
    """

    def __init__(
        self,
        name: str = 'Magmoms',
        unit: str = 'mu_B',
        criterion: Optional[Callable] = None,
        ref_key: str = KEY.MAGMOMS,
        pred_key: str = KEY.PRED_MAGMOMS,
        **kwargs,
    ) -> None:
        super().__init__(
            name=name,
            unit=unit,
            criterion=criterion,
            ref_key=ref_key,
            pred_key=pred_key,
            **kwargs,
        )

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (pred, ref, per-atom weight tensor or None) for the criterion."""
        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
        pred = torch.reshape(batch_data[self.pred_key], (-1,))
        ref = torch.reshape(batch_data[self.ref_key], (-1,))
        w_tensor = None

        if self.use_weight:
            loss_type = self.name.lower()
            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
            # Broadcast the per-structure weight to per-atom via the batch
            # index so each atom inherits its structure's weight. The previous
            # torch.repeat_interleave(w_tensor, 1) was a no-op and is removed.
            w_tensor = weight[batch_data[KEY.BATCH]]

        return pred, ref, w_tensor


def get_loss_functions_from_config(
config: Dict[str, Any],
) -> List[Tuple[LossDefinition, float]]:
Expand All @@ -222,6 +301,12 @@ def get_loss_functions_from_config(
loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT]))
if config[KEY.IS_TRAIN_STRESS]:
loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT]))
if config.get(KEY.IS_TRAIN_BANDGAP, False):
w = config.get(KEY.BANDGAP_WEIGHT, 1.0)
loss_functions.append((BandGapLoss(**commons), w))
if config.get(KEY.IS_TRAIN_MAGMOMS, False):
w = config.get(KEY.MAGMOMS_WEIGHT, 1.0)
loss_functions.append((MagmomsLoss(**commons), w))

for loss_function, _ in loss_functions: # why do these?
if loss_function.criterion is None:
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc):
res_script = copy.copy(atoms_pbc.calc.results)

for k in res_cp:
if k in ['bandgap', 'magmoms']:
continue
assert np.allclose(res_cp[k], res_script[k], rtol=1e-4, atol=1e-4)


Expand All @@ -141,6 +143,8 @@ def test_sevennet_0_cal_as_instance_consistency(atoms_pbc):
res_script = copy.copy(atoms_pbc.calc.results)

for k in res_cp:
if k in ['bandgap', 'magmoms']:
continue
assert np.allclose(res_cp[k], res_script[k])


Expand Down
30 changes: 15 additions & 15 deletions tests/unit_tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,16 @@ def test_batch():


_n_param_tests = [
({}, 20642),
({'train_denominator': True}, 20642 + 3),
({'train_shift_scale': True}, 20642 + 2),
({'shift': [1.0] * 4}, 20642),
({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8),
({'num_convolution_layer': 4}, 33458),
({'lmax': 3}, 26866),
({'channel': 2}, 16883),
({'is_parity': False}, 20386),
({'self_connection_type': 'linear'}, 20114),
({}, 20642 + 8),
({'train_denominator': True}, 20642 + 8 + 3),
({'train_shift_scale': True}, 20642 + 8 + 2),
({'shift': [1.0] * 4}, 20642 + 8),
({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8 + 8),
({'num_convolution_layer': 4}, 33458 + 8),
({'lmax': 3}, 26866 + 8),
({'channel': 2}, 16883 + 4),
({'is_parity': False}, 20386 + 8),
({'self_connection_type': 'linear'}, 20114 + 8),
]


Expand All @@ -183,11 +183,11 @@ def test_num_params(cf, ref):


_n_modal_param_tests = [
({}, 20642),
({'use_modal_node_embedding': True}, 20642 + 8),
({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3),
({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)),
({'use_modal_output_block': True}, 20642 + 2 * 4 / 2),
({}, 20642 + 8),
({'use_modal_node_embedding': True}, 20642 + 8 + 8),
({'use_modal_self_inter_intro': True}, 20642 + 8 + 2 * 4 * 3),
({'use_modal_self_inter_outro': True}, 20642 + 8 + 2 * (12 + 20 + 4)),
({'use_modal_output_block': True}, 20642 + 8 + 2 * 4 / 2),
]


Expand Down
Loading