From 768b9bbfd95cf49ea1ea0e821893cedd03365e08 Mon Sep 17 00:00:00 2001 From: JustinBakerMath Date: Tue, 23 May 2023 17:36:22 +0000 Subject: [PATCH 1/3] dimenet baseline --- README.md | 2 +- hydragnn/models/DIMEStack.py | 239 +++++++++++++++++++++++++++++++++ hydragnn/models/create.py | 37 +++++ tests/inputs/ci.json | 6 + tests/inputs/ci_multihead.json | 6 + tests/test_graphs.py | 3 +- 6 files changed, 291 insertions(+), 2 deletions(-) create mode 100644 hydragnn/models/DIMEStack.py diff --git a/README.md b/README.md index 379b83954..2a7afcf6b 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ There are many options for HydraGNN; the dataset and model type are particularly important: - `["Verbosity"]["level"]`: `0`, `1`, `2`, `3`, `4` - `["Dataset"]["name"]`: `CuAu_32atoms`, `FePt_32atoms`, `FeSi_1024atoms` - - `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet` + - `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`, `DimeNet` ### Citations "HydraGNN: Distributed PyTorch implementation of multi-headed graph convolutional neural networks", Copyright ID#: 81929619 diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py new file mode 100644 index 000000000..644855774 --- /dev/null +++ b/hydragnn/models/DIMEStack.py @@ -0,0 +1,239 @@ +""" +DimeNet +======== +Directional message passing neural network +for molecular graphs. The convolutional +layer uses spherical and radial basis +functions to perform message passing. + +In particular this message passing layer +relies on the angle formed by the triplet +of incomming and outgoing messages. + +The three key components of this network are +outlined below. In particular, the convolutional +network that is used for the message passing +the triplet function that generates to/from +information for angular values, and finally +the radial basis embedding that is used to +include radial basis information. + +""" +from math import sqrt + +from typing import Callable, Optional, Tuple +from torch_geometric.typing import SparseTensor + +import torch +from torch import Tensor +from torch.nn import Embedding, Linear, SiLU + +from torch_geometric.nn.inits import glorot_orthogonal +from torch_geometric.nn.models.dimenet import BesselBasisLayer, SphericalBasisLayer, ResidualLayer +from torch_geometric.utils import scatter + +from .Base import Base + + +class DIMEStack(Base): + """ + Generates angles, distances, to/from indices, radial basis + functions and spherical basis functions for learning. + """ + def __init__(self, + num_bilinear, + num_radial, + num_spherical, + radius, + envelope_exponent, + num_before_skip, + num_after_skip, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.radius = radius + + self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent) + self.sbf = SphericalBasisLayer(num_spherical, num_radial, radius, + envelope_exponent) + + self.interact = Interaction( + hidden_channels=self.hidden_dim, + num_bilinear=num_bilinear, + num_spherical=num_spherical, + num_radial=num_radial, + num_before_skip=num_before_skip, + num_after_skip=num_after_skip, + ) + + def _conv_args(self, data): + conv_args = {"edge_index": data.edge_index} + assert data.pos is not None, 'DimeNet requires node positions (data.pos) to be set.' + conv_args.update({"pos": data.pos}) + return conv_args + + + def forward(self, z, pos, edge_index): + + z = z.to(torch.long) + # edge_index = radius_graph(pos, r=self.radius, batch=batch, + # max_num_neighbors=self.max_num_neighbors) + i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(edge_index, num_nodes=z.size(0)) + dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() + + # Calculate angles. + pos_i = pos[idx_i] + pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i + a = (pos_ji * pos_ki).sum(dim=-1) + b = torch.cross(pos_ji, pos_ki).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(dist) + sbf = self.sbf(dist, angle, idx_kj) + z = self.interact(z, rbf, sbf, idx_kj, idx_ji) + # z = z + output_block(x, rbf, i, num_nodes=pos.size(0)) + + return z + + +class Interaction(torch.nn.Module): + def __init__( + self, + hidden_channels: int, + num_bilinear: int, + num_spherical: int, + num_radial: int, + num_before_skip: int, + num_after_skip: int, + ): + super().__init__() + self.act = SiLU() + + self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) + self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, + bias=False) + + # Dense transformations of input messages. + self.lin_from = Linear(hidden_channels, hidden_channels) + self.lin_to = Linear(hidden_channels, hidden_channels) + + self.W = torch.nn.Parameter( + torch.Tensor(hidden_channels, num_bilinear, hidden_channels)) + + self.layers_before_skip = torch.nn.ModuleList([ + ResidualLayer(hidden_channels, SiLU()) for _ in range(num_before_skip) + ]) + self.lin = Linear(hidden_channels, hidden_channels) + self.layers_after_skip = torch.nn.ModuleList([ + ResidualLayer(hidden_channels, SiLU()) for _ in range(num_after_skip) + ]) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf.weight, scale=2.0) + glorot_orthogonal(self.lin_sbf.weight, scale=2.0) + glorot_orthogonal(self.lin_from.weight, scale=2.0) + self.lin_from.bias.data.fill_(0) + glorot_orthogonal(self.lin_to.weight, scale=2.0) + self.lin_to.bias.data.fill_(0) + self.W.data.normal_(mean=0, std=2 / self.W.size(0)) + for res_layer in self.layers_before_skip: + res_layer.reset_parameters() + glorot_orthogonal(self.lin.weight, scale=2.0) + self.lin.bias.data.fill_(0) + for res_layer in self.layers_after_skip: + res_layer.reset_parameters() + + def forward(self, + x: Tensor, + radial_basis: Tensor, + spherical_basis: Tensor, + edge_index_from: Tensor, + edge_index_to: Tensor + ) -> Tensor: + + radial_basis = self.lin_rbf(radial_basis) + spherical_basis = self.lin_sbf(spherical_basis) + + x_kj = self.act(self.lin_from(x)) + x_kj = x_kj * radial_basis + x_kj = torch.einsum('wj,wl,ijl->wi', spherical_basis, x_kj[edge_index_from], self.W) + x_kj = scatter(x_kj, edge_index_to, dim=0, dim_size=x.size(0), reduce='sum') # message passing + + x_ji = self.act(self.lin_to(x)) + h = x_ji + x_kj # aggregates my learned message and my from messages to the next neighbor + + for layer in self.layers_before_skip: # this added resnet is not actually doing any message passing and is an interesting addition + h = layer(h) + h = self.act(self.lin(h)) + x # incorporates a residual connection to the input feature + for layer in self.layers_after_skip: + h = layer(h) + + return h + +""" +Triplets +--------- +Generates to/from edge_indices for +angle generating purposes. + +""" +def triplets( + edge_index: Tensor, + num_nodes: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + row, col = edge_index # j->i + + value = torch.arange(row.size(0), device=row.device) + adj_t = SparseTensor(row=col, col=row, value=value, + sparse_sizes=(num_nodes, num_nodes)) + adj_t_row = adj_t[row] + num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) + + # Node indices (k->j->i) for triplets. + idx_i = col.repeat_interleave(num_triplets) + idx_j = row.repeat_interleave(num_triplets) + idx_k = adj_t_row.storage.col() + mask = idx_i != idx_k # Remove i == k triplets. + idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] + + # Edge indices (k-j, j->i) for triplets. + idx_kj = adj_t_row.storage.value()[mask] + idx_ji = adj_t_row.storage.row()[mask] + + return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji + + +""" +EmbeddingBlock +--------------- +An embedding block that utilizes the +radial basis function and the to/from +information in the embedding by +concatentating the to/from nodes with +the radial basis functions. + +""" +class EmbeddingBlock(torch.nn.Module): + def __init__(self, num_radial: int, hidden_channels: int, act: Callable): + super().__init__() + self.act = act + + self.emb = Embedding(95, hidden_channels) + self.lin_rbf = Linear(num_radial, hidden_channels) + self.lin = Linear(3 * hidden_channels, hidden_channels) + + self.reset_parameters() + + def reset_parameters(self): + self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) + self.lin_rbf.reset_parameters() + self.lin.reset_parameters() + + def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor: + x = self.emb(x) + rbf = self.act(self.lin_rbf(rbf)) + return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) \ No newline at end of file diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 806f93199..7ccc529e4 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -20,6 +20,7 @@ from hydragnn.models.CGCNNStack import CGCNNStack from hydragnn.models.SAGEStack import SAGEStack from hydragnn.models.SCFStack import SCFStack +from hydragnn.models.DIMEStack import DIMEStack from hydragnn.utils.distributed import get_device from hydragnn.utils.print_utils import print_distributed @@ -72,6 +73,12 @@ def create_model( max_neighbours: int = None, edge_dim: int = None, pna_deg: torch.tensor = None, + num_before_skip: int = None, + num_after_skip: int = None, + num_bilinear: int = None, + num_radial: int = None, + num_spherical: int = None, + envelope_exponent: int = None, num_gaussians: int = None, num_filters: int = None, radius: float = None, @@ -206,6 +213,36 @@ def create_model( num_nodes=num_nodes, ) + elif model_type == "DimeNet": + assert num_bilinear is not None, "DimeNet requires num_bilinear input." + assert num_radial is not None, "DimeNet requires num_radial input." + assert num_spherical is not None, "DimeNet requires num_spherical input." + assert envelope_exponent is not None, "DimeNet requires envelope_exponent input." + assert num_before_skip is not None, "DimeNet requires num_before_skip input." + assert num_after_skip is not None, "DimeNet requires num_after_skip input." + assert radius is not None, "DimeNet requires radius input." + model = DIMEStack( + num_bilinear, + num_radial, + num_spherical, + radius, + envelope_exponent, + num_before_skip, + num_after_skip, + input_dim, + hidden_dim, + output_dim, + output_type, + output_heads, + loss_function_type, + max_neighbours=max_neighbours, + loss_weights=task_weights, + freeze_conv=freeze_conv, + initial_bias=initial_bias, + num_conv_layers=num_conv_layers, + num_nodes=num_nodes, + ) + else: raise ValueError("Unknown model_type: {0}".format(model_type)) diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index 3a141cb74..4d9040428 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -30,6 +30,12 @@ "max_neighbours": 100, "num_gaussians": 50, "num_filters": 126, + "num_before_skip": 1, + "num_after_skip": 1, + "num_bilinear": 2, + "num_radial": 2, + "num_spherical": 2, + "envelope_exponent": 5, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index aeb89f267..51f3d7993 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -28,6 +28,12 @@ "max_neighbours": 100, "num_gaussians": 50, "num_filters": 126, + "num_before_skip": 1, + "num_after_skip": 1, + "num_bilinear": 2, + "num_radial": 2, + "num_spherical": 2, + "envelope_exponent": 5, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 75ce792df..6adc7e6b4 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -131,6 +131,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False "GAT": [0.60, 0.70], "CGCNN": [0.50, 0.40], "SchNet": [0.20, 0.20], + "DimeNet": [0.20, 0.20], } if use_lengths and ("vector" not in ci_input): thresholds["CGCNN"] = [0.175, 0.175] @@ -173,7 +174,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False # Test across all models with both single/multihead @pytest.mark.parametrize( - "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet"] + "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet", "DimeNet"] ) @pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"]) def pytest_train_model(model_type, ci_input, overwrite_data=False): From b45ea5560a2fa8d0fcb4f115de8c026614299767 Mon Sep 17 00:00:00 2001 From: JustinBakerMath Date: Tue, 23 May 2023 19:01:27 +0000 Subject: [PATCH 2/3] black formatting --- examples/csce/train_gap.py | 6 +- examples/eam/eam.py | 6 +- examples/ising_model/create_configurations.py | 12 +-- examples/ising_model/train_ising.py | 13 ++- examples/lsms/lsms.py | 6 +- examples/md17/md17.py | 7 +- examples/ogb/train_gap.py | 6 +- examples/qm9/qm9.py | 7 +- hydragnn/models/DIMEStack.py | 95 ++++++++++++------- hydragnn/models/PNAStack.py | 1 - hydragnn/models/create.py | 4 +- hydragnn/postprocess/visualizer.py | 5 +- hydragnn/preprocess/cfg_raw_dataset_loader.py | 2 - .../compositional_data_splitting.py | 1 + hydragnn/preprocess/load_data.py | 2 - hydragnn/preprocess/raw_dataset_loader.py | 1 - .../preprocess/serialized_dataset_loader.py | 2 +- hydragnn/preprocess/utils.py | 1 + hydragnn/run_prediction.py | 2 - hydragnn/run_training.py | 2 - hydragnn/train/train_validate_test.py | 2 - hydragnn/utils/abstractrawdataset.py | 3 +- hydragnn/utils/atomicdescriptors.py | 2 - hydragnn/utils/cfgdataset.py | 2 - hydragnn/utils/distributed.py | 4 - hydragnn/utils/smiles_utils.py | 1 - hydragnn/utils/time_utils.py | 1 - hydragnn/utils/xyzdataset.py | 2 - setup.py | 1 + tests/deterministic_graph_data.py | 4 +- tests/test_config.py | 1 - tests/test_enthalpy.py | 1 - tests/test_graphs.py | 2 +- utils/lsms/compositional_histogram_cutoff.py | 1 - ...convert_total_energy_to_formation_gibbs.py | 3 - 35 files changed, 118 insertions(+), 93 deletions(-) diff --git a/examples/csce/train_gap.py b/examples/csce/train_gap.py index 360647b4f..ca8520cef 100644 --- a/examples/csce/train_gap.py +++ b/examples/csce/train_gap.py @@ -322,7 +322,11 @@ def __getitem__(self, idx): % (len(trainset), len(valset), len(testset)) ) - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + ( + train_loader, + val_loader, + test_loader, + ) = hydragnn.preprocess.create_dataloaders( trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] ) diff --git a/examples/eam/eam.py b/examples/eam/eam.py index 64a8b804e..ea8c902d5 100644 --- a/examples/eam/eam.py +++ b/examples/eam/eam.py @@ -165,7 +165,11 @@ def info(*args, logtype="info", sep=" "): % (len(trainset), len(valset), len(testset)) ) - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + ( + train_loader, + val_loader, + test_loader, + ) = hydragnn.preprocess.create_dataloaders( trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] ) timer.stop() diff --git a/examples/ising_model/create_configurations.py b/examples/ising_model/create_configurations.py index 16ac2c07b..140f7a354 100644 --- a/examples/ising_model/create_configurations.py +++ b/examples/ising_model/create_configurations.py @@ -8,7 +8,6 @@ def write_to_file(total_energy, atomic_features, count_config, dir): - numpy_string_total_value = np.array2string(total_energy) filetxt = numpy_string_total_value @@ -40,7 +39,7 @@ def E_dimensionless(config, L, spin_function, scale_spin): spin[x, y, z] = spin_function(config[x, y, z]) count_pos = 0 - number_nodes = L ** 3 + number_nodes = L**3 positions = np.zeros((number_nodes, 3)) atomic_features = np.zeros((number_nodes, 5)) for x in range(L): @@ -76,18 +75,16 @@ def E_dimensionless(config, L, spin_function, scale_spin): def create_dataset( L, histogram_cutoff, dir, spin_function=lambda x: x, scale_spin=False ): - count_config = 0 - for num_downs in tqdm(range(0, L ** 3)): - - primal_configuration = np.ones((L ** 3,)) + for num_downs in tqdm(range(0, L**3)): + primal_configuration = np.ones((L**3,)) for down in range(0, num_downs): primal_configuration[down] = -1.0 # If the current composition has a total number of possible configurations above # the hard cutoff threshold, a random configurational subset is picked - if scipy.special.binom(L ** 3, num_downs) > histogram_cutoff: + if scipy.special.binom(L**3, num_downs) > histogram_cutoff: for num_config in range(0, histogram_cutoff): config = np.random.permutation(primal_configuration) config = np.reshape(config, (L, L, L)) @@ -115,7 +112,6 @@ def create_dataset( if __name__ == "__main__": - dir = os.path.join(os.path.dirname(__file__), "../../dataset/ising_model") if os.path.exists(dir): shutil.rmtree(dir) diff --git a/examples/ising_model/train_ising.py b/examples/ising_model/train_ising.py index b4390d9ea..919db604b 100644 --- a/examples/ising_model/train_ising.py +++ b/examples/ising_model/train_ising.py @@ -43,7 +43,6 @@ def write_to_file(total_energy, atomic_features, count_config, dir, prefix): - numpy_string_total_value = np.array2string(total_energy) filetxt = numpy_string_total_value @@ -67,7 +66,7 @@ def create_dataset_mpi( comm_size = comm.Get_size() count_config = 0 - rx = list(nsplit(range(0, L ** 3), comm_size))[rank] + rx = list(nsplit(range(0, L**3), comm_size))[rank] info("rx", rx.start, rx.stop) for num_downs in iterate_tqdm( @@ -75,13 +74,13 @@ def create_dataset_mpi( ): prefix = "output_%d_" % num_downs - primal_configuration = np.ones((L ** 3,)) + primal_configuration = np.ones((L**3,)) for down in range(0, num_downs): primal_configuration[down] = -1.0 # If the current composition has a total number of possible configurations above # the hard cutoff threshold, a random configurational subset is picked - if scipy.special.binom(L ** 3, num_downs) > histogram_cutoff: + if scipy.special.binom(L**3, num_downs) > histogram_cutoff: for num_config in range(0, histogram_cutoff): config = np.random.permutation(primal_configuration) config = np.reshape(config, (L, L, L)) @@ -288,7 +287,11 @@ def info(*args, logtype="info", sep=" "): % (len(trainset), len(valset), len(testset)) ) - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + ( + train_loader, + val_loader, + test_loader, + ) = hydragnn.preprocess.create_dataloaders( trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] ) timer.stop() diff --git a/examples/lsms/lsms.py b/examples/lsms/lsms.py index 8d6654e1b..282c740b4 100644 --- a/examples/lsms/lsms.py +++ b/examples/lsms/lsms.py @@ -164,7 +164,11 @@ def info(*args, logtype="info", sep=" "): % (len(trainset), len(valset), len(testset)) ) - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + ( + train_loader, + val_loader, + test_loader, + ) = hydragnn.preprocess.create_dataloaders( trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] ) timer.stop() diff --git a/examples/md17/md17.py b/examples/md17/md17.py index 6e38a3b3d..bfcd80a38 100644 --- a/examples/md17/md17.py +++ b/examples/md17/md17.py @@ -11,6 +11,7 @@ import hydragnn + # Update each sample prior to loading. def md17_pre_transform(data): # Set descriptor as element type. @@ -68,7 +69,11 @@ def md17_pre_filter(data): train, val, test = hydragnn.preprocess.split_dataset( dataset, config["NeuralNetwork"]["Training"]["perc_train"], False ) -(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( +( + train_loader, + val_loader, + test_loader, +) = hydragnn.preprocess.create_dataloaders( train, val, test, config["NeuralNetwork"]["Training"]["batch_size"] ) diff --git a/examples/ogb/train_gap.py b/examples/ogb/train_gap.py index 47e360094..1214b66bc 100644 --- a/examples/ogb/train_gap.py +++ b/examples/ogb/train_gap.py @@ -334,7 +334,11 @@ def __getitem__(self, idx): % (len(trainset), len(valset), len(testset)) ) - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + ( + train_loader, + val_loader, + test_loader, + ) = hydragnn.preprocess.create_dataloaders( trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] ) diff --git a/examples/qm9/qm9.py b/examples/qm9/qm9.py index cd2943a29..e4f80b01b 100644 --- a/examples/qm9/qm9.py +++ b/examples/qm9/qm9.py @@ -11,6 +11,7 @@ import hydragnn + # Update each sample prior to loading. def qm9_pre_transform(data): # Set descriptor as element type. @@ -59,7 +60,11 @@ def qm9_pre_filter(data): train, val, test = hydragnn.preprocess.split_dataset( dataset, config["NeuralNetwork"]["Training"]["perc_train"], False ) -(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( +( + train_loader, + val_loader, + test_loader, +) = hydragnn.preprocess.create_dataloaders( train, val, test, config["NeuralNetwork"]["Training"]["batch_size"] ) diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py index 644855774..5ba5e5191 100644 --- a/hydragnn/models/DIMEStack.py +++ b/hydragnn/models/DIMEStack.py @@ -29,7 +29,11 @@ from torch.nn import Embedding, Linear, SiLU from torch_geometric.nn.inits import glorot_orthogonal -from torch_geometric.nn.models.dimenet import BesselBasisLayer, SphericalBasisLayer, ResidualLayer +from torch_geometric.nn.models.dimenet import ( + BesselBasisLayer, + SphericalBasisLayer, + ResidualLayer, +) from torch_geometric.utils import scatter from .Base import Base @@ -40,7 +44,9 @@ class DIMEStack(Base): Generates angles, distances, to/from indices, radial basis functions and spherical basis functions for learning. """ - def __init__(self, + + def __init__( + self, num_bilinear, num_radial, num_spherical, @@ -56,8 +62,9 @@ def __init__(self, self.radius = radius self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent) - self.sbf = SphericalBasisLayer(num_spherical, num_radial, radius, - envelope_exponent) + self.sbf = SphericalBasisLayer( + num_spherical, num_radial, radius, envelope_exponent + ) self.interact = Interaction( hidden_channels=self.hidden_dim, @@ -70,17 +77,19 @@ def __init__(self, def _conv_args(self, data): conv_args = {"edge_index": data.edge_index} - assert data.pos is not None, 'DimeNet requires node positions (data.pos) to be set.' + assert ( + data.pos is not None + ), "DimeNet requires node positions (data.pos) to be set." conv_args.update({"pos": data.pos}) return conv_args - def forward(self, z, pos, edge_index): - z = z.to(torch.long) # edge_index = radius_graph(pos, r=self.radius, batch=batch, # max_num_neighbors=self.max_num_neighbors) - i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(edge_index, num_nodes=z.size(0)) + i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( + edge_index, num_nodes=z.size(0) + ) dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() # Calculate angles. @@ -112,23 +121,23 @@ def __init__( self.act = SiLU() self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) - self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, - bias=False) + self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, bias=False) # Dense transformations of input messages. self.lin_from = Linear(hidden_channels, hidden_channels) self.lin_to = Linear(hidden_channels, hidden_channels) self.W = torch.nn.Parameter( - torch.Tensor(hidden_channels, num_bilinear, hidden_channels)) + torch.Tensor(hidden_channels, num_bilinear, hidden_channels) + ) - self.layers_before_skip = torch.nn.ModuleList([ - ResidualLayer(hidden_channels, SiLU()) for _ in range(num_before_skip) - ]) + self.layers_before_skip = torch.nn.ModuleList( + [ResidualLayer(hidden_channels, SiLU()) for _ in range(num_before_skip)] + ) self.lin = Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList([ - ResidualLayer(hidden_channels, SiLU()) for _ in range(num_after_skip) - ]) + self.layers_after_skip = torch.nn.ModuleList( + [ResidualLayer(hidden_channels, SiLU()) for _ in range(num_after_skip)] + ) self.reset_parameters() @@ -147,33 +156,46 @@ def reset_parameters(self): for res_layer in self.layers_after_skip: res_layer.reset_parameters() - def forward(self, - x: Tensor, - radial_basis: Tensor, - spherical_basis: Tensor, - edge_index_from: Tensor, - edge_index_to: Tensor - ) -> Tensor: - + def forward( + self, + x: Tensor, + radial_basis: Tensor, + spherical_basis: Tensor, + edge_index_from: Tensor, + edge_index_to: Tensor, + ) -> Tensor: radial_basis = self.lin_rbf(radial_basis) spherical_basis = self.lin_sbf(spherical_basis) x_kj = self.act(self.lin_from(x)) x_kj = x_kj * radial_basis - x_kj = torch.einsum('wj,wl,ijl->wi', spherical_basis, x_kj[edge_index_from], self.W) - x_kj = scatter(x_kj, edge_index_to, dim=0, dim_size=x.size(0), reduce='sum') # message passing + x_kj = torch.einsum( + "wj,wl,ijl->wi", spherical_basis, x_kj[edge_index_from], self.W + ) + x_kj = scatter( + x_kj, edge_index_to, dim=0, dim_size=x.size(0), reduce="sum" + ) # message passing x_ji = self.act(self.lin_to(x)) - h = x_ji + x_kj # aggregates my learned message and my from messages to the next neighbor - - for layer in self.layers_before_skip: # this added resnet is not actually doing any message passing and is an interesting addition + h = ( + x_ji + x_kj + ) # aggregates my learned message and my from messages to the next neighbor + + for ( + layer + ) in ( + self.layers_before_skip + ): # this added resnet is not actually doing any message passing and is an interesting addition h = layer(h) - h = self.act(self.lin(h)) + x # incorporates a residual connection to the input feature + h = ( + self.act(self.lin(h)) + x + ) # incorporates a residual connection to the input feature for layer in self.layers_after_skip: h = layer(h) return h + """ Triplets --------- @@ -181,6 +203,8 @@ def forward(self, angle generating purposes. """ + + def triplets( edge_index: Tensor, num_nodes: int, @@ -188,8 +212,9 @@ def triplets( row, col = edge_index # j->i value = torch.arange(row.size(0), device=row.device) - adj_t = SparseTensor(row=col, col=row, value=value, - sparse_sizes=(num_nodes, num_nodes)) + adj_t = SparseTensor( + row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes) + ) adj_t_row = adj_t[row] num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) @@ -217,6 +242,8 @@ def triplets( the radial basis functions. """ + + class EmbeddingBlock(torch.nn.Module): def __init__(self, num_radial: int, hidden_channels: int, act: Callable): super().__init__() @@ -236,4 +263,4 @@ def reset_parameters(self): def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor: x = self.emb(x) rbf = self.act(self.lin_rbf(rbf)) - return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) \ No newline at end of file + return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) diff --git a/hydragnn/models/PNAStack.py b/hydragnn/models/PNAStack.py index 427363f1d..a7f5353cb 100644 --- a/hydragnn/models/PNAStack.py +++ b/hydragnn/models/PNAStack.py @@ -24,7 +24,6 @@ def __init__( *args, **kwargs, ): - self.aggregators = ["mean", "min", "max", "std"] self.scalers = [ "identity", diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 7ccc529e4..84fe3df92 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -217,7 +217,9 @@ def create_model( assert num_bilinear is not None, "DimeNet requires num_bilinear input." assert num_radial is not None, "DimeNet requires num_radial input." assert num_spherical is not None, "DimeNet requires num_spherical input." - assert envelope_exponent is not None, "DimeNet requires envelope_exponent input." + assert ( + envelope_exponent is not None + ), "DimeNet requires envelope_exponent input." assert num_before_skip is not None, "DimeNet requires num_before_skip input." assert num_after_skip is not None, "DimeNet requires num_after_skip input." assert radius is not None, "DimeNet requires radius input." diff --git a/hydragnn/postprocess/visualizer.py b/hydragnn/postprocess/visualizer.py index 83ee4701f..d43498fb8 100644 --- a/hydragnn/postprocess/visualizer.py +++ b/hydragnn/postprocess/visualizer.py @@ -116,7 +116,6 @@ def __scatter_impl( y_label=None, xylim_equal=False, ): - ax.scatter(x, y, s=s, edgecolor="b", marker=marker, facecolor="none") ax.set_title(title + ", number of samples =" + str(len(x))) @@ -179,10 +178,10 @@ def create_plot_global_analysis( vsum_pred = [] for isamp in range(nshape[0]): vlen_true.append( - sqrt(sum([comp ** 2 for comp in true_values[isamp][:]])) + sqrt(sum([comp**2 for comp in true_values[isamp][:]])) ) vlen_pred.append( - sqrt(sum([comp ** 2 for comp in predicted_values[isamp][:]])) + sqrt(sum([comp**2 for comp in predicted_values[isamp][:]])) ) vsum_true.append(sum(true_values[isamp][:])) vsum_pred.append(sum(predicted_values[isamp][:])) diff --git a/hydragnn/preprocess/cfg_raw_dataset_loader.py b/hydragnn/preprocess/cfg_raw_dataset_loader.py index b5043abb1..977707b61 100644 --- a/hydragnn/preprocess/cfg_raw_dataset_loader.py +++ b/hydragnn/preprocess/cfg_raw_dataset_loader.py @@ -55,7 +55,6 @@ def __transform_CFG_input_to_data_object_base(self, filepath): """ if filepath.endswith(".cfg"): - data_object = self.__transform_ASE_object_to_data_object(filepath) return data_object @@ -64,7 +63,6 @@ def __transform_CFG_input_to_data_object_base(self, filepath): return None def __transform_ASE_object_to_data_object(self, filepath): - # FIXME: # this still assumes bulk modulus is specific to the CFG format. # To deal with multiple files across formats, one should generalize this function diff --git a/hydragnn/preprocess/compositional_data_splitting.py b/hydragnn/preprocess/compositional_data_splitting.py index 574c10dcf..89fdc5655 100644 --- a/hydragnn/preprocess/compositional_data_splitting.py +++ b/hydragnn/preprocess/compositional_data_splitting.py @@ -14,6 +14,7 @@ import torch import sklearn + # function to return key for any value def get_keys(dictionary, val): keys = [] diff --git a/hydragnn/preprocess/load_data.py b/hydragnn/preprocess/load_data.py index 27533fea4..642ee4712 100644 --- a/hydragnn/preprocess/load_data.py +++ b/hydragnn/preprocess/load_data.py @@ -225,7 +225,6 @@ def dataset_loading_and_splitting(config: {}): def create_dataloaders(trainset, valset, testset, batch_size): if dist.is_initialized(): - train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) val_sampler = torch.utils.data.distributed.DistributedSampler(valset) test_sampler = torch.utils.data.distributed.DistributedSampler(testset) @@ -267,7 +266,6 @@ def create_dataloaders(trainset, valset, testset, batch_size): ) else: - train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) val_loader = DataLoader( valset, diff --git a/hydragnn/preprocess/raw_dataset_loader.py b/hydragnn/preprocess/raw_dataset_loader.py index c0443bf2a..9be45ebe3 100644 --- a/hydragnn/preprocess/raw_dataset_loader.py +++ b/hydragnn/preprocess/raw_dataset_loader.py @@ -192,7 +192,6 @@ def scale_features_by_num_nodes(self, dataset): return dataset def normalize_dataset(self): - """Performs the normalization on Data objects and returns the normalized dataset.""" num_node_features = len(self.node_feature_dim) num_graph_features = len(self.graph_feature_dim) diff --git a/hydragnn/preprocess/serialized_dataset_loader.py b/hydragnn/preprocess/serialized_dataset_loader.py index 6d3028ea4..be9d8f94f 100644 --- a/hydragnn/preprocess/serialized_dataset_loader.py +++ b/hydragnn/preprocess/serialized_dataset_loader.py @@ -238,7 +238,7 @@ def __stratified_sampling(self, dataset: [Data], subsample_percentage: float): frequencies = sorted(frequencies[frequencies > 0].tolist()) category = 0 for index, frequency in enumerate(frequencies): - category += frequency * (100 ** index) + category += frequency * (100**index) dataset_categories.append(category) subsample_indices = [] diff --git a/hydragnn/preprocess/utils.py b/hydragnn/preprocess/utils.py index bf41d7246..ddc6cca77 100644 --- a/hydragnn/preprocess/utils.py +++ b/hydragnn/preprocess/utils.py @@ -17,6 +17,7 @@ import ase.neighborlist import os + ## This function can be slow if dataset is too large. Use with caution. ## Recommend to use check_if_graph_size_variable_dist def check_if_graph_size_variable(train_loader, val_loader, test_loader): diff --git a/hydragnn/run_prediction.py b/hydragnn/run_prediction.py index 0d997e085..2ff7aecf6 100755 --- a/hydragnn/run_prediction.py +++ b/hydragnn/run_prediction.py @@ -31,7 +31,6 @@ def run_prediction(config): @run_prediction.register def _(config_file: str): - with open(config_file, "r") as f: config = json.load(f) @@ -40,7 +39,6 @@ def _(config_file: str): @run_prediction.register def _(config: dict): - try: os.environ["SERIALIZED_DATA_PATH"] except: diff --git a/hydragnn/run_training.py b/hydragnn/run_training.py index ade8d1681..c7f2aa863 100644 --- a/hydragnn/run_training.py +++ b/hydragnn/run_training.py @@ -46,7 +46,6 @@ def run_training(config): @run_training.register def _(config_file: str): - with open(config_file, "r") as f: config = json.load(f) @@ -55,7 +54,6 @@ def _(config_file: str): @run_training.register def _(config: dict): - try: os.environ["SERIALIZED_DATA_PATH"] except: diff --git a/hydragnn/train/train_validate_test.py b/hydragnn/train/train_validate_test.py index 3e165ec68..0cc83d38f 100644 --- a/hydragnn/train/train_validate_test.py +++ b/hydragnn/train/train_validate_test.py @@ -373,7 +373,6 @@ def train( @torch.no_grad() def validate(loader, model, verbosity, reduce_ranks=True): - total_error = torch.tensor(0.0, device=get_device()) tasks_error = torch.zeros(model.module.num_heads, device=get_device()) num_samples_local = 0 @@ -398,7 +397,6 @@ def validate(loader, model, verbosity, reduce_ranks=True): @torch.no_grad() def test(loader, model, verbosity, reduce_ranks=True, return_samples=True): - total_error = torch.tensor(0.0, device=get_device()) tasks_error = torch.zeros(model.module.num_heads, device=get_device()) num_samples_local = 0 diff --git a/hydragnn/utils/abstractrawdataset.py b/hydragnn/utils/abstractrawdataset.py index edce58e56..f6aafdf2b 100644 --- a/hydragnn/utils/abstractrawdataset.py +++ b/hydragnn/utils/abstractrawdataset.py @@ -189,7 +189,6 @@ def __load_raw_data(self): self.__normalize_dataset() def __normalize_dataset(self): - """Performs the normalization on Data objects and returns the normalized dataset.""" num_node_features = len(self.node_feature_dim) num_graph_features = len(self.graph_feature_dim) @@ -434,7 +433,7 @@ def stratified_sampling(dataset: [Data], subsample_percentage: float, verbosity= frequencies = sorted(frequencies[frequencies > 0].tolist()) category = 0 for index, frequency in enumerate(frequencies): - category += frequency * (100 ** index) + category += frequency * (100**index) dataset_categories.append(category) subsample_indices = [] diff --git a/hydragnn/utils/atomicdescriptors.py b/hydragnn/utils/atomicdescriptors.py index 7c4d95035..b4b2d986b 100644 --- a/hydragnn/utils/atomicdescriptors.py +++ b/hydragnn/utils/atomicdescriptors.py @@ -124,7 +124,6 @@ def get_period(self, num_classes=-1): return torch.Tensor(period).reshape(len(self.element_types), -1) def __propertynormalize__(self, prop_list, prop_name): - None_elements = [ ele for ele, item in zip(self.element_types, prop_list) if item is None ] @@ -138,7 +137,6 @@ def __propertynormalize__(self, prop_list, prop_name): return [(item - minval) / (maxval - minval) for item in prop_list] def __realtocategorical__(self, prop_tensor, num_classes=10): - delval = (prop_tensor.max() - prop_tensor.min()) / num_classes categories = torch.minimum( (prop_tensor - prop_tensor.min()) / delval, torch.tensor([num_classes - 1]) diff --git a/hydragnn/utils/cfgdataset.py b/hydragnn/utils/cfgdataset.py index 5e7c59e7d..eecfd00e3 100644 --- a/hydragnn/utils/cfgdataset.py +++ b/hydragnn/utils/cfgdataset.py @@ -30,7 +30,6 @@ def __transform_CFG_input_to_data_object_base(self, filepath): """ if filepath.endswith(".cfg"): - data_object = self.__transform_ASE_object_to_data_object(filepath) return data_object @@ -39,7 +38,6 @@ def __transform_CFG_input_to_data_object_base(self, filepath): return None def __transform_ASE_object_to_data_object(self, filepath): - # FIXME: # this still assumes bulk modulus is specific to the CFG format. # To deal with multiple files across formats, one should generalize this function diff --git a/hydragnn/utils/distributed.py b/hydragnn/utils/distributed.py index 69117dcd8..5c16bedd6 100644 --- a/hydragnn/utils/distributed.py +++ b/hydragnn/utils/distributed.py @@ -163,14 +163,12 @@ def setup_ddp(): def get_device_list(): - available_gpus = [i for i in range(torch.cuda.device_count())] return available_gpus def get_device_name(use_gpu=True, rank_per_model=1, verbosity_level=0): - available_gpus = get_device_list() if not use_gpu or not available_gpus: print_distributed(verbosity_level, "Using CPU") @@ -203,12 +201,10 @@ def get_device_name(use_gpu=True, rank_per_model=1, verbosity_level=0): def get_device_from_name(name: str): - return torch.device(name) def get_device(use_gpu=True, rank_per_model=1, verbosity_level=0): - name = get_device_name(use_gpu, rank_per_model, verbosity_level) return get_device_from_name(name) diff --git a/hydragnn/utils/smiles_utils.py b/hydragnn/utils/smiles_utils.py index 30e32a719..dca24908e 100644 --- a/hydragnn/utils/smiles_utils.py +++ b/hydragnn/utils/smiles_utils.py @@ -33,7 +33,6 @@ def get_node_attribute_name(types): def generate_graphdata_from_smilestr(simlestr, ytarget, types, var_config=None): - ps = Chem.SmilesParserParams() ps.removeHs = False diff --git a/hydragnn/utils/time_utils.py b/hydragnn/utils/time_utils.py index f30bb9b11..1653ad319 100644 --- a/hydragnn/utils/time_utils.py +++ b/hydragnn/utils/time_utils.py @@ -93,7 +93,6 @@ def reset(self): def print_timers(verbosity): - world_size, world_rank = get_comm_size_and_rank() # With proper lever of verbosity >=1, the local timers will have different values per process diff --git a/hydragnn/utils/xyzdataset.py b/hydragnn/utils/xyzdataset.py index b7c89be30..612e8df80 100644 --- a/hydragnn/utils/xyzdataset.py +++ b/hydragnn/utils/xyzdataset.py @@ -31,7 +31,6 @@ def __transform_XYZ_input_to_data_object_base(self, filepath): """ if filepath.endswith(".xyz"): - data_object = self.__transform_XYZ_ASE_object_to_data_object(filepath) return data_object @@ -40,7 +39,6 @@ def __transform_XYZ_input_to_data_object_base(self, filepath): return None def __transform_XYZ_ASE_object_to_data_object(self, filepath): - # FIXME: # this still assumes bulk modulus is specific to the XYZ format. diff --git a/setup.py b/setup.py index ff0727869..e1580a0c4 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ # Note: setup() has access to cmd arguments of the setup.py script via sys.argv + # Utility function to read the README file. def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() diff --git a/tests/deterministic_graph_data.py b/tests/deterministic_graph_data.py index a383cc75d..0a8d95a38 100755 --- a/tests/deterministic_graph_data.py +++ b/tests/deterministic_graph_data.py @@ -130,8 +130,8 @@ def create_configuration( knn.fit(positions, node_feature) node_output_x = torch.Tensor(knn.predict(positions)) - node_output_x_square = node_output_x ** 2 + node_feature - node_output_x_cube = node_output_x ** 3 + node_output_x_square = node_output_x**2 + node_feature + node_output_x_cube = node_output_x**3 updated_table = torch.cat( ( diff --git a/tests/test_config.py b/tests/test_config.py index bac3caf7c..ff51a9b6c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,7 +16,6 @@ @pytest.mark.parametrize("config_file", ["lsms/lsms.json"]) @pytest.mark.mpi_skip() def pytest_config(config_file): - config_file = os.path.join("examples", config_file) with open(config_file, "r") as f: config = json.load(f) diff --git a/tests/test_enthalpy.py b/tests/test_enthalpy.py index 4fd7ac04c..63dc9f8a2 100644 --- a/tests/test_enthalpy.py +++ b/tests/test_enthalpy.py @@ -19,7 +19,6 @@ def unittest_formation_enthalpy(): - dir = "dataset/unit_test_enthalpy" if not os.path.exists(dir): os.makedirs(dir) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 6adc7e6b4..0f8243e47 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -182,7 +182,7 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models -@pytest.mark.parametrize("model_type", ["PNA", "CGCNN", "SchNet"]) +@pytest.mark.parametrize("model_type", ["PNA", "CGCNN", "SchNet", "DimeNet"]) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) diff --git a/utils/lsms/compositional_histogram_cutoff.py b/utils/lsms/compositional_histogram_cutoff.py index 9bae71f7a..749422a54 100644 --- a/utils/lsms/compositional_histogram_cutoff.py +++ b/utils/lsms/compositional_histogram_cutoff.py @@ -41,7 +41,6 @@ def compositional_histogram_cutoff( comp_final = [] comp_all = np.zeros([num_bins]) for filename in tqdm(os.listdir(dir)): - path = os.path.join(dir, filename) # This is LSMS specific - it assumes only one header line and only atoms following. atoms = np.loadtxt(path, skiprows=1) diff --git a/utils/lsms/convert_total_energy_to_formation_gibbs.py b/utils/lsms/convert_total_energy_to_formation_gibbs.py index 620eeafbf..b2e6617f6 100644 --- a/utils/lsms/convert_total_energy_to_formation_gibbs.py +++ b/utils/lsms/convert_total_energy_to_formation_gibbs.py @@ -51,7 +51,6 @@ def convert_raw_data_energy_to_gibbs( # Search for the configurations with pure elements and store their total energy all_files = os.listdir(dir) for filename in tqdm(all_files): - path = os.path.join(dir, filename) total_energy, txt = read_file(path) atoms = np.loadtxt(txt[1:]) @@ -73,7 +72,6 @@ def convert_raw_data_energy_to_gibbs( # compute thermodynamic entropy # compute formation gibbs energy using formation enthalpy and thermodynamic entropy for fn, filename in enumerate(tqdm(all_files)): - path = os.path.join(dir, filename) total_energy_txt, txt = read_file(path) atoms = np.loadtxt(txt[1:]) @@ -143,7 +141,6 @@ def convert_raw_data_energy_to_gibbs( def compute_formation_enthalpy( path, elements_list, pure_elements_energy, total_energy, atoms ): - # FIXME: this currently works only for binary alloys elements, counts = np.unique(atoms[:, 0], return_counts=True) From afeb20fc030479b12caef02da76a9f3827f3fdfd Mon Sep 17 00:00:00 2001 From: JustinBakerMath Date: Tue, 23 May 2023 20:40:20 +0000 Subject: [PATCH 3/3] emb output temp --- hydragnn/models/DIMEStack.py | 178 ++++++++------------------------- hydragnn/models/create.py | 7 +- hydragnn/utils/config_utils.py | 14 +++ tests/test_graphs.py | 3 +- 4 files changed, 64 insertions(+), 138 deletions(-) diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py index 5ba5e5191..88d177a99 100644 --- a/hydragnn/models/DIMEStack.py +++ b/hydragnn/models/DIMEStack.py @@ -19,23 +19,20 @@ include radial basis information. """ -from math import sqrt - -from typing import Callable, Optional, Tuple +from typing import Callable, Tuple from torch_geometric.typing import SparseTensor import torch from torch import Tensor -from torch.nn import Embedding, Linear, SiLU +from torch.nn import SiLU -from torch_geometric.nn.inits import glorot_orthogonal +from torch_geometric.nn import Linear, Sequential from torch_geometric.nn.models.dimenet import ( BesselBasisLayer, + InteractionBlock, SphericalBasisLayer, - ResidualLayer, + OutputBlock, ) -from torch_geometric.utils import scatter - from .Base import Base @@ -57,143 +54,64 @@ def __init__( *args, **kwargs ): - super().__init__(*args, **kwargs) - + self.num_bilinear = num_bilinear + self.num_radial = num_radial + self.num_spherical = num_spherical + self.num_before_skip = num_before_skip + self.num_after_skip = num_after_skip self.radius = radius + super().__init__(*args, **kwargs) + self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent) self.sbf = SphericalBasisLayer( num_spherical, num_radial, radius, envelope_exponent ) - self.interact = Interaction( + + pass + + + def get_conv(self, input_dim, output_dim): + emb = EmbeddingBlock(self.num_radial, input_dim, act=SiLU()) + inter = InteractionBlock( hidden_channels=self.hidden_dim, - num_bilinear=num_bilinear, - num_spherical=num_spherical, - num_radial=num_radial, - num_before_skip=num_before_skip, - num_after_skip=num_after_skip, - ) + num_bilinear=self.num_bilinear, + num_spherical=self.num_spherical, + num_radial=self.num_radial, + num_before_skip=self.num_before_skip, + num_after_skip=self.num_after_skip, + act=SiLU(), + ) + dec = OutputBlock(self.num_radial, self.hidden_dim, output_dim, 1, SiLU()) + return Sequential('x, rbf, sbf, i, j, idx_kj, idx_ji', [ + (emb, 'x, rbf, i, j -> x1'), + (inter,'x1, rbf, sbf, idx_kj, idx_ji -> x2'), + (dec,'x2, rbf, i -> c'), + ]) def _conv_args(self, data): - conv_args = {"edge_index": data.edge_index} assert ( data.pos is not None ), "DimeNet requires node positions (data.pos) to be set." - conv_args.update({"pos": data.pos}) - return conv_args - - def forward(self, z, pos, edge_index): - z = z.to(torch.long) - # edge_index = radius_graph(pos, r=self.radius, batch=batch, - # max_num_neighbors=self.max_num_neighbors) i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( - edge_index, num_nodes=z.size(0) + data.edge_index, num_nodes=data.x.size(0) ) - dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() + dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt() # Calculate angles. - pos_i = pos[idx_i] - pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i + pos_i = data.pos[idx_i] + pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i a = (pos_ji * pos_ki).sum(dim=-1) b = torch.cross(pos_ji, pos_ki).norm(dim=-1) angle = torch.atan2(b, a) rbf = self.rbf(dist) sbf = self.sbf(dist, angle, idx_kj) - z = self.interact(z, rbf, sbf, idx_kj, idx_ji) - # z = z + output_block(x, rbf, i, num_nodes=pos.size(0)) - - return z - - -class Interaction(torch.nn.Module): - def __init__( - self, - hidden_channels: int, - num_bilinear: int, - num_spherical: int, - num_radial: int, - num_before_skip: int, - num_after_skip: int, - ): - super().__init__() - self.act = SiLU() - self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) - self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, bias=False) + conv_args = {"rbf":rbf, "sbf":sbf, "i": i, "j":j, "idx_kj":idx_kj, "idx_ji":idx_ji} - # Dense transformations of input messages. - self.lin_from = Linear(hidden_channels, hidden_channels) - self.lin_to = Linear(hidden_channels, hidden_channels) - - self.W = torch.nn.Parameter( - torch.Tensor(hidden_channels, num_bilinear, hidden_channels) - ) - - self.layers_before_skip = torch.nn.ModuleList( - [ResidualLayer(hidden_channels, SiLU()) for _ in range(num_before_skip)] - ) - self.lin = Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList( - [ResidualLayer(hidden_channels, SiLU()) for _ in range(num_after_skip)] - ) - - self.reset_parameters() - - def reset_parameters(self): - glorot_orthogonal(self.lin_rbf.weight, scale=2.0) - glorot_orthogonal(self.lin_sbf.weight, scale=2.0) - glorot_orthogonal(self.lin_from.weight, scale=2.0) - self.lin_from.bias.data.fill_(0) - glorot_orthogonal(self.lin_to.weight, scale=2.0) - self.lin_to.bias.data.fill_(0) - self.W.data.normal_(mean=0, std=2 / self.W.size(0)) - for res_layer in self.layers_before_skip: - res_layer.reset_parameters() - glorot_orthogonal(self.lin.weight, scale=2.0) - self.lin.bias.data.fill_(0) - for res_layer in self.layers_after_skip: - res_layer.reset_parameters() - - def forward( - self, - x: Tensor, - radial_basis: Tensor, - spherical_basis: Tensor, - edge_index_from: Tensor, - edge_index_to: Tensor, - ) -> Tensor: - radial_basis = self.lin_rbf(radial_basis) - spherical_basis = self.lin_sbf(spherical_basis) - - x_kj = self.act(self.lin_from(x)) - x_kj = x_kj * radial_basis - x_kj = torch.einsum( - "wj,wl,ijl->wi", spherical_basis, x_kj[edge_index_from], self.W - ) - x_kj = scatter( - x_kj, edge_index_to, dim=0, dim_size=x.size(0), reduce="sum" - ) # message passing - - x_ji = self.act(self.lin_to(x)) - h = ( - x_ji + x_kj - ) # aggregates my learned message and my from messages to the next neighbor - - for ( - layer - ) in ( - self.layers_before_skip - ): # this added resnet is not actually doing any message passing and is an interesting addition - h = layer(h) - h = ( - self.act(self.lin(h)) + x - ) # incorporates a residual connection to the input feature - for layer in self.layers_after_skip: - h = layer(h) - - return h + return conv_args """ @@ -232,35 +150,23 @@ def triplets( return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji -""" -EmbeddingBlock ---------------- -An embedding block that utilizes the -radial basis function and the to/from -information in the embedding by -concatentating the to/from nodes with -the radial basis functions. - -""" - - class EmbeddingBlock(torch.nn.Module): def __init__(self, num_radial: int, hidden_channels: int, act: Callable): super().__init__() self.act = act - self.emb = Embedding(95, hidden_channels) + # self.emb = Embedding(95, hidden_channels) # Atomic Embeddings are handles by Hydra self.lin_rbf = Linear(num_radial, hidden_channels) self.lin = Linear(3 * hidden_channels, hidden_channels) self.reset_parameters() def reset_parameters(self): - self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) + # self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) self.lin_rbf.reset_parameters() self.lin.reset_parameters() def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor: - x = self.emb(x) + # x = self.emb(x) rbf = self.act(self.lin_rbf(rbf)) - return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) + return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) \ No newline at end of file diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 84fe3df92..3d4f966bc 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -48,6 +48,12 @@ def create_model_config( config["Architecture"]["max_neighbours"], config["Architecture"]["edge_dim"], config["Architecture"]["pna_deg"], + config["Architecture"]["num_before_skip"], + config["Architecture"]["num_after_skip"], + config["Architecture"]["num_bilinear"], + config["Architecture"]["num_radial"], + config["Architecture"]["num_spherical"], + config["Architecture"]["envelope_exponent"], config["Architecture"]["num_gaussians"], config["Architecture"]["num_filters"], config["Architecture"]["radius"], @@ -237,7 +243,6 @@ def create_model( output_type, output_heads, loss_function_type, - max_neighbours=max_neighbours, loss_weights=task_weights, freeze_conv=freeze_conv, initial_bias=initial_bias, diff --git a/hydragnn/utils/config_utils.py b/hydragnn/utils/config_utils.py index bafbf9a8e..5de7edbc3 100644 --- a/hydragnn/utils/config_utils.py +++ b/hydragnn/utils/config_utils.py @@ -59,10 +59,24 @@ def update_config(config, train_loader, val_loader, test_loader): if "radius" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["radius"] = None + # SchNet if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None if "num_filters" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_filters"] = None + # DimeNet + if "num_before_skip" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["num_before_skip"] = None + if "num_after_skip" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["num_after_skip"] = None + if "num_bilinear" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["num_bilinear"] = None + if "num_radial" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["num_radial"] = None + if "num_spherical" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["num_spherical"] = None + if "envelope_exponent" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["envelope_exponent"] = None config["NeuralNetwork"]["Architecture"] = update_config_edge_dim( config["NeuralNetwork"]["Architecture"] diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 0f8243e47..78f17c5d1 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -174,7 +174,8 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False # Test across all models with both single/multihead @pytest.mark.parametrize( - "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet", "DimeNet"] + # "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet", "DimeNet"] + "model_type", ["DimeNet"] ) @pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"]) def pytest_train_model(model_type, ci_input, overwrite_data=False):