diff --git a/.gitignore b/.gitignore index ed23fafd..b7920944 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .cache/ docs/reference/* ./examples/MP/experiments -./examples/QM9s +./examples/QM9 *doctrees* /site diff --git a/src/electrai/configs/MP/config_resnet_lcn.yaml b/src/electrai/configs/MP/config_resnet_lcn.yaml new file mode 100644 index 00000000..d1b94cae --- /dev/null +++ b/src/electrai/configs/MP/config_resnet_lcn.yaml @@ -0,0 +1,48 @@ +# Dataset / loader parameters +data: + _target_: src.electrai.dataloader.dataset.RhoRead + root: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/mp_filelist.txt + split_file: null #/scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/split.json + precision: f32 + batch_size: 1 + train_workers: 8 + val_workers: 2 + pin_memory: false + val_frac: 0.005 + drop_last: false + augmentation: false + random_seed: 42 + # downsample_label: 0 + # downsample_data: 0 + +# Model +model: + _target_: src.electrai.model.resnet_LCN.GeneratorResNet + n_residual_blocks: 32 + n_channels: 32 + kernel_size1: 5 + kernel_size2: 5 + normalize: True + use_checkpoint: False + use_lattice_conv: true + use_radial_embedding: true + num_gaussians: 500 + use_positional_embedding: true + pos_embed_dim: 500 + +# Training parameters +precision: 32 +epochs: 50 +lr: 0.01 +weight_decay: 0.0 +warmup_length: 1 +beta1: 0.9 +beta2: 0.99 + +# Weights and biases +wandb_mode: online +entity: PrinceOA +wb_pname: mp-experiment + +# checkpoints +ckpt_path: ./checkpoints diff --git a/src/electrai/dataloader/dataset.py b/src/electrai/dataloader/dataset.py index b3321143..78cf72df 100644 --- a/src/electrai/dataloader/dataset.py +++ b/src/electrai/dataloader/dataset.py @@ -107,7 +107,7 @@ def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs): else: raise ValueError("No filename found.") - self.category = Path(datapath).name.split("_")[0] # example: mp_filelist.txt + self.category = Path(datapath).name.split("_")[0] 
def load_chgcar(root: str | bytes | os.PathLike, index: str):
    """Load the (input, target) CHGCAR pair and the cell for sample ``index``.

    Note: ``root`` is combined with ``/``, so despite the broad annotation it
    must be a ``pathlib.Path`` (the caller passes one).

    Returns:
        data: grid from ``root/data/<index>.CHGCAR`` divided by the cell
            volume (pymatgen's ``data["total"]`` is density times volume).
        label: same for ``root/label/<index>.CHGCAR``.
        lattice: 3x3 row-vector lattice matrix of the *data* structure
            (``Lattice.matrix``, a NumPy array) — assumes data and label
            share the same cell; TODO confirm.
    """
    data = Chgcar.from_file(root / "data" / f"{index}.CHGCAR")
    label = Chgcar.from_file(root / "label" / f"{index}.CHGCAR")
    # Cell is taken from the input structure only.
    lattice = data.structure.lattice.matrix
    data = data.data["total"] / data.structure.lattice.volume
    label = label.data["total"] / label.structure.lattice.volume
    return data, label, lattice
trainer.is_global_zero now; rank_zero_only handles it + for name, module in self.model.named_modules(): + if not isinstance(module, LatticeConv3d) or not hasattr( + module, "kernel_stats" + ): + continue + if target_layer is not None and name != target_layer: + continue + + s = module.kernel_stats + if self.cfg.model["use_lattice_conv"]: + log_dict = { + f"kernels/{name}/alpha": s["alpha"], + f"kernels/{name}/ratio": s["ratio"], + f"kernels/{name}/geo_rms": s["geo_rms"], + f"kernels/{name}/base_rms": s["base_rms"], + } + else: + log_dict = {f"kernels/{name}/base_rms": s["base_rms"]} + + # IMPORTANT: let Lightning manage step + syncing + self.log_dict( + log_dict, + on_step=True, + on_epoch=False, + prog_bar=False, + logger=True, + sync_dist=False, # keep False if you're only logging on rank 0 + ) def training_step(self, batch): loss = self._loss_calculation(batch) @@ -29,6 +63,9 @@ def training_step(self, batch): ) return loss + def on_train_batch_end(self, outputs, batch, batch_idx): # noqa: ARG002 + self._collect_kernel_stats(target_layer="mid.0.conv1") + def validation_step(self, batch): loss = self._loss_calculation(batch) self.log( @@ -39,15 +76,16 @@ def validation_step(self, batch): def _loss_calculation(self, batch): x = batch["data"] y = batch["label"] + A = batch["lattice"] if isinstance(x, list): losses = [] - for x_i, y_i in zip(x, y, strict=True): - pred = self(x_i.unsqueeze(0)) + for x_i, y_i, A_i in zip(x, y, strict=True): + pred = self(x_i.unsqueeze(0), A_i.unsqueeze(0)) loss = self.loss_fn(pred, y_i.unsqueeze(0)) losses.append(loss) loss = torch.stack(losses).mean() else: - pred = self(x) + pred = self(x, A) loss = self.loss_fn(pred, y) return loss @@ -56,8 +94,17 @@ def configure_optimizers(self): self.model.parameters(), lr=float(self.cfg.lr), weight_decay=float(self.cfg.weight_decay), + betas=(getattr(self.cfg, "beta1", 0.9), getattr(self.cfg, "beta2", 0.999)), ) + # flat = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, 
total_iters=1) + # cos = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, T_max=self.cfg.epochs - 1, eta_min=1e-5 + # ) + + # scheduler = torch.optim.lr_scheduler.SequentialLR( + # optimizer, [flat, cos], milestones=[1] + # ) linsch = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=1e-5, diff --git a/src/electrai/model/LCN.py b/src/electrai/model/LCN.py new file mode 100644 index 00000000..a90947f1 --- /dev/null +++ b/src/electrai/model/LCN.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from src.electrai.model.utils import CartesianFourierEmbedding, GaussianRadialBasis + + +class LatticeConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + padding_mode: str = "circular", + stride: int = 1, + dilation: int = 1, + use_lattice_conv: bool = False, + mix_weight: float = 0.1, + use_radial_embedding: bool = False, + use_positional_embedding: bool = False, + trainable_gaussian_params: bool = False, + num_gaussians: int = 16, + pos_embed_dim: int = 16, + r_max: float = 5.0, + hidden_dim: int = 64, + ): + super().__init__() + padding = kernel_size // 2 + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3 + ) + self.padding = padding if isinstance(padding, tuple) else (padding,) * 3 + self.padding_mode = padding_mode + self.stride = stride + self.dilation = dilation + + self.use_lattice_conv = use_lattice_conv + self.use_radial_embedding = use_radial_embedding + self.use_positional_embedding = use_positional_embedding + self.trainable_gaussian_params = trainable_gaussian_params + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + padding=padding, + padding_mode=padding_mode, + stride=stride, + dilation=dilation, + bias=True, + ) + + if use_lattice_conv: + if 
use_radial_embedding: + self.gaussian_smear = GaussianRadialBasis( + num_gaussians=num_gaussians, + r_min=0.0, + r_max=r_max, + trainable=self.trainable_gaussian_params, + ) + + if use_positional_embedding: + self.pos_embedding = CartesianFourierEmbedding(num_freqs=60) # 6) + + if use_radial_embedding or use_positional_embedding: + input_size = 0 + if use_radial_embedding: + input_size += num_gaussians + if use_positional_embedding: + input_size += pos_embed_dim + + self.filter_network = nn.Sequential( + nn.Linear(input_size, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, in_channels * out_channels), + ) + + # Optional stabilization (uncomment if desired) + nn.init.zeros_(self.filter_network[-1].weight) + nn.init.zeros_(self.filter_network[-1].bias) + + # Learnable mixing weight (scalar) + # self.mix_weight = nn.Parameter(torch.tensor(float(mix_weight))) + # Alternative per-out-channel alpha: + # self.mix_weight = nn.Parameter(torch.full((out_channels,), float(mix_weight))) + self.mix_weight = nn.Parameter( + torch.full((out_channels,), float(mix_weight)) + ) + + def _apply_padding(self, x: torch.Tensor) -> torch.Tensor: + if self.padding_mode == "circular" and any(p > 0 for p in self.padding): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode="circular") + + if self.padding_mode in ["zeros", "reflect", "replicate"] and any( + p > 0 for p in self.padding + ): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode=self.padding_mode) + + return x + + def compute_geometric_kernel(self, lattice_vectors: torch.Tensor) -> torch.Tensor: + """ + lattice_vectors: (B, 3, 3) or (3, 3) in voxel units (z,y,x ordering upstream) + returns: (B, out, in, kz, 
ky, kx) or (out, in, kz, ky, kx) if input was (3,3) + """ + if lattice_vectors.dim() == 2: + lattice_vectors = lattice_vectors.unsqueeze(0) + squeeze_batch = True + else: + squeeze_batch = False + + B = lattice_vectors.shape[0] # noqa: F841 + kz, ky, kx = self.kernel_size + device = lattice_vectors.device + + z = torch.arange(kz, device=device) - kz // 2 + y = torch.arange(ky, device=device) - ky // 2 + x = torch.arange(kx, device=device) - kx // 2 + + grid_z, grid_y, grid_x = torch.meshgrid(z, y, x, indexing="ij") + frac_coords = torch.stack( + [grid_z, grid_y, grid_x], dim=-1 + ).float() # (kz,ky,kx,3) + + cart_coords = torch.einsum( + "ijkl,bml->bijkm", frac_coords, lattice_vectors + ) # (B,kz,ky,kx,3) + distances = torch.norm(cart_coords, dim=-1) # (B,kz,ky,kx) + + if self.use_radial_embedding: + radial_features = self.gaussian_smear(distances) # (B,kz,ky,kx,Ng) + radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky kx) n") + + if self.use_positional_embedding: + pos_features = self.pos_embedding(cart_coords) # (B,kz,ky,kx,Np) + pos_flat = rearrange(pos_features, "b kz ky kx n -> b (kz ky kx) n") + + if self.use_radial_embedding and self.use_positional_embedding: + features = torch.cat([radial_flat, pos_flat], dim=-1) + elif self.use_radial_embedding: + features = radial_flat + else: + features = pos_flat + + kernel_flat = self.filter_network(features) # (B, kz*ky*kx, out*in) + kernel = rearrange( + kernel_flat, + "b (kz ky kx) (o i) -> b o i kz ky kx", + o=self.out_channels, + i=self.in_channels, + kz=kz, + ky=ky, + kx=kx, + ) + + if squeeze_batch: + kernel = kernel.squeeze(0) + + return kernel + + def forward( + self, x: torch.Tensor, lattice_vectors: torch.Tensor | None = None + ) -> torch.Tensor: + base_kernel = self.conv.weight.unsqueeze(0) # (1,out,in,kz,ky,kx) + + if (not self.use_lattice_conv) or ( + (not self.use_radial_embedding) and (not self.use_positional_embedding) + ): + with torch.no_grad(): + b_rms = 
base_kernel.pow(2).mean().sqrt() + self.kernel_stats = {"base_rms": b_rms.item()} + return self.conv(x) + + if lattice_vectors is None: + raise ValueError( + "lattice_vectors must be provided when use_lattice_conv=True" + ) + + D, H, W = x.shape[-3:] + + # lattice_vectors assumed (B,3,3) in Å; convert to voxel units and reorder to (z,y,x) + a = lattice_vectors[:, 0, :] + b = lattice_vectors[:, 1, :] + c = lattice_vectors[:, 2, :] + lv_voxel = torch.stack( + [c / D, b / H, a / W], dim=1 + ) # (B,3,3) in voxel units, z/y/x order + + B = x.shape[0] + geometric_kernels = self.compute_geometric_kernel( + lv_voxel + ) # (B,out,in,kz,ky,kx) + + # Optional global RMS match (comment out if you don't want this constraint) + # g_rms = geometric_kernels.pow(2).mean().sqrt().clamp(min=1e-8) + # b_rms = base_kernel.pow(2).mean().sqrt().clamp(min=1e-8) + # geometric_kernels = geometric_kernels * (b_rms / g_rms) + + # alpha = torch.sigmoid(self.mix_weight) # scalar + alpha = torch.sigmoid(self.mix_weight).view(1, self.out_channels, 1, 1, 1, 1) + + with torch.no_grad(): + g_rms2 = geometric_kernels.pow(2).mean().sqrt() + b_rms2 = base_kernel.pow(2).mean().sqrt() + self.kernel_stats = { + "geo_rms": g_rms2.item(), + "base_rms": b_rms2.item(), + "ratio": (g_rms2 / (b_rms2 + 1e-8)).item(), + "alpha": float(alpha.mean().item()) + if alpha.numel() > 1 + else float(alpha.item()), + } + + # Current mixing rule (as in your snippet): + mod = torch.tanh(geometric_kernels) # bounded + kernels = base_kernel * (1 + alpha * mod) + # kernels = ( + # alpha * geometric_kernels + (1 - alpha) * base_kernel + # ) # (B,out,in,kz,ky,kx) + + x_grouped = x.reshape(1, B * self.in_channels, *x.shape[2:]) + x_grouped = self._apply_padding(x_grouped) + + kernels_grouped = kernels.reshape( + B * self.out_channels, self.in_channels, *self.kernel_size + ) + + out = F.conv3d( + x_grouped, + kernels_grouped, + bias=None, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=B, + ) + + out = 
out.reshape(B, self.out_channels, *out.shape[2:]) + + if self.conv.bias is not None: + out = out + self.conv.bias.view(1, -1, 1, 1, 1) + + return out diff --git a/src/electrai/model/LCN_al.py b/src/electrai/model/LCN_al.py new file mode 100644 index 00000000..ee6ea473 --- /dev/null +++ b/src/electrai/model/LCN_al.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.electrai.model.utils import DistanceTriangleNet + + +class LatticeConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=False, + mix_weight=0.1, # noqa: ARG002 + use_radial_embedding=False, + use_positional_embedding=False, + trainable_gaussian_params=False, + num_gaussians=16, # noqa: ARG002 + pos_embed_dim=16, # noqa: ARG002 + # pos_embed_type="learnable", + r_max=5.0, # noqa: ARG002 + hidden_dim=64, # noqa: ARG002 + ): + super().__init__() + padding = kernel_size // 2 # - 1 + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3 + ) + self.padding = padding if isinstance(padding, tuple) else (padding,) * 3 + self.padding_mode = padding_mode + self.stride = stride + self.dilation = dilation + self.use_lattice_conv = use_lattice_conv + self.trainable_gaussian_params = trainable_gaussian_params + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + padding=padding, + padding_mode=padding_mode, + stride=stride, + dilation=dilation, + bias=True, + ) + self.use_radial_embedding = use_radial_embedding + self.use_positional_embedding = use_positional_embedding + + if use_lattice_conv: + self.dist_triangle = DistanceTriangleNet( + d=128, + n_heads=4, + n_blocks=2, + rbf_num=32, + r_max=16.0, + dropout=0.0, + in_channels=in_channels, + out_channels=out_channels, + ) + # if use_radial_embedding: + # 
self.gaussian_smear = GaussianRadialBasis( + # num_gaussians=num_gaussians, + # r_min=0.0, + # r_max=r_max, + # trainable=self.trainable_gaussian_params, + # ) + + # if use_positional_embedding: + # self.pos_embedding = CartesianFourierEmbedding(num_freqs=6) + # # if pos_embed_type == "learnable": + # # self.pos_embedding = PositionalEmbedding( + # # embed_dim=pos_embed_dim, max_kernel_size=max(self.kernel_size) + # # ) + # # elif pos_embed_type == "fourier": + # # self.pos_embedding = FourierPositionalEmbedding( + # # embed_dim=pos_embed_dim, max_freq=10 + # # ) + # # else: + # # raise ValueError(f"Unknown pos_embed_type: {pos_embed_type}") + # # self.pos_embed_type = pos_embed_type + + # if use_radial_embedding or use_positional_embedding: + # input_size = 0 + # if use_radial_embedding: + # input_size += num_gaussians + # if use_positional_embedding: + # input_size += pos_embed_dim + # self.filter_network = nn.Sequential( + # nn.Linear(input_size, hidden_dim), + # nn.LayerNorm(hidden_dim), + # nn.SiLU(), + # # nn.Dropout(0.1), + # nn.Linear(hidden_dim, hidden_dim), + # nn.LayerNorm(hidden_dim), + # nn.SiLU(), + # # nn.Dropout(0.1), + # nn.Linear(hidden_dim, in_channels * out_channels), + # ) + # # nn.init.zeros_(self.filter_network[-1].weight) + # # nn.init.zeros_(self.filter_network[-1].bias) + + # # self.register_buffer("mix_weight", torch.tensor(float(mix_weight))) + # self.mix_weight = nn.Parameter(torch.tensor(float(mix_weight))) + # # self.mix_weight = nn.Parameter(torch.zeros(out_channels)) + + def _apply_padding(self, x): + if self.padding_mode == "circular" and any(p > 0 for p in self.padding): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode="circular") + elif self.padding_mode in ["zeros", "reflect", "replicate"] and any( + p > 0 for p in self.padding + ): + pad_3d = ( + self.padding[2], + self.padding[2], + self.padding[1], + 
self.padding[1], + self.padding[0], + self.padding[0], + ) + return F.pad(x, pad_3d, mode=self.padding_mode) + else: + return x + + def compute_geometric_kernel(self, lattice_vectors): + if lattice_vectors.dim() == 2: + lattice_vectors = lattice_vectors.unsqueeze(0) + squeeze_batch = True + else: + squeeze_batch = False # noqa: F841 + + B = lattice_vectors.shape[0] + kz, ky, kx = self.kernel_size + device = lattice_vectors.device + + z = torch.arange(kz, device=device) - kz // 2 + y = torch.arange(ky, device=device) - ky // 2 + x = torch.arange(kx, device=device) - kx // 2 + + grid_z, grid_y, grid_x = torch.meshgrid(z, y, x, indexing="ij") + frac_coords = torch.stack([grid_z, grid_y, grid_x], dim=-1).float() + + cart_coords = torch.einsum("ijkl,bml->bijkm", frac_coords, lattice_vectors) + return cart_coords.reshape(B, -1, 3) + + # distances = torch.norm(cart_coords, dim=-1) + + # # debug_stats = {} + # # debug_stats["distances"] = { + # # "min": distances.min().item(), + # # "max": distances.max().item(), + # # "mean": distances.mean().item(), + # # "has_nan": torch.isnan(distances).any().item(), + # # "has_inf": torch.isinf(distances).any().item(), + # # } + + # if self.use_radial_embedding: + # radial_features = self.gaussian_smear(distances) + # radial_flat = rearrange(radial_features, "b kz ky kx n -> b (kz ky kx) n") + # if self.use_positional_embedding: + # pos_features = self.pos_embedding(cart_coords) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # # if self.pos_embed_type == "learnable": + # # pos_features = self.pos_embedding(frac_coords, self.kernel_size) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # # else: + # # pos_features = self.pos_embedding(frac_coords) + # # pos_features = pos_features.unsqueeze(0).expand(B, -1, -1, -1, -1) + # pos_flat = rearrange(pos_features, "b kz ky kx n -> b (kz ky kx) n") + + # if self.use_radial_embedding and self.use_positional_embedding: + # features = 
torch.cat([radial_flat, pos_flat], dim=-1) + # elif self.use_radial_embedding: + # features = torch.cat([radial_flat], dim=-1) + # else: + # features = torch.cat([pos_flat], dim=-1) + + # kernel_flat = self.filter_network(features) + # kernel = rearrange( + # kernel_flat, + # "b (kz ky kx) (o i) -> b o i kz ky kx", + # o=self.out_channels, + # i=self.in_channels, + # kz=kz, + # ky=ky, + # kx=kx, + # ) + # # kernel_norm = torch.linalg.vector_norm( + # # kernel, dim=(-4, -3, -2, -1), keepdim=True + # # ).clamp(min=1e-8) + # # kernel = ( + # # kernel / kernel_norm * (1.0 / (3 * self.in_channels * kz * ky * kx) ** 0.5) + # # ) + # # after kernel computed + # if squeeze_batch: + # kernel = kernel.squeeze(0) + + # return kernel + + def forward(self, x, lattice_vectors=None): + base_kernel = self.conv.weight.unsqueeze(0) + if not self.use_lattice_conv or ( + not self.use_radial_embedding and not self.use_positional_embedding + ): + with torch.no_grad(): + b_rms = base_kernel.pow(2).mean().sqrt() + self.kernel_stats = {"base_rms": b_rms.item()} + return self.conv(x) + D, H, W = x.shape[-3:] + # scale = x.new_tensor([D, H, W]).view(1, 3, 1) # (1,3,1) + a = lattice_vectors[:, 0, :] + b = lattice_vectors[:, 1, :] + c = lattice_vectors[:, 2, :] + lv_voxel = torch.stack([c / D, b / H, a / W], dim=1) # (z,y,x) + # lv_voxel = lattice_vectors / scale # (B,3,3) + + B = x.shape[0] + # geometric_kernels = self.compute_geometric_kernel(lv_voxel) + # g_rms = geometric_kernels.pow(2).mean().sqrt().clamp(min=1e-8) + # b_rms = base_kernel.pow(2).mean().sqrt().clamp(min=1e-8) + # geometric_kernels = geometric_kernels * (b_rms / g_rms) + # alpha = torch.sigmoid(self.mix_weight) + + # with torch.no_grad(): + # g_rms = geometric_kernels.pow(2).mean().sqrt() + # b_rms = base_kernel.pow(2).mean().sqrt() + # self.kernel_stats = { + # "geo_rms": g_rms.item(), + # "base_rms": b_rms.item(), + # "ratio": (g_rms / (b_rms + 1e-8)).item(), + # "alpha": alpha.item(), + # } + + # kernels = alpha * 
geometric_kernels + base_kernel # (1 - alpha) * base_kernel + pos = self.compute_geometric_kernel(lv_voxel) + kernels = base_kernel + self.dist_triangle(pos, self.kernel_size[0]) + x_grouped = x.reshape(1, B * self.in_channels, *x.shape[2:]) + x_grouped = self._apply_padding(x_grouped) + kernels_grouped = kernels.reshape( + B * self.out_channels, self.in_channels, *self.kernel_size + ) + out = F.conv3d( + x_grouped, + kernels_grouped, + bias=None, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=B, + ) + + out = out.reshape(B, self.out_channels, *out.shape[2:]) + + if self.conv.bias is not None: + out = out + self.conv.bias.view(1, -1, 1, 1, 1) + return out diff --git a/src/electrai/model/resnet_LCN.py b/src/electrai/model/resnet_LCN.py new file mode 100644 index 00000000..a5460029 --- /dev/null +++ b/src/electrai/model/resnet_LCN.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +from src.electrai.model.LCN import LatticeConv3d +from torch.utils.checkpoint import checkpoint + + +class ResidualBlock(nn.Module): + def __init__( + self, + in_features, + K=3, + use_checkpoint=True, + use_lattice_conv=False, + mix_weight=0.1, + use_radial_embedding=False, + use_positional_embedding=False, + trainable_gaussian_params=False, + num_gaussians=16, + pos_embed_dim=16, + hidden_dim=64, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + + self.conv1 = LatticeConv3d( + in_features, + in_features, + kernel_size=K, + padding_mode="circular", + stride=1, + dilation=1, + use_lattice_conv=use_lattice_conv, + mix_weight=mix_weight, + use_radial_embedding=use_radial_embedding, + use_positional_embedding=use_positional_embedding, + trainable_gaussian_params=trainable_gaussian_params, + num_gaussians=num_gaussians, + pos_embed_dim=pos_embed_dim, + hidden_dim=hidden_dim, + ) + self.norm1 = nn.InstanceNorm3d(in_features) + self.act1 = nn.PReLU() + self.conv2 = LatticeConv3d( + in_features, + in_features, + 
class GeneratorResNet(nn.Module):
    """SRResNet-style 3D generator built from LatticeConv3d layers.

    All convolutions use circular padding (periodic grids) and optionally the
    lattice-modulated kernel path. When ``normalize`` is True the output is
    rescaled so each sample integrates to the same total as its input.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_residual_blocks=16,
        n_channels=64,
        kernel_size1=5,
        kernel_size2=3,
        normalize=True,
        use_checkpoint=True,
        use_lattice_conv=False,
        mix_weight=0.1,
        use_radial_embedding=False,
        use_positional_embedding=False,
        trainable_gaussian_params=False,
        num_gaussians=16,
        pos_embed_dim=16,
        hidden_dim=64,
    ):
        super().__init__()
        self.normalize = normalize
        self.use_checkpoint = use_checkpoint

        # Shared configuration for every LatticeConv3d (factored out of the
        # four duplicated call sites in the original).
        lcn_kwargs = dict(
            padding_mode="circular",
            stride=1,
            dilation=1,
            use_lattice_conv=use_lattice_conv,
            mix_weight=mix_weight,
            use_radial_embedding=use_radial_embedding,
            use_positional_embedding=use_positional_embedding,
            trainable_gaussian_params=trainable_gaussian_params,
            num_gaussians=num_gaussians,
            pos_embed_dim=pos_embed_dim,
            hidden_dim=hidden_dim,
        )

        # First (feature-extraction) layer.
        self.conv1 = LatticeConv3d(
            in_channels, n_channels, kernel_size=kernel_size1, **lcn_kwargs
        )
        self.act1 = nn.PReLU()

        # Residual trunk.
        self.res_blocks = nn.ModuleList(
            ResidualBlock(
                n_channels,
                K=kernel_size2,
                use_checkpoint=use_checkpoint,
                use_lattice_conv=use_lattice_conv,
                mix_weight=mix_weight,
                use_radial_embedding=use_radial_embedding,
                use_positional_embedding=use_positional_embedding,
                trainable_gaussian_params=trainable_gaussian_params,
                num_gaussians=num_gaussians,
                pos_embed_dim=pos_embed_dim,
                hidden_dim=hidden_dim,
            )
            for _ in range(n_residual_blocks)
        )

        # Post-trunk conv feeding the long skip connection.
        self.conv2 = LatticeConv3d(
            n_channels, n_channels, kernel_size=kernel_size2, **lcn_kwargs
        )
        self.norm = nn.InstanceNorm3d(n_channels)

        # Output projection; ReLU keeps the predicted density non-negative.
        self.conv3 = LatticeConv3d(
            n_channels, out_channels, kernel_size=kernel_size1, **lcn_kwargs
        )
        self.act2 = nn.ReLU()

    def forward(self, x, lattice_vectors):
        """Run the generator on a batched tensor or a list of grids.

        NOTE(review): in the list branch every sample is given the same
        ``lattice_vectors`` object; callers with per-sample lattices (the
        LightningModule) instead loop sample-by-sample themselves — confirm
        whether this branch is still exercised.
        """
        if isinstance(x, torch.Tensor):
            return self._forward(x, lattice_vectors)
        return [self._forward(xi.unsqueeze(0), lattice_vectors).squeeze(0) for xi in x]

    def _forward(self, x, lattice_vectors=None):
        out1 = self.act1(self.conv1(x, lattice_vectors))
        out = out1
        for block in self.res_blocks:
            out = block(out, lattice_vectors)
        out2 = self.norm(self.conv2(out, lattice_vectors))
        # Long skip from the first layer, then the output head.
        out = self.conv3(out1 + out2, lattice_vectors)
        out = self.act2(out)

        if self.normalize:
            # Rescale so the output integrates to the input's total (charge
            # conservation). act2 is ReLU, so out >= 0 and the clamp only
            # guards the all-zero case.
            out_sum = torch.sum(out, dim=(-3, -2, -1), keepdim=True).clamp(min=1e-8)
            x_sum = torch.sum(x, dim=(-3, -2, -1), keepdim=True)
            out = (out / out_sum) * x_sum
        return out
class ResUNet3D(nn.Module):
    """Residual 3D U-Net built from lattice-aware, circularly padded convs.

    Encoder: ``depth`` stages of ``n_residual_blocks`` ResBlock3D followed by
    a strided DownsampleBlock that doubles channels. Bottleneck: twice as many
    ResBlock3D. Decoder: periodic upsampling, concatenation with the matching
    encoder skip, then ResBlock3D stages. A 1x1 LatticeConv3d projects to
    ``out_channels`` and the result is renormalized to the input's total.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_channels,
        depth,
        n_residual_blocks,
        kernel_size,
        use_lattice_conv=False,
        mix_weight=0.1,
        use_radial_embedding=False,
        use_positional_embedding=False,
        trainable_gaussian_params=False,
        num_gaussians=16,
        pos_embed_dim=16,
        hidden_dim=64,
    ):
        super().__init__()

        # Shared LatticeConv3d configuration forwarded to every block.
        lcn_kwargs = {
            "use_lattice_conv": use_lattice_conv,
            "mix_weight": mix_weight,
            "use_radial_embedding": use_radial_embedding,
            "use_positional_embedding": use_positional_embedding,
            "trainable_gaussian_params": trainable_gaussian_params,
            "num_gaussians": num_gaussians,
            "pos_embed_dim": pos_embed_dim,
            "hidden_dim": hidden_dim,
        }

        # Stem: lifts in_channels -> n_channels.
        self.in_conv = ResBlock3D(in_channels, n_channels, kernel_size, **lcn_kwargs)

        # -------- Encoder --------
        self.enc_blocks = nn.ModuleList()
        self.downs = nn.ModuleList()

        ch = n_channels
        for _ in range(depth):
            self.enc_blocks.append(
                LatticeSequential(
                    *[
                        ResBlock3D(ch, ch, kernel_size, **lcn_kwargs)
                        for _ in range(n_residual_blocks)
                    ]
                )
            )
            # Stride-2 downsample doubles the channel count.
            self.downs.append(DownsampleBlock(ch, 2 * ch, **lcn_kwargs))
            ch *= 2

        # -------- Bottleneck --------
        self.mid = LatticeSequential(
            *[
                ResBlock3D(ch, ch, kernel_size, **lcn_kwargs)
                for _ in range(2 * n_residual_blocks)
            ]
        )

        # -------- Decoder --------
        self.ups = nn.ModuleList()
        self.dec_blocks = nn.ModuleList()
        for _ in range(depth):
            self.ups.append(PeriodicUpsampleConv3d(ch, ch // 2, **lcn_kwargs))
            ch //= 2

            # First decoder block consumes the skip concatenation (2*ch).
            blocks = [ResBlock3D(2 * ch, ch, kernel_size, **lcn_kwargs)]
            blocks.extend(
                [
                    ResBlock3D(ch, ch, kernel_size, **lcn_kwargs)
                    for _ in range(n_residual_blocks - 1)
                ]
            )
            self.dec_blocks.append(LatticeSequential(*blocks))

        # -------- Output --------
        # 1x1 projection; LatticeConv3d used so lattice kwargs stay uniform.
        self.out_conv = LatticeConv3d(
            n_channels, out_channels, kernel_size=1, **lcn_kwargs
        )

    def forward(self, x, lattice_vectors):
        """U-Net forward pass; ``lattice_vectors`` is threaded to every conv.

        Returns the output rescaled so each sample sums to its input's total.

        NOTE(review): unlike GeneratorResNet, the denominator below is not
        clamped and out_conv's output is not constrained non-negative, so a
        near-zero sum would produce inf/NaN — confirm whether a clamp (or a
        final activation) is intended here.
        """
        skips = []
        out = self.in_conv(x, lattice_vectors)

        for enc, down in zip(self.enc_blocks, self.downs, strict=False):
            out = enc(out, lattice_vectors)
            skips.append(out)  # skip saved at full pre-downsample resolution
            out = down(out, lattice_vectors)

        out = self.mid(out, lattice_vectors)

        for up, dec in zip(self.ups, self.dec_blocks, strict=False):
            out = up(out, lattice_vectors)
            # Deepest skip is popped first (LIFO matches decoder order).
            out = torch.cat([out, skips.pop()], dim=1)
            out = dec(out, lattice_vectors)

        out = self.out_conv(out, lattice_vectors)
        out = out / torch.sum(out, dim=(-3, -2, -1), keepdim=True)
        return out * torch.sum(x, dim=(-3, -2, -1), keepdim=True)
use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.conv1 = LatticeConv3d( +# cin, +# cout, +# kernel_size=k, +# padding_mode="circular", +# padding=k // 2, +# stride=1, +# dilation=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# self.norm1 = nn.InstanceNorm3d(cout) +# self.act1 = nn.PReLU() +# self.conv2 = LatticeConv3d( +# cout, +# cout, +# kernel_size=k, +# padding_mode="circular", +# padding=k // 2, +# stride=1, +# dilation=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# self.norm2 = nn.InstanceNorm3d(cout) +# self.act2 = nn.PReLU() +# # self.conv1 = nn.Conv3d(cin, cout, k, padding=k // 2, padding_mode="circular") +# # self.norm1 = nn.InstanceNorm3d(cout) +# # self.act = nn.PReLU() +# # self.conv2 = nn.Conv3d(cout, cout, k, padding=k // 2, padding_mode="circular") +# # self.norm2 = nn.InstanceNorm3d(cout) + +# if cin != cout: +# self.skip = LatticeConv3d( +# cin, +# cout, +# kernel_size=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.skip = nn.Conv3d(cin, cout, 1) +# else: +# self.skip = nn.Identity() + +# def forward(self, x, 
lattice_vectors): +# h = self.act1(self.norm1(self.conv1(x, lattice_vectors))) +# h = self.norm2(self.conv2(h, lattice_vectors)) +# return self.act2(h + self.skip(x)) + + +# class ResUNet3D(nn.Module): +# def __init__( +# self, +# in_channels, +# out_channels, +# n_channels, +# depth, +# n_residual_blocks, +# kernel_size, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.in_conv = ResBlock3D( +# in_channels, +# n_channels, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.in_conv = ResBlock3D(in_channels, n_channels, kernel_size) + +# # -------- Encoder -------- +# self.enc_blocks = nn.ModuleList() +# self.downs = nn.ModuleList() + +# ch = n_channels +# for _ in range(depth): +# self.enc_blocks.append( +# nn.Sequential( +# *[ +# ResBlock3D( +# ch, +# ch, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(n_residual_blocks) +# ] +# ) +# ) +# # self.enc_blocks.append( +# # nn.Sequential( +# # *[ResBlock3D(ch, ch, kernel_size,) for _ in range(n_residual_blocks)] +# # ) +# # ) +# self.downs.append( +# downsample( +# ch, +# 2 * ch, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# 
trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# ) +# ch *= 2 + +# # -------- Bottleneck -------- +# self.mid = nn.Sequential( +# *[ +# ResBlock3D( +# ch, +# ch, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(2 * n_residual_blocks) +# ] +# ) +# # self.mid = nn.Sequential( +# # *[ResBlock3D(ch, ch, kernel_size) for _ in range(2 * n_residual_blocks)] +# # ) + +# # -------- Decoder -------- +# self.ups = nn.ModuleList() +# self.dec_blocks = nn.ModuleList() + +# for _ in range(depth): +# self.ups.append(PeriodicUpsampleConv3d(ch, ch // 2)) +# ch //= 2 +# self.dec_blocks.append( +# nn.Sequential( +# *[ +# ResBlock3D( +# 2 * ch, +# ch, +# kernel_size, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# for _ in range(n_residual_blocks) +# ] +# ) +# ) +# # self.dec_blocks.append( +# # nn.Sequential( +# # *[ +# # ResBlock3D(2 * ch, ch, kernel_size) +# # for _ in range(n_residual_blocks) +# # ] +# # ) +# # ) + +# # -------- Output -------- +# self.out_conv = LatticeConv3d( +# n_channels, +# out_channels, +# kernel_size=1, +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# 
hidden_dim=hidden_dim, +# ) +# # self.out_conv = nn.Conv3d(n_channels, out_channels, kernel_size=1) + +# def forward(self, x, lattice_vectors): +# skips = [] +# out = self.in_conv(x, lattice_vectors) + +# for enc, down in zip(self.enc_blocks, self.downs, strict=False): +# out = enc(out) +# skips.append(out) +# out = down(out) +# out = self.mid(out) + +# for up, dec in zip(self.ups, self.dec_blocks, strict=False): +# out = up(out) +# out = torch.cat([out, skips.pop()], dim=1) +# out = dec(out) +# out = self.out_conv(out) +# out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None] +# return out * torch.sum(x, axis=(-3, -2, -1))[..., None, None, None] + + +# class PeriodicUpsampleConv3d(nn.Module): +# def __init__( +# self, +# cin, +# cout, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# use_positional_embedding=False, +# trainable_gaussian_params=False, +# num_gaussians=16, +# pos_embed_dim=16, +# hidden_dim=64, +# ): +# super().__init__() +# self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False) +# self.conv = LatticeConv3d( +# cin, +# cout, +# 3, +# stride=2, +# padding=1, +# padding_mode="circular", +# use_lattice_conv=use_lattice_conv, +# mix_weight=mix_weight, +# use_radial_embedding=use_radial_embedding, +# use_positional_embedding=use_positional_embedding, +# trainable_gaussian_params=trainable_gaussian_params, +# num_gaussians=num_gaussians, +# pos_embed_dim=pos_embed_dim, +# hidden_dim=hidden_dim, +# ) +# # self.conv = nn.Conv3d(cin, cout, 3, padding=1, padding_mode="circular") +# self.norm = nn.InstanceNorm3d(cout) +# self.act = nn.PReLU() + +# def forward(self, x, lattice_vectors): +# x = F.pad(x, (1, 1, 1, 1, 1, 1), mode="circular") +# x = self.up(x) +# x = x[..., 2:-2, 2:-2, 2:-2] +# x = self.conv(x, lattice_vectors) +# x = self.norm(x) +# return self.act(x) + + +# def downsample( +# cin, +# cout, +# use_lattice_conv=False, +# mix_weight=0.1, +# use_radial_embedding=False, +# 
import torch
import torch.nn as nn


class GaussianRadialBasis(nn.Module):
    """Expand scalar distances into Gaussian radial-basis features.

    Centers are spread uniformly over [r_min, r_max]; every Gaussian shares
    the same width, equal to the center spacing. With ``trainable=True`` the
    centers and widths become learnable parameters; otherwise they are
    registered as buffers so they follow the module across devices.
    """

    def __init__(
        self,
        num_gaussians: int = 50,
        r_min: float = 0.0,
        r_max: float = 5.0,
        trainable: bool = False,
    ):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.r_min = r_min
        self.r_max = r_max

        centers = torch.linspace(r_min, r_max, num_gaussians)
        # Width = spacing between adjacent centers (1.0 for a single center).
        if num_gaussians > 1:
            spacing = (r_max - r_min) / (num_gaussians - 1)
        else:
            spacing = 1.0
        widths = torch.ones(num_gaussians) * spacing

        if trainable:
            self.centers = nn.Parameter(centers)
            self.widths = nn.Parameter(widths)
        else:
            self.register_buffer("centers", centers)
            self.register_buffer("widths", widths)

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Expand distances using Gaussian basis functions.

        Args:
            distances: Tensor of shape [...] containing interatomic distances

        Returns:
            Tensor of shape [..., num_gaussians] with Gaussian features
        """
        # Broadcast a trailing basis axis against the 1-D center/width buffers.
        expanded = distances.unsqueeze(-1)
        trailing = [1] * (expanded.dim() - 1)
        mu = self.centers.view(*trailing, -1)
        sigma = self.widths.view(*trailing, -1)

        offset = expanded - mu
        inv_two_sigma_sq = 1.0 / (2.0 * sigma**2)
        return torch.exp(-inv_two_sigma_sq * offset**2)


class CartesianFourierEmbedding(nn.Module):
    """
    Fourier features of real displacement vectors (cartesian, in Å).
    Uses a fixed physical scale (r_max) so features are comparable across samples.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        include_radius: bool = True,
        r_max: float = 5.0,
        freq_min: float = 0.5,
        freq_max: float = 3.0,
    ):
        super().__init__()

        # Frequencies (wave numbers). Larger -> higher spatial variation.
        self.register_buffer("freqs", torch.linspace(freq_min, freq_max, num_freqs))

        self.include_radius = include_radius
        self.r_max = float(r_max)

        # sin+cos per frequency per coordinate, plus optional (r, r^2).
        radial_terms = 2 if include_radius else 0
        self.out_dim = 2 * num_freqs * 3 + radial_terms

    def forward(self, cart_coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            cart_coords: [B, kz, ky, kx, 3] displacement vectors in Å

        Returns:
            features: [B, kz, ky, kx, out_dim]
        """
        # Fixed physical normalization keeps the features dimensionless and
        # comparable across samples with different cell sizes.
        scaled = cart_coords / (self.r_max + 1e-6)

        # [B,kz,ky,kx,3] -> [B,kz,ky,kx,F,3]
        phase = scaled.unsqueeze(-2) * self.freqs.view(1, 1, 1, 1, -1, 1)

        # Stack sin/cos along the frequency axis, then flatten to the feature dim.
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-2)
        features = features.reshape(*scaled.shape[:-1], -1)

        if self.include_radius:
            radius = torch.linalg.norm(scaled, dim=-1, keepdim=True)
            features = torch.cat([features, radius, radius**2], dim=-1)

        return features
+# Much more meaningful than index or fractional embedding. +# """ + +# def __init__(self, num_freqs: int = 6, include_radius: bool = True): +# super().__init__() +# # Smooth low frequencies — not exponential like NeRF +# freqs = torch.linspace(0.5, 3.0, num_freqs) +# self.register_buffer("freqs", freqs) + +# self.include_radius = include_radius +# self.out_dim = 2 * num_freqs * 3 + (2 if include_radius else 0) + +# def forward(self, cart_coords: torch.Tensor) -> torch.Tensor: +# """ +# Args: +# cart_coords: [B, kz, ky, kx, 3] real displacement vectors (Å) + +# Returns: +# positional features: [B, kz, ky, kx, out_dim] +# """ +# # Normalize scale for stability (prevents huge lattice from exploding features) +# scale = ( +# cart_coords.norm(dim=-1, keepdim=True).mean(dim=(1, 2, 3, 4), keepdim=True) +# + 1e-6 +# ) +# v = cart_coords / scale + +# # Project onto frequencies: [B,kz,ky,kx,3] -> [B,kz,ky,kx,F,3] +# angles = v.unsqueeze(-2) * self.freqs.view(1, 1, 1, 1, -1, 1) + +# sin = torch.sin(angles) +# cos = torch.cos(angles) + +# # Concatenate sin/cos over frequency axis and flatten +# feat = torch.cat([sin, cos], dim=-2).reshape(*v.shape[:-1], -1) + +# if self.include_radius: +# r = torch.norm(v, dim=-1, keepdim=True) +# feat = torch.cat([feat, r, r**2], dim=-1) + +# return feat + + +# class PositionalEmbedding(nn.Module): +# """ +# Learnable positional embedding for kernel positions. +# Similar to positional encodings in Transformers but learnable. 
+# """ +# +# def __init__(self, embed_dim=32, max_kernel_size=7): +# super().__init__() +# self.embed_dim = embed_dim +# +# # Learnable embeddings for each coordinate dimension +# # Range: [-max_kernel_size//2, max_kernel_size//2] +# self.z_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.y_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.x_embed = nn.Embedding(max_kernel_size, embed_dim) +# +# # Optional: learnable way to combine the three directions +# self.combine = nn.Linear(3 * embed_dim, embed_dim) +# +# def forward(self, frac_coords, kernel_size): +# """ +# Args: +# frac_coords: [kz, ky, kx, 3] fractional coordinates (z, y, x offsets) +# kernel_size: (kz, ky, kx) tuple +# +# Returns: +# embeddings: [kz, ky, kx, embed_dim] +# """ +# kz, ky, kx = kernel_size +# +# # Convert coordinates to indices (shift from [-k//2, k//2] to [0, k]) +# z_idx = (frac_coords[..., 0] + kz // 2).long() +# y_idx = (frac_coords[..., 1] + ky // 2).long() +# x_idx = (frac_coords[..., 2] + kx // 2).long() +# +# # Get embeddings for each dimension +# z_emb = self.z_embed(z_idx) # [kz, ky, kx, embed_dim] +# y_emb = self.y_embed(y_idx) # [kz, ky, kx, embed_dim] +# x_emb = self.x_embed(x_idx) # [kz, ky, kx, embed_dim] +# +# # Combine (can use addition, concatenation + linear, etc.) +# combined = torch.cat([z_emb, y_emb, x_emb], dim=-1) # [kz, ky, kx, 3*embed_dim] +# return self.combine(combined) # [kz, ky, kx, embed_dim] +# +# +# class FourierPositionalEmbedding(nn.Module): +# """ +# Fourier features for positional encoding (non-learnable but more expressive). +# Similar to what's used in NeRF. 
+# """ +# +# def __init__(self, embed_dim=32, max_freq=10): +# super().__init__() +# self.embed_dim = embed_dim +# +# # Number of frequency bands +# num_freqs = embed_dim // 6 # 6 because we have sin+cos for each of 3 coords +# +# # Logarithmically spaced frequencies +# freq_bands = 2.0 ** torch.linspace(0, max_freq, num_freqs) +# self.register_buffer("freq_bands", freq_bands) +# +# def forward(self, frac_coords): +# """ +# Args: +# frac_coords: [..., 3] fractional coordinates +# +# Returns: +# embeddings: [..., embed_dim] +# """ +# coords_expanded = frac_coords.unsqueeze(-2) # [..., 1, 3] +# freqs_expanded = self.freq_bands.unsqueeze(-1) # [num_freqs, 1] +# +# # Compute sin and cos for each frequency and coordinate +# angles = coords_expanded * freqs_expanded # [..., num_freqs, 3] +# +# sin_features = torch.sin(angles) +# cos_features = torch.cos(angles) +# +# # Concatenate all features +# fourier_features = torch.cat([sin_features, cos_features], dim=-2) # [..., 2*num_freqs, 3] +# return fourier_features.reshape(*frac_coords.shape[:-1], -1) # [..., 6*num_freqs] diff --git a/src/electrai/model/utils_al.py b/src/electrai/model/utils_al.py new file mode 100644 index 00000000..75a181e2 --- /dev/null +++ b/src/electrai/model/utils_al.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def pairwise_dist(R: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """ + R: (B, N, 3) or (N, 3) + returns D: (B, N, N) or (N, N) + """ + if R.ndim == 2: + R = R[None, ...] 
+ squeeze = True + else: + squeeze = False + + diff = R[:, :, None, :] - R[:, None, :, :] # (B,N,N,3) + D = torch.linalg.norm(diff, dim=-1).clamp_min(eps) + if squeeze: + D = D[0] + return D + + +# 2) RBF embedder for distances (recommended) + + +class RBF(nn.Module): + def __init__(self, num=32, r_max=10.0, trainable=False): + super().__init__() + centers = torch.linspace(0.0, r_max, num) + widths = torch.full((num,), (r_max / num)) + if trainable: + self.centers = nn.Parameter(centers) + self.widths = nn.Parameter(widths) + else: + self.register_buffer("centers", centers) + self.register_buffer("widths", widths) + + def forward(self, D): + """ + D: (B,N,N) distances + returns: (B,N,N,num) + """ + # (B,N,N,1) - (num,) -> (B,N,N,num) + x = D[..., None] - self.centers + return torch.exp(-0.5 * (x / self.widths).pow(2)) + + +class TriangleAttention(nn.Module): + def __init__(self, d, n_heads=4, dropout=0.0, starting=True): + super().__init__() + assert d % n_heads == 0 + self.h = n_heads + self.dk = d // n_heads + self.starting = starting + + self.ln = nn.LayerNorm(d) + self.to_q = nn.Linear(d, d, bias=False) + self.to_k = nn.Linear(d, d, bias=False) + self.to_v = nn.Linear(d, d, bias=False) + self.to_out = nn.Linear(d, d, bias=False) + self.drop = nn.Dropout(dropout) + + def forward(self, P, mask=None): + """ + P: (B,N,N,d) + mask: (B,N) with 1 for real atoms, 0 for padding + """ + B, N, _, d = P.shape + x = self.ln(P) + + if self.starting: + # fixed i: update row i using keys from (i,k) + q = x.reshape(B * N, N, d) + k = x.reshape(B * N, N, d) + v = x.reshape(B * N, N, d) + else: + # fixed j: update col j using keys from (k,j) + xt = x.transpose(1, 2) + q = xt.reshape(B * N, N, d) + k = xt.reshape(B * N, N, d) + v = xt.reshape(B * N, N, d) + + q = self.to_q(q).view(B * N, N, self.h, self.dk).transpose(1, 2) # (BN,h,N,dk) + k = self.to_k(k).view(B * N, N, self.h, self.dk).transpose(1, 2) + v = self.to_v(v).view(B * N, N, self.h, self.dk).transpose(1, 2) + + attn = 
(q @ k.transpose(-1, -2)) / (self.dk**0.5) # (BN,h,N,N) + + if mask is not None: + kv_mask = mask.repeat_interleave(N, dim=0) # (BN,N) + attn = attn.masked_fill(kv_mask[:, None, None, :] == 0, float("-inf")) + + w = self.drop(F.softmax(attn, dim=-1)) + out = (w @ v).transpose(1, 2).contiguous().view(B * N, N, d) + out = self.to_out(out).view(B, N, N, d) + + if not self.starting: + out = out.transpose(1, 2) + + return P + out + + +class DistanceTriangleNet(nn.Module): + def __init__( + self, + d=128, + n_heads=4, + n_blocks=2, + rbf_num=32, + r_max=10.0, + dropout=0.0, + in_channels=32, + out_channels=32, + ): + super().__init__() + self.rbf = RBF(num=rbf_num, r_max=r_max, trainable=False) + + self.pair_in = nn.Sequential(nn.Linear(rbf_num, d), nn.SiLU(), nn.Linear(d, d)) + + self.blocks = nn.ModuleList([]) + for _ in range(n_blocks): + self.blocks.append( + nn.ModuleList( + [ + TriangleAttention( + d, n_heads=n_heads, dropout=dropout, starting=True + ), + TriangleAttention( + d, n_heads=n_heads, dropout=dropout, starting=False + ), + nn.LayerNorm(d), + nn.Sequential( + nn.Linear(d, 2 * d), nn.SiLU(), nn.Linear(2 * d, d) + ), + ] + ) + ) + + self.node_ln = nn.LayerNorm(d) + self.in_channels = in_channels + self.out_channels = out_channels + self.node_head = nn.Sequential( + nn.Linear(d, 2 * d), + nn.SiLU(), + nn.Linear(2 * d, self.in_channels * self.out_channels), + ) + + def forward(self, R, k: int = 5, mask=None): + """ + R: (B,N,3) coordinates + mask: (B,N) optional + returns y: (B,N) per-atom values + """ + N = R.shape[0] # noqa: F841 + D = pairwise_dist(R) # (B,N,N) + # Optional: zero-out padded interactions + if mask is not None: + m2 = mask[:, :, None] * mask[:, None, :] # (B,N,N) + D = D * m2 + (1.0 - m2) * D.max().detach() # keep padded far away + + P = self.pair_in(self.rbf(D)) # (B,N,N,d) + + for tri_s, tri_e, ln, ff in self.blocks: + P = tri_s(P, mask=mask) + P = tri_e(P, mask=mask) + P = P + ff(ln(P)) + + # pair -> node (row/col pooling) + row = 
P.mean(dim=2) # (B,N,d) + col = P.mean(dim=1) # (B,N,d) + H = 0.5 * (row + col) + + y = self.node_head(self.node_ln(H)) # (B,N) + y = rearrange( + y, + "B (k1 k2 k3) (o i)-> B o i k1 k2 k3", + k1=k, + k2=k, + k3=k, + o=self.out_channels, + i=self.in_channels, + ) + if mask is not None: + y = y * mask + return y + + +class GaussianRadialBasis(nn.Module): + def __init__(self, num_gaussians=50, r_min=0.0, r_max=5.0, trainable=False): + super().__init__() + self.num_gaussians = num_gaussians + self.r_min = r_min + self.r_max = r_max + centers = torch.linspace(r_min, r_max, num_gaussians) + spacing = (r_max - r_min) / (num_gaussians - 1) if num_gaussians > 1 else 1.0 + widths = torch.ones(num_gaussians) * spacing + + if trainable: + self.centers = nn.Parameter(centers) + self.widths = nn.Parameter(widths) + else: + self.register_buffer("centers", centers) + self.register_buffer("widths", widths) + + def forward(self, distances): + """ + Expand distances using Gaussian basis functions. + + Args: + distances: Tensor of shape [...] containing interatomic distances + + Returns: + Tensor of shape [..., num_gaussians] with Gaussian features + """ + distances = distances.unsqueeze(-1) + centers = self.centers.view(*([1] * (distances.dim() - 1)), -1) + widths = self.widths.view(*([1] * (distances.dim() - 1)), -1) + + diff = distances - centers + gamma = 1.0 / (2 * widths**2) + return torch.exp(-gamma * diff**2) + + +class CartesianFourierEmbedding(nn.Module): + """ + Fourier features of real displacement vector (cartesian). + Much more meaningful than index or fractional embedding. 
+ """ + + def __init__(self, num_freqs=6, include_radius=True): + super().__init__() + + # smooth low frequencies — not exponential like NeRF + freqs = torch.linspace(0.5, 3.0, num_freqs) + self.register_buffer("freqs", freqs) + + self.include_radius = include_radius + self.out_dim = 2 * num_freqs * 3 + (2 if include_radius else 0) + + def forward(self, cart_coords): + """ + cart_coords: [B, kz, ky, kx, 3] real displacement vectors (Å) + returns: positional features + """ + + # normalize scale for stability + # prevents huge lattice from exploding features + scale = ( + cart_coords.norm(dim=-1, keepdim=True).mean(dim=(1, 2, 3, 4), keepdim=True) + + 1e-6 + ) + v = cart_coords / scale + + # project onto frequencies + # [B,kz,ky,kx,3] -> [B,kz,ky,kx,F,3] + angles = v.unsqueeze(-2) * self.freqs.view(1, 1, 1, 1, -1, 1) + + sin = torch.sin(angles) + cos = torch.cos(angles) + + feat = torch.cat([sin, cos], dim=-2).reshape(*v.shape[:-1], -1) + + if self.include_radius: + r = torch.norm(v, dim=-1, keepdim=True) + feat = torch.cat([feat, r, r**2], dim=-1) + + return feat + + +# class PositionalEmbedding(nn.Module): +# """ +# Learnable positional embedding for kernel positions. +# Similar to positional encodings in Transformers but learnable. 
+# """ + +# def __init__(self, embed_dim=32, max_kernel_size=7): +# super().__init__() +# self.embed_dim = embed_dim + +# # Learnable embeddings for each coordinate dimension +# # Range: [-max_kernel_size//2, max_kernel_size//2] +# self.z_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.y_embed = nn.Embedding(max_kernel_size, embed_dim) +# self.x_embed = nn.Embedding(max_kernel_size, embed_dim) + +# # Optional: learnable way to combine the three directions +# self.combine = nn.Linear(3 * embed_dim, embed_dim) + +# def forward(self, frac_coords, kernel_size): +# """ +# Args: +# frac_coords: [kz, ky, kx, 3] fractional coordinates (z, y, x offsets) +# kernel_size: (kz, ky, kx) tuple + +# Returns: +# embeddings: [kz, ky, kx, embed_dim] +# """ +# kz, ky, kx = kernel_size + +# # Convert coordinates to indices (shift from [-k//2, k//2] to [0, k]) +# z_idx = (frac_coords[..., 0] + kz // 2).long() +# y_idx = (frac_coords[..., 1] + ky // 2).long() +# x_idx = (frac_coords[..., 2] + kx // 2).long() + +# # Get embeddings for each dimension +# z_emb = self.z_embed(z_idx) # [kz, ky, kx, embed_dim] +# y_emb = self.y_embed(y_idx) # [kz, ky, kx, embed_dim] +# x_emb = self.x_embed(x_idx) # [kz, ky, kx, embed_dim] + +# # Combine (can use addition, concatenation + linear, etc.) +# combined = torch.cat([z_emb, y_emb, x_emb], dim=-1) # [kz, ky, kx, 3*embed_dim] +# return self.combine(combined) # [kz, ky, kx, embed_dim] + + +# class FourierPositionalEmbedding(nn.Module): +# """ +# Fourier features for positional encoding (non-learnable but more expressive). +# Similar to what's used in NeRF. 
+# """ + +# def __init__(self, embed_dim=32, max_freq=10): +# super().__init__() +# self.embed_dim = embed_dim + +# # Number of frequency bands +# num_freqs = embed_dim // 6 # 6 because we have sin+cos for each of 3 coords + +# # Logarithmically spaced frequencies +# freq_bands = 2.0 ** torch.linspace(0, max_freq, num_freqs) +# self.register_buffer("freq_bands", freq_bands) + +# def forward(self, frac_coords): +# """ +# Args: +# frac_coords: [..., 3] fractional coordinates + +# Returns: +# embeddings: [..., embed_dim] +# """ +# # frac_coords: [..., 3] +# # freq_bands: [num_freqs] + +# coords_expanded = frac_coords.unsqueeze(-2) # [..., 1, 3] +# freqs_expanded = self.freq_bands.unsqueeze(-1) # [num_freqs, 1] + +# # Compute sin and cos for each frequency and coordinate +# angles = coords_expanded * freqs_expanded # [..., num_freqs, 3] + +# sin_features = torch.sin(angles) +# cos_features = torch.cos(angles) + +# # Concatenate all features +# fourier_features = torch.cat( +# [sin_features, cos_features], dim=-2 +# ) # [..., 2*num_freqs, 3] +# return fourier_features.reshape( +# *frac_coords.shape[:-1], -1 +# ) # [..., 6*num_freqs]