From 18232585bf12bae2442d87ed6be17f36b5995385 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 11 Feb 2026 14:03:24 -0500 Subject: [PATCH 1/4] updated gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ed23fafd..b7920944 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .cache/ docs/reference/* ./examples/MP/experiments -./examples/QM9s +./examples/QM9 *doctrees* /site From 5f89604f33950a50dc0c982b008ddeb281fa2fa0 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 11 Feb 2026 14:18:39 -0500 Subject: [PATCH 2/4] added lattice-aware modules --- .../configs/MP/config_resnet_lcn.yaml | 48 ++++ src/electrai/dataloader/dataset.py | 6 +- src/electrai/dataloader/utils.py | 12 +- src/electrai/lightning.py | 11 +- src/electrai/model/LCN.py | 230 ++++++++++++++++++ src/electrai/model/resnet_LCN.py | 172 +++++++++++++ src/electrai/model/utils.py | 141 +++++++++++ 7 files changed, 609 insertions(+), 11 deletions(-) create mode 100644 src/electrai/configs/MP/config_resnet_lcn.yaml create mode 100644 src/electrai/model/LCN.py create mode 100644 src/electrai/model/resnet_LCN.py create mode 100644 src/electrai/model/utils.py diff --git a/src/electrai/configs/MP/config_resnet_lcn.yaml b/src/electrai/configs/MP/config_resnet_lcn.yaml new file mode 100644 index 00000000..d1b94cae --- /dev/null +++ b/src/electrai/configs/MP/config_resnet_lcn.yaml @@ -0,0 +1,48 @@ +# Dataset / loader parameters +data: + _target_: src.electrai.dataloader.dataset.RhoRead + root: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/mp_filelist.txt + split_file: null #/scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/split.json + precision: f32 + batch_size: 1 + train_workers: 8 + val_workers: 2 + pin_memory: false + val_frac: 0.005 + drop_last: false + augmentation: false + random_seed: 42 + # downsample_label: 0 + # downsample_data: 0 + +# Model +model: + _target_: 
src.electrai.model.resnet_LCN.GeneratorResNet + n_residual_blocks: 32 + n_channels: 32 + kernel_size1: 5 + kernel_size2: 5 + normalize: True + use_checkpoint: False + use_lattice_conv: true + use_radial_embedding: true + num_gaussians: 500 + use_positional_embedding: true + pos_embed_dim: 500 + +# Training parameters +precision: 32 +epochs: 50 +lr: 0.01 +weight_decay: 0.0 +warmup_length: 1 +beta1: 0.9 +beta2: 0.99 + +# Weights and biases +wandb_mode: online +entity: PrinceOA +wb_pname: mp-experiment + +# checkpoints +ckpt_path: ./checkpoints diff --git a/src/electrai/dataloader/dataset.py b/src/electrai/dataloader/dataset.py index b3321143..78cf72df 100644 --- a/src/electrai/dataloader/dataset.py +++ b/src/electrai/dataloader/dataset.py @@ -107,7 +107,7 @@ def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs): else: raise ValueError("No filename found.") - self.category = Path(datapath).name.split("_")[0] # example: mp_filelist.txt + self.category = Path(datapath).name.split("_")[0] self.root = Path(datapath).parent self.member_list = member_list @@ -116,7 +116,7 @@ def __len__(self): def __getitem__(self, index): index = self.member_list[index] - data, label = utils.load_numpy_rho( + data, label, lattice = utils.load_numpy_rho( root=self.root, category=self.category, index=index, @@ -125,4 +125,4 @@ def __getitem__(self, index): ) data = data.unsqueeze(0) label = label.unsqueeze(0) - return {"data": data, "label": label, "index": index} + return {"data": data, "label": label, "index": index, "lattice": lattice} diff --git a/src/electrai/dataloader/utils.py b/src/electrai/dataloader/utils.py index 63f775ca..9db22cfc 100644 --- a/src/electrai/dataloader/utils.py +++ b/src/electrai/dataloader/utils.py @@ -25,22 +25,28 @@ def load_numpy_rho( """ root = Path(root) if category == "mp": - data, label = load_chgcar(root, index) + data, label, lattice = load_chgcar(root, index) elif category == "qm9": data, label = load_npy(root, index) data = 
torch.tensor(data, dtype=dtype_map[precision]) label = torch.tensor(label, dtype=dtype_map[precision]) + lattice = torch.tensor(lattice, dtype=dtype_map[precision]) + grid_shape = torch.tensor( + data.shape, dtype=dtype_map[precision], device=lattice.device + ) + lattice = lattice / grid_shape[:, None] if augmentation: data, label = rand_rotate([data, label]) - return data, label + return data, label, lattice def load_chgcar(root: str | bytes | os.PathLike, index: str): data = Chgcar.from_file(root / "data" / f"{index}.CHGCAR") label = Chgcar.from_file(root / "label" / f"{index}.CHGCAR") + lattice = data.structure.lattice.matrix data = data.data["total"] / data.structure.lattice.volume label = label.data["total"] / label.structure.lattice.volume - return data, label + return data, label, lattice def load_npy(root: str | bytes | os.PathLike, index: str): diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index d216b723..c9a86966 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -14,8 +14,8 @@ def __init__(self, cfg): self.model = instantiate(cfg.model) self.loss_fn = NormMAE() - def forward(self, x): - return self.model(x) + def forward(self, x, lattice_vectors=None): + return self.model(x, lattice_vectors) def training_step(self, batch): loss = self._loss_calculation(batch) @@ -39,15 +39,16 @@ def validation_step(self, batch): def _loss_calculation(self, batch): x = batch["data"] y = batch["label"] + A = batch["lattice"] if isinstance(x, list): losses = [] - for x_i, y_i in zip(x, y, strict=True): - pred = self(x_i.unsqueeze(0)) + for x_i, y_i, A_i in zip(x, y, A, strict=True): + pred = self(x_i.unsqueeze(0), A_i.unsqueeze(0)) loss = self.loss_fn(pred, y_i.unsqueeze(0)) losses.append(loss) loss = torch.stack(losses).mean() else: - pred = self(x) + pred = self(x, A) loss = self.loss_fn(pred, y) return loss diff --git a/src/electrai/model/LCN.py b/src/electrai/model/LCN.py new file mode 100644 index 00000000..bc6c8571 --- 
/dev/null +++ b/src/electrai/model/LCN.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from src.electrai.model.utils import ( + FourierPositionalEmbedding, + GaussianRadialBasis, + PositionalEmbedding, +) + + +class LatticeConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=False, + use_radial_embedding=False, + use_positional_embedding=False, + num_gaussians=16, + pos_embed_dim=16, + pos_embed_type="learnable", + r_max=5.0, + hidden_dim=64, + ): + super().__init__() + padding = kernel_size // 2 # - 1 + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3 + ) + self.padding = padding if isinstance(padding, tuple) else (padding,) * 3 + self.padding_mode = padding_mode + self.stride = stride + self.dilation = dilation + self.use_lattice_conv = use_lattice_conv + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + padding=padding, + padding_mode=padding_mode, + stride=stride, + dilation=dilation, + bias=True, + ) + w = self.conv.weight.detach().clone() + del self.conv._parameters["weight"] + self.conv.register_buffer("weight", w) + self.use_radial_embedding = use_radial_embedding + self.use_positional_embedding = use_positional_embedding + + if use_lattice_conv: + if use_radial_embedding: + self.gaussian_smear = GaussianRadialBasis( + num_gaussians=num_gaussians, r_min=0.0, r_max=r_max, trainable=True + ) + + if use_positional_embedding: + if pos_embed_type == "learnable": + self.pos_embedding = PositionalEmbedding( + embed_dim=pos_embed_dim, max_kernel_size=max(self.kernel_size) + ) + elif pos_embed_type == "fourier": + self.pos_embedding = FourierPositionalEmbedding( + embed_dim=pos_embed_dim, max_freq=10 + ) + else: + raise 
ValueError(f"Unknown pos_embed_type: {pos_embed_type}") + self.pos_embed_type = pos_embed_type + + if use_radial_embedding or use_positional_embedding: + input_size = 0 + if use_radial_embedding: + input_size += num_gaussians + if use_positional_embedding: + input_size += pos_embed_dim + self.filter_network = nn.Sequential( + nn.Linear(input_size, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + nn.Dropout(0.1), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + nn.Dropout(0.1), + nn.Linear(hidden_dim, in_channels * out_channels), + ) + + # if self.use_radial_embedding or self.use_positional_embedding: + # self.mix_weight = nn.Parameter(torch.tensor(0.1)) + + def _apply_padding(self, x): + if self.padding_mode == "circular" and any(p > 0 for p in self.padding): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode="circular") + elif self.padding_mode in ["zeros", "reflect", "replicate"] and any( + p > 0 for p in self.padding + ): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode=self.padding_mode) + else: + return x + + def compute_geometric_kernel(self, lattice_vectors): + if lattice_vectors.dim() == 2: + lattice_vectors = lattice_vectors.unsqueeze(0) + squeeze_batch = True + else: + squeeze_batch = False + + B = lattice_vectors.shape[0] + kz, ky, kx = self.kernel_size + device = lattice_vectors.device + + z = torch.arange(kz, device=device) - kz // 2 + y = torch.arange(ky, device=device) - ky // 2 + x = torch.arange(kx, device=device) - kx // 2 + + grid_z, grid_y, grid_x = torch.meshgrid(z, y, x, indexing="ij") + frac_coords = torch.stack([grid_z, grid_y, grid_x], dim=-1).float() + + cart_coords = torch.einsum("ijkl,bml->bijkm", frac_coords, lattice_vectors) + distances = torch.norm(cart_coords, dim=-1) + + if 
self.use_radial_embedding: + radial_features = self.gaussian_smear(distances) + radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky kx) n") + if self.use_positional_embedding: + if self.pos_embed_type == "learnable": + pos_features = self.pos_embedding(frac_coords, self.kernel_size) + pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + else: + pos_features = self.pos_embedding(frac_coords) + pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + pos_flat = rearrange(pos_features, "b kz ky kx n -> b (kz ky kx) n") + + if self.use_radial_embedding and self.use_positional_embedding: + features = torch.cat([radial_flat, pos_flat], dim=-1) + elif self.use_radial_embedding: + features = torch.cat([radial_flat], dim=-1) + else: + features = torch.cat([pos_flat], dim=-1) + + kernel_flat = self.filter_network(features) + kernel = rearrange( + kernel_flat, + "b (kz ky kx) (o i) -> b o i kz ky kx", + o=self.out_channels, + i=self.in_channels, + kz=kz, + ky=ky, + kx=kx, + ) + + if squeeze_batch: + kernel = kernel.squeeze(0) + + return kernel + + def forward(self, x, lattice_vectors=None): + if not self.use_lattice_conv or ( + not self.use_radial_embedding and not self.use_positional_embedding + ): + return self.conv(x) + + B = x.shape[0] + if lattice_vectors.dim() == 2: + x_padded = self._apply_padding(x) + geometric_kernel = self.compute_geometric_kernel(lattice_vectors) + alpha = 1 # self.mix_weight + kernel = alpha * geometric_kernel + + return F.conv3d( + x_padded, + kernel, + self.conv.bias, + stride=self.stride, + padding=0, + dilation=self.dilation, + ) + + else: + geometric_kernels = self.compute_geometric_kernel(lattice_vectors) + alpha = 1 # 0.1 # self.mix_weight + # self.register_buffer("base_weight", w) # saved + moved with .to(device), not trained + + # base_kernel = self.conv.weight.unsqueeze(0) + kernels = alpha * geometric_kernels # + (1 - alpha) * base_kernel + x_grouped = x.reshape(1, B * self.in_channels, 
*x.shape[2:]) + x_grouped = self._apply_padding(x_grouped) + kernels_grouped = kernels.reshape( + B * self.out_channels, self.in_channels, *self.kernel_size + ) + out = F.conv3d( + x_grouped, + kernels_grouped, + bias=None, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=B, + ) + + out = out.reshape(B, self.out_channels, *out.shape[2:]) + + if self.conv.bias is not None: + out = out + self.conv.bias.view(1, -1, 1, 1, 1) + return out diff --git a/src/electrai/model/resnet_LCN.py b/src/electrai/model/resnet_LCN.py new file mode 100644 index 00000000..f8d15477 --- /dev/null +++ b/src/electrai/model/resnet_LCN.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +from src.electrai.model.LCN import LatticeConv3d +from torch.utils.checkpoint import checkpoint + + +class ResidualBlock(nn.Module): + def __init__( + self, + in_features, + K=3, + use_checkpoint=True, + use_lattice_conv=False, + use_radial_embedding=False, + use_positional_embedding=False, + num_gaussians=16, + pos_embed_dim=16, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + + self.conv1 = LatticeConv3d( + in_features, + in_features, + kernel_size=K, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + ) + self.norm1 = nn.InstanceNorm3d(in_features) + self.act1 = nn.PReLU() + self.conv2 = LatticeConv3d( + in_features, + in_features, + kernel_size=K, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + ) + self.norm2 = nn.InstanceNorm3d(in_features) + + def forward(self, x, lattice_vectors=None): + if self.use_checkpoint and 
self.training: + return x + checkpoint( + self._forward, x, lattice_vectors, use_reentrant=False + ) + else: + return x + self._forward(x, lattice_vectors) + + def _forward(self, x, lattice_vectors): + out = self.conv1(x, lattice_vectors) + out = self.norm1(out) + out = self.act1(out) + out = self.conv2(out, lattice_vectors) + return self.norm2(out) + + +class GeneratorResNet(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + n_residual_blocks=16, + n_channels=64, + kernel_size1=5, + kernel_size2=3, + normalize=True, + use_checkpoint=True, + use_lattice_conv=False, + use_radial_embedding=False, + use_positional_embedding=False, + num_gaussians=16, + pos_embed_dim=16, + ): + super().__init__() + self.normalize = normalize + self.use_checkpoint = use_checkpoint + self.use_lattice_conv = use_lattice_conv + + # First layer + self.conv1 = LatticeConv3d( + in_channels, + n_channels, + kernel_size=kernel_size1, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + ) + self.act1 = nn.PReLU() + + # Residual blocks + res_blocks = [ + ResidualBlock( + n_channels, + K=kernel_size2, + use_checkpoint=use_checkpoint, + use_lattice_conv=use_lattice_conv, + ) + for _ in range(n_residual_blocks) + ] + self.res_blocks = nn.ModuleList(res_blocks) + + # Second conv layer post residual blocks + self.conv2 = LatticeConv3d( + n_channels, + n_channels, + kernel_size=kernel_size2, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + ) + self.norm = nn.InstanceNorm3d(n_channels) + + # Final output layer + self.conv3 = LatticeConv3d( + n_channels, + out_channels, + 
kernel_size=kernel_size1, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + ) + self.act2 = nn.ReLU() + + def forward(self, x, lattice_vectors): + if isinstance(x, torch.Tensor): + return self._forward(x, lattice_vectors) + return [self._forward(xi.unsqueeze(0), lattice_vectors).squeeze(0) for xi in x] + + def _forward(self, x, lattice_vectors=None): + out1 = self.conv1(x, lattice_vectors) + out1 = self.act1(out1) + out = out1 + for block in self.res_blocks: + out = block(out, lattice_vectors) + out2 = self.conv2(out, lattice_vectors) + out2 = self.norm(out2) + out = torch.add(out1, out2) + out = self.conv3(out, lattice_vectors) + out = self.act2(out) + + if self.normalize: + out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None] + out = out * torch.sum(x, axis=(-3, -2, -1))[..., None, None, None] + return out diff --git a/src/electrai/model/utils.py b/src/electrai/model/utils.py new file mode 100644 index 00000000..60e1f922 --- /dev/null +++ b/src/electrai/model/utils.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import torch +import torch.nn as nn + + +class GaussianRadialBasis(nn.Module): + def __init__(self, num_gaussians=50, r_min=0.0, r_max=5.0, trainable=False): + super().__init__() + self.num_gaussians = num_gaussians + self.r_min = r_min + self.r_max = r_max + + # Initialize Gaussian centers uniformly between r_min and r_max + centers = torch.linspace(r_min, r_max, num_gaussians) + + # Initialize widths based on spacing + # Common choice: width = distance between adjacent centers + spacing = (r_max - r_min) / (num_gaussians - 1) if num_gaussians > 1 else 1.0 + widths = torch.ones(num_gaussians) * spacing + + if trainable: + self.centers = nn.Parameter(centers) + self.widths = nn.Parameter(widths) + else: + 
self.register_buffer("centers", centers) + self.register_buffer("widths", widths) + + def forward(self, distances): + """ + Expand distances using Gaussian basis functions. + + Args: + distances: Tensor of shape [...] containing interatomic distances + + Returns: + Tensor of shape [..., num_gaussians] with Gaussian features + """ + # distances: [...] -> [..., 1] + # centers: [num_gaussians] -> [1, ..., 1, num_gaussians] + distances = distances.unsqueeze(-1) # [..., 1] + centers = self.centers.view( + *([1] * (distances.dim() - 1)), -1 + ) # [1, ..., num_gaussians] + widths = self.widths.view(*([1] * (distances.dim() - 1)), -1) + + # Gaussian RBF: exp(-(d - c)^2 / (2 * w^2)) + diff = distances - centers + gamma = 1.0 / (2 * widths**2) + return torch.exp(-gamma * diff**2) + + +class PositionalEmbedding(nn.Module): + """ + Learnable positional embedding for kernel positions. + Similar to positional encodings in Transformers but learnable. + """ + + def __init__(self, embed_dim=32, max_kernel_size=7): + super().__init__() + self.embed_dim = embed_dim + + # Learnable embeddings for each coordinate dimension + # Range: [-max_kernel_size//2, max_kernel_size//2] + self.z_embed = nn.Embedding(max_kernel_size, embed_dim) + self.y_embed = nn.Embedding(max_kernel_size, embed_dim) + self.x_embed = nn.Embedding(max_kernel_size, embed_dim) + + # Optional: learnable way to combine the three directions + self.combine = nn.Linear(3 * embed_dim, embed_dim) + + def forward(self, frac_coords, kernel_size): + """ + Args: + frac_coords: [kz, ky, kx, 3] fractional coordinates (z, y, x offsets) + kernel_size: (kz, ky, kx) tuple + + Returns: + embeddings: [kz, ky, kx, embed_dim] + """ + kz, ky, kx = kernel_size + + # Convert coordinates to indices (shift from [-k//2, k//2] to [0, k]) + z_idx = (frac_coords[..., 0] + kz // 2).long() + y_idx = (frac_coords[..., 1] + ky // 2).long() + x_idx = (frac_coords[..., 2] + kx // 2).long() + + # Get embeddings for each dimension + z_emb = 
self.z_embed(z_idx) # [kz, ky, kx, embed_dim] + y_emb = self.y_embed(y_idx) # [kz, ky, kx, embed_dim] + x_emb = self.x_embed(x_idx) # [kz, ky, kx, embed_dim] + + # Combine (can use addition, concatenation + linear, etc.) + combined = torch.cat([z_emb, y_emb, x_emb], dim=-1) # [kz, ky, kx, 3*embed_dim] + return self.combine(combined) # [kz, ky, kx, embed_dim] + + +class FourierPositionalEmbedding(nn.Module): + """ + Fourier features for positional encoding (non-learnable but more expressive). + Similar to what's used in NeRF. + """ + + def __init__(self, embed_dim=32, max_freq=10): + super().__init__() + self.embed_dim = embed_dim + + # Number of frequency bands + num_freqs = embed_dim // 6 # 6 because we have sin+cos for each of 3 coords + + # Logarithmically spaced frequencies + freq_bands = 2.0 ** torch.linspace(0, max_freq, num_freqs) + self.register_buffer("freq_bands", freq_bands) + + def forward(self, frac_coords): + """ + Args: + frac_coords: [..., 3] fractional coordinates + + Returns: + embeddings: [..., embed_dim] + """ + # frac_coords: [..., 3] + # freq_bands: [num_freqs] + + coords_expanded = frac_coords.unsqueeze(-2) # [..., 1, 3] + freqs_expanded = self.freq_bands.unsqueeze(-1) # [num_freqs, 1] + + # Compute sin and cos for each frequency and coordinate + angles = coords_expanded * freqs_expanded # [..., num_freqs, 3] + + sin_features = torch.sin(angles) + cos_features = torch.cos(angles) + + # Concatenate all features + fourier_features = torch.cat( + [sin_features, cos_features], dim=-2 + ) # [..., 2*num_freqs, 3] + return fourier_features.reshape( + *frac_coords.shape[:-1], -1 + ) # [..., 6*num_freqs] From 13bee6df7eb55258b3a7cdaf07e4401cfdca6be4 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Tue, 17 Feb 2026 09:31:31 -0500 Subject: [PATCH 3/4] updated modules --- src/electrai/lightning.py | 69 ++++++++++++++++++++++++++++++++ src/electrai/model/LCN.py | 27 ++++++++++--- src/electrai/model/resnet_LCN.py | 7 ++++ 3 files changed, 98 
insertions(+), 5 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index c9a86966..126f295a 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -27,6 +27,13 @@ def training_step(self, batch): on_epoch=True, sync_dist=False, ) + if hasattr(self.model, "conv1") and hasattr( + self.model.conv1, "last_debug_stats" + ): + stats = self.model.conv1.last_debug_stats + for key, values in stats.items(): + for metric, val in values.items(): + self.log(f"debug/{key}/{metric}", val, on_step=True, on_epoch=False) return loss def validation_step(self, batch): @@ -36,6 +43,68 @@ def validation_step(self, batch): ) return loss + # def _log_gaussian_params(self, prefix="train_"): + # for name, module in self.model.named_modules(): + # if isinstance(module, torch.nn.Module) and hasattr( + # module, "gaussian_smear" + # ): + # gaussian_smear = module.gaussian_smear + + # if hasattr(gaussian_smear, "centers"): + # centers = gaussian_smear.centers + # self.log( + # f"{prefix}gaussian/centers_mean", + # centers.mean(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/centers_std", + # centers.std(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/centers_min", + # centers.min(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/centers_max", + # centers.max(), + # on_step=True, + # on_epoch=True, + # ) + + # if hasattr(gaussian_smear, "widths"): + # widths = gaussian_smear.widths + # self.log( + # f"{prefix}gaussian/widths_mean", + # widths.mean(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/widths_std", + # widths.std(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/widths_min", + # widths.min(), + # on_step=True, + # on_epoch=True, + # ) + # self.log( + # f"{prefix}gaussian/widths_max", + # widths.max(), + # on_step=True, + # on_epoch=True, + # ) + # break + def 
_loss_calculation(self, batch): x = batch["data"] y = batch["label"] diff --git a/src/electrai/model/LCN.py b/src/electrai/model/LCN.py index bc6c8571..45fc6cad 100644 --- a/src/electrai/model/LCN.py +++ b/src/electrai/model/LCN.py @@ -23,6 +23,7 @@ def __init__( use_lattice_conv=False, use_radial_embedding=False, use_positional_embedding=False, + trainable_gaussian_params=False, num_gaussians=16, pos_embed_dim=16, pos_embed_type="learnable", @@ -41,6 +42,7 @@ def __init__( self.stride = stride self.dilation = dilation self.use_lattice_conv = use_lattice_conv + self.trainable_gaussian_params = trainable_gaussian_params self.conv = nn.Conv3d( in_channels, out_channels, @@ -60,7 +62,10 @@ def __init__( if use_lattice_conv: if use_radial_embedding: self.gaussian_smear = GaussianRadialBasis( - num_gaussians=num_gaussians, r_min=0.0, r_max=r_max, trainable=True + num_gaussians=num_gaussians, + r_min=0.0, + r_max=r_max, + trainable=self.trainable_gaussian_params, ) if use_positional_embedding: @@ -144,6 +149,15 @@ def compute_geometric_kernel(self, lattice_vectors): cart_coords = torch.einsum("ijkl,bml->bijkm", frac_coords, lattice_vectors) distances = torch.norm(cart_coords, dim=-1) + debug_stats = {} + debug_stats["distances"] = { + "min": distances.min().item(), + "max": distances.max().item(), + "mean": distances.mean().item(), + "has_nan": torch.isnan(distances).any().item(), + "has_inf": torch.isinf(distances).any().item(), + } + if self.use_radial_embedding: radial_features = self.gaussian_smear(distances) radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky kx) n") @@ -177,7 +191,7 @@ def compute_geometric_kernel(self, lattice_vectors): if squeeze_batch: kernel = kernel.squeeze(0) - return kernel + return kernel, debug_stats def forward(self, x, lattice_vectors=None): if not self.use_lattice_conv or ( @@ -188,9 +202,9 @@ def forward(self, x, lattice_vectors=None): B = x.shape[0] if lattice_vectors.dim() == 2: x_padded = self._apply_padding(x) - 
geometric_kernel = self.compute_geometric_kernel(lattice_vectors) + geometric_kernels = self.compute_geometric_kernel(lattice_vectors) alpha = 1 # self.mix_weight - kernel = alpha * geometric_kernel + kernel = alpha * geometric_kernels return F.conv3d( x_padded, @@ -202,7 +216,10 @@ def forward(self, x, lattice_vectors=None): ) else: - geometric_kernels = self.compute_geometric_kernel(lattice_vectors) + geometric_kernels, debug_stats = self.compute_geometric_kernel( + lattice_vectors + ) + self.last_debug_stats = debug_stats alpha = 1 # 0.1 # self.mix_weight # self.register_buffer("base_weight", w) # saved + moved with .to(device), not trained diff --git a/src/electrai/model/resnet_LCN.py b/src/electrai/model/resnet_LCN.py index f8d15477..be1537b0 100644 --- a/src/electrai/model/resnet_LCN.py +++ b/src/electrai/model/resnet_LCN.py @@ -15,6 +15,7 @@ def __init__( use_lattice_conv=False, use_radial_embedding=False, use_positional_embedding=False, + trainable_gaussian_params=False, num_gaussians=16, pos_embed_dim=16, ): @@ -31,6 +32,7 @@ def __init__( use_lattice_conv=use_lattice_conv, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, ) @@ -46,6 +48,7 @@ def __init__( use_lattice_conv=use_lattice_conv, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, ) @@ -81,6 +84,7 @@ def __init__( use_lattice_conv=False, use_radial_embedding=False, use_positional_embedding=False, + trainable_gaussian_params=False, num_gaussians=16, pos_embed_dim=16, ): @@ -100,6 +104,7 @@ def __init__( use_lattice_conv=use_lattice_conv, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, 
num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, ) @@ -128,6 +133,7 @@ def __init__( use_lattice_conv=use_lattice_conv, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, ) @@ -144,6 +150,7 @@ def __init__( use_lattice_conv=use_lattice_conv, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, ) From 7318af5436112513ce2a2655e213d4ca09808874 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Mon, 2 Mar 2026 22:06:16 -0500 Subject: [PATCH 4/4] added trial modules --- src/electrai/dataloader/utils.py | 13 +- src/electrai/entrypoints/train.py | 11 +- src/electrai/lightning.py | 115 +++---- src/electrai/model/LCN.py | 249 +++++++------- src/electrai/model/LCN_al.py | 269 +++++++++++++++ src/electrai/model/resnet_LCN.py | 29 +- src/electrai/model/resunet_LCN.py | 543 ++++++++++++++++++++++++++++++ src/electrai/model/utils.py | 269 ++++++++++----- src/electrai/model/utils_al.py | 362 ++++++++++++++++++++ 9 files changed, 1578 insertions(+), 282 deletions(-) create mode 100644 src/electrai/model/LCN_al.py create mode 100644 src/electrai/model/resunet_LCN.py create mode 100644 src/electrai/model/utils_al.py diff --git a/src/electrai/dataloader/utils.py b/src/electrai/dataloader/utils.py index 9db22cfc..87896594 100644 --- a/src/electrai/dataloader/utils.py +++ b/src/electrai/dataloader/utils.py @@ -31,12 +31,17 @@ def load_numpy_rho( data = torch.tensor(data, dtype=dtype_map[precision]) label = torch.tensor(label, dtype=dtype_map[precision]) lattice = torch.tensor(lattice, dtype=dtype_map[precision]) - grid_shape = torch.tensor( - data.shape, dtype=dtype_map[precision], device=lattice.device - ) - lattice = lattice / grid_shape[:, None] + # grid_shape = 
torch.tensor( + # data.shape, dtype=dtype_map[precision], device=lattice.device + # ) + # lattice = lattice / grid_shape[:, None] if augmentation: data, label = rand_rotate([data, label]) + # print("shapeeeeeeee", index, data.shape, label.shape) + data = data.permute(2, 1, 0) + label = label.permute(2, 1, 0) + # a, b, c = lattice[0], lattice[1], lattice[2] + # lattice = torch.stack([c, b, a], dim=0) # (z,y,x) return data, label, lattice diff --git a/src/electrai/entrypoints/train.py b/src/electrai/entrypoints/train.py index 235b3c84..b7ec947a 100644 --- a/src/electrai/entrypoints/train.py +++ b/src/electrai/entrypoints/train.py @@ -42,7 +42,7 @@ def train(args): from lightning.pytorch.loggers import WandbLogger wandb_logger = WandbLogger( - project=cfg.wb_pname, entity=cfg.entity, config=vars(cfg) + project=cfg.wb_pname, entity=cfg.entity, config=vars(cfg), name=cfg.run_name ) else: wandb_logger = None @@ -62,12 +62,19 @@ def train(args): # ----------------------------- # Trainer # ----------------------------- + local_world_size = int( + os.environ.get("LOCAL_WORLD_SIZE", torch.cuda.device_count()) + ) + world_size = int(os.environ.get("WORLD_SIZE", local_world_size)) + num_nodes = max(1, world_size // local_world_size) trainer = Trainer( max_epochs=int(cfg.epochs), logger=wandb_logger, callbacks=[checkpoint_cb, lr_monitor], accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1, + devices="auto", + num_nodes=num_nodes, + strategy="ddp", precision=cfg.precision, log_every_n_steps=1, gradient_clip_val=getattr(cfg, "gradient_clip_value", 1.0), diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 126f295a..fb797362 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -3,6 +3,8 @@ import torch from hydra.utils import instantiate from lightning.pytorch import LightningModule +from lightning.pytorch.utilities import rank_zero_only +from src.electrai.model.LCN import LatticeConv3d from 
src.electrai.model.loss.charge import NormMAE @@ -17,6 +19,38 @@ def __init__(self, cfg): def forward(self, x, lattice_vectors=None): return self.model(x, lattice_vectors) + @rank_zero_only + def _collect_kernel_stats(self, target_layer: str | None = None) -> None: + # no need for trainer.is_global_zero now; rank_zero_only handles it + for name, module in self.model.named_modules(): + if not isinstance(module, LatticeConv3d) or not hasattr( + module, "kernel_stats" + ): + continue + if target_layer is not None and name != target_layer: + continue + + s = module.kernel_stats + if self.cfg.model["use_lattice_conv"]: + log_dict = { + f"kernels/{name}/alpha": s["alpha"], + f"kernels/{name}/ratio": s["ratio"], + f"kernels/{name}/geo_rms": s["geo_rms"], + f"kernels/{name}/base_rms": s["base_rms"], + } + else: + log_dict = {f"kernels/{name}/base_rms": s["base_rms"]} + + # IMPORTANT: let Lightning manage step + syncing + self.log_dict( + log_dict, + on_step=True, + on_epoch=False, + prog_bar=False, + logger=True, + sync_dist=False, # keep False if you're only logging on rank 0 + ) + def training_step(self, batch): loss = self._loss_calculation(batch) self.log( @@ -27,15 +61,11 @@ def training_step(self, batch): on_epoch=True, sync_dist=False, ) - if hasattr(self.model, "conv1") and hasattr( - self.model.conv1, "last_debug_stats" - ): - stats = self.model.conv1.last_debug_stats - for key, values in stats.items(): - for metric, val in values.items(): - self.log(f"debug/{key}/{metric}", val, on_step=True, on_epoch=False) return loss + def on_train_batch_end(self, outputs, batch, batch_idx): # noqa: ARG002 + self._collect_kernel_stats(target_layer="mid.0.conv1") + def validation_step(self, batch): loss = self._loss_calculation(batch) self.log( @@ -43,68 +73,6 @@ def validation_step(self, batch): ) return loss - # def _log_gaussian_params(self, prefix="train_"): - # for name, module in self.model.named_modules(): - # if isinstance(module, torch.nn.Module) and hasattr( - # 
module, "gaussian_smear" - # ): - # gaussian_smear = module.gaussian_smear - - # if hasattr(gaussian_smear, "centers"): - # centers = gaussian_smear.centers - # self.log( - # f"{prefix}gaussian/centers_mean", - # centers.mean(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/centers_std", - # centers.std(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/centers_min", - # centers.min(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/centers_max", - # centers.max(), - # on_step=True, - # on_epoch=True, - # ) - - # if hasattr(gaussian_smear, "widths"): - # widths = gaussian_smear.widths - # self.log( - # f"{prefix}gaussian/widths_mean", - # widths.mean(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/widths_std", - # widths.std(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/widths_min", - # widths.min(), - # on_step=True, - # on_epoch=True, - # ) - # self.log( - # f"{prefix}gaussian/widths_max", - # widths.max(), - # on_step=True, - # on_epoch=True, - # ) - # break - def _loss_calculation(self, batch): x = batch["data"] y = batch["label"] @@ -126,8 +94,17 @@ def configure_optimizers(self): self.model.parameters(), lr=float(self.cfg.lr), weight_decay=float(self.cfg.weight_decay), + betas=(getattr(self.cfg, "beta1", 0.9), getattr(self.cfg, "beta2", 0.999)), ) + # flat = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1) + # cos = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, T_max=self.cfg.epochs - 1, eta_min=1e-5 + # ) + + # scheduler = torch.optim.lr_scheduler.SequentialLR( + # optimizer, [flat, cos], milestones=[1] + # ) linsch = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=1e-5, diff --git a/src/electrai/model/LCN.py b/src/electrai/model/LCN.py index 45fc6cad..a90947f1 100644 --- a/src/electrai/model/LCN.py +++ b/src/electrai/model/LCN.py @@ -4,34 
+4,31 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from src.electrai.model.utils import ( - FourierPositionalEmbedding, - GaussianRadialBasis, - PositionalEmbedding, -) +from src.electrai.model.utils import CartesianFourierEmbedding, GaussianRadialBasis class LatticeConv3d(nn.Module): def __init__( self, - in_channels, - out_channels, + in_channels: int, + out_channels: int, kernel_size, - padding_mode="circular", - stride=1, - dilation=1, - use_lattice_conv=False, - use_radial_embedding=False, - use_positional_embedding=False, - trainable_gaussian_params=False, - num_gaussians=16, - pos_embed_dim=16, - pos_embed_type="learnable", - r_max=5.0, - hidden_dim=64, + padding_mode: str = "circular", + stride: int = 1, + dilation: int = 1, + use_lattice_conv: bool = False, + mix_weight: float = 0.1, + use_radial_embedding: bool = False, + use_positional_embedding: bool = False, + trainable_gaussian_params: bool = False, + num_gaussians: int = 16, + pos_embed_dim: int = 16, + r_max: float = 5.0, + hidden_dim: int = 64, ): super().__init__() - padding = kernel_size // 2 # - 1 + padding = kernel_size // 2 + self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = ( @@ -41,8 +38,12 @@ def __init__( self.padding_mode = padding_mode self.stride = stride self.dilation = dilation + self.use_lattice_conv = use_lattice_conv + self.use_radial_embedding = use_radial_embedding + self.use_positional_embedding = use_positional_embedding self.trainable_gaussian_params = trainable_gaussian_params + self.conv = nn.Conv3d( in_channels, out_channels, @@ -53,11 +54,6 @@ def __init__( dilation=dilation, bias=True, ) - w = self.conv.weight.detach().clone() - del self.conv._parameters["weight"] - self.conv.register_buffer("weight", w) - self.use_radial_embedding = use_radial_embedding - self.use_positional_embedding = use_positional_embedding if use_lattice_conv: if use_radial_embedding: @@ -69,17 +65,7 @@ def __init__( ) if 
use_positional_embedding: - if pos_embed_type == "learnable": - self.pos_embedding = PositionalEmbedding( - embed_dim=pos_embed_dim, max_kernel_size=max(self.kernel_size) - ) - elif pos_embed_type == "fourier": - self.pos_embedding = FourierPositionalEmbedding( - embed_dim=pos_embed_dim, max_freq=10 - ) - else: - raise ValueError(f"Unknown pos_embed_type: {pos_embed_type}") - self.pos_embed_type = pos_embed_type + self.pos_embedding = CartesianFourierEmbedding(num_freqs=60) # 6) if use_radial_embedding or use_positional_embedding: input_size = 0 @@ -87,22 +73,30 @@ def __init__( input_size += num_gaussians if use_positional_embedding: input_size += pos_embed_dim + self.filter_network = nn.Sequential( nn.Linear(input_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), - nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), - nn.Dropout(0.1), nn.Linear(hidden_dim, in_channels * out_channels), ) - # if self.use_radial_embedding or self.use_positional_embedding: - # self.mix_weight = nn.Parameter(torch.tensor(0.1)) + # Optional stabilization (uncomment if desired) + nn.init.zeros_(self.filter_network[-1].weight) + nn.init.zeros_(self.filter_network[-1].bias) + + # Learnable mixing weight (scalar) + # self.mix_weight = nn.Parameter(torch.tensor(float(mix_weight))) + # Alternative per-out-channel alpha: + # self.mix_weight = nn.Parameter(torch.full((out_channels,), float(mix_weight))) + self.mix_weight = nn.Parameter( + torch.full((out_channels,), float(mix_weight)) + ) - def _apply_padding(self, x): + def _apply_padding(self, x: torch.Tensor) -> torch.Tensor: if self.padding_mode == "circular" and any(p > 0 for p in self.padding): pad_3d = ( self.padding[2], @@ -113,7 +107,8 @@ def _apply_padding(self, x): self.padding[0], ) return F.pad(x, pad_3d, mode="circular") - elif self.padding_mode in ["zeros", "reflect", "replicate"] and any( + + if self.padding_mode in ["zeros", "reflect", "replicate"] and any( p > 0 for p in self.padding 
): pad_3d = ( @@ -125,17 +120,21 @@ def _apply_padding(self, x): self.padding[0], ) return F.pad(x, pad_3d, mode=self.padding_mode) - else: - return x - def compute_geometric_kernel(self, lattice_vectors): + return x + + def compute_geometric_kernel(self, lattice_vectors: torch.Tensor) -> torch.Tensor: + """ + lattice_vectors: (B, 3, 3) or (3, 3) in voxel units (z,y,x ordering upstream) + returns: (B, out, in, kz, ky, kx) or (out, in, kz, ky, kx) if input was (3,3) + """ if lattice_vectors.dim() == 2: lattice_vectors = lattice_vectors.unsqueeze(0) squeeze_batch = True else: squeeze_batch = False - B = lattice_vectors.shape[0] + B = lattice_vectors.shape[0] # noqa: F841 kz, ky, kx = self.kernel_size device = lattice_vectors.device @@ -144,40 +143,31 @@ def compute_geometric_kernel(self, lattice_vectors): x = torch.arange(kx, device=device) - kx // 2 grid_z, grid_y, grid_x = torch.meshgrid(z, y, x, indexing="ij") - frac_coords = torch.stack([grid_z, grid_y, grid_x], dim=-1).float() - - cart_coords = torch.einsum("ijkl,bml->bijkm", frac_coords, lattice_vectors) - distances = torch.norm(cart_coords, dim=-1) + frac_coords = torch.stack( + [grid_z, grid_y, grid_x], dim=-1 + ).float() # (kz,ky,kx,3) - debug_stats = {} - debug_stats["distances"] = { - "min": distances.min().item(), - "max": distances.max().item(), - "mean": distances.mean().item(), - "has_nan": torch.isnan(distances).any().item(), - "has_inf": torch.isinf(distances).any().item(), - } + cart_coords = torch.einsum( + "ijkl,bml->bijkm", frac_coords, lattice_vectors + ) # (B,kz,ky,kx,3) + distances = torch.norm(cart_coords, dim=-1) # (B,kz,ky,kx) if self.use_radial_embedding: - radial_features = self.gaussian_smear(distances) + radial_features = self.gaussian_smear(distances) # (B,kz,ky,kx,Ng) radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky kx) n") + if self.use_positional_embedding: - if self.pos_embed_type == "learnable": - pos_features = self.pos_embedding(frac_coords, self.kernel_size) - 
pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) - else: - pos_features = self.pos_embedding(frac_coords) - pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + pos_features = self.pos_embedding(cart_coords) # (B,kz,ky,kx,Np) pos_flat = rearrange(pos_features, "b kz ky kx n -> b (kz ky kx) n") if self.use_radial_embedding and self.use_positional_embedding: features = torch.cat([radial_flat, pos_flat], dim=-1) elif self.use_radial_embedding: - features = torch.cat([radial_flat], dim=-1) + features = radial_flat else: - features = torch.cat([pos_flat], dim=-1) + features = pos_flat - kernel_flat = self.filter_network(features) + kernel_flat = self.filter_network(features) # (B, kz*ky*kx, out*in) kernel = rearrange( kernel_flat, "b (kz ky kx) (o i) -> b o i kz ky kx", @@ -191,57 +181,88 @@ def compute_geometric_kernel(self, lattice_vectors): if squeeze_batch: kernel = kernel.squeeze(0) - return kernel, debug_stats + return kernel + + def forward( + self, x: torch.Tensor, lattice_vectors: torch.Tensor | None = None + ) -> torch.Tensor: + base_kernel = self.conv.weight.unsqueeze(0) # (1,out,in,kz,ky,kx) - def forward(self, x, lattice_vectors=None): - if not self.use_lattice_conv or ( - not self.use_radial_embedding and not self.use_positional_embedding + if (not self.use_lattice_conv) or ( + (not self.use_radial_embedding) and (not self.use_positional_embedding) ): + with torch.no_grad(): + b_rms = base_kernel.pow(2).mean().sqrt() + self.kernel_stats = {"base_rms": b_rms.item()} return self.conv(x) - B = x.shape[0] - if lattice_vectors.dim() == 2: - x_padded = self._apply_padding(x) - geometric_kernels = self.compute_geometric_kernel(lattice_vectors) - alpha = 1 # self.mix_weight - kernel = alpha * geometric_kernels - - return F.conv3d( - x_padded, - kernel, - self.conv.bias, - stride=self.stride, - padding=0, - dilation=self.dilation, + if lattice_vectors is None: + raise ValueError( + "lattice_vectors must be provided when 
use_lattice_conv=True" ) - else: - geometric_kernels, debug_stats = self.compute_geometric_kernel( - lattice_vectors - ) - self.last_debug_stats = debug_stats - alpha = 1 # 0.1 # self.mix_weight - # self.register_buffer("base_weight", w) # saved + moved with .to(device), not trained - - # base_kernel = self.conv.weight.unsqueeze(0) - kernels = alpha * geometric_kernels # + (1 - alpha) * base_kernel - x_grouped = x.reshape(1, B * self.in_channels, *x.shape[2:]) - x_grouped = self._apply_padding(x_grouped) - kernels_grouped = kernels.reshape( - B * self.out_channels, self.in_channels, *self.kernel_size - ) - out = F.conv3d( - x_grouped, - kernels_grouped, - bias=None, - stride=self.stride, - padding=0, - dilation=self.dilation, - groups=B, - ) + D, H, W = x.shape[-3:] + + # lattice_vectors assumed (B,3,3) in Å; convert to voxel units and reorder to (z,y,x) + a = lattice_vectors[:, 0, :] + b = lattice_vectors[:, 1, :] + c = lattice_vectors[:, 2, :] + lv_voxel = torch.stack( + [c / D, b / H, a / W], dim=1 + ) # (B,3,3) in voxel units, z/y/x order + + B = x.shape[0] + geometric_kernels = self.compute_geometric_kernel( + lv_voxel + ) # (B,out,in,kz,ky,kx) + + # Optional global RMS match (comment out if you don't want this constraint) + # g_rms = geometric_kernels.pow(2).mean().sqrt().clamp(min=1e-8) + # b_rms = base_kernel.pow(2).mean().sqrt().clamp(min=1e-8) + # geometric_kernels = geometric_kernels * (b_rms / g_rms) + + # alpha = torch.sigmoid(self.mix_weight) # scalar + alpha = torch.sigmoid(self.mix_weight).view(1, self.out_channels, 1, 1, 1, 1) + + with torch.no_grad(): + g_rms2 = geometric_kernels.pow(2).mean().sqrt() + b_rms2 = base_kernel.pow(2).mean().sqrt() + self.kernel_stats = { + "geo_rms": g_rms2.item(), + "base_rms": b_rms2.item(), + "ratio": (g_rms2 / (b_rms2 + 1e-8)).item(), + "alpha": float(alpha.mean().item()) + if alpha.numel() > 1 + else float(alpha.item()), + } + + # Current mixing rule (as in your snippet): + mod = torch.tanh(geometric_kernels) # 
bounded + kernels = base_kernel * (1 + alpha * mod) + # kernels = ( + # alpha * geometric_kernels + (1 - alpha) * base_kernel + # ) # (B,out,in,kz,ky,kx) + + x_grouped = x.reshape(1, B * self.in_channels, *x.shape[2:]) + x_grouped = self._apply_padding(x_grouped) + + kernels_grouped = kernels.reshape( + B * self.out_channels, self.in_channels, *self.kernel_size + ) + + out = F.conv3d( + x_grouped, + kernels_grouped, + bias=None, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=B, + ) + + out = out.reshape(B, self.out_channels, *out.shape[2:]) - out = out.reshape(B, self.out_channels, *out.shape[2:]) + if self.conv.bias is not None: + out = out + self.conv.bias.view(1, -1, 1, 1, 1) - if self.conv.bias is not None: - out = out + self.conv.bias.view(1, -1, 1, 1, 1) - return out + return out diff --git a/src/electrai/model/LCN_al.py b/src/electrai/model/LCN_al.py new file mode 100644 index 00000000..ee6ea473 --- /dev/null +++ b/src/electrai/model/LCN_al.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.electrai.model.utils import DistanceTriangleNet + + +class LatticeConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=False, + mix_weight=0.1, # noqa: ARG002 + use_radial_embedding=False, + use_positional_embedding=False, + trainable_gaussian_params=False, + num_gaussians=16, # noqa: ARG002 + pos_embed_dim=16, # noqa: ARG002 + # pos_embed_type="learnable", + r_max=5.0, # noqa: ARG002 + hidden_dim=64, # noqa: ARG002 + ): + super().__init__() + padding = kernel_size // 2 # - 1 + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3 + ) + self.padding = padding if isinstance(padding, tuple) else (padding,) * 3 + self.padding_mode = padding_mode + 
self.stride = stride + self.dilation = dilation + self.use_lattice_conv = use_lattice_conv + self.trainable_gaussian_params = trainable_gaussian_params + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + padding=padding, + padding_mode=padding_mode, + stride=stride, + dilation=dilation, + bias=True, + ) + self.use_radial_embedding = use_radial_embedding + self.use_positional_embedding = use_positional_embedding + + if use_lattice_conv: + self.dist_triangle = DistanceTriangleNet( + d=128, + n_heads=4, + n_blocks=2, + rbf_num=32, + r_max=16.0, + dropout=0.0, + in_channels=in_channels, + out_channels=out_channels, + ) + # if use_radial_embedding: + # self.gaussian_smear = GaussianRadialBasis( + # num_gaussians=num_gaussians, + # r_min=0.0, + # r_max=r_max, + # trainable=self.trainable_gaussian_params, + # ) + + # if use_positional_embedding: + # self.pos_embedding = CartesianFourierEmbedding(num_freqs=6) + # # if pos_embed_type == "learnable": + # # self.pos_embedding = PositionalEmbedding( + # # embed_dim=pos_embed_dim, max_kernel_size=max(self.kernel_size) + # # ) + # # elif pos_embed_type == "fourier": + # # self.pos_embedding = FourierPositionalEmbedding( + # # embed_dim=pos_embed_dim, max_freq=10 + # # ) + # # else: + # # raise ValueError(f"Unknown pos_embed_type: {pos_embed_type}") + # # self.pos_embed_type = pos_embed_type + + # if use_radial_embedding or use_positional_embedding: + # input_size = 0 + # if use_radial_embedding: + # input_size += num_gaussians + # if use_positional_embedding: + # input_size += pos_embed_dim + # self.filter_network = nn.Sequential( + # nn.Linear(input_size, hidden_dim), + # nn.LayerNorm(hidden_dim), + # nn.SiLU(), + # # nn.Dropout(0.1), + # nn.Linear(hidden_dim, hidden_dim), + # nn.LayerNorm(hidden_dim), + # nn.SiLU(), + # # nn.Dropout(0.1), + # nn.Linear(hidden_dim, in_channels * out_channels), + # ) + # # nn.init.zeros_(self.filter_network[-1].weight) + # # nn.init.zeros_(self.filter_network[-1].bias) + + # 
# self.register_buffer("mix_weight", torch.tensor(float(mix_weight))) + # self.mix_weight = nn.Parameter(torch.tensor(float(mix_weight))) + # # self.mix_weight = nn.Parameter(torch.zeros(out_channels)) + + def _apply_padding(self, x): + if self.padding_mode == "circular" and any(p > 0 for p in self.padding): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode="circular") + elif self.padding_mode in ["zeros", "reflect", "replicate"] and any( + p > 0 for p in self.padding + ): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode=self.padding_mode) + else: + return x + + def compute_geometric_kernel(self, lattice_vectors): + if lattice_vectors.dim() == 2: + lattice_vectors = lattice_vectors.unsqueeze(0) + squeeze_batch = True + else: + squeeze_batch = False # noqa: F841 + + B = lattice_vectors.shape[0] + kz, ky, kx = self.kernel_size + device = lattice_vectors.device + + z = torch.arange(kz, device=device) - kz // 2 + y = torch.arange(ky, device=device) - ky // 2 + x = torch.arange(kx, device=device) - kx // 2 + + grid_z, grid_y, grid_x = torch.meshgrid(z, y, x, indexing="ij") + frac_coords = torch.stack([grid_z, grid_y, grid_x], dim=-1).float() + + cart_coords = torch.einsum("ijkl,bml->bijkm", frac_coords, lattice_vectors) + return cart_coords.reshape(B, -1, 3) + + # distances = torch.norm(cart_coords, dim=-1) + + # # debug_stats = {} + # # debug_stats["distances"] = { + # # "min": distances.min().item(), + # # "max": distances.max().item(), + # # "mean": distances.mean().item(), + # # "has_nan": torch.isnan(distances).any().item(), + # # "has_inf": torch.isinf(distances).any().item(), + # # } + + # if self.use_radial_embedding: + # radial_features = self.gaussian_smear(distances) + # radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky 
kx) n") + # if self.use_positional_embedding: + # pos_features = self.pos_embedding(cart_coords) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # # if self.pos_embed_type == "learnable": + # # pos_features = self.pos_embedding(frac_coords, self.kernel_size) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # # else: + # # pos_features = self.pos_embedding(frac_coords) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # pos_flat = rearrange(pos_features, "b kz ky kx n -> b (kz ky kx) n") + + # if self.use_radial_embedding and self.use_positional_embedding: + # features = torch.cat([radial_flat, pos_flat], dim=-1) + # elif self.use_radial_embedding: + # features = torch.cat([radial_flat], dim=-1) + # else: + # features = torch.cat([pos_flat], dim=-1) + + # kernel_flat = self.filter_network(features) + # kernel = rearrange( + # kernel_flat, + # "b (kz ky kx) (o i) -> b o i kz ky kx", + # o=self.out_channels, + # i=self.in_channels, + # kz=kz, + # ky=ky, + # kx=kx, + # ) + # # kernel_norm = torch.linalg.vector_norm( + # # kernel, dim=(-4, -3, -2, -1), keepdim=True + # # ).clamp(min=1e-8) + # # kernel = ( + # # kernel / kernel_norm * (1.0 / (3 * self.in_channels * kz * ky * kx) ** 0.5) + # # ) + # # after kernel computed + # if squeeze_batch: + # kernel = kernel.squeeze(0) + + # return kernel + + def forward(self, x, lattice_vectors=None): + base_kernel = self.conv.weight.unsqueeze(0) + if not self.use_lattice_conv or ( + not self.use_radial_embedding and not self.use_positional_embedding + ): + with torch.no_grad(): + b_rms = base_kernel.pow(2).mean().sqrt() + self.kernel_stats = {"base_rms": b_rms.item()} + return self.conv(x) + D, H, W = x.shape[-3:] + # scale = x.new_tensor([D, H, W]).view(1, 3, 1) # (1,3,1) + a = lattice_vectors[:, 0, :] + b = lattice_vectors[:, 1, :] + c = lattice_vectors[:, 2, :] + lv_voxel = torch.stack([c / D, b / H, a / W], dim=1) # (z,y,x) + # lv_voxel = 
lattice_vectors / scale # (B,3,3) + + B = x.shape[0] + # geometric_kernels = self.compute_geometric_kernel(lv_voxel) + # g_rms = geometric_kernels.pow(2).mean().sqrt().clamp(min=1e-8) + # b_rms = base_kernel.pow(2).mean().sqrt().clamp(min=1e-8) + # geometric_kernels = geometric_kernels * (b_rms / g_rms) + # alpha = torch.sigmoid(self.mix_weight) + + # with torch.no_grad(): + # g_rms = geometric_kernels.pow(2).mean().sqrt() + # b_rms = base_kernel.pow(2).mean().sqrt() + # self.kernel_stats = { + # "geo_rms": g_rms.item(), + # "base_rms": b_rms.item(), + # "ratio": (g_rms / (b_rms + 1e-8)).item(), + # "alpha": alpha.item(), + # } + + # kernels = alpha * geometric_kernels + base_kernel # (1 - alpha) * base_kernel + pos = self.compute_geometric_kernel(lv_voxel) + kernels = base_kernel + self.dist_triangle(pos, self.kernel_size[0]) + x_grouped = x.reshape(1, B * self.in_channels, *x.shape[2:]) + x_grouped = self._apply_padding(x_grouped) + kernels_grouped = kernels.reshape( + B * self.out_channels, self.in_channels, *self.kernel_size + ) + out = F.conv3d( + x_grouped, + kernels_grouped, + bias=None, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=B, + ) + + out = out.reshape(B, self.out_channels, *out.shape[2:]) + + if self.conv.bias is not None: + out = out + self.conv.bias.view(1, -1, 1, 1, 1) + return out diff --git a/src/electrai/model/resnet_LCN.py b/src/electrai/model/resnet_LCN.py index be1537b0..a5460029 100644 --- a/src/electrai/model/resnet_LCN.py +++ b/src/electrai/model/resnet_LCN.py @@ -13,11 +13,13 @@ def __init__( K=3, use_checkpoint=True, use_lattice_conv=False, + mix_weight=0.1, use_radial_embedding=False, use_positional_embedding=False, trainable_gaussian_params=False, num_gaussians=16, pos_embed_dim=16, + hidden_dim=64, ): super().__init__() self.use_checkpoint = use_checkpoint @@ -30,11 +32,13 @@ def __init__( stride=1, dilation=1, use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, 
use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) self.norm1 = nn.InstanceNorm3d(in_features) self.act1 = nn.PReLU() @@ -46,11 +50,13 @@ def __init__( stride=1, dilation=1, use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) self.norm2 = nn.InstanceNorm3d(in_features) @@ -82,16 +88,17 @@ def __init__( normalize=True, use_checkpoint=True, use_lattice_conv=False, + mix_weight=0.1, use_radial_embedding=False, use_positional_embedding=False, trainable_gaussian_params=False, num_gaussians=16, pos_embed_dim=16, + hidden_dim=64, ): super().__init__() self.normalize = normalize self.use_checkpoint = use_checkpoint - self.use_lattice_conv = use_lattice_conv # First layer self.conv1 = LatticeConv3d( @@ -102,11 +109,13 @@ def __init__( stride=1, dilation=1, use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) self.act1 = nn.PReLU() @@ -117,6 +126,13 @@ def __init__( K=kernel_size2, use_checkpoint=use_checkpoint, use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) for _ in range(n_residual_blocks) ] @@ -131,11 +147,13 @@ def __init__( stride=1, dilation=1, 
use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) self.norm = nn.InstanceNorm3d(n_channels) @@ -148,11 +166,13 @@ def __init__( stride=1, dilation=1, use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, use_radial_embedding=use_radial_embedding, use_positional_embedding=use_positional_embedding, trainable_gaussian_params=trainable_gaussian_params, num_gaussians=num_gaussians, pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, ) self.act2 = nn.ReLU() @@ -174,6 +194,9 @@ def _forward(self, x, lattice_vectors=None): out = self.act2(out) if self.normalize: - out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None] - out = out * torch.sum(x, axis=(-3, -2, -1))[..., None, None, None] + # out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None] + # out = out * torch.sum(x, axis=(-3, -2, -1))[..., None, None, None] + out_sum = torch.sum(out, axis=(-3, -2, -1), keepdim=True).clamp(min=1e-8) + x_sum = torch.sum(x, axis=(-3, -2, -1), keepdim=True) + out = (out / out_sum) * x_sum return out diff --git a/src/electrai/model/resunet_LCN.py b/src/electrai/model/resunet_LCN.py new file mode 100644 index 00000000..ebc036cc --- /dev/null +++ b/src/electrai/model/resunet_LCN.py @@ -0,0 +1,543 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.electrai.model.LCN import LatticeConv3d + + +class LatticeSequential(nn.Sequential): + def forward(self, x, lattice_vectors): + for module in self: + x = module(x, lattice_vectors) + return x + + +class ResBlock3D(nn.Module): + def __init__(self, cin, cout, k, **lcn_kwargs): + super().__init__() + + self.conv1 = LatticeConv3d( + cin, cout, kernel_size=k, padding_mode="circular", **lcn_kwargs + ) + 
self.norm1 = nn.InstanceNorm3d(cout) + self.act1 = nn.PReLU() + + self.conv2 = LatticeConv3d( + cout, cout, kernel_size=k, padding_mode="circular", **lcn_kwargs + ) + self.norm2 = nn.InstanceNorm3d(cout) + self.act2 = nn.PReLU() + + self.skip = ( + LatticeConv3d(cin, cout, kernel_size=1, **lcn_kwargs) + if cin != cout + else nn.Identity() + ) + + def forward(self, x, lattice_vectors): + h = self.act1(self.norm1(self.conv1(x, lattice_vectors))) + h = self.norm2(self.conv2(h, lattice_vectors)) + skip_out = ( + self.skip(x, lattice_vectors) + if isinstance(self.skip, LatticeConv3d) + else self.skip(x) + ) + return self.act2(h + skip_out) + + +class DownsampleBlock(nn.Module): + def __init__(self, cin, cout, **lcn_kwargs): + super().__init__() + self.conv = LatticeConv3d( + cin, cout, 3, stride=2, padding_mode="circular", **lcn_kwargs + ) + self.norm = nn.InstanceNorm3d(cout) + self.act = nn.PReLU() + + def forward(self, x, lattice_vectors): + return self.act(self.norm(self.conv(x, lattice_vectors))) + + +class PeriodicUpsampleConv3d(nn.Module): + def __init__(self, cin, cout, **lcn_kwargs): + super().__init__() + self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False) + self.conv = LatticeConv3d( + cin, + cout, + 3, + **lcn_kwargs, + # padding=1, + # padding_mode="circular", # , use_lattice_conv=False + ) + self.norm = nn.InstanceNorm3d(cout) + self.act = nn.PReLU() + + def forward(self, x, lattice_vectors): + x = F.pad(x, (1, 1, 1, 1, 1, 1), mode="circular") + x = self.up(x) + x = x[..., 2:-2, 2:-2, 2:-2] + x = self.conv(x, lattice_vectors) + x = self.norm(x) + return self.act(x) + + +class ResUNet3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + n_channels, + depth, + n_residual_blocks, + kernel_size, + use_lattice_conv=False, + mix_weight=0.1, + use_radial_embedding=False, + use_positional_embedding=False, + trainable_gaussian_params=False, + num_gaussians=16, + pos_embed_dim=16, + hidden_dim=64, + ): + super().__init__() 
+ + lcn_kwargs = { + "use_lattice_conv": use_lattice_conv, + "mix_weight": mix_weight, + "use_radial_embedding": use_radial_embedding, + "use_positional_embedding": use_positional_embedding, + "trainable_gaussian_params": trainable_gaussian_params, + "num_gaussians": num_gaussians, + "pos_embed_dim": pos_embed_dim, + "hidden_dim": hidden_dim, + } + + self.in_conv = ResBlock3D(in_channels, n_channels, kernel_size, **lcn_kwargs) + + # -------- Encoder -------- + self.enc_blocks = nn.ModuleList() + self.downs = nn.ModuleList() + + ch = n_channels + for _ in range(depth): + self.enc_blocks.append( + LatticeSequential( + *[ + ResBlock3D(ch, ch, kernel_size, **lcn_kwargs) + for _ in range(n_residual_blocks) + ] + ) + ) + self.downs.append(DownsampleBlock(ch, 2 * ch, **lcn_kwargs)) + ch *= 2 + + # -------- Bottleneck -------- + self.mid = LatticeSequential( + *[ + ResBlock3D(ch, ch, kernel_size, **lcn_kwargs) + for _ in range(2 * n_residual_blocks) + ] + ) + + # -------- Decoder -------- + self.ups = nn.ModuleList() + self.dec_blocks = nn.ModuleList() + # for _ in range(depth): + # self.ups.append(PeriodicUpsampleConv3d(ch, ch // 2)) + # ch //= 2 + + # blocks = [ResBlock3D(2 * ch, ch, kernel_size, **lcn_kwargs)] + # blocks.extend( + # [ + # ResBlock3D(ch, ch, kernel_size, **lcn_kwargs) + # for _ in range(n_residual_blocks - 1) + # ] + # ) + # self.dec_blocks.append(LatticeSequential(*blocks)) + for _ in range(depth): + self.ups.append(PeriodicUpsampleConv3d(ch, ch // 2, **lcn_kwargs)) + ch //= 2 + + blocks = [ResBlock3D(2 * ch, ch, kernel_size, **lcn_kwargs)] + blocks.extend( + [ + ResBlock3D(ch, ch, kernel_size, **lcn_kwargs) + for _ in range(n_residual_blocks - 1) + ] + ) + self.dec_blocks.append(LatticeSequential(*blocks)) + + # -------- Output -------- + # self.out_conv = nn.Conv3d(n_channels, out_channels, kernel_size=1) + self.out_conv = LatticeConv3d( + n_channels, out_channels, kernel_size=1, **lcn_kwargs + ) + + def forward(self, x, lattice_vectors): + skips = [] 
+ out = self.in_conv(x, lattice_vectors) + + for enc, down in zip(self.enc_blocks, self.downs, strict=False): + out = enc(out, lattice_vectors) + skips.append(out) + out = down(out, lattice_vectors) + + out = self.mid(out, lattice_vectors) + + for up, dec in zip(self.ups, self.dec_blocks, strict=False): + out = up(out, lattice_vectors) + out = torch.cat([out, skips.pop()], dim=1) + out = dec(out, lattice_vectors) + + out = self.out_conv(out, lattice_vectors) + out = out / torch.sum(out, dim=(-3, -2, -1), keepdim=True) + return out * torch.sum(x, dim=(-3, -2, -1), keepdim=True) + + +# from __future__ import annotations + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from src.electrai.model.LCN import LatticeConv3d + + +# class ResBlock3D(nn.Module): +# def __init__( +# self, +# cin, +# cout, +# k, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.conv1 = LatticeConv3d( +# cin, +# cout, +# kernel_size=k, +# padding_mode="circular", +# padding=k // 2, +# stride=1, +# dilation=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# self.norm1 = nn.InstanceNorm3d(cout) +# self.act1 = nn.PReLU() +# self.conv2 = LatticeConv3d( +# cout, +# cout, +# kernel_size=k, +# padding_mode="circular", +# padding=k // 2, +# stride=1, +# dilation=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# 
pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# self.norm2 = nn.InstanceNorm3d(cout) +# self.act2 = nn.PReLU() +# # self.conv1 = nn.Conv3d(cin, cout, k, padding=k // 2, padding_mode="circular") +# # self.norm1 = nn.InstanceNorm3d(cout) +# # self.act = nn.PReLU() +# # self.conv2 = nn.Conv3d(cout, cout, k, padding=k // 2, padding_mode="circular") +# # self.norm2 = nn.InstanceNorm3d(cout) + +# if cin != cout: +# self.skip = LatticeConv3d( +# cin, +# cout, +# kernel_size=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.skip = nn.Conv3d(cin, cout, 1) +# else: +# self.skip = nn.Identity() + +# def forward(self, x, lattice_vectors): +# h = self.act1(self.norm1(self.conv1(x, lattice_vectors))) +# h = self.norm2(self.conv2(h, lattice_vectors)) +# return self.act2(h + self.skip(x)) + + +# class ResUNet3D(nn.Module): +# def __init__( +# self, +# in_channels, +# out_channels, +# n_channels, +# depth, +# n_residual_blocks, +# kernel_size, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.in_conv = ResBlock3D( +# in_channels, +# n_channels, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.in_conv = ResBlock3D(in_channels, n_channels, kernel_size) + +# # -------- Encoder -------- +# self.enc_blocks = 
nn.ModuleList() +# self.downs = nn.ModuleList() + +# ch = n_channels +# for _ in range(depth): +# self.enc_blocks.append( +# nn.Sequential( +# *[ +# ResBlock3D( +# ch, +# ch, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(n_residual_blocks) +# ] +# ) +# ) +# # self.enc_blocks.append( +# # nn.Sequential( +# # *[ResBlock3D(ch, ch, kernel_size,) for _ in range(n_residual_blocks)] +# # ) +# # ) +# self.downs.append( +# downsample( +# ch, +# 2 * ch, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# ) +# ch *= 2 + +# # -------- Bottleneck -------- +# self.mid = nn.Sequential( +# *[ +# ResBlock3D( +# ch, +# ch, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(2 * n_residual_blocks) +# ] +# ) +# # self.mid = nn.Sequential( +# # *[ResBlock3D(ch, ch, kernel_size) for _ in range(2 * n_residual_blocks)] +# # ) + +# # -------- Decoder -------- +# self.ups = nn.ModuleList() +# self.dec_blocks = nn.ModuleList() + +# for _ in range(depth): +# self.ups.append(PeriodicUpsampleConv3d(ch, ch // 2)) +# ch //= 2 +# self.dec_blocks.append( +# nn.Sequential( +# *[ +# ResBlock3D( +# 2 * ch, +# ch, +# kernel_size, +# 
use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(n_residual_blocks) +# ] +# ) +# ) +# # self.dec_blocks.append( +# # nn.Sequential( +# # *[ +# # ResBlock3D(2 * ch, ch, kernel_size) +# # for _ in range(n_residual_blocks) +# # ] +# # ) +# # ) + +# # -------- Output -------- +# self.out_conv = LatticeConv3d( +# n_channels, +# out_channels, +# kernel_size=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.out_conv = nn.Conv3d(n_channels, out_channels, kernel_size=1) + +# def forward(self, x, lattice_vectors): +# skips = [] +# out = self.in_conv(x, lattice_vectors) + +# for enc, down in zip(self.enc_blocks, self.downs, strict=False): +# out = enc(out) +# skips.append(out) +# out = down(out) +# out = self.mid(out) + +# for up, dec in zip(self.ups, self.dec_blocks, strict=False): +# out = up(out) +# out = torch.cat([out, skips.pop()], dim=1) +# out = dec(out) +# out = self.out_conv(out) +# out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None] +# return out * torch.sum(x, axis=(-3, -2, -1))[..., None, None, None] + + +# class PeriodicUpsampleConv3d(nn.Module): +# def __init__( +# self, +# cin, +# cout, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.up = nn.Upsample(scale_factor=2, mode="trilinear", 
class GaussianRadialBasis(nn.Module):
    """Expand scalar distances into Gaussian radial basis features.

    Centers are spaced uniformly on [r_min, r_max]; each Gaussian's width
    equals the spacing between adjacent centers so neighbouring basis
    functions overlap smoothly (width 1.0 when there is a single Gaussian).
    """

    def __init__(
        self,
        num_gaussians: int = 50,
        r_min: float = 0.0,
        r_max: float = 5.0,
        trainable: bool = False,
    ):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.r_min = r_min
        self.r_max = r_max

        centers = torch.linspace(r_min, r_max, num_gaussians)
        # Width = spacing between adjacent centers; guard the degenerate
        # num_gaussians == 1 case to avoid division by zero.
        spacing = (r_max - r_min) / (num_gaussians - 1) if num_gaussians > 1 else 1.0
        widths = torch.ones(num_gaussians) * spacing

        if trainable:
            self.centers = nn.Parameter(centers)
            self.widths = nn.Parameter(widths)
        else:
            self.register_buffer("centers", centers)
            self.register_buffer("widths", widths)

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Expand distances using Gaussian basis functions.

        Args:
            distances: Tensor of shape [...] containing interatomic distances

        Returns:
            Tensor of shape [..., num_gaussians] with Gaussian features
        """
        distances = distances.unsqueeze(-1)  # [..., 1]
        # centers/widths are 1-D [num_gaussians]; right-aligned broadcasting
        # against [..., 1] is identical to the former explicit .view calls.
        diff = distances - self.centers
        gamma = 1.0 / (2.0 * self.widths**2)
        return torch.exp(-gamma * diff**2)


class CartesianFourierEmbedding(nn.Module):
    """
    Fourier features of real displacement vectors (cartesian, in Å).
    Uses a fixed physical scale (r_max) so features are comparable across samples.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        include_radius: bool = True,
        r_max: float = 5.0,
        freq_min: float = 0.5,
        freq_max: float = 3.0,
    ):
        super().__init__()
        # Frequencies (wave numbers). Larger -> higher spatial variation.
        freqs = torch.linspace(freq_min, freq_max, num_freqs)
        self.register_buffer("freqs", freqs)

        self.include_radius = include_radius
        self.r_max = float(r_max)

        # 3 coords * num_freqs * (sin+cos) + optional (r, r^2)
        self.out_dim = 2 * num_freqs * 3 + (2 if include_radius else 0)

    def forward(self, cart_coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            cart_coords: [..., 3] displacement vectors in Å. Any number of
                leading dimensions is supported (e.g. [B, kz, ky, kx, 3]).

        Returns:
            features: [..., out_dim]
        """
        # Fixed physical normalization: make coords dimensionless and stable.
        v = cart_coords / (self.r_max + 1e-6)

        # [..., 3] -> [..., F, 3]. freqs.view(-1, 1) broadcasts from the
        # right, so this works for any leading dims (the previous
        # view(1, 1, 1, 1, -1, 1) hard-coded a 5-D input).
        angles = v.unsqueeze(-2) * self.freqs.view(-1, 1)

        sin = torch.sin(angles)
        cos = torch.cos(angles)

        # Concatenate sin/cos on the frequency axis, then flatten to the last
        # dim; feature order is [sin(f1..fF, xyz), cos(f1..fF, xyz)].
        feat = torch.cat([sin, cos], dim=-2).reshape(*v.shape[:-1], -1)

        if self.include_radius:
            r = torch.linalg.norm(v, dim=-1, keepdim=True)
            feat = torch.cat([feat, r, r**2], dim=-1)

        return feat
+# Similar to positional encodings in Transformers but learnable. +# """ +# +# def __init__(self, embed_dim=32, max_kernel_size=7): +# super().__init__() +# self.embed_dim = embed_dim +# +# # Learnable embeddings for each coordinate dimension +# # Range: [-max_kernel_size//2, max_kernel_size//2] +# self.z_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.y_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.x_embed = nn.Embedding(max_kernel_size, embed_dim) +# +# # Optional: learnable way to combine the three directions +# self.combine = nn.Linear(3 * embed_dim, embed_dim) +# +# def forward(self, frac_coords, kernel_size): +# """ +# Args: +# frac_coords: [kz, ky, kx, 3] fractional coordinates (z, y, x offsets) +# kernel_size: (kz, ky, kx) tuple +# +# Returns: +# embeddings: [kz, ky, kx, embed_dim] +# """ +# kz, ky, kx = kernel_size +# +# # Convert coordinates to indices (shift from [-k//2, k//2] to [0, k]) +# z_idx = (frac_coords[..., 0] + kz // 2).long() +# y_idx = (frac_coords[..., 1] + ky // 2).long() +# x_idx = (frac_coords[..., 2] + kx // 2).long() +# +# # Get embeddings for each dimension +# z_emb = self.z_embed(z_idx) # [kz, ky, kx, embed_dim] +# y_emb = self.y_embed(y_idx) # [kz, ky, kx, embed_dim] +# x_emb = self.x_embed(x_idx) # [kz, ky, kx, embed_dim] +# +# # Combine (can use addition, concatenation + linear, etc.) +# combined = torch.cat([z_emb, y_emb, x_emb], dim=-1) # [kz, ky, kx, 3*embed_dim] +# return self.combine(combined) # [kz, ky, kx, embed_dim] +# +# +# class FourierPositionalEmbedding(nn.Module): +# """ +# Fourier features for positional encoding (non-learnable but more expressive). +# Similar to what's used in NeRF. 
+# """ +# +# def __init__(self, embed_dim=32, max_freq=10): +# super().__init__() +# self.embed_dim = embed_dim +# +# # Number of frequency bands +# num_freqs = embed_dim // 6 # 6 because we have sin+cos for each of 3 coords +# +# # Logarithmically spaced frequencies +# freq_bands = 2.0 ** torch.linspace(0, max_freq, num_freqs) +# self.register_buffer("freq_bands", freq_bands) +# +# def forward(self, frac_coords): +# """ +# Args: +# frac_coords: [..., 3] fractional coordinates +# +# Returns: +# embeddings: [..., embed_dim] +# """ +# coords_expanded = frac_coords.unsqueeze(-2) # [..., 1, 3] +# freqs_expanded = self.freq_bands.unsqueeze(-1) # [num_freqs, 1] +# +# # Compute sin and cos for each frequency and coordinate +# angles = coords_expanded * freqs_expanded # [..., num_freqs, 3] +# +# sin_features = torch.sin(angles) +# cos_features = torch.cos(angles) +# +# # Concatenate all features +# fourier_features = torch.cat([sin_features, cos_features], dim=-2) # [..., 2*num_freqs, 3] +# return fourier_features.reshape(*frac_coords.shape[:-1], -1) # [..., 6*num_freqs] diff --git a/src/electrai/model/utils_al.py b/src/electrai/model/utils_al.py new file mode 100644 index 00000000..75a181e2 --- /dev/null +++ b/src/electrai/model/utils_al.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def pairwise_dist(R: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """ + R: (B, N, 3) or (N, 3) + returns D: (B, N, N) or (N, N) + """ + if R.ndim == 2: + R = R[None, ...] 
+ squeeze = True + else: + squeeze = False + + diff = R[:, :, None, :] - R[:, None, :, :] # (B,N,N,3) + D = torch.linalg.norm(diff, dim=-1).clamp_min(eps) + if squeeze: + D = D[0] + return D + + +# 2) RBF embedder for distances (recommended) + + +class RBF(nn.Module): + def __init__(self, num=32, r_max=10.0, trainable=False): + super().__init__() + centers = torch.linspace(0.0, r_max, num) + widths = torch.full((num,), (r_max / num)) + if trainable: + self.centers = nn.Parameter(centers) + self.widths = nn.Parameter(widths) + else: + self.register_buffer("centers", centers) + self.register_buffer("widths", widths) + + def forward(self, D): + """ + D: (B,N,N) distances + returns: (B,N,N,num) + """ + # (B,N,N,1) - (num,) -> (B,N,N,num) + x = D[..., None] - self.centers + return torch.exp(-0.5 * (x / self.widths).pow(2)) + + +class TriangleAttention(nn.Module): + def __init__(self, d, n_heads=4, dropout=0.0, starting=True): + super().__init__() + assert d % n_heads == 0 + self.h = n_heads + self.dk = d // n_heads + self.starting = starting + + self.ln = nn.LayerNorm(d) + self.to_q = nn.Linear(d, d, bias=False) + self.to_k = nn.Linear(d, d, bias=False) + self.to_v = nn.Linear(d, d, bias=False) + self.to_out = nn.Linear(d, d, bias=False) + self.drop = nn.Dropout(dropout) + + def forward(self, P, mask=None): + """ + P: (B,N,N,d) + mask: (B,N) with 1 for real atoms, 0 for padding + """ + B, N, _, d = P.shape + x = self.ln(P) + + if self.starting: + # fixed i: update row i using keys from (i,k) + q = x.reshape(B * N, N, d) + k = x.reshape(B * N, N, d) + v = x.reshape(B * N, N, d) + else: + # fixed j: update col j using keys from (k,j) + xt = x.transpose(1, 2) + q = xt.reshape(B * N, N, d) + k = xt.reshape(B * N, N, d) + v = xt.reshape(B * N, N, d) + + q = self.to_q(q).view(B * N, N, self.h, self.dk).transpose(1, 2) # (BN,h,N,dk) + k = self.to_k(k).view(B * N, N, self.h, self.dk).transpose(1, 2) + v = self.to_v(v).view(B * N, N, self.h, self.dk).transpose(1, 2) + + attn = 
(q @ k.transpose(-1, -2)) / (self.dk**0.5) # (BN,h,N,N) + + if mask is not None: + kv_mask = mask.repeat_interleave(N, dim=0) # (BN,N) + attn = attn.masked_fill(kv_mask[:, None, None, :] == 0, float("-inf")) + + w = self.drop(F.softmax(attn, dim=-1)) + out = (w @ v).transpose(1, 2).contiguous().view(B * N, N, d) + out = self.to_out(out).view(B, N, N, d) + + if not self.starting: + out = out.transpose(1, 2) + + return P + out + + +class DistanceTriangleNet(nn.Module): + def __init__( + self, + d=128, + n_heads=4, + n_blocks=2, + rbf_num=32, + r_max=10.0, + dropout=0.0, + in_channels=32, + out_channels=32, + ): + super().__init__() + self.rbf = RBF(num=rbf_num, r_max=r_max, trainable=False) + + self.pair_in = nn.Sequential(nn.Linear(rbf_num, d), nn.SiLU(), nn.Linear(d, d)) + + self.blocks = nn.ModuleList([]) + for _ in range(n_blocks): + self.blocks.append( + nn.ModuleList( + [ + TriangleAttention( + d, n_heads=n_heads, dropout=dropout, starting=True + ), + TriangleAttention( + d, n_heads=n_heads, dropout=dropout, starting=False + ), + nn.LayerNorm(d), + nn.Sequential( + nn.Linear(d, 2 * d), nn.SiLU(), nn.Linear(2 * d, d) + ), + ] + ) + ) + + self.node_ln = nn.LayerNorm(d) + self.in_channels = in_channels + self.out_channels = out_channels + self.node_head = nn.Sequential( + nn.Linear(d, 2 * d), + nn.SiLU(), + nn.Linear(2 * d, self.in_channels * self.out_channels), + ) + + def forward(self, R, k: int = 5, mask=None): + """ + R: (B,N,3) coordinates + mask: (B,N) optional + returns y: (B,N) per-atom values + """ + N = R.shape[0] # noqa: F841 + D = pairwise_dist(R) # (B,N,N) + # Optional: zero-out padded interactions + if mask is not None: + m2 = mask[:, :, None] * mask[:, None, :] # (B,N,N) + D = D * m2 + (1.0 - m2) * D.max().detach() # keep padded far away + + P = self.pair_in(self.rbf(D)) # (B,N,N,d) + + for tri_s, tri_e, ln, ff in self.blocks: + P = tri_s(P, mask=mask) + P = tri_e(P, mask=mask) + P = P + ff(ln(P)) + + # pair -> node (row/col pooling) + row = 
P.mean(dim=2) # (B,N,d) + col = P.mean(dim=1) # (B,N,d) + H = 0.5 * (row + col) + + y = self.node_head(self.node_ln(H)) # (B,N) + y = rearrange( + y, + "B (k1 k2 k3) (o i)-> B o i k1 k2 k3", + k1=k, + k2=k, + k3=k, + o=self.out_channels, + i=self.in_channels, + ) + if mask is not None: + y = y * mask + return y + + +class GaussianRadialBasis(nn.Module): + def __init__(self, num_gaussians=50, r_min=0.0, r_max=5.0, trainable=False): + super().__init__() + self.num_gaussians = num_gaussians + self.r_min = r_min + self.r_max = r_max + centers = torch.linspace(r_min, r_max, num_gaussians) + spacing = (r_max - r_min) / (num_gaussians - 1) if num_gaussians > 1 else 1.0 + widths = torch.ones(num_gaussians) * spacing + + if trainable: + self.centers = nn.Parameter(centers) + self.widths = nn.Parameter(widths) + else: + self.register_buffer("centers", centers) + self.register_buffer("widths", widths) + + def forward(self, distances): + """ + Expand distances using Gaussian basis functions. + + Args: + distances: Tensor of shape [...] containing interatomic distances + + Returns: + Tensor of shape [..., num_gaussians] with Gaussian features + """ + distances = distances.unsqueeze(-1) + centers = self.centers.view(*([1] * (distances.dim() - 1)), -1) + widths = self.widths.view(*([1] * (distances.dim() - 1)), -1) + + diff = distances - centers + gamma = 1.0 / (2 * widths**2) + return torch.exp(-gamma * diff**2) + + +class CartesianFourierEmbedding(nn.Module): + """ + Fourier features of real displacement vector (cartesian). + Much more meaningful than index or fractional embedding. 
+ """ + + def __init__(self, num_freqs=6, include_radius=True): + super().__init__() + + # smooth low frequencies — not exponential like NeRF + freqs = torch.linspace(0.5, 3.0, num_freqs) + self.register_buffer("freqs", freqs) + + self.include_radius = include_radius + self.out_dim = 2 * num_freqs * 3 + (2 if include_radius else 0) + + def forward(self, cart_coords): + """ + cart_coords: [B, kz, ky, kx, 3] real displacement vectors (Å) + returns: positional features + """ + + # normalize scale for stability + # prevents huge lattice from exploding features + scale = ( + cart_coords.norm(dim=-1, keepdim=True).mean(dim=(1, 2, 3, 4), keepdim=True) + + 1e-6 + ) + v = cart_coords / scale + + # project onto frequencies + # [B,kz,ky,kx,3] -> [B,kz,ky,kx,F,3] + angles = v.unsqueeze(-2) * self.freqs.view(1, 1, 1, 1, -1, 1) + + sin = torch.sin(angles) + cos = torch.cos(angles) + + feat = torch.cat([sin, cos], dim=-2).reshape(*v.shape[:-1], -1) + + if self.include_radius: + r = torch.norm(v, dim=-1, keepdim=True) + feat = torch.cat([feat, r, r**2], dim=-1) + + return feat + + +# class PositionalEmbedding(nn.Module): +# """ +# Learnable positional embedding for kernel positions. +# Similar to positional encodings in Transformers but learnable. 
+# """ + +# def __init__(self, embed_dim=32, max_kernel_size=7): +# super().__init__() +# self.embed_dim = embed_dim + +# # Learnable embeddings for each coordinate dimension +# # Range: [-max_kernel_size//2, max_kernel_size//2] +# self.z_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.y_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.x_embed = nn.Embedding(max_kernel_size, embed_dim) + +# # Optional: learnable way to combine the three directions +# self.combine = nn.Linear(3 * embed_dim, embed_dim) + +# def forward(self, frac_coords, kernel_size): +# """ +# Args: +# frac_coords: [kz, ky, kx, 3] fractional coordinates (z, y, x offsets) +# kernel_size: (kz, ky, kx) tuple + +# Returns: +# embeddings: [kz, ky, kx, embed_dim] +# """ +# kz, ky, kx = kernel_size + +# # Convert coordinates to indices (shift from [-k//2, k//2] to [0, k]) +# z_idx = (frac_coords[..., 0] + kz // 2).long() +# y_idx = (frac_coords[..., 1] + ky // 2).long() +# x_idx = (frac_coords[..., 2] + kx // 2).long() + +# # Get embeddings for each dimension +# z_emb = self.z_embed(z_idx) # [kz, ky, kx, embed_dim] +# y_emb = self.y_embed(y_idx) # [kz, ky, kx, embed_dim] +# x_emb = self.x_embed(x_idx) # [kz, ky, kx, embed_dim] + +# # Combine (can use addition, concatenation + linear, etc.) +# combined = torch.cat([z_emb, y_emb, x_emb], dim=-1) # [kz, ky, kx, 3*embed_dim] +# return self.combine(combined) # [kz, ky, kx, embed_dim] + + +# class FourierPositionalEmbedding(nn.Module): +# """ +# Fourier features for positional encoding (non-learnable but more expressive). +# Similar to what's used in NeRF. 
+# """ + +# def __init__(self, embed_dim=32, max_freq=10): +# super().__init__() +# self.embed_dim = embed_dim + +# # Number of frequency bands +# num_freqs = embed_dim // 6 # 6 because we have sin+cos for each of 3 coords + +# # Logarithmically spaced frequencies +# freq_bands = 2.0 ** torch.linspace(0, max_freq, num_freqs) +# self.register_buffer("freq_bands", freq_bands) + +# def forward(self, frac_coords): +# """ +# Args: +# frac_coords: [..., 3] fractional coordinates + +# Returns: +# embeddings: [..., embed_dim] +# """ +# # frac_coords: [..., 3] +# # freq_bands: [num_freqs] + +# coords_expanded = frac_coords.unsqueeze(-2) # [..., 1, 3] +# freqs_expanded = self.freq_bands.unsqueeze(-1) # [num_freqs, 1] + +# # Compute sin and cos for each frequency and coordinate +# angles = coords_expanded * freqs_expanded # [..., num_freqs, 3] + +# sin_features = torch.sin(angles) +# cos_features = torch.cos(angles) + +# # Concatenate all features +# fourier_features = torch.cat( +# [sin_features, cos_features], dim=-2 +# ) # [..., 2*num_freqs, 3] +# return fourier_features.reshape( +# *frac_coords.shape[:-1], -1 +# ) # [..., 6*num_freqs]