Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
.cache/
docs/reference/*
./examples/MP/experiments
./examples/QM9s
./examples/QM9
*doctrees*
/site

Expand Down
48 changes: 48 additions & 0 deletions src/electrai/configs/MP/config_resnet_lcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Dataset / loader parameters
data:
_target_: src.electrai.dataloader.dataset.RhoRead
root: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/mp_filelist.txt
split_file: null #/scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/split.json
precision: f32
batch_size: 1
train_workers: 8
val_workers: 2
pin_memory: false
val_frac: 0.005
drop_last: false
augmentation: false
random_seed: 42
# downsample_label: 0
# downsample_data: 0

# Model
model:
_target_: src.electrai.model.resnet_LCN.GeneratorResNet
n_residual_blocks: 32
n_channels: 32
kernel_size1: 5
kernel_size2: 5
normalize: True
use_checkpoint: False
use_lattice_conv: true
use_radial_embedding: true
num_gaussians: 500
use_positional_embedding: true
pos_embed_dim: 500

# Training parameters
precision: 32
epochs: 50
lr: 0.01
weight_decay: 0.0
warmup_length: 1
beta1: 0.9
beta2: 0.99

# Weights and biases
wandb_mode: online
entity: PrinceOA
wb_pname: mp-experiment

# checkpoints
ckpt_path: ./checkpoints
6 changes: 3 additions & 3 deletions src/electrai/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs):
else:
raise ValueError("No filename found.")

self.category = Path(datapath).name.split("_")[0] # example: mp_filelist.txt
self.category = Path(datapath).name.split("_")[0]
self.root = Path(datapath).parent
self.member_list = member_list

Expand All @@ -116,7 +116,7 @@ def __len__(self):

def __getitem__(self, index):
index = self.member_list[index]
data, label = utils.load_numpy_rho(
data, label, lattice = utils.load_numpy_rho(
root=self.root,
category=self.category,
index=index,
Expand All @@ -125,4 +125,4 @@ def __getitem__(self, index):
)
data = data.unsqueeze(0)
label = label.unsqueeze(0)
return {"data": data, "label": label, "index": index}
return {"data": data, "label": label, "index": index, "lattice": lattice}
17 changes: 14 additions & 3 deletions src/electrai/dataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,33 @@ def load_numpy_rho(
"""
root = Path(root)
if category == "mp":
data, label = load_chgcar(root, index)
data, label, lattice = load_chgcar(root, index)
elif category == "qm9":
data, label = load_npy(root, index)
data = torch.tensor(data, dtype=dtype_map[precision])
label = torch.tensor(label, dtype=dtype_map[precision])
lattice = torch.tensor(lattice, dtype=dtype_map[precision])
# grid_shape = torch.tensor(
# data.shape, dtype=dtype_map[precision], device=lattice.device
# )
# lattice = lattice / grid_shape[:, None]
if augmentation:
data, label = rand_rotate([data, label])
return data, label
# print("shapeeeeeeee", index, data.shape, label.shape)
data = data.permute(2, 1, 0)
label = label.permute(2, 1, 0)
# a, b, c = lattice[0], lattice[1], lattice[2]
# lattice = torch.stack([c, b, a], dim=0) # (z,y,x)
return data, label, lattice


def load_chgcar(root: str | bytes | os.PathLike, index: str):
    """Load one charge-density sample stored as a CHGCAR pair.

    Reads ``<root>/data/<index>.CHGCAR`` (input density) and
    ``<root>/label/<index>.CHGCAR`` (target density), normalizes both
    ``total`` grids by the cell volume, and also returns the lattice matrix
    of the input structure.

    Parameters
    ----------
    root : path-like
        Dataset root directory. Must support the ``/`` operator (i.e. a
        ``pathlib.Path``), despite the broader annotation — TODO confirm
        callers never pass a plain str.
    index : str
        Sample identifier (the CHGCAR file stem).

    Returns
    -------
    tuple
        ``(data, label, lattice)`` — two volume-normalized density grids and
        the lattice matrix (presumably 3x3 numpy arrays per the Chgcar API).
    """
    data = Chgcar.from_file(root / "data" / f"{index}.CHGCAR")
    label = Chgcar.from_file(root / "label" / f"{index}.CHGCAR")
    # Capture the lattice before `data` is rebound to the raw grid below.
    lattice = data.structure.lattice.matrix
    data = data.data["total"] / data.structure.lattice.volume
    label = label.data["total"] / label.structure.lattice.volume
    # Single return: the caller in load_numpy_rho unpacks three values.
    # A stale duplicated `return data, label` previously shadowed this one,
    # making the lattice unreachable.
    return data, label, lattice


def load_npy(root: str | bytes | os.PathLike, index: str):
Expand Down
11 changes: 9 additions & 2 deletions src/electrai/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def train(args):
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
project=cfg.wb_pname, entity=cfg.entity, config=vars(cfg)
project=cfg.wb_pname, entity=cfg.entity, config=vars(cfg), name=cfg.run_name
)
else:
wandb_logger = None
Expand All @@ -62,12 +62,19 @@ def train(args):
# -----------------------------
# Trainer
# -----------------------------
local_world_size = int(
os.environ.get("LOCAL_WORLD_SIZE", torch.cuda.device_count())
)
world_size = int(os.environ.get("WORLD_SIZE", local_world_size))
num_nodes = max(1, world_size // local_world_size)
trainer = Trainer(
max_epochs=int(cfg.epochs),
logger=wandb_logger,
callbacks=[checkpoint_cb, lr_monitor],
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1,
devices="auto",
num_nodes=num_nodes,
strategy="ddp",
precision=cfg.precision,
log_every_n_steps=1,
gradient_clip_val=getattr(cfg, "gradient_clip_value", 1.0),
Expand Down
57 changes: 52 additions & 5 deletions src/electrai/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from hydra.utils import instantiate
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities import rank_zero_only
from src.electrai.model.LCN import LatticeConv3d
from src.electrai.model.loss.charge import NormMAE


Expand All @@ -14,8 +16,40 @@ def __init__(self, cfg):
self.model = instantiate(cfg.model)
self.loss_fn = NormMAE()

def forward(self, x, lattice_vectors=None):
    """Run the wrapped generator model on a batch.

    Parameters
    ----------
    x : input density grid batch passed straight to ``self.model``.
    lattice_vectors : optional lattice matrices forwarded to the model;
        defaults to ``None`` so lattice-unaware callers keep working.

    Returns the model output unchanged.
    """
    # A stale pre-change `def forward(self, x)` was stacked above this
    # definition (diff artifact); only the lattice-aware version is kept —
    # _loss_calculation calls `self(x, A)` with two arguments.
    return self.model(x, lattice_vectors)

@rank_zero_only
def _collect_kernel_stats(self, target_layer: str | None = None) -> None:
    """Push LatticeConv3d kernel statistics to the experiment logger.

    Executes on the global-zero rank only (enforced by ``@rank_zero_only``).
    When ``target_layer`` is given, only the module with that exact name is
    logged; otherwise every ``LatticeConv3d`` exposing ``kernel_stats`` is.
    """
    for layer_name, layer in self.model.named_modules():
        # Only LatticeConv3d layers that actually expose stats matter here.
        if not (isinstance(layer, LatticeConv3d) and hasattr(layer, "kernel_stats")):
            continue
        if target_layer is not None and layer_name != target_layer:
            continue

        stats = layer.kernel_stats
        if self.cfg.model["use_lattice_conv"]:
            metrics = {
                f"kernels/{layer_name}/alpha": stats["alpha"],
                f"kernels/{layer_name}/ratio": stats["ratio"],
                f"kernels/{layer_name}/geo_rms": stats["geo_rms"],
                f"kernels/{layer_name}/base_rms": stats["base_rms"],
            }
        else:
            # Without lattice convolutions only the base kernel RMS exists.
            metrics = {f"kernels/{layer_name}/base_rms": stats["base_rms"]}

        # Let Lightning manage step counting; no cross-rank sync is needed
        # because only rank zero ever reaches this point.
        self.log_dict(
            metrics,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=False,
        )

def training_step(self, batch):
loss = self._loss_calculation(batch)
Expand All @@ -29,6 +63,9 @@ def training_step(self, batch):
)
return loss

def on_train_batch_end(self, outputs, batch, batch_idx):  # noqa: ARG002
    # Lightning hook invoked after each training batch: log kernel statistics
    # for the single probe layer "mid.0.conv1" (the helper is rank-zero only,
    # so this is a no-op on other ranks).
    self._collect_kernel_stats(target_layer="mid.0.conv1")

def validation_step(self, batch):
loss = self._loss_calculation(batch)
self.log(
Expand All @@ -39,15 +76,16 @@ def validation_step(self, batch):
def _loss_calculation(self, batch):
x = batch["data"]
y = batch["label"]
A = batch["lattice"]
if isinstance(x, list):
losses = []
for x_i, y_i in zip(x, y, strict=True):
pred = self(x_i.unsqueeze(0))
for x_i, y_i, A_i in zip(x, y, strict=True):
pred = self(x_i.unsqueeze(0), A_i.unsqueeze(0))
loss = self.loss_fn(pred, y_i.unsqueeze(0))
losses.append(loss)
loss = torch.stack(losses).mean()
else:
pred = self(x)
pred = self(x, A)
loss = self.loss_fn(pred, y)
return loss

Expand All @@ -56,8 +94,17 @@ def configure_optimizers(self):
self.model.parameters(),
lr=float(self.cfg.lr),
weight_decay=float(self.cfg.weight_decay),
betas=(getattr(self.cfg, "beta1", 0.9), getattr(self.cfg, "beta2", 0.999)),
)

# flat = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
# cos = torch.optim.lr_scheduler.CosineAnnealingLR(
# optimizer, T_max=self.cfg.epochs - 1, eta_min=1e-5
# )

# scheduler = torch.optim.lr_scheduler.SequentialLR(
# optimizer, [flat, cos], milestones=[1]
# )
linsch = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1e-5,
Expand Down
Loading