Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
394 changes: 394 additions & 0 deletions other/experiments/mace/test_mace_browser.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,394 @@
<!DOCTYPE html>
<html>

<head>
<meta charset="utf-8">
<title>MACE + ASE in Pyodide (WASM)</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}

body {
font-family: 'SF Mono', 'Fira Code', monospace;
background: #0d1117;
color: #c9d1d9;
padding: 2rem;
}

h1 {
color: #58a6ff;
margin-bottom: 1rem;
font-size: 1.5rem;
}

#log {
background: #161b22;
border: 1px solid #30363d;
border-radius: 8px;
padding: 1rem;
white-space: pre-wrap;
font-size: 0.85rem;
line-height: 1.6;
max-height: 85vh;
overflow-y: auto;
}

.info {
color: #8b949e;
}

.success {
color: #3fb950;
}

.error {
color: #f85149;
}

.highlight {
color: #d2a8ff;
}

.step {
color: #58a6ff;
font-weight: bold;
}
</style>
</head>

<body>
<h1>⚛️ MACE + ASE in Pyodide (WASM)</h1>
<div id="log"></div>

<script>
const logEl = document.getElementById('log');
function log(msg, cls = 'info') {
const span = document.createElement('span');
span.className = cls;
span.textContent = msg + '\n';
logEl.appendChild(span);
logEl.scrollTop = logEl.scrollHeight;
}

async function main() {
log('=== MACE + ASE Pyodide Test ===', 'step');
log('');

// Step 1: Load Pyodide
log('[1/7] Loading Pyodide runtime...', 'step');
const pyodide = await loadPyodide({
indexURL: "https://cdn.jsdelivr.net/pyodide/v0.24.1/full/"
});
log(' ✓ Pyodide loaded', 'success');
log('');

// Step 2: Install base packages
log('[2/7] Installing base packages (micropip, sympy, scipy)...', 'step');
await pyodide.loadPackage('micropip');
await pyodide.loadPackage(['sympy', 'scipy', 'typing-extensions', 'matplotlib', 'pandas', 'numpy', 'ssl']);
log(' ✓ Base packages installed', 'success');
log('');

// Step 3: Install torch wheel
log('[3/7] Installing PyTorch wheel...', 'step');
// If opened from file://, use localhost server instead
const baseUrl = window.location.protocol === 'file:'
? 'http://localhost:8800'
: window.location.origin;
const wheelUrl = baseUrl + '/dist/torch-2.1.0a0-cp311-cp311-emscripten_3_1_45_wasm32.whl?v=' + Date.now();
log(` URL: ${wheelUrl}`, 'info');

try {
const result = await pyodide.runPythonAsync(`
import micropip
await micropip.install("${wheelUrl}", deps=False)
import torch
f"torch {torch.__version__} loaded"
`);
log(` ✓ ${result}`, 'success');
} catch (e) {
log(` ✗ torch install failed: ${e.message}`, 'error');
log('', 'info');
log('=== Test aborted ===', 'step');
return;
}
log('');

// Step 4: Install MACE dependencies
log('[4/7] Installing MACE dependencies (e3nn, opt_einsum, etc.)...', 'step');
try {
const result = await pyodide.runPythonAsync(`
import micropip

# Core deps
await micropip.install("opt_einsum")
await micropip.install("opt_einsum_fx", deps=False)
await micropip.install("e3nn==0.4.4", deps=False)

# MACE training deps
await micropip.install("prettytable")
await micropip.install("torch_ema", deps=False)
await micropip.install("lightning-utilities", deps=False)
await micropip.install("torchmetrics", deps=False)

# ASE
await micropip.install("ase", deps=False)

# MACE itself
await micropip.install("mace-torch", deps=False)

# Verify imports
import e3nn
import ase
import mace
f"e3nn {e3nn.__version__}, ase {ase.__version__}, mace installed"
`);
log(` ✓ ${result}`, 'success');
} catch (e) {
log(` ✗ MACE deps failed: ${e.message}`, 'error');
log('', 'info');
log('=== Test aborted ===', 'step');
return;
}
log('');

// Step 5: Import MACE modules
log('[5/7] Importing MACE modules...', 'step');
try {
const result = await pyodide.runPythonAsync(`
import torch
# Patch torch.compiler.is_compiling for newer MACE versions
if not hasattr(torch, 'compiler'):
import types
torch.compiler = types.ModuleType('torch.compiler')
if not hasattr(torch.compiler, 'is_compiling'):
torch.compiler.is_compiling = lambda: False
# Patch Tensor.numpy since WASM build lacks NumPy integration
import numpy as _np
def _tensor_numpy(self):
return _np.array(self.detach().tolist())
torch.Tensor.numpy = _tensor_numpy
from mace.tools.torch_geometric.data import Data
from mace.tools.torch_geometric.batch import Batch
from mace.tools.torch_geometric.dataloader import DataLoader
from mace import modules as mace_modules
from mace.tools import utils as mace_utils
"MACE modules imported successfully"
`);
log(` ✓ ${result}`, 'success');
} catch (e) {
log(` ✗ MACE import failed: ${e.message}`, 'error');
log('', 'info');
log('=== Test aborted ===', 'step');
return;
}
log('');

// Step 6: Create MACE model and run inference
log('[6/7] Creating MACE model & running inference on H2O...', 'step');
try {
const result = await pyodide.runPythonAsync(`
import torch
import numpy as np
from e3nn import o3
from mace.modules import MACE
from mace.tools.torch_geometric.data import Data
import mace.modules as mace_modules

# Create a minimal MACE model
model_config = {
"r_max": 5.0,
"num_bessel": 8,
"num_polynomial_cutoff": 5,
"max_ell": 2,
"interaction_cls_first": mace_modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
"interaction_cls": mace_modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
"num_interactions": 1,
"num_elements": 4,
"hidden_irreps": o3.Irreps("16x0e + 16x1o"),
"MLP_irreps": o3.Irreps("16x0e"),
"gate": torch.nn.functional.silu,
"atomic_energies": np.array([-1.0, -3.0, -5.0, -7.0]),
"avg_num_neighbors": 5.0,
"atomic_numbers": [1, 6, 7, 8],
"correlation": 2,
"radial_type": "bessel",
}

model = MACE(**model_config)
model.double()
model.eval()
n_params = sum(p.numel() for p in model.parameters())

# Test inference on a water molecule (H2O)
positions = torch.tensor([
[0.0000, 0.0000, 0.1173], # O
[0.0000, 0.7572, -0.4692], # H
[0.0000, -0.7572, -0.4692], # H
], dtype=torch.float64)
positions.requires_grad_(True)

node_attrs = torch.zeros(3, 4, dtype=torch.float64)
node_attrs[0, 3] = 1.0 # O
node_attrs[1, 0] = 1.0 # H
node_attrs[2, 0] = 1.0 # H

edge_index = torch.tensor([[0,0,1,1,2,2],[1,2,0,2,0,1]], dtype=torch.long)
shifts = torch.zeros(6, 3, dtype=torch.float64)
unit_shifts = torch.zeros(6, 3, dtype=torch.float64)

data = Data(
positions=positions,
node_attrs=node_attrs,
edge_index=edge_index,
shifts=shifts,
unit_shifts=unit_shifts,
cell=torch.zeros(3, 3, dtype=torch.float64),
batch=torch.zeros(3, dtype=torch.long),
ptr=torch.tensor([0, 3], dtype=torch.long),
)

output = model(data.to_dict(), training=False)

energy = output["energy"].item()
forces = output["forces"].detach().numpy()
result_lines = []
result_lines.append(f"MACE model: {n_params:,} parameters")
result_lines.append(f"H2O Energy: {energy:.6f} eV")
result_lines.append(f"H2O Forces (eV/A):")
for i, (elem, f) in enumerate(zip(['O','H','H'], forces)):
result_lines.append(f" {elem}: [{f[0]:+.6f}, {f[1]:+.6f}, {f[2]:+.6f}]")
result_lines.append(f"Force sum: [{forces.sum(axis=0)[0]:.2e}, {forces.sum(axis=0)[1]:.2e}, {forces.sum(axis=0)[2]:.2e}] (should be ~0)")
"\\n".join(result_lines)
`);
log(` ✓ MACE inference succeeded!`, 'success');
log(result, 'highlight');
} catch (e) {
log(` ✗ MACE model failed: ${e.message}`, 'error');
log('', 'info');
log('=== Test aborted ===', 'step');
return;
}
log('');

// Step 7: ASE relaxation with MACE
log('[7/7] Running ASE BFGS relaxation with MACE...', 'step');
try {
const result = await pyodide.runPythonAsync(`
import torch
import numpy as np
from ase import Atoms
from ase.optimize import BFGS
from ase.calculators.calculator import Calculator, all_changes
from mace.tools.torch_geometric.data import Data

class MACECalculatorSimple(Calculator):
implemented_properties = ['energy', 'forces']

def __init__(self, model, atomic_numbers_map, **kwargs):
super().__init__(**kwargs)
self.model = model
self.model.eval()
self.z_to_idx = {z: i for i, z in enumerate(atomic_numbers_map)}
self.num_types = len(atomic_numbers_map)

def calculate(self, atoms=None, properties=['energy', 'forces'], system_changes=all_changes):
super().calculate(atoms, properties, system_changes)

positions = torch.tensor(self.atoms.positions, dtype=torch.float64)
n = len(self.atoms)

node_attrs = torch.zeros(n, self.num_types, dtype=torch.float64)
for i, z in enumerate(self.atoms.numbers):
node_attrs[i, self.z_to_idx[z]] = 1.0

edges_src, edges_dst = [], []
for i in range(n):
for j in range(n):
if i != j:
edges_src.append(i)
edges_dst.append(j)
edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
n_edges = edge_index.shape[1]

data = Data(
positions=positions,
node_attrs=node_attrs,
edge_index=edge_index,
shifts=torch.zeros(n_edges, 3, dtype=torch.float64),
unit_shifts=torch.zeros(n_edges, 3, dtype=torch.float64),
cell=torch.zeros(3, 3, dtype=torch.float64),
batch=torch.zeros(n, dtype=torch.long),
ptr=torch.tensor([0, n], dtype=torch.long),
)

positions.requires_grad_(True)
data.positions = positions

output = self.model(data.to_dict(), training=False)

self.results['energy'] = output['energy'].item()
self.results['forces'] = output['forces'].detach().numpy()

# Create the calculator
calc = MACECalculatorSimple(model, atomic_numbers_map=[1, 6, 7, 8])

# Create a distorted water molecule
atoms = Atoms('OH2', positions=[
[0.0, 0.0, 0.0],
[0.0, 0.9, -0.5],
[0.0, -0.6, -0.6],
])
atoms.calc = calc

result_lines = []
e0 = atoms.get_potential_energy()
f0 = atoms.get_forces()
result_lines.append(f"Initial: E={e0:.6f} eV, max|F|={np.max(np.abs(f0)):.6f} eV/A")

# Run BFGS relaxation
opt = BFGS(atoms, logfile=None)
opt.run(fmax=0.1, steps=20)

e1 = atoms.get_potential_energy()
f1 = atoms.get_forces()
result_lines.append(f"After {opt.nsteps} BFGS steps:")
result_lines.append(f" Energy: {e1:.6f} eV (delta: {e1-e0:.6f} eV)")
result_lines.append(f" Max |F|: {np.max(np.abs(f1)):.6f} eV/A")
result_lines.append(f" Final positions:")
for i, (elem, pos) in enumerate(zip(['O','H','H'], atoms.positions)):
result_lines.append(f" {elem}: [{pos[0]:+.4f}, {pos[1]:+.4f}, {pos[2]:+.4f}]")
result_lines.append("")
result_lines.append("🎉 ASE + MACE relaxation in Pyodide WASM — SUCCESS! 🎉")
"\\n".join(result_lines)
`);
log(` ✓ ASE relaxation completed!`, 'success');
log(result, 'highlight');
} catch (e) {
log(` ✗ ASE relaxation failed: ${e.message}`, 'error');
}

log('');
log('=== Test complete ===', 'step');
}

// Load Pyodide script then run
const script = document.createElement('script');
script.src = 'https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js';
script.onload = () => {
main().catch(err => {
log(`Fatal error: ${err.message}`, 'error');
console.error(err);
log('', 'info');
log('=== Test complete ===', 'step');
});
};
script.onerror = () => log('Failed to load Pyodide JS', 'error');
document.head.appendChild(script);
</script>
</body>

</html>
Loading