forked from matthew-lowery/kernel_neural_operator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathns_pipe.py
More file actions
142 lines (115 loc) · 4.6 KB
/
ns_pipe.py
File metadata and controls
142 lines (115 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import jax
import optax
from jax import numpy as jnp, random as jr
import jax.random as jr
from utils import *
import equinox as eqx
from kernels import *
from models import KNO_NS_PIPE as model
import argparse
# ---------------------------------------------------------------------------
# Command-line configuration.
# Options are registered in the order listed below; that order is what
# ``print(args)`` and ``--help`` display.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
_CLI_OPTIONS = [
    ('--epochs', dict(type=int, default=500)),
    ('--batch-size', dict(type=int, default=1)),
    ('--lr-max', dict(type=float, default=0.001)),
    ('--lift-dim', dict(type=int, default=128)),
    ('--depth', dict(type=int, default=7)),
    ('--test-batch-size', dict(type=int, default=1)),
    # Name of the integration kernel, looked up in ``kernels`` further down.
    ('--int-kernel', dict(type=str, default='ns_gsm',
                          choices=['g', 'a_g', 'ns_g', 'gsm', 'ns_gsm'])),
    ('--seed', dict(type=int, default=42)),
    ('--print-every', dict(type=int, default=1)),
    ('--eval-every', dict(type=int, default=5)),
    ('--grad-clip', dict(type=float, default=0.5)),
]
for _flag, _opts in _CLI_OPTIONS:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
print(args)

DTYPE = jnp.float32
# Root PRNG key; the later keys in this script are split from it.
key = jr.PRNGKey(args.seed)
### load data
def _trapezoid_weights(n, h):
    """Return composite-trapezoid weights for ``n`` nodes with spacing ``h``.

    The weights are ``[h/2, h, ..., h, h/2]``, returned as a column of
    shape ``(n, 1)``.
    """
    w = jnp.zeros((n,))
    w = w.at[0].set(h / 2).at[-1].set(h / 2).at[1:-1].set(h)
    return w.reshape(-1, 1)


fp = './datasets/ns_pipe.npz'
# Use the archive as a context manager so its file handle is closed, and list
# the members via ``.files`` instead of ``dict(data)``; the latter would
# decompress every array just to print its name.
with jnp.load(fp) as data:
    print(data.files)
    y_grid = data['y_grid'].astype(DTYPE).reshape(2310, -1, 2)
    y = data['y'].astype(DTYPE).reshape(2310, -1)

# Standardise each coordinate channel using statistics over all samples and nodes.
# NOTE(review): these statistics include the held-out test samples; confirm
# this is intended rather than computing them on the training split only.
y_mu = jnp.mean(y_grid, axis=(0, 1), keepdims=True)
y_std = jnp.std(y_grid, axis=(0, 1), keepdims=True)
y_grid = (y_grid - y_mu) / y_std
# (sample, i, j, coord): channel 0 is treated as x and channel 1 as y below.
y_grid = y_grid.reshape(-1, 129, 129, 2)

# Node spacings are read from the first sample only and reused for every
# sample. NOTE(review): assumes the spacing is uniform along each axis and
# shared by all samples — confirm for this dataset.
y_h = y_grid[0, 0, 1, 1] - y_grid[0, 0, 0, 1]
x_h = y_grid[0, 1, 0, 0] - y_grid[0, 0, 0, 0]
wx = _trapezoid_weights(129, x_h)
wy = _trapezoid_weights(129, y_h)

key, _ = jr.split(key)
q_nodes = y_grid
domain_dims = 2    # query points are 2-D coordinates
codomain_dims = 0  # no extra input channels beyond the coordinates
ntrain = 1000
ntest = 200
# First ntrain samples train, last ntest samples test; the rest are unused.
q_train, q_test = q_nodes[:ntrain], q_nodes[-ntest:]
y_train, y_test = y[:ntrain], y[-ntest:]
### data config
if args.batch_size < 1:
    parser.error('--batch-size must be a positive integer')
num_train_batches = len(q_train) // args.batch_size
if num_train_batches == 0:
    # Otherwise every epoch would run zero steps and the LR schedule would be
    # built for zero steps.
    parser.error(f'--batch-size ({args.batch_size}) exceeds the number of '
                 f'training samples ({len(q_train)})')
num_steps = args.epochs * num_train_batches

## kernel setup
# NOTE(review): ndims=1 presumably makes the kernel act on one grid axis at a
# time (matching the separate wx / wy weights) — confirm in ``kernels``.
integration_kernel = partial(kernels[args.int_kernel], ndims=1)

### preprocess data
# Fitted on the training targets only; predictions are decoded with it.
y_normalizer = UnitGaussianNormalizer(y_train)
in_feats = domain_dims + codomain_dims
# NOTE: this rebinds the imported class name ``model`` to the instance.
model = model(
    integration_kernel,
    args.depth,
    args.lift_dim,
    domain_dims,
    in_feats,
    129,  # nodes per grid axis, matching the reshape of y_grid
    key=key,
)
# One schedule step per optimizer update.
lr_schedule = cosine_annealing(num_steps, peak_value=args.lr_max)
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=args.grad_clip),
    optax.adam(lr_schedule),
)
# The trainable leaves are wrapped in a list here and in train_step's
# ``optimizer.update`` call; both must agree so the state pytree matches.
opt_state = optimizer.init(eqx.filter([model], is_trainable))
param_count = sum(x.size for x in jax.tree.leaves(eqx.filter(model, is_trainable)))
print(f'{param_count=}')
@eqx.filter_jit
def train_step(model, opt_state, optimizer, batch):
    """Apply one optimizer update (grad-clip + Adam chain) on a mini-batch.

    Args:
        model: equinox model; only leaves selected by ``is_trainable`` are
            passed to the optimizer.
        opt_state: optimizer state initialised from
            ``eqx.filter([model], is_trainable)``.
        optimizer: the optax gradient transformation.
        batch: ``(q, y)`` — per-sample query points and targets; ``y`` is
            compared row-wise against the flattened predictions.

    Returns:
        ``(model, opt_state, train_loss, rel_l2)``. ``train_loss`` is the
        batch mean of the per-sample summed squared error (the differentiated
        quantity); ``rel_l2`` is the batch-mean relative L2 error.
    """
    q, y = batch
    # Derived from the batch itself rather than args.batch_size so any batch
    # length (e.g. a ragged final batch) is handled.
    n_batch = q.shape[0]

    def loss(model):
        # Each sample is evaluated independently with the module-level wx / wy weights.
        y_pred = eqx.filter_vmap(lambda qi: model(qi, wx, wy))(q)
        y_pred = y_normalizer.decode(y_pred.reshape(n_batch, -1))
        l2 = ((y - y_pred) ** 2).sum(axis=-1).mean()
        rel_l2 = (jnp.linalg.norm(y - y_pred, axis=1) / jnp.linalg.norm(y, axis=1)).mean()
        return l2, rel_l2

    (train_loss, rel_l2), grads = eqx.filter_value_and_grad(loss, has_aux=True)(model)
    # The list wrapper mirrors the structure used when opt_state was created.
    updates, opt_state = optimizer.update([grads],
                                          opt_state,
                                          eqx.filter([model], is_trainable))
    model = eqx.apply_updates(model, updates[0])
    return model, opt_state, train_loss, rel_l2
@eqx.filter_jit
def eval(model, batch):
    """Return the mean relative L2 error of ``model`` on ``batch``.

    ``batch`` is ``(q, y)`` as in ``train_step``. Samples are evaluated with
    ``jax.lax.map`` in chunks of ``args.test_batch_size`` (presumably to bound
    memory compared with vmapping the whole split at once).

    NOTE: the name shadows the ``eval`` builtin; it is kept for compatibility.
    """
    q, y = batch
    y_pred = jax.lax.map(lambda qi: model(qi, wx, wy), q, batch_size=args.test_batch_size)
    # Sized from the batch rather than the global ntest, so any split can be scored.
    y_pred = y_normalizer.decode(y_pred.reshape(q.shape[0], -1))
    return (jnp.linalg.norm(y - y_pred, axis=1) / jnp.linalg.norm(y, axis=1)).mean()
# Best (lowest) test relative L2 error seen so far, as a fraction (printed x100).
test_l2_best = float('inf')
for epoch in range(args.epochs):
    # Carry the advanced key forward so each epoch gets a distinct subkey for
    # get_batch; splitting without reassigning ``key`` yields the same subkey
    # (and hence the same batch order) every epoch.
    key, epoch_key = jr.split(key)
    train_rel_l2_sum = 0.0
    for i in range(num_train_batches):
        batch = get_batch(epoch_key, (q_train, y_train), i, args.batch_size)
        model, opt_state, train_loss, rel_l2 = train_step(model, opt_state, optimizer, batch)
        # Accumulated as a device array; only synced to the host when printed.
        train_rel_l2_sum = train_rel_l2_sum + rel_l2
    is_last_epoch = epoch == args.epochs - 1
    if (epoch % args.print_every) == 0 or is_last_epoch:
        # Report the epoch mean rather than just the last mini-batch's value.
        train_rel_l2 = train_rel_l2_sum / num_train_batches
        print(f'{epoch=}, train rel_l2: {float(train_rel_l2)*100:.3f}')
    if (epoch % args.eval_every) == 0 or is_last_epoch:
        test_rel_l2 = eval(model, (q_test, y_test))
        test_l2_best = min(test_l2_best, test_rel_l2.item())
        print(f'test rel_l2: {test_rel_l2.item()*100:.3f}')
print(f'best test rel_l2: {test_l2_best*100:.3f}')