148 changes: 103 additions & 45 deletions drug_pinn.py
@@ -1,10 +1,32 @@
# SPDX-License-Identifier: MIT
#
# Copyright (c) 2025 Zara Dar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

from typing import Tuple, Dict, List, Any
from muon import SingleDeviceMuonWithAuxAdam

# -------------------------------
# Problem setup: dC/dt = -k C
@@ -14,8 +36,20 @@
np.random.seed(42)


def get_device() -> torch.device:
"""Returns the best available device (CUDA, MPS, or CPU)."""
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


class DrugPINN(nn.Module):
def __init__(self, hidden_layers=(32, 32, 32)):
"""
Physics-Informed Neural Network to solve the drug concentration decay problem.
"""
def __init__(self, hidden_layers: Tuple[int, ...] = (32, 32, 32)):
super().__init__()
layers = []
in_features = 1
@@ -28,70 +62,96 @@ def __init__(self, hidden_layers=(32, 32, 32)):
self.apply(self._init_weights)

@staticmethod
def _init_weights(m):
def _init_weights(m: nn.Module):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)

def forward(self, t):
def forward(self, t: torch.Tensor) -> torch.Tensor:
return self.net(t)


def true_solution(t, C0, k):
def true_solution(t: np.ndarray, C0: float, k: float) -> np.ndarray:
"""Computes the analytical solution C(t) = C0 * exp(-k * t)."""
return C0 * np.exp(-k * t)
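# Illustrative check (not part of the diff): with k = 0.1 1/h the half-life is
# ln(2)/k ≈ 6.93 h, so true_solution(6.93, 10.0, 0.1) ≈ 5.0 mg/L, half of C0.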


def train_drug_pinn(
# Number of optimization steps
epochs=4000,
# Number of collocation points per epoch
M=128,
# C0: initial plasma concentration in mg/L (10 mg/L is a typical
# therapeutic level for several small‑molecule drugs)
C0=10.0,
# k: elimination rate constant in 1/hour. 0.1 1/h corresponds to a
# half‑life t_half ≈ ln(2)/k ≈ 6.9 hours, which is common for many
# medications.
k=0.1,
lam_bc=10.0,
# Time horizon in hours
t_max=24.0,
lr=1e-3,
):
device = torch.device("cpu") # use GPU if you have one :-)
epochs: int = 4000,
M: int = 128,
C0: float = 10.0,
k: float = 0.1,
lam_bc: float = 10.0,
t_max: float = 24.0,
lr: float = 1e-3,
) -> Tuple[DrugPINN, np.ndarray, Dict[str, Any]]:
"""
    Trains the PINN, using the Muon optimizer for 2D weight matrices and Adam for biases and other 1D parameters.

Args:
epochs: Number of optimization steps.
M: Number of collocation points per epoch.
C0: Initial plasma concentration (mg/L).
k: Elimination rate constant (1/hour).
lam_bc: Weight for boundary condition loss.
t_max: Time horizon in hours.
lr: Learning rate for the Adam component (auxiliary). Muon uses its own default (0.02).

Returns:
The trained model, loss history array, and configuration dictionary.
"""
device = get_device()
print(f"Using device: {device}")

model = DrugPINN().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Split parameters for Muon (matrices) and Adam (vectors/scalars)
# Muon is for 2D parameters (weights of Linear layers)
# Adam is for 1D parameters (biases) or others
muon_params = []
adam_params = []

for name, param in model.named_parameters():
if param.ndim >= 2:
muon_params.append(param)
else:
adam_params.append(param)

# Configure the MuonWithAuxAdam optimizer
    # We use the provided lr for the Adam part to match the original semantics as closely as possible for non-matrix params.
    # We use Muon's default 0.02 for the matrix params, as it is the standard for that optimizer.
param_groups = [
dict(params=muon_params, use_muon=True, lr=0.02, weight_decay=0.0),
dict(params=adam_params, use_muon=False, lr=lr, weight_decay=0.0)
]

optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

loss_history = []

# Fixed BC point at t = 0
t_bc = torch.zeros((1, 1), dtype=torch.float32, device=device)
C0_tensor = torch.full((1, 1), float(C0), dtype=torch.float32, device=device)

# Normalize time for the network input: t_norm in [0, 1]
# We must adjust the derivative: dC/dt = (dC/dt_norm) * (dt_norm/dt)
# dt_norm/dt = 1/t_max

for epoch in range(epochs):
optimizer.zero_grad()

# Sample t in domain [0, t_max]
t_colloc = torch.rand((M, 1), dtype=torch.float32, device=device) * t_max
t_colloc.requires_grad_(True)

# Normalize input for the model
# Normalize input for the model: t_norm in [0, 1]
t_norm = t_colloc / t_max
C_pred = model(t_norm)

# Compute gradient w.r.t. t_colloc (using chain rule implicitly via autograd if we used t_colloc in graph)
# However, it's cleaner to compute grad w.r.t t_norm and scale manually
# Compute gradient w.r.t. t_norm
dC_dt_norm = torch.autograd.grad(
C_pred,
t_norm,
grad_outputs=torch.ones_like(C_pred),
create_graph=True,
)[0]

# dC/dt = (dC/dt_norm) * (dt_norm/dt) = dC_dt_norm * (1/t_max)
dCdt = dC_dt_norm * (1.0 / t_max)
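        # Sanity check for this scaling: with C(t) = C0 * exp(-k t) and t = t_max * t_norm,
        # dC/dt_norm = -k * t_max * C, so dC/dt = dC_dt_norm / t_max = -k * C,
        # recovering the original ODE.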

# Physics residual: dC/dt + k C = 0
@@ -122,7 +182,8 @@ def train_drug_pinn(
)


def plot_loss(loss_history, filename="drug_pinn_loss.png"):
def plot_loss(loss_history: np.ndarray, filename: str = "drug_pinn_loss.png") -> None:
"""Plots the training loss history."""
plt.rcParams.update({"font.size": 18})

fig, ax = plt.subplots(figsize=(8, 5))
@@ -142,29 +203,28 @@ def plot_loss(loss_history, filename="drug_pinn_loss.png"):


def create_animation(
model,
config,
filename="drug_pinn_concentration.gif",
duration_seconds=5.0,
fps=25,
):
model: nn.Module,
config: Dict[str, Any],
filename: str = "drug_pinn_concentration.gif",
duration_seconds: float = 5.0,
fps: int = 25,
) -> None:
"""Creates an animation of the drug concentration prediction vs ground truth."""
C0 = config["C0"]
k = config["k"]
t_max = config["t_max"]

device = torch.device("cpu")
device = next(model.parameters()).device
model.eval()

# Time grid for evaluation
t_np = np.linspace(0.0, t_max, 200)
# Normalize time for inference!
# Normalize time for inference
t_norm = torch.tensor(t_np.reshape(-1, 1) / t_max, dtype=torch.float32, device=device)

with torch.no_grad():
C_pred_np = model(t_norm).cpu().numpy().flatten()

C_true_np = true_solution(t_np, C0, k)

# Ground truth scatter points (static)
t_gt = np.linspace(0.0, t_max, 25)
C_gt = true_solution(t_gt, C0, k)
@@ -238,5 +298,3 @@ def main():

if __name__ == "__main__":
main()


133 changes: 133 additions & 0 deletions muon.py
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: MIT
#
# Copyright (c) 2025 Erkin Alp Güney
#
# Based on code copyrighted by Keller Jordan (original Muon implementation).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.optim as optim

def zeropower_via_newtonschulz5(G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X
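# Illustrative check (a sketch, not part of this file's API): the result should be
# approximately semi-orthogonal, with singular values spread roughly over [0.5, 1.5]
# as described in the docstring, e.g.:
#
#   G = torch.randn(16, 32)
#   X = zeropower_via_newtonschulz5(G, steps=5).float()
#   print(torch.linalg.svdvals(X))  # values roughly within [0.5, 1.5]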


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
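    # Scale the update so its RMS stays roughly consistent across matrix shapes:
    # tall matrices (rows > cols) are boosted by sqrt(rows / cols).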
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
return update


def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
return buf1c / (buf2c.sqrt() + eps)
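# Note: `step` must start at 1 so the (1 - beta**step) bias corrections are
# well-defined; the optimizer below increments state["step"] before each call.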


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
if group["use_muon"]:
# defaults
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
group["betas"] = group.get("betas", (0.9, 0.95))
group["eps"] = group.get("eps", 1e-10)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
super().__init__(param_groups, dict())

@torch.no_grad()
def step(self, closure=None):

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if group["use_muon"]:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])

return loss
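
# Minimal usage sketch (illustrative; `model`, `criterion`, `x`, and `y` are
# assumed placeholders, not part of this file — this mirrors the param-group
# split used in drug_pinn.py):
#
#   muon_params = [p for p in model.parameters() if p.ndim >= 2]
#   adam_params = [p for p in model.parameters() if p.ndim < 2]
#   opt = SingleDeviceMuonWithAuxAdam([
#       dict(params=muon_params, use_muon=True, lr=0.02, weight_decay=0.0),
#       dict(params=adam_params, use_muon=False, lr=3e-4, weight_decay=0.0),
#   ])
#   loss = criterion(model(x), y)
#   opt.zero_grad()
#   loss.backward()
#   opt.step()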