Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ There are many options for HydraGNN; the dataset and model type are particularly
important:
- `["Verbosity"]["level"]`: `0`, `1`, `2`, `3`, `4`
- `["Dataset"]["name"]`: `CuAu_32atoms`, `FePt_32atoms`, `FeSi_1024atoms`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`, `DimeNet`

### Citations
"HydraGNN: Distributed PyTorch implementation of multi-headed graph convolutional neural networks", Copyright ID#: 81929619
Expand Down
6 changes: 5 additions & 1 deletion examples/csce/train_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ def __getitem__(self, idx):
% (len(trainset), len(valset), len(testset))
)

(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
)

Expand Down
6 changes: 5 additions & 1 deletion examples/eam/eam.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def info(*args, logtype="info", sep=" "):
% (len(trainset), len(valset), len(testset))
)

(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
)
timer.stop()
Expand Down
12 changes: 4 additions & 8 deletions examples/ising_model/create_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


def write_to_file(total_energy, atomic_features, count_config, dir):

numpy_string_total_value = np.array2string(total_energy)

filetxt = numpy_string_total_value
Expand Down Expand Up @@ -40,7 +39,7 @@ def E_dimensionless(config, L, spin_function, scale_spin):
spin[x, y, z] = spin_function(config[x, y, z])

count_pos = 0
number_nodes = L ** 3
number_nodes = L**3
positions = np.zeros((number_nodes, 3))
atomic_features = np.zeros((number_nodes, 5))
for x in range(L):
Expand Down Expand Up @@ -76,18 +75,16 @@ def E_dimensionless(config, L, spin_function, scale_spin):
def create_dataset(
L, histogram_cutoff, dir, spin_function=lambda x: x, scale_spin=False
):

count_config = 0

for num_downs in tqdm(range(0, L ** 3)):

primal_configuration = np.ones((L ** 3,))
for num_downs in tqdm(range(0, L**3)):
primal_configuration = np.ones((L**3,))
for down in range(0, num_downs):
primal_configuration[down] = -1.0

# If the current composition has a total number of possible configurations above
# the hard cutoff threshold, a random configurational subset is picked
if scipy.special.binom(L ** 3, num_downs) > histogram_cutoff:
if scipy.special.binom(L**3, num_downs) > histogram_cutoff:
for num_config in range(0, histogram_cutoff):
config = np.random.permutation(primal_configuration)
config = np.reshape(config, (L, L, L))
Expand Down Expand Up @@ -115,7 +112,6 @@ def create_dataset(


if __name__ == "__main__":

dir = os.path.join(os.path.dirname(__file__), "../../dataset/ising_model")
if os.path.exists(dir):
shutil.rmtree(dir)
Expand Down
13 changes: 8 additions & 5 deletions examples/ising_model/train_ising.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@


def write_to_file(total_energy, atomic_features, count_config, dir, prefix):

numpy_string_total_value = np.array2string(total_energy)

filetxt = numpy_string_total_value
Expand All @@ -67,21 +66,21 @@ def create_dataset_mpi(
comm_size = comm.Get_size()

count_config = 0
rx = list(nsplit(range(0, L ** 3), comm_size))[rank]
rx = list(nsplit(range(0, L**3), comm_size))[rank]
info("rx", rx.start, rx.stop)

for num_downs in iterate_tqdm(
range(rx.start, rx.stop), verbosity_level=2, desc="Creating dataset"
):
prefix = "output_%d_" % num_downs

primal_configuration = np.ones((L ** 3,))
primal_configuration = np.ones((L**3,))
for down in range(0, num_downs):
primal_configuration[down] = -1.0

# If the current composition has a total number of possible configurations above
# the hard cutoff threshold, a random configurational subset is picked
if scipy.special.binom(L ** 3, num_downs) > histogram_cutoff:
if scipy.special.binom(L**3, num_downs) > histogram_cutoff:
for num_config in range(0, histogram_cutoff):
config = np.random.permutation(primal_configuration)
config = np.reshape(config, (L, L, L))
Expand Down Expand Up @@ -288,7 +287,11 @@ def info(*args, logtype="info", sep=" "):
% (len(trainset), len(valset), len(testset))
)

(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
)
timer.stop()
Expand Down
6 changes: 5 additions & 1 deletion examples/lsms/lsms.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def info(*args, logtype="info", sep=" "):
% (len(trainset), len(valset), len(testset))
)

(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
)
timer.stop()
Expand Down
7 changes: 6 additions & 1 deletion examples/md17/md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import hydragnn


# Update each sample prior to loading.
def md17_pre_transform(data):
# Set descriptor as element type.
Expand Down Expand Up @@ -68,7 +69,11 @@ def md17_pre_filter(data):
train, val, test = hydragnn.preprocess.split_dataset(
dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
)
(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
)

Expand Down
6 changes: 5 additions & 1 deletion examples/ogb/train_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,11 @@ def __getitem__(self, idx):
% (len(trainset), len(valset), len(testset))
)

(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
)

Expand Down
7 changes: 6 additions & 1 deletion examples/qm9/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import hydragnn


# Update each sample prior to loading.
def qm9_pre_transform(data):
# Set descriptor as element type.
Expand Down Expand Up @@ -59,7 +60,11 @@ def qm9_pre_filter(data):
train, val, test = hydragnn.preprocess.split_dataset(
dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
)
(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
(
train_loader,
val_loader,
test_loader,
) = hydragnn.preprocess.create_dataloaders(
train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
)

Expand Down
172 changes: 172 additions & 0 deletions hydragnn/models/DIMEStack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
DimeNet
========
Directional message passing neural network
for molecular graphs. The convolutional
layer uses spherical and radial basis
functions to perform message passing.

In particular this message passing layer
relies on the angle formed by the triplet
of incoming and outgoing messages.

The three key components of this network are
outlined below: the convolutional network
that is used for the message passing, the
triplet function that generates to/from
information for angular values, and finally
the radial basis embedding that is used to
include radial basis information.

"""
from typing import Callable, Tuple
from torch_geometric.typing import SparseTensor

import torch
from torch import Tensor
from torch.nn import SiLU

from torch_geometric.nn import Linear, Sequential
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
InteractionBlock,
SphericalBasisLayer,
OutputBlock,
)
from .Base import Base


class DIMEStack(Base):
    """
    DimeNet-style stack of message-passing layers.

    Generates angles, distances, to/from indices, radial basis
    functions and spherical basis functions for learning.

    Parameters
    ----------
    num_bilinear: size of the bilinear tensor in each interaction block
    num_radial: number of radial basis functions
    num_spherical: number of spherical harmonics
    radius: cutoff radius used by the radial/spherical basis layers
    envelope_exponent: exponent of the smooth cutoff envelope
    num_before_skip: residual layers before the skip connection
    num_after_skip: residual layers after the skip connection
    """

    def __init__(
        self,
        num_bilinear,
        num_radial,
        num_spherical,
        radius,
        envelope_exponent,
        num_before_skip,
        num_after_skip,
        *args,
        **kwargs
    ):
        self.num_bilinear = num_bilinear
        self.num_radial = num_radial
        self.num_spherical = num_spherical
        self.num_before_skip = num_before_skip
        self.num_after_skip = num_after_skip
        self.radius = radius

        # Base.__init__ builds the convolution layers via get_conv, so the
        # hyperparameters above must be assigned before this call.
        super().__init__(*args, **kwargs)

        self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent)
        self.sbf = SphericalBasisLayer(
            num_spherical, num_radial, radius, envelope_exponent
        )

    def get_conv(self, input_dim, output_dim):
        """Return one DimeNet unit: embedding -> interaction -> output block."""
        emb = EmbeddingBlock(self.num_radial, input_dim, act=SiLU())
        inter = InteractionBlock(
            hidden_channels=self.hidden_dim,
            num_bilinear=self.num_bilinear,
            num_spherical=self.num_spherical,
            num_radial=self.num_radial,
            num_before_skip=self.num_before_skip,
            num_after_skip=self.num_after_skip,
            act=SiLU(),
        )
        dec = OutputBlock(self.num_radial, self.hidden_dim, output_dim, 1, SiLU())
        return Sequential(
            "x, rbf, sbf, i, j, idx_kj, idx_ji",
            [
                (emb, "x, rbf, i, j -> x1"),
                (inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
                (dec, "x2, rbf, i -> c"),
            ],
        )

    def _conv_args(self, data):
        """Precompute the basis values and triplet indices consumed by every
        convolution layer for this batch."""
        assert (
            data.pos is not None
        ), "DimeNet requires node positions (data.pos) to be set."
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
            data.edge_index, num_nodes=data.x.size(0)
        )
        # Euclidean distance of every edge (j -> i).
        dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()

        # Angle at node i between edges (j -> i) and (k -> i).
        pos_i = data.pos[idx_i]
        pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # Explicit dim=-1 avoids torch.cross's deprecated default-dimension
        # lookup; for (N, 3) tensors the result is identical.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        conv_args = {
            "rbf": rbf,
            "sbf": sbf,
            "i": i,
            "j": j,
            "idx_kj": idx_kj,
            "idx_ji": idx_ji,
        }

        return conv_args


"""
Triplets
---------
Generates to/from edge_indices for
angle generating purposes.

"""


def triplets(
    edge_index: Tensor,
    num_nodes: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Enumerate all directed triplets (k -> j -> i) of the graph.

    Returns ``(i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji)``: the per-edge
    target/source node indices followed by, per triplet, the node indices
    i, j, k and the edge indices of (k -> j) and (j -> i).
    """
    row, col = edge_index  # j->i

    # Give each edge a unique id so the transposed adjacency stores edge
    # indices as its values.
    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(
        row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes)
    )
    adj_t_row = adj_t[row]
    # Count of incoming edges (k -> j) for each edge (j -> i).
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji


class EmbeddingBlock(torch.nn.Module):
    """Initial DimeNet embedding: fuses the two endpoint feature vectors of
    each edge with its radial basis expansion into one hidden vector.

    Atomic embeddings are handled upstream by HydraGNN, so unlike the
    reference DimeNet there is no ``Embedding`` lookup table here.
    """

    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act

        # Project the radial basis onto the hidden width, then mix it with
        # the two endpoint features (hence the factor of 3 on the input).
        self.lin_rbf = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all learnable layers."""
        self.lin_rbf.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
        edge_rbf = self.act(self.lin_rbf(rbf))
        fused = torch.cat([x[i], x[j], edge_rbf], dim=-1)
        return self.act(self.lin(fused))
1 change: 0 additions & 1 deletion hydragnn/models/PNAStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
*args,
**kwargs,
):

self.aggregators = ["mean", "min", "max", "std"]
self.scalers = [
"identity",
Expand Down
Loading