diff --git a/examples/G-FNO/README.md b/examples/G-FNO/README.md new file mode 100644 index 0000000..0781a78 --- /dev/null +++ b/examples/G-FNO/README.md @@ -0,0 +1,130 @@ +# G-FNO + +## 1. Background + +G-FNO (Group Equivariant Fourier Neural Operators) is a family of operator-learning models for PDE surrogate modeling. This PaddleCFD integration keeps the Paddle version of the main 2D/3D FNO, GCNN, GFNO, Ghybrid, and radialNO variants. + +![G-FNO network](assets/network_visual.png) + +## 2. Code Layout + +- Core model code: `ppcfd/models/g_fno` +- Training and data generation scripts: `examples/G-FNO` + +## 3. Installation + +At the PaddleCFD repository root: + +```bash +python -m pip install -r requirements.txt +python -m pip install -e . +``` + +No extra Python packages beyond PaddleCFD root requirements are needed for the Paddle G-FNO runtime. + +## 4. Import Models From Installed PaddleCFD + +```python +from ppcfd.models.g_fno import FNO2d, GFNO2d + +fno = FNO2d( + num_channels=1, + modes1=12, + modes2=12, + width=20, + initial_step=10, + grid_type="symmetric", +) + +gfno = GFNO2d( + num_channels=1, + modes=12, + width=10, + initial_step=10, + reflection=False, + grid_type="symmetric", +) +``` + +## 5. Data Preparation + +### 5.1 Navier-Stokes with Symmetric Forcing + +From `examples/G-FNO/data_generation/navier_stokes`: + +```bash +python "ns_2d_rt.py" --nu=1e-4 --T=30 --N=1200 --save_path="./data" --ntest=100 --period=4 --device=auto +``` + +### 5.2 Other datasets + +- NS / PDEArena / PDEBench shallow-water datasets still follow the original upstream data sources referenced by the original paper. +- Place datasets under a local `data/` directory and pass absolute or repository-relative paths to `experiments.py`. + +## 6. Training From Examples + +### Supported `--model_type` + +The Paddle version currently supports the model types below. 
+ +| Status | Model types | +| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Supported | `FNO2d`, `FNO2d_aug`, `FNO3d`, `FNO3d_aug`, `GCNN2d_p4`, `GCNN2d_p4m`, `GCNN3d_p4`, `GCNN3d_p4m`, `GFNO2d_p4`, `GFNO2d_p4m`, `GFNO3d_p4`, `GFNO3d_p4m`, `Ghybrid2d_p4`, `Ghybrid2d_p4m`, `radialNO2d_p4`, `radialNO2d_p4m`, `radialNO3d_p4`, `radialNO3d_p4m` | +| Removed in Paddle | `GFNO2d_p4_steer`, `GFNO2d_p4m_steer`, `Unet_Rot2d`, `Unet_Rot_M2d`, `Unet_Rot_3D` | + +The removed Paddle-only model types depended on `e2cnn` or `escnn`, which require Torch at runtime. The four main experiment commands documented below all use `GFNO2d_p4`, which remains supported. + +From `examples/G-FNO`: + +### NS + +```bash +python "experiments.py" --seed=1 --data_path="./data/ns_V1e-4_N10000_T30.mat" \ + --results_path="./results/ns_V1e-4_N10000_T30.mat/GFNO2d_p4" --strategy=teacher_forcing \ + --T=20 --ntrain=1000 --nvalid=100 --ntest=100 --model_type=GFNO2d_p4 --modes=12 --width=10 \ + --batch_size=20 --epochs=100 --suffix=seed1 --txt_suffix="ns_V1e-4_N10000_T30.mat_GFNO2d_p4_seed1" \ + --learning_rate=1e-3 --early_stopping=100 --verbose --super \ + --super_path="./data/ns_data_V1e-4_N20_T50_R256test.mat" --device=auto +``` + +### NS-Sym + +```bash +python "experiments.py" --seed=1 --data_path="./data/ns_V0.0001_N1200_T30_cos4.mat" \ + --results_path="./results/ns_V0.0001_N1200_T30_cos4.mat/GFNO2d_p4" --strategy=teacher_forcing \ + --T=10 --ntrain=1000 --nvalid=100 --ntest=100 --model_type=GFNO2d_p4 --modes=12 --width=10 \ + --batch_size=20 --epochs=100 --suffix=seed1 --txt_suffix="ns_V0.0001_N1200_T30_cos4.mat_GFNO2d_p4_seed1" \ + --learning_rate=1e-3 --early_stopping=100 --verbose --super \ + --super_path="./data/ns_V0.0001_N1200_T30_cos4_super.mat" --device=auto 
+``` + +### SWE (PDEArena) + +```bash +python "experiments.py" --seed=1 --data_path="./data/ShallowWater2D" \ + --results_path="./results/ShallowWater2D/GFNO2d_p4" --strategy=teacher_forcing \ + --T=9 --ntrain=5600 --nvalid=1120 --ntest=1120 --model_type=GFNO2d_p4 --modes=32 --width=10 \ + --batch_size=20 --epochs=100 --suffix=seed1 --txt_suffix="ShallowWater2D_GFNO2d_p4_seed1" \ + --learning_rate=1e-3 --early_stopping=100 --verbose --time_pad --device=auto +``` + +### SWE-Sym (PDEBench) + +```bash +python "experiments.py" --seed=1 --data_path="./data/2D_rdb_NA_NA.h5" \ + --results_path="./results/2D_rdb_NA_NA.h5/GFNO2d_p4" --strategy=teacher_forcing \ + --T=24 --ntrain=800 --nvalid=100 --ntest=100 --model_type=GFNO2d_p4 --modes=12 --width=10 \ + --batch_size=20 --epochs=100 --suffix=seed1 --txt_suffix="2D_rdb_NA_NA.h5_GFNO2d_p4_seed1" \ + --learning_rate=1e-3 --early_stopping=100 --verbose --super --device=auto +``` + +## 7. Citation + +```latex +@inproceedings{helwig2023group, +author = {Jacob Helwig and Xuan Zhang and Cong Fu and Jerry Kurtin and Stephan Wojtowytsch and Shuiwang Ji}, +title = {Group Equivariant {Fourier} Neural Operators for Partial Differential Equations}, +booktitle = {Proceedings of the 40th International Conference on Machine Learning}, +year = {2023}, +} +``` diff --git a/examples/G-FNO/assets/network_visual.png b/examples/G-FNO/assets/network_visual.png new file mode 100644 index 0000000..9e60427 Binary files /dev/null and b/examples/G-FNO/assets/network_visual.png differ diff --git a/examples/G-FNO/data_generation/navier_stokes/ns_2d_rt.py b/examples/G-FNO/data_generation/navier_stokes/ns_2d_rt.py new file mode 100644 index 0000000..b053fba --- /dev/null +++ b/examples/G-FNO/data_generation/navier_stokes/ns_2d_rt.py @@ -0,0 +1,153 @@ +import os +from pathlib import Path + +import paddle + +""" +This is a modified version of ns_2d.py from https://github.com/zongyi-li/fourier_neural_operator +""" +import argparse +import math +from timeit 
import default_timer + +import scipy.io +from random_fields import GaussianRF +from tqdm import tqdm + +from ppcfd.models.g_fno.paddle_utils import set_runtime_device + + +def navier_stokes_2d(w0, f, domain_size, visc, T, delta_t=0.0001, record_steps=1): + N = w0.size()[-1] + k_max = math.floor(N / 2.0) + steps = math.ceil(T / delta_t) + w_h = paddle.fft.rfft2(w0) + f_h = paddle.fft.rfft2(f) + if len(f_h.size()) < len(w_h.size()): + f_h = paddle.unsqueeze(f_h, 0) + record_time = math.floor(steps / record_steps) + k_y = paddle.cat( + ( + paddle.arange(start=0, end=k_max, step=1, device=w0.device), + paddle.arange(start=-k_max, end=0, step=1, device=w0.device), + ), + 0, + ).repeat(N, 1) + k_x = k_y.transpose(0, 1) + k_x = k_x[..., : k_max + 1] + k_y = k_y[..., : k_max + 1] + lap = 4 * math.pi**2 * (k_x**2 + k_y**2) / domain_size**2 + lap[0, 0] = 1.0 + dealias = paddle.unsqueeze( + paddle.logical_and( + paddle.abs(k_y) <= 2.0 / 3.0 * k_max, paddle.abs(k_x) <= 2.0 / 3.0 * k_max + ).float(), + 0, + ) + sol = paddle.zeros(*w0.size(), record_steps, device=w0.device) + sol_t = paddle.zeros(record_steps, device=w0.device) + c = 0 + t = 0.0 + for j in range(steps): + psi_h = w_h / lap + q = 2.0 * math.pi / domain_size * k_y * 1.0j * psi_h + q = paddle.fft.irfft2(q, s=(N, N)) + v = -2.0 * math.pi / domain_size * k_x * 1.0j * psi_h + v = paddle.fft.irfft2(v, s=(N, N)) + w_x = 2.0 * math.pi / domain_size * k_x * 1.0j * w_h + w_x = paddle.fft.irfft2(w_x, s=(N, N)) + w_y = 2.0 * math.pi / domain_size * k_y * 1.0j * w_h + w_y = paddle.fft.irfft2(w_y, s=(N, N)) + F_h = paddle.fft.rfft2(q * w_x + v * w_y) + F_h = dealias * F_h + w_h = ( + -delta_t * F_h + delta_t * f_h + (1.0 - 0.5 * delta_t * visc * lap) * w_h + ) / (1.0 + 0.5 * delta_t * visc * lap) + t += delta_t + if (j + 1) % record_time == 0: + w = paddle.fft.irfft2(w_h, s=(N, N)) + sol[..., c] = w + sol_t[c] = t + c += 1 + return sol, sol_t + + +parser = argparse.ArgumentParser() +parser.add_argument("--nu", type=float, 
required=True) +parser.add_argument("--s", type=int, default=256) +parser.add_argument("--T", type=int, required=True, help="Time horizon") +parser.add_argument("--N", type=int, required=True) +parser.add_argument("--save_path", type=str, required=True) +parser.add_argument("--bsize", type=int, default=20) +parser.add_argument("--suffix", type=str, default=None) +parser.add_argument( + "--ntest", type=int, required=True, help="Number of superresolution examples" +) +parser.add_argument("--period", type=int, required=True, help="Period if sym is true") +parser.add_argument( + "--sym", action="store_true", default=True, help="Use a symmetric forcing term" +) +parser.add_argument("--domain_size", type=float, default=1) +parser.add_argument( + "--device", + type=str, + default="auto", + help="runtime device, e.g. auto, cpu, gpu, xpu, npu, gcu", +) +args = parser.parse_args() +device = set_runtime_device(args.device) +s = args.s +N = args.N +GRF = GaussianRF(2, s, args.domain_size, alpha=2.5, tau=7, device=device) +t = paddle.linspace(0, args.domain_size, s + 1, device=device) +t = t[0:-1] +X, Y = paddle.meshgrid(t, t, indexing="ij") +if args.sym: + f = 0.1 * ( + paddle.cos(args.period * math.pi * X) + paddle.cos(args.period * math.pi * Y) + ) +else: + f = 0.1 * (paddle.sin(2 * math.pi * (X + Y)) + paddle.cos(2 * math.pi * (X + Y))) +record_steps = args.T * 4 +a = paddle.zeros(N, s, s) +u = paddle.zeros(N, s, s, record_steps) +bsize = args.bsize +c = 0 +t0 = default_timer() +for j in tqdm(range(N // bsize)): + w0 = GRF.sample(shape=bsize) + sol, sol_t = navier_stokes_2d( + w0, f, args.domain_size, args.nu, args.T, 0.0001, record_steps + ) + a[c : c + bsize, ...] = w0 + u[c : c + bsize, ...] 
= sol + c += bsize + t1 = default_timer() + print(j, c, t1 - t0) +a_super = a[-args.ntest :] +u_super = u[-args.ntest :] +space_sub = s // 64 +time_sub = 4 +a = a[..., ::space_sub, ::space_sub] +u = u[..., ::space_sub, ::space_sub, ::time_sub] +if args.sym: + data_name = f"ns_V{args.nu}_N{args.N}_T{args.T}_cos{args.period}{'_' + args.suffix if args.suffix is not None else ''}.mat" +else: + data_name = f"ns_V{args.nu}_N{args.N}_T{args.T}_sin{'_' + args.suffix if args.suffix is not None else ''}.mat" +super_name = data_name[:-4] + "_super.mat" +if not os.path.exists(args.save_path): + os.makedirs(args.save_path) +save_dir = os.path.join(args.save_path, data_name) +super_dir = os.path.join(args.save_path, super_name) +scipy.io.savemat( + save_dir, + mdict={"a": a.cpu().numpy(), "u": u.cpu().numpy(), "t": sol_t.cpu().numpy()}, +) +scipy.io.savemat( + super_dir, + mdict={ + "a": a_super.cpu().numpy(), + "u": u_super.cpu().numpy(), + "t": sol_t.cpu().numpy(), + }, +) diff --git a/examples/G-FNO/data_generation/navier_stokes/random_fields.py b/examples/G-FNO/data_generation/navier_stokes/random_fields.py new file mode 100644 index 0000000..6033a1c --- /dev/null +++ b/examples/G-FNO/data_generation/navier_stokes/random_fields.py @@ -0,0 +1,97 @@ +import paddle + +""" +Source: https://github.com/zongyi-li/fourier_neural_operator +""" +import math +from timeit import default_timer + + +class GaussianRF(object): + def __init__( + self, + dim, + size, + domain_size, + alpha=2, + tau=3, + sigma=None, + boundary="periodic", + device=None, + ): + self.dim = dim + self.device = device + if sigma is None: + sigma = tau ** (0.5 * (2 * alpha - self.dim)) + k_max = size // 2 + if dim == 1: + k = paddle.cat( + ( + paddle.arange(start=0, end=k_max, step=1, device=device), + paddle.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + self.sqrt_eig = ( + size + * math.sqrt(2.0) + * sigma + * (4 * math.pi**2 / domain_size**2 * k**2 + tau**2) + ** (-alpha / 2.0) + ) + 
self.sqrt_eig[0] = 0.0 + elif dim == 2: + wavenumers = paddle.cat( + ( + paddle.arange(start=0, end=k_max, step=1, device=device), + paddle.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ).repeat(size, 1) + k_x = wavenumers.transpose(0, 1) + k_y = wavenumers + self.sqrt_eig = ( + size**2 + * math.sqrt(2.0) + * sigma + * ( + 4 * math.pi**2 / domain_size**2 * (k_x**2 + k_y**2) + + tau**2 + ) + ** (-alpha / 2.0) + ) + self.sqrt_eig[0, 0] = 0.0 + elif dim == 3: + wavenumers = paddle.cat( + ( + paddle.arange(start=0, end=k_max, step=1, device=device), + paddle.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ).repeat(size, size, 1) + k_x = wavenumers.transpose(1, 2) + k_y = wavenumers + k_z = wavenumers.transpose(0, 2) + self.sqrt_eig = ( + size**3 + * math.sqrt(2.0) + * sigma + * ( + 4 + * math.pi**2 + / domain_size**2 + * (k_x**2 + k_y**2 + k_z**2) + + tau**2 + ) + ** (-alpha / 2.0) + ) + self.sqrt_eig[0, 0, 0] = 0.0 + self.size = [] + for j in range(self.dim): + self.size.append(size) + self.size = tuple(self.size) + + def sample(self, N): + coeff = paddle.randn(N, *self.size, dtype=paddle.complex64, device=self.device) + coeff = self.sqrt_eig * coeff + return paddle.fft.ifftn(coeff, dim=list(range(-1, -self.dim - 1, -1))).real() diff --git a/examples/G-FNO/experiments.py b/examples/G-FNO/experiments.py new file mode 100644 index 0000000..8f46c54 --- /dev/null +++ b/examples/G-FNO/experiments.py @@ -0,0 +1,926 @@ +import os + +import paddle +from ppcfd.models.g_fno import FNO2d, FNO3d +from ppcfd.models.g_fno import GCNN2d, GCNN3d +from ppcfd.models.g_fno import GFNO2d, GFNO3d +from ppcfd.models.g_fno import Ghybrid2d +from ppcfd.models.g_fno import radialNO2d, radialNO3d +from ppcfd.models.g_fno.paddle_utils import _set_num_threads +from ppcfd.models.g_fno.paddle_utils import move_to_device +from ppcfd.models.g_fno.paddle_utils import set_runtime_device + +""" +This is a modified version of fourier_2d_time.py from 
https://github.com/zongyi-li/fourier_neural_operator +""" +import datetime +import random +import argparse +from timeit import default_timer + +import h5py +import numpy as np +import scipy +import xarray as xr +from tqdm import tqdm +from utils import LpLoss, eq_check_rf, eq_check_rt, pde_data + +_set_num_threads(1) + + +def get_eval_pred(model, x, strategy, T, times): + if strategy == "oneshot": + pred = model(x) + else: + for t in range(T): + t1 = default_timer() + im = model(x) + times.append(default_timer() - t1) + if t == 0: + pred = im + else: + pred = paddle.cat((pred, im), -2) + if strategy == "markov": + x = im + else: + x = paddle.cat((x[..., 1:, :], im), dim=-2) + return pred + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--results_path", type=str, default="./results/tmp", help="path to store results" +) +parser.add_argument( + "--suffix", type=str, default=None, help="suffix to add to the results path" +) +parser.add_argument( + "--txt_suffix", type=str, default=None, help="suffix to add to the results txt" +) +parser.add_argument("--data_path", type=str, required=True, help="path to the data") +parser.add_argument( + "--super_path", type=str, default=None, help="path to the superresolution data" +) +parser.add_argument("--super", action="store_true", help="enable superres testing") +parser.add_argument("--verbose", action="store_true") +parser.add_argument( + "--T", type=int, required=True, help="number of timesteps to predict" +) +parser.add_argument("--ntrain", type=int, required=True, help="training sample size") +parser.add_argument("--nvalid", type=int, required=True, help="valid sample size") +parser.add_argument("--ntest", type=int, required=True, help="test sample size") +parser.add_argument("--nsuper", type=int, default=None) +parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--model_type", type=str, required=True) +parser.add_argument("--depth", type=int, default=4) +parser.add_argument("--modes", 
type=int, default=12) +parser.add_argument("--width", type=int, default=20) +parser.add_argument( + "--Gwidth", + type=int, + default=10, + help="hidden dimension of equivariant layers if model_type=hybrid", +) +parser.add_argument( + "--n_equiv", + type=int, + default=3, + help="number of equivariant layers if model_type=hybrid", +) +parser.add_argument( + "--reflection", + action="store_true", + help="symmetry group p4->p4m for data augmentation", +) +parser.add_argument( + "--grid", type=str, default=None, help="[symmetric, cartesian, None]" +) +parser.add_argument("--epochs", type=int, default=100) +parser.add_argument( + "--early_stopping", + type=int, + default=None, + help="stop if validation error does not improve for successive epochs", +) +parser.add_argument("--batch_size", type=int, default=20) +parser.add_argument("--learning_rate", type=float, default=0.001) +parser.add_argument("--step", action="store_true", help="use step scheduler") +parser.add_argument( + "--gamma", type=float, default=None, help="gamma for step scheduler" +) +parser.add_argument( + "--step_size", type=int, default=None, help="step size for step scheduler" +) +parser.add_argument("--lmbda", type=float, default=0.0001, help="weight decay for adam") +parser.add_argument( + "--strategy", type=str, default="markov", help="markov, recurrent or oneshot" +) +parser.add_argument( + "--time_pad", + action="store_true", + help="pad the time dimension for strategy=oneshot", +) +parser.add_argument( + "--noise_std", + type=float, + default=0.0, + help="amount of noise to inject for strategy=markov", +) +parser.add_argument( + "--device", + type=str, + default="auto", + help="runtime device, e.g. 
auto, cpu, gpu, xpu, npu, gcu", +) +parser.add_argument( + "--rdb_super_res", + type=int, + default=128, + help="original spatial resolution for PDEBench radial dam break data", +) +parser.add_argument( + "--rdb_downsample", + type=int, + default=4, + help="spatial downsampling factor for PDEBench radial dam break data", +) +args = parser.parse_args() +assert args.model_type in [ + "FNO2d", + "FNO2d_aug", + "FNO3d", + "FNO3d_aug", + "GCNN2d_p4", + "GCNN2d_p4m", + "GCNN3d_p4", + "GCNN3d_p4m", + "GFNO2d_p4", + "GFNO2d_p4m", + "GFNO3d_p4", + "GFNO3d_p4m", + "Ghybrid2d_p4", + "Ghybrid2d_p4m", + "radialNO2d_p4", + "radialNO2d_p4m", + "radialNO3d_p4", + "radialNO3d_p4m", +], f"Invalid model type {args.model_type}" +assert args.strategy in [ + "teacher_forcing", + "markov", + "recurrent", + "oneshot", +], "Invalid training strategy" +runtime_device = set_runtime_device(args.device) +paddle.seed(args.seed) +np.random.seed(args.seed) +random.seed(args.seed) + + +def to_runtime(obj): + return move_to_device(obj, runtime_device) + + +def move_batch(xx, yy): + return to_runtime(xx), to_runtime(yy) + + +data_aug = "aug" in args.model_type +TRAIN_PATH = args.data_path +S = Sx = Sy = 64 +S_super = 4 * S +T_in = 10 +T = args.T +T_super = 4 * T +d = 2 +num_channels = 1 +threeD = args.model_type in [ + "FNO3d", + "FNO3d_aug", + "GCNN3d_p4", + "GCNN3d_p4m", + "GFNO3d_p4", + "GFNO3d_p4m", + "radialNO3d_p4", + "radialNO3d_p4m", +] +extension = TRAIN_PATH.split(".")[-1] +swe = os.path.split(TRAIN_PATH)[-1] == "ShallowWater2D" +rdb = TRAIN_PATH.split(os.path.sep)[-1][:6] == "2D_rdb" +grid_type = "symmetric" +if args.grid: + grid_type = args.grid + assert grid_type in ["symmetric", "cartesian", "None"] +if rdb: + assert T == 24, "T should be 24 for rdb" + assert args.rdb_downsample > 0, "rdb_downsample should be positive" + assert args.rdb_super_res > 0, "rdb_super_res should be positive" + assert ( + args.rdb_super_res % args.rdb_downsample == 0 + ), "rdb_super_res should be divisible by 
rdb_downsample" + T_in = 1 + S_super = args.rdb_super_res + S = Sx = Sy = S_super // args.rdb_downsample + T_super = 96 +elif swe: + assert not args.super, "Superresolution not supported for pdearena" + assert T == 9, "T should be 9 for swe" + T_in = 2 + Sy, Sx = 96, 192 + num_channels = 2 + grid_type = "cartesian" +spatial_dims = range(1, d + 1) +if args.strategy == "oneshot": + assert threeD, "oneshot strategy only for 3d models" +if threeD: + assert args.strategy == "oneshot", "threeD models use oneshot strategy" +ntrain = args.ntrain +nvalid = args.nvalid +ntest = args.ntest +time_modes = None +time = args.strategy == "oneshot" +if time and not args.time_pad: + time_modes = 5 if swe else 6 +elif time and swe: + time_modes = 8 +modes = args.modes +width = args.width +n_layer = args.depth +batch_size = args.batch_size +epochs = args.epochs +learning_rate = args.learning_rate +scheduler_step = args.step_size +scheduler_gamma = args.gamma +initial_step = 1 if args.strategy == "markov" else T_in +root = args.results_path + f"/{'_'.join(str(datetime.datetime.now()).split())}" +if args.suffix: + root += "_" + args.suffix +os.makedirs(root) +path_model = os.path.join(root, "model.pt") +if args.model_type in ["FNO2d", "FNO2d_aug"]: + model = to_runtime( + FNO2d( + num_channels=num_channels, + initial_step=initial_step, + modes1=modes, + modes2=modes, + width=width, + grid_type=grid_type, + ) + ) +elif args.model_type in ["FNO3d", "FNO3d_aug"]: + modes3 = time_modes if time_modes else modes + model = to_runtime( + FNO3d( + num_channels=num_channels, + initial_step=initial_step, + modes1=modes, + modes2=modes, + modes3=modes3, + width=width, + time=time, + time_pad=args.time_pad, + ) + ) +elif "GCNN2d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + GCNN2d( + num_channels=num_channels, + initial_step=initial_step, + width=width, + reflection=reflection, + ) + ) +elif "GCNN3d" in args.model_type: + reflection = "p4m" in args.model_type + 
model = to_runtime( + GCNN3d( + num_channels=num_channels, + initial_step=initial_step, + width=width, + reflection=reflection, + ) + ) +elif "GFNO2d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + GFNO2d( + num_channels=num_channels, + initial_step=initial_step, + modes=modes, + width=width, + reflection=reflection, + grid_type=grid_type, + ) + ) +elif "GFNO3d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + GFNO3d( + num_channels=num_channels, + initial_step=initial_step, + modes=modes, + time_modes=time_modes, + width=width, + reflection=reflection, + grid_type=grid_type, + time_pad=args.time_pad, + ) + ) +elif "Ghybrid2d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + Ghybrid2d( + num_channels=num_channels, + initial_step=initial_step, + modes=modes, + Gwidth=args.Gwidth, + width=width, + reflection=reflection, + n_equiv=args.n_equiv, + ) + ) +elif "radialNO2d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + radialNO2d( + num_channels=num_channels, + initial_step=initial_step, + modes=modes, + width=width, + reflection=reflection, + grid_type=grid_type, + ) + ) +elif "radialNO3d" in args.model_type: + reflection = "p4m" in args.model_type + model = to_runtime( + radialNO3d( + num_channels=num_channels, + initial_step=initial_step, + modes=modes, + time_modes=time_modes, + width=width, + reflection=reflection, + grid_type=grid_type, + time_pad=args.time_pad, + ) + ) +else: + raise NotImplementedError("Model not recognized") +if args.strategy == "oneshot": + x_shape = [batch_size, Sy, Sx, T, initial_step, num_channels] + x_shape_super = [1, S_super, S_super, T_super, initial_step, num_channels] +elif args.strategy == "markov": + x_shape = [batch_size, Sy, Sx, 1, num_channels] + x_shape_super = [1, *((S_super,) * d), 1, num_channels] +else: + x_shape = [batch_size, Sy, Sx, T_in, num_channels] + x_shape_super = [1, 
*((S_super,) * d), T_in, num_channels] +model.train() +x = to_runtime(paddle.randn(*x_shape)) +if args.strategy == "recurrent": + for _ in range(T): + im = model(x) + x = paddle.cat([x[..., 1:, :], im], dim=-2) +else: + model(x) +eq_check_rt(model, x, spatial_dims) +eq_check_rf(model, x, spatial_dims) +if args.super: + model.eval() + with paddle.no_grad(): + x = to_runtime(paddle.randn(*x_shape_super)) + model(x) +full_data = None +if extension == "mat": + assert num_channels == 1, "num channels should be 1 for .mat data" + assert d == 2, "spatial dim should be 2 for .mat data" + sub = 1 + try: + with h5py.File(TRAIN_PATH, "r") as f: + data = np.array(f["u"]) + data = np.transpose(data, axes=range(len(data.shape) - 1, -1, -1)) + except: + data = scipy.io.loadmat(os.path.expandvars(TRAIN_PATH))["u"].astype(np.float32) + data = data[..., None] +elif rdb: + assert num_channels == 1, "num channels should be 1 for shallow water equations" + assert d == 2, "spatial dim should be 2 for shallow water equations" + with h5py.File(TRAIN_PATH, "r") as f: + data_list = sorted(f.keys()) + data = np.concatenate( + [np.array(f[key]["data"])[None] for key in data_list] + ).transpose(0, 2, 3, 1, 4)[..., :-1, :] + full_data = data[-ntest:] + sampler = paddle.compat.nn.AvgPool2d(kernel_size=args.rdb_downsample) + data = ( + sampler(paddle.tensor(data[..., ::4, 0]).permute(0, 3, 1, 2)) + .permute(0, 2, 3, 1) + .unsqueeze(-1) + .numpy() + ) +elif swe: + assert num_channels == 2, "num channels should be 2 for shallow water equations" + assert ( + ntrain + nvalid + ntest <= 5600 + 1120 + 1120 + ), f"Only {5600 + 1120 + 1120} solutions available" + splits = {"train": ntrain, "valid": nvalid, "test": ntest} + datas = {} + for split, n in splits.items(): + if args.verbose: + print(f"SWE: loading {split}") + path = os.path.join(TRAIN_PATH, f"{split}.zarr") + data = xr.open_zarr(path) + normstat = paddle.load(path=str(os.path.join(TRAIN_PATH, "normstats.pt"))) + sample_rate = 8 + VORT_IND = 0 
+ PRES_IND = 1 + datas[split] = [] + for idx in tqdm(range(n), disable=not args.verbose): + vort = paddle.tensor(data["vor"][idx].to_numpy()) + vort = (vort - normstat["vor"]["mean"]) / normstat["vor"]["std"] + pres = paddle.tensor(data["pres"][idx].to_numpy()) + pres = (pres - normstat["pres"]["mean"]) / normstat["pres"]["std"] + pres = pres.unsqueeze(1) + pres = pres[4::sample_rate] + vort = vort[4::sample_rate] + pres_vort = paddle.cat([vort, pres], dim=1).permute(2, 3, 0, 1).unsqueeze(0) + datas[split].append(pres_vort) + datas[split] = paddle.cat(datas[split]) + data = paddle.cat([datas["train"], datas["valid"], datas["test"]]) +else: + raise ValueError(f"Extension {extension} not recognized") +assert data.shape[-2] >= T + T_in, "not enough time" +if args.super: + assert not swe, "Superresolution is not supported for the PDE Arena SWE" + assert full_data is not None or args.super_path is not None, "missing super dataset" +if not swe: + data = paddle.from_numpy(data) +assert len(data) >= ntrain + nvalid + ntest, f"not enough data; {len(data)}" +train = data[:ntrain] +assert len(train) == ntrain, "not enough training data" +test = data[-ntest:] +test_rt = test.rot90(axes=list(spatial_dims)[:2]) +test_rf = test.flip(axis=(spatial_dims[0],)) +assert len(test) == ntest, "not enough test data" +valid = data[-(ntest + nvalid) : -ntest] +assert len(valid) == nvalid, "not enough validation data" +if args.verbose: + print(f"{args.model_type}: Train/valid/test data shape: ") + print(train.shape) + print(valid.shape) + print(test.shape) +assert Sx == train.shape[-3], f"Spatial downsampling should give {Sx} grid points" +assert Sy == train.shape[-4], f"Spatial downsampling should give {Sy} grid points" +train_data = pde_data( + train, strategy=args.strategy, T_in=T_in, T_out=T, std=args.noise_std +) +ntrain = len(train_data) +valid_data = pde_data(valid, train=False, strategy=args.strategy, T_in=T_in, T_out=T) +nvalid = len(valid_data) +test_data = pde_data(test, 
train=False, strategy=args.strategy, T_in=T_in, T_out=T) +test_rt_data = pde_data( + test_rt, train=False, strategy=args.strategy, T_in=T_in, T_out=T +) +test_rf_data = pde_data( + test_rf, train=False, strategy=args.strategy, T_in=T_in, T_out=T +) +ntest = len(test_data) +train_loader = paddle.io.DataLoader( + dataset=train_data, batch_size=batch_size, shuffle=True +) +valid_loader = paddle.io.DataLoader( + dataset=valid_data, batch_size=batch_size, shuffle=False +) +test_loader = paddle.io.DataLoader( + dataset=test_data, batch_size=batch_size, shuffle=False +) +test_rt_loader = paddle.io.DataLoader( + dataset=test_rt_data, batch_size=batch_size, shuffle=False +) +test_rf_loader = paddle.io.DataLoader( + dataset=test_rf_data, batch_size=batch_size, shuffle=False +) +complex_ct = sum(par.size * (1 + par.is_complex()) for par in model.parameters()) +real_ct = sum(par.size for par in model.parameters()) +if args.verbose: + print( + f"{args.model_type}; # Params: complex count {complex_ct}, real count: {real_ct}" + ) +optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), learning_rate=learning_rate, weight_decay=args.lmbda +) +if args.step: + assert args.step_size is not None, "step_size is None" + assert scheduler_gamma is not None, "gamma is None" + tmp_lr = paddle.optimizer.lr.StepDecay( + step_size=args.step_size, + gamma=scheduler_gamma, + learning_rate=optimizer.get_lr(), + ) + optimizer.set_lr_scheduler(tmp_lr) + scheduler = tmp_lr +else: + num_training_steps = epochs * len(train_loader) + tmp_lr = paddle.optimizer.lr.CosineAnnealingDecay( + T_max=num_training_steps, learning_rate=optimizer.get_lr() + ) + optimizer.set_lr_scheduler(tmp_lr) + scheduler = tmp_lr +lploss = LpLoss(size_average=False) +best_valid = float("inf") +x_train, y_train = next(iter(train_loader)) +x, y = move_batch(x_train, y_train) +x_valid, y_valid = next(iter(valid_loader)) +if args.verbose: + print(f"{args.model_type}; Input shape: {x.shape}, Target shape: {y.shape}") +if 
args.strategy == "oneshot": + assert x_train[0].shape == paddle.Size([Sy, Sx, T, T_in, num_channels]), x_train[ + 0 + ].shape + assert y_train[0].shape == paddle.Size([Sy, Sx, T, num_channels]), y_train[0].shape + assert x_valid[0].shape == paddle.Size([Sy, Sx, T, T_in, num_channels]), x_valid[ + 0 + ].shape + assert y_valid[0].shape == paddle.Size([Sy, Sx, T, num_channels]), y_valid[0].shape +elif args.strategy == "markov": + assert x_train[0].shape == paddle.Size([Sy, Sx, 1, num_channels]), x_train[0].shape + assert y_train[0].shape == paddle.Size([Sy, Sx, num_channels]), y_train[0].shape + assert x_valid[0].shape == paddle.Size([Sy, Sx, 1, num_channels]), x_valid[0].shape + assert y_valid[0].shape == paddle.Size([Sy, Sx, T, num_channels]), y_valid[0].shape +else: + assert x_train[0].shape == paddle.Size([Sy, Sx, T_in, num_channels]), x_train[ + 0 + ].shape + assert x_valid[0].shape == paddle.Size([Sy, Sx, T_in, num_channels]), x_valid[ + 0 + ].shape + assert y_valid[0].shape == paddle.Size([Sy, Sx, T, num_channels]), y_valid[0].shape + if args.strategy == "recurrent": + assert y_train[0].shape == paddle.Size([Sy, Sx, T, num_channels]), y_train[ + 0 + ].shape + else: + assert y_train[0].shape == paddle.Size([Sy, Sx, num_channels]), y_train[0].shape +model.eval() +if args.verbose: + print( + f"{args.model_type} pre-train equivariance checks: Rotations - {eq_check_rt(model, x, spatial_dims)}, Reflections - {eq_check_rf(model, x, spatial_dims)}" + ) +start = default_timer() +if args.verbose: + print("Training...") +train_times = [] +eval_times = [] +for ep in range(epochs): + model.train() + t1 = default_timer() + train_l2 = train_vort_l2 = train_pres_l2 = 0 + for xx, yy in tqdm(train_loader, disable=not args.verbose): + loss = 0 + xx, yy = move_batch(xx, yy) + if data_aug: + for b in range(len(xx)): + for j in range(len(spatial_dims)): + for l in range(j + 1, len(spatial_dims)): + k_rt = random.randint(0, 3) + if k_rt > 0: + if not swe: + dims = [spatial_dims[j] - 
1, spatial_dims[l] - 1] + xx[b] = xx[b].rot90(axes=dims, k=k_rt) + yy[b] = yy[b].rot90(axes=dims, k=k_rt) + elif b == 0: + dims = [spatial_dims[j], spatial_dims[l]] + xx = xx.rot90(axes=dims, k=k_rt) + yy = yy.rot90(axes=dims, k=k_rt) + if args.reflection: + k_rf = random.randint(0, 1) + if k_rf == 1: + xx[b] = xx[b].flip(axis=(spatial_dims[j] - 1,)) + yy[b] = yy[b].flip(axis=(spatial_dims[j] - 1,)) + if args.strategy == "recurrent": + for t in range(yy.shape[-2]): + y = yy[..., t, :] + im = model(xx) + loss += lploss( + im.reshape(len(im), -1, num_channels), + y.reshape(len(y), -1, num_channels), + ) + xx = paddle.cat((xx[..., 1:, :], im), dim=-2) + loss /= yy.shape[-2] + else: + im = model(xx) + if args.strategy == "oneshot": + im = im.squeeze(-1) + loss = lploss( + im.reshape(len(im), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ) + train_l2 += loss.item() + if swe: + train_vort_l2 += lploss( + im[..., VORT_IND].reshape(len(im), -1, 1), + yy[..., VORT_IND].reshape(len(yy), -1, 1), + ).item() + train_pres_l2 += lploss( + im[..., PRES_IND].reshape(len(im), -1, 1), + yy[..., PRES_IND].reshape(len(yy), -1, 1), + ).item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if not args.step: + scheduler.step() + if args.step: + scheduler.step() + train_times.append(default_timer() - t1) + valid_l2 = valid_vort_l2 = valid_pres_l2 = 0 + valid_loss_by_channel = None + with paddle.no_grad(): + model.eval() + model(xx) + for xx, yy in valid_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=eval_times + ).view(len(xx), Sy, Sx, T, num_channels) + valid_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + if swe: + valid_vort_l2 += lploss( + pred[..., VORT_IND].reshape(len(pred), -1, 1), + yy[..., VORT_IND].reshape(len(yy), -1, 1), + ).item() + valid_pres_l2 += lploss( + pred[..., PRES_IND].reshape(len(pred), -1, 1), + 
yy[..., PRES_IND].reshape(len(yy), -1, 1), + ).item() + t2 = default_timer() + if args.verbose: + print( + f"Ep: {ep}, time: {t2 - t1}, train: {train_l2 / ntrain}, valid: {valid_l2 / nvalid}" + ) + if valid_l2 < best_valid: + best_epoch = ep + best_valid = valid_l2 + paddle.save(obj=model.state_dict(), path=path_model) + if args.early_stopping: + if ep - best_epoch > args.early_stopping: + break +stop = default_timer() +train_time = stop - start +train_times = paddle.tensor(train_times).mean().item() +num_eval = len(eval_times) +eval_times = paddle.tensor(eval_times).mean().item() +model.eval() +if args.verbose: + print( + f"{args.model_type} post-train equivariance checks: Rotations - {eq_check_rt(model, xx, spatial_dims)}, Reflections - {eq_check_rf(model, xx, spatial_dims)}" + ) +model.load_state_dict(paddle.load(path=str(path_model))) +model.eval() +test_l2 = test_vort_l2 = test_pres_l2 = 0 +rotations_l2 = 0 +reflections_l2 = 0 +test_rt_l2 = 0 +test_rf_l2 = 0 +test_loss_by_channel = None +with paddle.no_grad(): + for xx, yy in test_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), Sy, Sx, T, num_channels) + test_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + if swe: + test_vort_l2 += lploss( + pred[..., VORT_IND].reshape(len(pred), -1, 1), + yy[..., VORT_IND].reshape(len(yy), -1, 1), + ).item() + test_pres_l2 += lploss( + pred[..., PRES_IND].reshape(len(pred), -1, 1), + yy[..., PRES_IND].reshape(len(yy), -1, 1), + ).item() + rotations_l2 += lploss( + model(xx) + .rot90(axes=list(spatial_dims)[:2]) + .reshape(len(pred), -1, num_channels), + model(xx.rot90(axes=list(spatial_dims)[:2])).reshape( + len(pred), -1, num_channels + ), + ) + reflections_l2 += lploss( + model(xx) + .flip(axis=(spatial_dims[0],)) + .reshape(len(pred), -1, num_channels), + model(xx.flip(axis=(spatial_dims[0],))).reshape( + len(pred), -1, 
num_channels + ), + ) + for xx, yy in test_rt_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), Sy, Sx, T, num_channels) + test_rt_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + for xx, yy in test_rf_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), Sy, Sx, T, num_channels) + test_rf_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + rotations_l2 = rotations_l2 / ntest + reflections_l2 = reflections_l2 / ntest +( + test_time_l2 +) = test_space_l2 = ntest_super = test_int_space_l2 = test_int_time_l2 = None +if args.super: + if args.super_path and full_data is None: + indent = 1 + try: + with h5py.File(args.super_path, "r") as f: + data = np.array(f["u"]) + data = np.transpose(data, axes=range(len(data.shape) - 1, -1, -1)) + except: + data = scipy.io.loadmat(os.path.expandvars(args.super_path))["u"].astype( + np.float32 + ) + if args.nsuper: + data = data[: args.nsuper] + assert data.shape[1] == S_super, "wrong super space" + assert data.shape[2] == S_super, "wrong super space" + test_a = data[..., 3 : T_in * 4 : 4] + test_space_u = data[..., T_in * 4 : (T + T_in) * 4 : 4] + test_time_u = data[..., T_in * 4 : (T + T_in) * 4] + assert test_time_u.shape[-1] == T_super, "wrong super time" + test_space = paddle.from_numpy( + np.concatenate([test_a, test_space_u], axis=-1) + ).unsqueeze(-1) + test_time = paddle.from_numpy( + np.concatenate([test_a, test_time_u], axis=-1) + ).unsqueeze(-1) + elif full_data is not None: + if args.nsuper: + full_data = full_data[: args.nsuper] + if rdb: + test_space = paddle.from_numpy(full_data[..., ::4, :]) + test_time = np.concatenate( + [full_data[..., ::4, :][..., :1, :], full_data[..., 4:, :]], axis=-2 + ) + test_time = 
paddle.from_numpy(test_time) + else: + raise ValueError("Missing super data") + test_int_space = test_space.clone() + test_int_time = test_time.clone() + batch_size = 1 + test_space = pde_data( + test_space, train=False, strategy=args.strategy, T_in=T_in, T_out=T + ) + test_int_space = pde_data( + test_int_space, train=False, strategy=args.strategy, T_in=T_in, T_out=T + ) + test_time = pde_data( + test_time, train=False, strategy=args.strategy, T_in=T_in, T_out=T_super + ) + test_int_time = pde_data( + test_int_time, train=False, strategy=args.strategy, T_in=T_in, T_out=T_super + ) + space_loader = paddle.io.DataLoader( + dataset=test_space, batch_size=batch_size, shuffle=False + ) + space_int_loader = paddle.io.DataLoader( + dataset=test_int_space, batch_size=batch_size, shuffle=False + ) + time_loader = paddle.io.DataLoader( + dataset=test_time, batch_size=batch_size, shuffle=False + ) + time_int_loader = paddle.io.DataLoader( + dataset=test_int_time, batch_size=batch_size, shuffle=False + ) + ntest_super = len(space_loader) + test_time_l2 = 0 + test_int_time_l2 = 0 + test_space_l2 = 0 + test_int_space_l2 = 0 + space_permute_inds = [0, 3, 1, 2] + space_unpermute_inds = [0, 2, 3, 1] + space_int_size = [*((S_super,) * d)] + time_permute_inds = [0, 4, 1, 2, 3] + time_unpermute_inds = [0, 2, 3, 4, 1] + time_int_size = [*((S_super,) * d), T_super] + with paddle.no_grad(): + for xx, yy in space_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), *((S_super,) * d), T, num_channels) + test_space_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + for xx, yy in space_int_loader: + if rdb: + xx = to_runtime( + sampler(xx.view(1, S_super, S_super, -1).permute(0, 3, 1, 2)) + .permute(0, 2, 3, 1) + .view((1, *x_shape[1:])) + ) + else: + xx = to_runtime(xx[:, ::4, ::4]) + yy = to_runtime(yy) + pred = get_eval_pred( + model=model, 
x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), *((S,) * d), T) + pred = paddle.nn.functional.interpolate( + pred.permute(space_permute_inds), size=space_int_size, mode="bilinear" + ).permute(space_unpermute_inds) + test_int_space_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + for xx, yy in time_loader: + xx, yy = move_batch(xx, yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T_super, times=[] + ).view(len(xx), *((S_super,) * d), T_super, num_channels) + test_time_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + x_new_shape = x_shape + if threeD: + x_new_shape[len(spatial_dims) + 1] = T_super + x_new_shape[0] = 1 + for xx, yy in time_int_loader: + if rdb: + xx = to_runtime( + sampler(xx.view(1, S_super, S_super, -1).permute(0, 3, 1, 2)) + .permute(0, 2, 3, 1) + .view(x_new_shape) + ) + else: + xx = to_runtime(xx[:, ::4, ::4]) + if threeD: + xx = xx[:, :, :, ::4] + yy = to_runtime(yy) + pred = get_eval_pred( + model=model, x=xx, strategy=args.strategy, T=T, times=[] + ).view(len(xx), *((S,) * d), T, num_channels) + pred = paddle.nn.functional.interpolate( + pred.permute(time_permute_inds), size=time_int_size, mode="trilinear" + ).permute(time_unpermute_inds) + test_int_time_l2 += lploss( + pred.reshape(len(pred), -1, num_channels), + yy.reshape(len(yy), -1, num_channels), + ).item() + test_space_l2 = test_space_l2 / ntest_super + test_int_space_l2 = test_int_space_l2 / ntest_super + test_time_l2 = test_time_l2 / ntest_super + test_int_time_l2 = test_int_time_l2 / ntest_super +print( + f"""{args.model_type} done training; +Test: {test_l2 / ntest}, Rotations: {rotations_l2}, Reflections: {reflections_l2}, Super Space Test: {test_space_l2}, Super Time Test: {test_time_l2}""" +) +summary = f"""Args: {str(args)} +Parameters: {complex_ct} +Train time: {train_time} +Mean epoch time: {train_times} +Mean 
inference time: {eval_times} +Num inferences: {num_eval} +Train: {train_l2 / ntrain} +Valid: {valid_l2 / nvalid} +Test: {test_l2 / ntest} +Rotation Test: {test_rt_l2 / ntest} +Reflection Test: {test_rf_l2 / ntest} +Super Space Test: {test_space_l2} +Super Space Interpolation Test: {test_int_space_l2} +Super S: {S_super} +Super Time Test: {test_time_l2} +Super Time Interpolation Test: {test_int_time_l2} +Super T: {T_super} +Best Valid: {best_valid / nvalid} +Best epoch: {best_epoch + 1} +Test Rotation Equivariance loss: {rotations_l2} +Test Reflection Equivariance loss: {reflections_l2} +Epochs trained: {ep}""" +if swe: + summary += f"\nVorticity Test: {test_vort_l2 / ntest}\nPressure Test: {test_pres_l2 / ntest}" +txt = "results" +if args.txt_suffix: + txt += f"_{args.txt_suffix}" +txt += ".txt" +with open(os.path.join(root, txt), "w") as f: + f.write(summary) diff --git a/examples/G-FNO/utils.py b/examples/G-FNO/utils.py new file mode 100644 index 0000000..a20f14b --- /dev/null +++ b/examples/G-FNO/utils.py @@ -0,0 +1,129 @@ +import paddle + + +class pde_data(paddle.io.Dataset): + def __init__(self, data, T_in, T_out=None, train=True, strategy="markov", std=0.0): + self.markov = strategy == "markov" + self.teacher_forcing = strategy == "teacher_forcing" + self.one_shot = strategy == "oneshot" + self.data = ( + data[..., : T_in + T_out] if self.one_shot else data[..., : T_in + T_out, :] + ) + self.nt = T_in + T_out + self.T_in = T_in + self.T_out = T_out + self.num_hist = 1 if self.markov else self.T_in + self.train = train + self.noise_std = std + + def __len__(self): + if self.train: + if self.markov: + return len(self.data) * (self.nt - 1) + if self.teacher_forcing: + return len(self.data) * (self.nt - self.T_in) + return len(self.data) + + def __getitem__(self, idx): + if not self.train or not (self.markov or self.teacher_forcing): + pde = self.data[idx] + if self.one_shot: + x = pde[..., : self.T_in, :] + x = x.unsqueeze(-3).repeat([1, 1, self.T_out, 1, 1]) + 
y = pde[..., self.T_in : self.T_in + self.T_out, :] + else: + x = pde[..., self.T_in - self.num_hist : self.T_in, :] + y = pde[..., self.T_in : self.T_in + self.T_out, :] + return x, y + pde_idx = idx // (self.nt - self.num_hist) + t_idx = idx % (self.nt - self.num_hist) + self.num_hist + pde = self.data[pde_idx] + x = pde[..., t_idx - self.num_hist : t_idx, :] + y = pde[..., t_idx, :] + if self.noise_std > 0: + x += paddle.randn(*x.shape, device=x.device) * self.noise_std + return x, y + + +class LpLoss(object): + def __init__(self, d=2, p=2, size_average=True, reduction=True): + super(LpLoss, self).__init__() + assert d > 0 and p > 0 + self.d = d + self.p = p + self.reduction = reduction + self.size_average = size_average + + def abs(self, x, y): + num_examples = x.size()[0] + h = 1.0 / (x.size()[1] - 1.0) + all_norms = h ** (self.d / self.p) * paddle.norm( + x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 1 + ) + if self.reduction: + if self.size_average: + return paddle.mean(all_norms) + return paddle.sum(all_norms) + return all_norms + + def rel(self, x, y): + num_examples = x.size()[0] + assert x.shape == y.shape and len(x.shape) == 3, "wrong shape" + diff_norms = paddle.norm(x - y, self.p, 1) + y_norms = paddle.norm(y, self.p, 1) + if self.reduction: + loss = (diff_norms / y_norms).mean(-1) + if self.size_average: + return paddle.mean(loss) + return paddle.sum(loss) + return diff_norms / y_norms + + def __call__(self, x, y): + return self.rel(x, y) + + +def eq_check_rt(model, x, spatial_dims): + model.eval() + diffs = [] + with paddle.no_grad(): + out = model(x) + out[out == 0] = float("nan") + for j in range(len(spatial_dims)): + for l in range(j + 1, len(spatial_dims)): + dims = [spatial_dims[j], spatial_dims[l]] + diffs.append( + [ + ( + ( + ( + out.rot90(k=k, axes=dims) + - model(x.rot90(k=k, axes=dims)) + ) + / out.rot90(k=k, axes=dims) + ) + .abs() + .nanmean() + .item() + * 100 + ) + for k in range(1, 4) + ] + ) + return 
paddle.tensor(diffs).mean().item() + + +def eq_check_rf(model, x, spatial_dims): + model.eval() + diffs = [] + with paddle.no_grad(): + out = model(x) + out[out == 0] = float("nan") + for j in spatial_dims: + diffs.append( + ((out.flip(axis=(j,)) - model(x.flip(axis=(j,)))) / out.flip(axis=(j,))) + .abs() + .nanmean() + .item() + * 100 + ) + return paddle.tensor(diffs).mean().item() diff --git a/ppcfd/models/__init__.py b/ppcfd/models/__init__.py index 13b523a..e096517 100755 --- a/ppcfd/models/__init__.py +++ b/ppcfd/models/__init__.py @@ -72,3 +72,11 @@ __all__.append("symbolic_gn") except ImportError: pass # Optional dependency + +# G-FNO +try: + from ppcfd.models import g_fno + + __all__.append("g_fno") +except ImportError: + pass # Optional dependency diff --git a/ppcfd/models/g_fno/FNO.py b/ppcfd/models/g_fno/FNO.py new file mode 100644 index 0000000..27991d9 --- /dev/null +++ b/ppcfd/models/g_fno/FNO.py @@ -0,0 +1,390 @@ +import paddle + +from .grid import grid +from .paddle_utils import move_to_device + + +class UnitGaussianNormalizer(object): + def __init__(self, x, eps=1e-05): + super(UnitGaussianNormalizer, self).__init__() + self.mean = paddle.mean(x, 0) + self.std = paddle.std(x=x, axis=0) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x + + def decode(self, x, sample_idx=None): + if sample_idx is None: + std = self.std + self.eps + mean = self.mean + else: + if len(self.mean.shape) == len(sample_idx[0].shape): + std = self.std[sample_idx] + self.eps + mean = self.mean[sample_idx] + if len(self.mean.shape) > len(sample_idx[0].shape): + std = self.std[:, sample_idx] + self.eps + mean = self.mean[:, sample_idx] + x = x * std + mean + return x + + def to(self, device): + self.mean = move_to_device(self.mean, device) + self.std = move_to_device(self.std, device) + return self + + def cpu(self): + return self.to("cpu") + + +class SpectralConv2d(paddle.nn.Module): + def __init__(self, in_channels, 
out_channels, modes1, modes2): + super(SpectralConv2d, self).__init__() + """ + 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes1 + self.modes2 = modes2 + self.scale = 1 / (in_channels * out_channels) + self.weights1 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + dtype=paddle.complex64, + ) + ) + self.weights2 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + dtype=paddle.complex64, + ) + ) + + def compl_mul2d(self, input, weights): + return paddle.einsum("bixy,ioxy->boxy", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + x_ft = paddle.fft.rfft2(x) + out_ft = paddle.zeros( + batchsize, + self.out_channels, + x.size(-2), + x.size(-1) // 2 + 1, + dtype=paddle.complex64, + device=x.device, + ) + out_ft[:, :, : self.modes1, : self.modes2] = self.compl_mul2d( + x_ft[:, :, : self.modes1, : self.modes2], self.weights1 + ) + out_ft[:, :, -self.modes1 :, : self.modes2] = self.compl_mul2d( + x_ft[:, :, -self.modes1 :, : self.modes2], self.weights2 + ) + x = paddle.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + return x + + +class MLP2d(paddle.nn.Module): + def __init__(self, in_channels, out_channels, mid_channels): + super(MLP2d, self).__init__() + self.mlp1 = paddle.nn.Conv2d(in_channels, mid_channels, 1) + self.mlp2 = paddle.nn.Conv2d(mid_channels, out_channels, 1) + + def forward(self, x): + x = self.mlp1(x) + x = paddle.nn.functional.gelu(x) + x = self.mlp2(x) + return x + + +class FNO2d(paddle.nn.Module): + def __init__(self, num_channels, modes1, modes2, width, initial_step, grid_type): + super(FNO2d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 
4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) + input shape: (batchsize, x=64, y=64, c=12) + output: the solution of the next timestep + output shape: (batchsize, x=64, y=64, c=1) + """ + self.modes1 = modes1 + self.modes2 = modes2 + self.width = width + self.padding = 8 + self.grid = grid(twoD=True, grid_type=grid_type) + self.p = paddle.compat.nn.Linear( + initial_step * num_channels + self.grid.grid_dim, self.width + ) + self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) + self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) + self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) + self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) + self.mlp0 = MLP2d(self.width, self.width, self.width) + self.mlp1 = MLP2d(self.width, self.width, self.width) + self.mlp2 = MLP2d(self.width, self.width, self.width) + self.mlp3 = MLP2d(self.width, self.width, self.width) + self.w0 = paddle.nn.Conv2d(self.width, self.width, 1) + self.w1 = paddle.nn.Conv2d(self.width, self.width, 1) + self.w2 = paddle.nn.Conv2d(self.width, self.width, 1) + self.w3 = paddle.nn.Conv2d(self.width, self.width, 1) + self.norm = paddle.nn.InstanceNorm2D(num_features=self.width) + self.q = MLP2d(self.width, num_channels, self.width * 4) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], -1) + x = self.grid(x) + x = self.p(x) + x = x.permute(0, 3, 1, 2) + x1 = self.norm(self.conv0(self.norm(x))) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv1(self.norm(x))) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = 
self.norm(self.conv2(self.norm(x))) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv3(self.norm(x))) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = x1 + x2 + x = self.q(x) + x = x.permute(0, 2, 3, 1) + return x.unsqueeze(-2) + + def get_grid(self, shape): + batchsize, size_x, size_y = shape[0], shape[1], shape[2] + gridx = paddle.linspace(0, 1, size_x) + gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) + gridy = paddle.linspace(0, 1, size_y) + gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) + return paddle.cat((gridx, gridy), dim=-1) + + +class SpectralConv3d(paddle.nn.Module): + def __init__(self, in_channels, out_channels, modes1, modes2, modes3): + super(SpectralConv3d, self).__init__() + """ + 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes1 + self.modes2 = modes2 + self.modes3 = modes3 + self.scale = 1 / (in_channels * out_channels) + self.weights1 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.modes3, + dtype=paddle.complex64, + ) + ) + self.weights2 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.modes3, + dtype=paddle.complex64, + ) + ) + self.weights3 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.modes3, + dtype=paddle.complex64, + ) + ) + self.weights4 = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.modes3, + dtype=paddle.complex64, + ) + ) + + def compl_mul3d(self, input, weights): + return paddle.einsum("bixyz,ioxyz->boxyz", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + x_ft = paddle.fft.rfftn(x, dim=[-3, -2, -1]) + 
out_ft = paddle.zeros( + batchsize, + self.out_channels, + x.size(-3), + x.size(-2), + x.size(-1) // 2 + 1, + dtype=paddle.complex64, + device=x.device, + ) + out_ft[:, :, : self.modes1, : self.modes2, : self.modes3] = self.compl_mul3d( + x_ft[:, :, : self.modes1, : self.modes2, : self.modes3], self.weights1 + ) + out_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3] = self.compl_mul3d( + x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3], self.weights2 + ) + out_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3] = self.compl_mul3d( + x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3], self.weights3 + ) + out_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3] = self.compl_mul3d( + x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3], self.weights4 + ) + x = paddle.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) + return x + + +class MLP3d(paddle.nn.Module): + def __init__(self, in_channels, out_channels, mid_channels): + super(MLP3d, self).__init__() + self.mlp1 = paddle.nn.Conv3d(in_channels, mid_channels, 1) + self.mlp2 = paddle.nn.Conv3d(mid_channels, out_channels, 1) + + def forward(self, x): + x = self.mlp1(x) + x = paddle.nn.functional.gelu(x) + x = self.mlp2(x) + return x + + +class FNO3d(paddle.nn.Module): + def __init__( + self, + num_channels, + modes1, + modes2, + modes3, + width, + initial_step, + time, + time_pad=False, + ): + super(FNO3d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index. 
+ input shape: (batchsize, x=64, y=64, t=40, c=13) + output: the solution of the next 40 timesteps + output shape: (batchsize, x=64, y=64, t=40, c=1) + """ + self.modes1 = modes1 + self.modes2 = modes2 + self.modes3 = modes3 + self.width = width + self.time = time + self.time_pad = time_pad + self.padding = 6 + self.p = paddle.compat.nn.Linear(initial_step * num_channels + 3, self.width) + self.conv0 = SpectralConv3d( + self.width, self.width, self.modes1, self.modes2, self.modes3 + ) + self.conv1 = SpectralConv3d( + self.width, self.width, self.modes1, self.modes2, self.modes3 + ) + self.conv2 = SpectralConv3d( + self.width, self.width, self.modes1, self.modes2, self.modes3 + ) + self.conv3 = SpectralConv3d( + self.width, self.width, self.modes1, self.modes2, self.modes3 + ) + self.mlp0 = MLP3d(self.width, self.width, self.width) + self.mlp1 = MLP3d(self.width, self.width, self.width) + self.mlp2 = MLP3d(self.width, self.width, self.width) + self.mlp3 = MLP3d(self.width, self.width, self.width) + self.w0 = paddle.nn.Conv3d(self.width, self.width, 1) + self.w1 = paddle.nn.Conv3d(self.width, self.width, 1) + self.w2 = paddle.nn.Conv3d(self.width, self.width, 1) + self.w3 = paddle.nn.Conv3d(self.width, self.width, 1) + self.q = MLP3d(self.width, num_channels, self.width * 4) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3], -1) + grid = self.get_grid(x.shape).to(x.device) + x = paddle.cat((x, grid), dim=-1) + x = self.p(x) + x = x.permute(0, 4, 1, 2, 3) + if self.time and self.time_pad: + x = paddle.compat.nn.functional.pad(x, [0, self.padding]) + x1 = self.conv0(x) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv1(x) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv2(x) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv3(x) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = 
x1 + x2 + if self.time and self.time_pad: + x = x[..., : -self.padding] + x = self.q(x) + x = x.permute(0, 2, 3, 4, 1) + if not self.time: + x = x.unsqueeze(-2) + return x + + def get_grid(self, shape): + batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] + gridx = paddle.linspace(0, 1, size_x) + gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat( + [batchsize, 1, size_y, size_z, 1] + ) + gridy = paddle.linspace(0, 1, size_y) + gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat( + [batchsize, size_x, 1, size_z, 1] + ) + gridz = paddle.linspace(0, 1, size_z) + gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat( + [batchsize, size_x, size_y, 1, 1] + ) + return paddle.cat((gridx, gridy, gridz), dim=-1) diff --git a/ppcfd/models/g_fno/GCNN.py b/ppcfd/models/g_fno/GCNN.py new file mode 100644 index 0000000..70c3ff8 --- /dev/null +++ b/ppcfd/models/g_fno/GCNN.py @@ -0,0 +1,335 @@ +from functools import partial + +import paddle + +from .GFNO import GConv2d, GConv3d, GMLP2d, GMLP3d, GNorm + + +class GCNN2d(paddle.nn.Module): + def __init__(self, num_channels, width, initial_step, reflection): + super(GCNN2d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
+ + input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) + input shape: (batchsize, x=64, y=64, c=12) + output: the solution of the next timestep + output shape: (batchsize, x=64, y=64, c=1) + """ + self.kernel_size = 3 + assert self.kernel_size % 2 == 1, "Kernel size should be odd" + self.padding = (self.kernel_size - 1) // 2 + self.pad = partial( + paddle.compat.nn.functional.pad, pad=[self.padding] * 4, mode="circular" + ) + self.width = width + self.p = GConv2d( + in_channels=num_channels * initial_step + 1, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + first_layer=True, + ) + self.conv0 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + reflection=reflection, + ) + self.conv1 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + reflection=reflection, + ) + self.conv2 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + reflection=reflection, + ) + self.conv3 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + reflection=reflection, + ) + self.mlp0 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp1 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp2 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp3 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.w0 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w1 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w2 = GConv2d( 
+ in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w3 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.norm = GNorm(self.width, group_size=4 * (1 + reflection)) + self.q = GMLP2d( + in_channels=self.width, + out_channels=num_channels, + mid_channels=self.width * 4, + reflection=reflection, + last_layer=True, + ) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], -1) + grid = self.get_grid(x.shape).to(x.device) + x = paddle.cat((x, grid), dim=-1) + x = x.permute(0, 3, 1, 2) + x = self.p(x) + x1 = self.norm(self.conv0(self.pad(self.norm(x)))) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv1(self.pad(self.norm(x)))) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv2(self.pad(self.norm(x)))) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv3(self.pad(self.norm(x)))) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = x1 + x2 + x = self.q(x) + x = x.permute(0, 2, 3, 1) + return x.unsqueeze(-2) + + def get_grid(self, shape): + batchsize, size_x, size_y = shape[0], shape[1], shape[2] + gridx = ( + paddle.linspace(0, 1, size_x) + .reshape(1, size_x, 1, 1) + .repeat([batchsize, 1, size_y, 1]) + ) + gridy = ( + paddle.linspace(0, 1, size_y) + .reshape(1, 1, size_y, 1) + .repeat([batchsize, size_x, 1, 1]) + ) + midpt = 0.5 + gridx = (gridx - midpt) ** 2 + gridy = (gridy - midpt) ** 2 + return gridx + gridy + + +class GCNN3d(paddle.nn.Module): + def __init__(self, num_channels, width, initial_step, reflection): + super(GCNN3d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). 
+ W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index. + input shape: (batchsize, x=64, y=64, t=40, c=13) + output: the solution of the next 40 timesteps + output shape: (batchsize, x=64, y=64, t=40, c=1) + """ + self.kernel_size = 3 + assert self.kernel_size % 2 == 1, "Kernel size should be odd" + self.padding = (self.kernel_size - 1) // 2 + self.pad = partial( + paddle.compat.nn.functional.pad, + pad=[0, 0] + [self.padding] * 4, + mode="circular", + ) + self.pad0 = partial(paddle.compat.nn.functional.pad, pad=[self.padding] * 2) + self.width = width + self.p = GConv3d( + in_channels=num_channels * initial_step + 2, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + first_layer=True, + ) + self.conv0 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + kernel_size_T=self.kernel_size, + reflection=reflection, + ) + self.conv1 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + kernel_size_T=self.kernel_size, + reflection=reflection, + ) + self.conv2 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + kernel_size_T=self.kernel_size, + reflection=reflection, + ) + self.conv3 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=self.kernel_size, + kernel_size_T=self.kernel_size, + reflection=reflection, + ) + self.mlp0 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp1 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp2 = GMLP3d( + in_channels=self.width, + 
out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp3 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.w0 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w1 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w2 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w3 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.q = GMLP3d( + in_channels=self.width, + out_channels=num_channels, + mid_channels=self.width * 4, + reflection=reflection, + last_layer=True, + ) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3], -1) + grid = self.get_grid(x.shape).to(x.device) + x = paddle.cat((x, grid), dim=-1) + x = x.permute(0, 4, 1, 2, 3) + x = self.p(x) + x1 = self.conv0(self.pad0(self.pad(x))) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv1(self.pad0(self.pad(x))) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv2(self.pad0(self.pad(x))) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv3(self.pad0(self.pad(x))) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = x1 + x2 + x = self.q(x) + x = x.permute(0, 2, 3, 4, 1) + return x.unsqueeze(-2) + + def get_grid(self, shape): + batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] + gridx = ( + paddle.linspace(0, 1, size_x) + .reshape(1, size_x, 1, 1, 1) + .repeat([batchsize, 1, size_y, size_z, 1]) + ) + gridy = ( + paddle.linspace(0, 1, size_y) + .reshape(1, 1, size_y, 
1, 1) + .repeat([batchsize, size_x, 1, size_z, 1]) + ) + gridz = ( + paddle.linspace(0, 1, size_z) + .reshape(1, 1, 1, size_z, 1) + .repeat([batchsize, size_x, size_y, 1, 1]) + ) + midpt = 0.5 + gridx = (gridx - midpt) ** 2 + gridy = (gridy - midpt) ** 2 + return paddle.cat((gridx + gridy, gridz), dim=-1) diff --git a/ppcfd/models/g_fno/GFNO.py b/ppcfd/models/g_fno/GFNO.py new file mode 100644 index 0000000..7914712 --- /dev/null +++ b/ppcfd/models/g_fno/GFNO.py @@ -0,0 +1,932 @@ +import math + +import paddle + +from .grid import grid + + +class GConv2d(paddle.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + bias=True, + first_layer=False, + last_layer=False, + spectral=False, + Hermitian=False, + reflection=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.reflection = reflection + self.rt_group_size = 4 + self.group_size = self.rt_group_size * (1 + reflection) + assert kernel_size % 2 == 1, "kernel size must be odd" + dtype = paddle.complex64 if spectral else paddle.float32 + self.kernel_size_Y = kernel_size + self.kernel_size_X = kernel_size // 2 + 1 if Hermitian else kernel_size + self.Hermitian = Hermitian + if first_layer or last_layer: + self.W = paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.kernel_size_Y, + self.kernel_size_X, + dtype=dtype, + ) + ) + elif self.Hermitian: + self.W = paddle.nn.ParameterDict( + parameters={ + "y0_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_X - 1, + 1, + 2, + dtype=paddle.float32, + ) + ), + "yposx_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_Y, + self.kernel_size_X - 1, + 2, + dtype=paddle.float32, + ) + ), + "00_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + 1, + 1, + dtype=paddle.float32, + ) + ), + 
} + ) + else: + self.W = paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_Y, + self.kernel_size_X, + dtype=dtype, + ) + ) + self.first_layer = first_layer + self.last_layer = last_layer + self.B = ( + paddle.nn.Parameter(paddle.empty(1, out_channels, 1, 1)) if bias else None + ) + self.eval_build = True + self.reset_parameters() + self.get_weight() + + def reset_parameters(self): + if self.Hermitian: + for key in self.W: + paddle.nn.init.kaiming_uniform_(self.W[key], a=math.sqrt(5)) + else: + paddle.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) + if self.B is not None: + paddle.nn.init.kaiming_uniform_(self.B, a=math.sqrt(5)) + + def get_weight(self): + if self.training: + self.eval_build = True + elif self.eval_build: + self.eval_build = False + else: + return + if self.Hermitian: + y0_modes = paddle.as_complex(self.W["y0_modes"]) + yposx_modes = paddle.as_complex(self.W["yposx_modes"]) + self.weights = paddle.cat( + [ + y0_modes, + self.W["00_modes"].astype(paddle.complex64), + y0_modes.flip(axis=(-2,)).conj(), + ], + dim=-2, + ) + self.weights = paddle.cat([self.weights, yposx_modes], dim=-1) + self.weights = paddle.cat( + [self.weights[..., 1:].conj().rot90(k=2, axes=[-2, -1]), self.weights], + dim=-1, + ) + else: + self.weights = self.W[:] + if self.first_layer or self.last_layer: + self.weights = self.weights.repeat(1, self.group_size, 1, 1, 1) + for k in range(1, self.rt_group_size): + self.weights[:, k] = self.weights[:, k].rot90(k=k, axes=[-2, -1]) + if self.reflection: + self.weights[:, self.rt_group_size :] = self.weights[ + :, : self.rt_group_size + ].flip(axis=[-2]) + if self.first_layer: + self.weights = self.weights.view( + -1, self.in_channels, self.kernel_size_Y, self.kernel_size_Y + ) + if self.B is not None: + self.bias = self.B.repeat_interleave(repeats=self.group_size, dim=1) + else: + self.weights = self.weights.transpose(2, 1).reshape( + self.out_channels, -1, 
self.kernel_size_Y, self.kernel_size_Y + ) + self.bias = self.B + else: + self.weights = self.weights.repeat(1, self.group_size, 1, 1, 1, 1) + for k in range(1, self.rt_group_size): + self.weights[:, k] = self.weights[:, k - 1].rot90(axes=[-2, -1]) + if self.reflection: + self.weights[:, k] = paddle.cat( + [ + self.weights[:, k, :, self.rt_group_size - 1].unsqueeze(2), + self.weights[:, k, :, : self.rt_group_size - 1], + self.weights[:, k, :, self.rt_group_size + 1 :], + self.weights[:, k, :, self.rt_group_size].unsqueeze(2), + ], + dim=2, + ) + else: + self.weights[:, k] = paddle.cat( + [ + self.weights[:, k, :, -1].unsqueeze(2), + self.weights[:, k, :, :-1], + ], + dim=2, + ) + if self.reflection: + self.weights[:, self.rt_group_size :] = paddle.cat( + [ + self.weights[:, : self.rt_group_size, :, self.rt_group_size :], + self.weights[:, : self.rt_group_size, :, : self.rt_group_size], + ], + dim=3, + ).flip(axis=[-2]) + self.weights = self.weights.view( + self.out_channels * self.group_size, + self.in_channels * self.group_size, + self.kernel_size_Y, + self.kernel_size_Y, + ) + if self.B is not None: + self.bias = self.B.repeat_interleave(repeats=self.group_size, dim=1) + if self.Hermitian: + self.weights = self.weights[..., -self.kernel_size_X :] + + def forward(self, x): + self.get_weight() + x = paddle.nn.functional.conv2d(input=x, weight=self.weights) + if self.B is not None: + x = x + self.bias + return x + + +class GSpectralConv2d(paddle.nn.Module): + def __init__(self, in_channels, out_channels, modes, reflection=False): + super(GSpectralConv2d, self).__init__() + """ + 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
+ """ + self.in_channels = in_channels + self.out_channels = out_channels + self.modes = modes + self.conv = GConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * modes - 1, + reflection=reflection, + bias=False, + spectral=True, + Hermitian=True, + ) + self.get_weight() + + def get_weight(self): + self.conv.get_weight() + self.weights = self.conv.weights.transpose(0, 1) + + def compl_mul2d(self, input, weights): + return paddle.einsum("bixy,ioxy->boxy", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + freq0_y = ( + (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-2])) == 0) + .nonzero() + .item() + ) + self.get_weight() + x_ft = paddle.fft.fftshift(paddle.fft.rfft2(x), dim=-2) + x_ft = x_ft[..., freq0_y - self.modes + 1 : freq0_y + self.modes, : self.modes] + out_ft = paddle.zeros( + batchsize, + self.weights.shape[0], + x.size(-2), + x.size(-1) // 2 + 1, + dtype=paddle.complex64, + device=x.device, + ) + out_ft[ + ..., freq0_y - self.modes + 1 : freq0_y + self.modes, : self.modes + ] = self.compl_mul2d(x_ft, self.weights) + x = paddle.fft.irfft2( + paddle.fft.ifftshift(out_ft, dim=-2), s=(x.size(-2), x.size(-1)) + ) + return x + + +class GMLP2d(paddle.nn.Module): + def __init__( + self, + in_channels, + out_channels, + mid_channels, + reflection=False, + last_layer=False, + ): + super(GMLP2d, self).__init__() + self.mlp1 = GConv2d( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + reflection=reflection, + ) + self.mlp2 = GConv2d( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + reflection=reflection, + last_layer=last_layer, + ) + + def forward(self, x): + x = self.mlp1(x) + x = paddle.nn.functional.gelu(x) + x = self.mlp2(x) + return x + + +class GNorm(paddle.nn.Module): + def __init__(self, width, group_size): + super().__init__() + self.group_size = group_size + self.norm = paddle.nn.InstanceNorm3D(num_features=width) + + def forward(self, x): + x = 
x.view(x.shape[0], -1, self.group_size, x.shape[-2], x.shape[-1]) + x = self.norm(x) + x = x.view(x.shape[0], -1, x.shape[-2], x.shape[-1]) + return x + + +class GFNO2d(paddle.nn.Module): + def __init__(self, num_channels, modes, width, initial_step, reflection, grid_type): + super(GFNO2d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) + input shape: (batchsize, x=64, y=64, c=12) + output: the solution of the next timestep + output shape: (batchsize, x=64, y=64, c=1) + """ + self.modes = modes + self.width = width + self.grid = grid(twoD=True, grid_type=grid_type) + self.p = GConv2d( + in_channels=num_channels * initial_step + self.grid.grid_dim, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + first_layer=True, + ) + self.conv0 = GSpectralConv2d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + reflection=reflection, + ) + self.conv1 = GSpectralConv2d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + reflection=reflection, + ) + self.conv2 = GSpectralConv2d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + reflection=reflection, + ) + self.conv3 = GSpectralConv2d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + reflection=reflection, + ) + self.mlp0 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp1 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp2 = GMLP2d( + in_channels=self.width, + 
out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp3 = GMLP2d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.w0 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w1 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w2 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.w3 = GConv2d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + reflection=reflection, + ) + self.norm = GNorm(self.width, group_size=4 * (1 + reflection)) + self.q = GMLP2d( + in_channels=self.width, + out_channels=num_channels, + mid_channels=self.width * 4, + reflection=reflection, + last_layer=True, + ) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], -1) + x = self.grid(x) + x = x.permute(0, 3, 1, 2) + x = self.p(x) + x1 = self.norm(self.conv0(self.norm(x))) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv1(self.norm(x))) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv2(self.norm(x))) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.norm(self.conv3(self.norm(x))) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = x1 + x2 + x = self.q(x) + x = x.permute(0, 2, 3, 1) + return x.unsqueeze(-2) + + +class GConv3d(paddle.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + kernel_size_T, + bias=True, + first_layer=False, + last_layer=False, + spectral=False, + Hermitian=False, + reflection=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.reflection = reflection + self.rt_group_size = 4 + 
self.group_size = self.rt_group_size * (1 + reflection) + assert kernel_size % 2 == 1, "kernel size must be odd" + dtype = paddle.complex64 if spectral else paddle.float32 + self.kernel_size_Y = kernel_size + self.kernel_size_X = kernel_size // 2 + 1 if Hermitian else kernel_size + self.kernel_size_T_full = kernel_size_T + self.kernel_size_T = kernel_size_T // 2 + 1 if Hermitian else kernel_size_T + self.Hermitian = Hermitian + if first_layer or last_layer: + self.W = paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.kernel_size_Y, + self.kernel_size_X, + self.kernel_size_T, + dtype=dtype, + ) + ) + elif self.Hermitian: + self.W = paddle.nn.ParameterDict( + parameters={ + "y00_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_X - 1, + 1, + 1, + dtype=paddle.complex64, + ) + ), + "yposx0_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_Y, + self.kernel_size_X - 1, + 1, + dtype=paddle.complex64, + ) + ), + "000_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, 1, in_channels, self.group_size, 1, 1, 1 + ) + ), + "yxpost_modes": paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_Y, + self.kernel_size_Y, + self.kernel_size_T - 1, + dtype=paddle.complex64, + ) + ), + } + ) + else: + self.W = paddle.nn.Parameter( + paddle.empty( + out_channels, + 1, + in_channels, + self.group_size, + self.kernel_size_Y, + self.kernel_size_X, + self.kernel_size_T, + dtype=dtype, + ) + ) + self.first_layer = first_layer + self.last_layer = last_layer + self.B = ( + paddle.nn.Parameter(paddle.empty(1, out_channels, 1, 1, 1)) + if bias + else None + ) + self.eval_build = True + self.reset_parameters() + self.get_weight() + + def reset_parameters(self): + if self.Hermitian: + for key in self.W: + paddle.nn.init.kaiming_uniform_(self.W[key], a=math.sqrt(5)) 
+ else: + paddle.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) + if self.B is not None: + paddle.nn.init.kaiming_uniform_(self.B, a=math.sqrt(5)) + + def get_weight(self): + if self.training: + self.eval_build = True + elif self.eval_build: + self.eval_build = False + else: + return + if self.Hermitian: + self.weights = paddle.cat( + [ + self.W["y00_modes"].conj().flip(axis=(-3,)), + self.W["000_modes"].astype(paddle.complex64), + self.W["y00_modes"], + ], + dim=-3, + ) + self.weights = paddle.cat( + [ + self.W["yposx0_modes"].conj().rot90(k=2, axes=[-3, -2]), + self.weights, + self.W["yposx0_modes"], + ], + dim=-2, + ) + self.weights = paddle.cat( + [ + self.W["yxpost_modes"] + .conj() + .rot90(k=2, axes=[-3, -2]) + .flip(axis=(-1,)), + self.weights, + self.W["yxpost_modes"], + ], + dim=-1, + ) + else: + self.weights = self.W[:] + if self.first_layer or self.last_layer: + self.weights = self.weights.repeat(1, self.group_size, 1, 1, 1, 1) + for k in range(1, self.rt_group_size): + self.weights[:, k] = self.weights[:, k].rot90(k=k, axes=[-3, -2]) + if self.reflection: + self.weights[:, self.rt_group_size :] = self.weights[ + :, : self.rt_group_size + ].flip(axis=[-3]) + if self.first_layer: + self.weights = self.weights.view( + -1, + self.in_channels, + self.kernel_size_Y, + self.kernel_size_Y, + self.kernel_size_T, + ) + if self.B is not None: + self.bias = self.B.repeat_interleave(repeats=self.group_size, dim=1) + else: + self.weights = self.weights.transpose(2, 1).reshape( + self.out_channels, + -1, + self.kernel_size_Y, + self.kernel_size_Y, + self.kernel_size_T, + ) + self.bias = self.B + else: + self.weights = self.weights.repeat(1, self.group_size, 1, 1, 1, 1, 1) + for k in range(1, self.rt_group_size): + self.weights[:, k] = self.weights[:, k - 1].rot90(axes=[-3, -2]) + if self.reflection: + self.weights[:, k] = paddle.cat( + [ + self.weights[:, k, :, self.rt_group_size - 1].unsqueeze(2), + self.weights[:, k, :, : self.rt_group_size - 1], + 
self.weights[:, k, :, self.rt_group_size + 1 :], + self.weights[:, k, :, self.rt_group_size].unsqueeze(2), + ], + dim=2, + ) + else: + self.weights[:, k] = paddle.cat( + [ + self.weights[:, k, :, -1].unsqueeze(2), + self.weights[:, k, :, :-1], + ], + dim=2, + ) + if self.reflection: + self.weights[:, self.rt_group_size :] = paddle.cat( + [ + self.weights[:, : self.rt_group_size, :, self.rt_group_size :], + self.weights[:, : self.rt_group_size, :, : self.rt_group_size], + ], + dim=3, + ).flip(axis=[-3]) + self.weights = self.weights.view( + self.out_channels * self.group_size, + self.in_channels * self.group_size, + self.kernel_size_Y, + self.kernel_size_Y, + self.kernel_size_T_full, + ) + if self.B is not None: + self.bias = self.B.repeat_interleave(repeats=self.group_size, dim=1) + if self.Hermitian: + self.weights = self.weights[..., -self.kernel_size_T :] + + def forward(self, x): + self.get_weight() + x = paddle.nn.functional.conv3d(input=x, weight=self.weights) + if self.B is not None: + x = x + self.bias + return x + + +class GSpectralConv3d(paddle.nn.Module): + def __init__(self, in_channels, out_channels, modes, time_modes, reflection): + super(GSpectralConv3d, self).__init__() + """ + 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
+ """ + self.in_channels = in_channels + self.out_channels = out_channels + self.modes = modes + self.time_modes = time_modes + self.conv = GConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * modes - 1, + kernel_size_T=2 * time_modes - 1, + reflection=reflection, + bias=False, + spectral=True, + Hermitian=True, + ) + self.get_weight() + + def get_weight(self): + self.conv.get_weight() + self.weights = self.conv.weights.transpose(0, 1) + + def compl_mul3d(self, input, weights): + return paddle.einsum("bixyz,ioxyz->boxyz", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + freq0_x = ( + (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-2])) == 0) + .nonzero() + .item() + ) + freq0_y = ( + (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-3])) == 0) + .nonzero() + .item() + ) + self.get_weight() + x_ft = paddle.fft.fftshift(paddle.fft.rfftn(x, dim=[-3, -2, -1]), dim=[-3, -2]) + x_ft = x_ft[ + ..., + freq0_y - self.modes + 1 : freq0_y + self.modes, + freq0_x - self.modes + 1 : freq0_x + self.modes, + : self.time_modes, + ] + out_ft = paddle.zeros( + batchsize, + self.weights.shape[0], + x.size(-3), + x.size(-2), + x.size(-1) // 2 + 1, + dtype=paddle.complex64, + device=x.device, + ) + out_ft[ + ..., + freq0_y - self.modes + 1 : freq0_y + self.modes, + freq0_x - self.modes + 1 : freq0_x + self.modes, + : self.time_modes, + ] = self.compl_mul3d(x_ft, self.weights) + x = paddle.fft.irfftn( + paddle.fft.ifftshift(out_ft, dim=[-3, -2]), + s=(x.size(-3), x.size(-2), x.size(-1)), + ) + return x + + +class GMLP3d(paddle.nn.Module): + def __init__( + self, + in_channels, + out_channels, + mid_channels, + reflection=False, + last_layer=False, + ): + super(GMLP3d, self).__init__() + self.mlp1 = GConv3d( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.mlp2 = GConv3d( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + 
kernel_size_T=1, + reflection=reflection, + last_layer=last_layer, + ) + + def forward(self, x): + x = self.mlp1(x) + x = paddle.nn.functional.gelu(x) + x = self.mlp2(x) + return x + + +class GFNO3d(paddle.nn.Module): + def __init__( + self, + num_channels, + modes, + time_modes, + width, + initial_step, + reflection, + grid_type, + time_pad=False, + ): + super(GFNO3d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index. + input shape: (batchsize, x=64, y=64, t=40, c=13) + output: the solution of the next 40 timesteps + output shape: (batchsize, x=64, y=64, t=40, c=1) + """ + self.modes = modes + self.time_modes = time_modes + self.width = width + self.time_pad = time_pad + self.padding = 6 + self.grid = grid(twoD=False, grid_type=grid_type) + self.p = GConv3d( + in_channels=num_channels * initial_step + self.grid.grid_dim, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + first_layer=True, + ) + self.conv0 = GSpectralConv3d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + time_modes=self.time_modes, + reflection=reflection, + ) + self.conv1 = GSpectralConv3d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + time_modes=self.time_modes, + reflection=reflection, + ) + self.conv2 = GSpectralConv3d( + in_channels=self.width, + out_channels=self.width, + modes=self.modes, + time_modes=self.time_modes, + reflection=reflection, + ) + self.conv3 = GSpectralConv3d( + in_channels=self.width, + out_channels=self.width, + 
modes=self.modes, + time_modes=self.time_modes, + reflection=reflection, + ) + self.mlp0 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp1 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp2 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.mlp3 = GMLP3d( + in_channels=self.width, + out_channels=self.width, + mid_channels=self.width, + reflection=reflection, + ) + self.w0 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w1 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w2 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.w3 = GConv3d( + in_channels=self.width, + out_channels=self.width, + kernel_size=1, + kernel_size_T=1, + reflection=reflection, + ) + self.q = GMLP3d( + in_channels=self.width, + out_channels=num_channels, + mid_channels=self.width * 4, + reflection=reflection, + last_layer=True, + ) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3], -1) + x = self.grid(x) + x = x.permute(0, 4, 1, 2, 3) + x = self.p(x) + if self.time_pad: + x = paddle.compat.nn.functional.pad(x, [0, self.padding]) + x1 = self.conv0(x) + x1 = self.mlp0(x1) + x2 = self.w0(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv1(x) + x1 = self.mlp1(x1) + x2 = self.w1(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv2(x) + x1 = self.mlp2(x1) + x2 = self.w2(x) + x = x1 + x2 + x = paddle.nn.functional.gelu(x) + x1 = self.conv3(x) + x1 = self.mlp3(x1) + x2 = self.w3(x) + x = x1 + x2 + if self.time_pad: + x = x[..., : -self.padding] + 
x = self.q(x) + x = x.permute(0, 2, 3, 4, 1) + return x.unsqueeze(-2) diff --git a/ppcfd/models/g_fno/Ghybrid.py b/ppcfd/models/g_fno/Ghybrid.py new file mode 100644 index 0000000..c546f34 --- /dev/null +++ b/ppcfd/models/g_fno/Ghybrid.py @@ -0,0 +1,113 @@ +import paddle +import paddle.nn.functional as F + +from .FNO import MLP2d, SpectralConv2d +from .GFNO import GConv2d, GMLP2d, GNorm, GSpectralConv2d + + +class Ghybrid2d(paddle.nn.Module): + def __init__( + self, num_channels, modes, Gwidth, width, initial_step, reflection, n_equiv + ): + super(Ghybrid2d, self).__init__() + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) + input shape: (batchsize, x=64, y=64, c=12) + output: the solution of the next timestep + output shape: (batchsize, x=64, y=64, c=1) + """ + self.modes = modes + self.Gwidth = Gwidth + self.width = width + self.n_equiv = n_equiv + self.rt_group_size = 4 + self.group_size = self.rt_group_size * (1 + reflection) + self.p = GConv2d( + in_channels=num_channels * initial_step + 1, + out_channels=self.Gwidth, + kernel_size=1, + reflection=reflection, + first_layer=True, + ) + assert n_equiv in [1, 2, 3], "Number of equivariant layers should be 1, 2, or 3" + self.spectral_convs = paddle.nn.ModuleList() + self.mlps = paddle.nn.ModuleList() + self.convs = paddle.nn.ModuleList() + for layer in range(n_equiv): + self.spectral_convs.append( + GSpectralConv2d( + in_channels=self.Gwidth, + out_channels=self.Gwidth, + modes=self.modes, + reflection=reflection, + ) + ) + self.mlps.append( + GMLP2d( + in_channels=self.Gwidth, + out_channels=self.Gwidth, + mid_channels=self.Gwidth, + 
reflection=reflection, + ) + ) + self.convs.append( + GConv2d( + in_channels=self.Gwidth, + out_channels=self.Gwidth, + kernel_size=1, + reflection=reflection, + ) + ) + for layer in range(4 - n_equiv): + in_width = self.Gwidth * self.group_size if layer == 0 else self.width + self.spectral_convs.append( + SpectralConv2d(in_width, self.width, self.modes, self.modes) + ) + self.mlps.append(MLP2d(self.width, self.width, self.width)) + self.convs.append(paddle.nn.Conv2d(in_width, self.width, 1)) + self.Gnorm = GNorm(self.Gwidth, group_size=4 * (1 + reflection)) + self.norm = paddle.nn.InstanceNorm2D(num_features=self.width) + self.q = MLP2d(self.width, num_channels, self.width * 4) + + def forward(self, x): + x = x.view(x.shape[0], x.shape[1], x.shape[2], -1) + grid = self.get_grid(x.shape).to(x.device) + x = paddle.cat((x, grid), dim=-1) + x = x.permute(0, 3, 1, 2) + x = self.p(x) + norm = self.Gnorm + for layer in range(4): + x1 = norm(self.spectral_convs[layer](norm(x))) + x1 = self.mlps[layer](x1) + x2 = self.convs[layer](x) + x = x1 + x2 + if layer < 3: + x = F.gelu(x) + if layer == self.n_equiv - 1: + norm = self.norm + x = self.q(x) + x = x.permute(0, 2, 3, 1) + return x.unsqueeze(-2) + + def get_grid(self, shape): + batchsize, size_x, size_y = shape[0], shape[1], shape[2] + gridx = ( + paddle.linspace(0, 1, size_x) + .reshape(1, size_x, 1, 1) + .repeat([batchsize, 1, size_y, 1]) + ) + gridy = ( + paddle.linspace(0, 1, size_y) + .reshape(1, 1, size_y, 1) + .repeat([batchsize, size_x, 1, 1]) + ) + midpt = 0.5 + gridx = (gridx - midpt) ** 2 + gridy = (gridy - midpt) ** 2 + return gridx + gridy diff --git a/ppcfd/models/g_fno/__init__.py b/ppcfd/models/g_fno/__init__.py new file mode 100644 index 0000000..32bb86f --- /dev/null +++ b/ppcfd/models/g_fno/__init__.py @@ -0,0 +1,28 @@ +"""G-FNO model package for PaddleCFD.""" + +from .FNO import FNO2d +from .FNO import FNO3d +from .GCNN import GCNN2d +from .GCNN import GCNN3d +from .GFNO import GFNO2d +from .GFNO 
from .GFNO import GFNO3d
from .Ghybrid import Ghybrid2d
from .paddle_utils import resolve_runtime_device
from .paddle_utils import set_runtime_device
from .radialNO import radialNO2d
from .radialNO import radialNO3d


# Public API of the package; the *_update helpers in paddle_utils stay private.
__all__ = [
    "FNO2d",
    "FNO3d",
    "GCNN2d",
    "GCNN3d",
    "GFNO2d",
    "GFNO3d",
    "Ghybrid2d",
    "radialNO2d",
    "radialNO3d",
    "resolve_runtime_device",
    "set_runtime_device",
]
diff --git a/ppcfd/models/g_fno/grid.py b/ppcfd/models/g_fno/grid.py
new file mode 100644
index 0000000..0468af4
--- /dev/null
+++ b/ppcfd/models/g_fno/grid.py
@@ -0,0 +1,73 @@
import paddle


# Appends positional grid features to the model input.
# grid_type:
#   "cartesian" - raw (x, y[, z]) coordinates in [0, 1];
#   "symmetric" - squared distance to the domain midpoint (rotation/reflection
#                 invariant, as required by the equivariant models);
#   "None"      - no grid feature (identity).
# NOTE(review): torch-style calls (paddle.nn.Module, paddle.cat with dim=,
# .repeat, .to(x.device)) presumably resolve via the compat shims in
# paddle_utils — TODO confirm on the target Paddle version.
class grid(paddle.nn.Module):
    def __init__(self, twoD, grid_type):
        super(grid, self).__init__()
        assert grid_type in ["cartesian", "symmetric", "None"], "Invalid grid type"
        self.symmetric = grid_type == "symmetric"
        self.include_grid = grid_type != "None"
        # Number of channels the grid feature adds: cartesian contributes one
        # channel per spatial axis, symmetric collapses x/y into one channel;
        # multiplying by the bool zeroes it out for grid_type == "None".
        self.grid_dim = (1 + (not self.symmetric) + (not twoD)) * self.include_grid
        if self.include_grid:
            if twoD:
                self.get_grid = self.twoD_grid
            else:
                self.get_grid = self.threeD_grid
        else:
            self.get_grid = paddle.nn.Identity()

    def forward(self, x):
        return self.get_grid(x)

    def twoD_grid(self, x):
        """Concatenate the 2D grid feature to x along the last (channel) dim."""
        shape = x.shape
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = (
            paddle.linspace(0, 1, size_x)
            .reshape(1, size_x, 1, 1)
            .repeat([batchsize, 1, size_y, 1])
        )
        gridy = (
            paddle.linspace(0, 1, size_y)
            .reshape(1, 1, size_y, 1)
            .repeat([batchsize, size_x, 1, 1])
        )
        if not self.symmetric:
            grid = paddle.cat((gridx, gridy), dim=-1)
        else:
            # Midpoint in y is scaled by the x spacing so the distance field is
            # isotropic on non-square domains.
            midx = 0.5
            midy = (size_y - 1) / (2 * (size_x - 1))
            gridx = (gridx - midx) ** 2
            gridy = (gridy - midy) ** 2
            grid = gridx + gridy
        grid = grid.to(x.device)
        return paddle.cat((x, grid), dim=-1)

    def threeD_grid(self, x):
        """Concatenate the 3D grid feature (z kept cartesian even when
        symmetric, since equivariance is only over the x/y plane)."""
        shape = x.shape
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = (
            paddle.linspace(0, 1, size_x)
            .reshape(1, size_x, 1, 1, 1)
            .repeat([batchsize, 1, size_y,
                     size_z, 1])
        )
        gridy = (
            paddle.linspace(0, 1, size_y)
            .reshape(1, 1, size_y, 1, 1)
            .repeat([batchsize, size_x, 1, size_z, 1])
        )
        gridz = (
            paddle.linspace(0, 1, size_z)
            .reshape(1, 1, 1, size_z, 1)
            .repeat([batchsize, size_x, size_y, 1, 1])
        )
        if not self.symmetric:
            grid = paddle.cat((gridx, gridy, gridz), dim=-1)
        else:
            # x/y become a single invariant distance channel; z stays raw.
            midx = 0.5
            midy = (size_y - 1) / (2 * (size_x - 1))
            gridx = (gridx - midx) ** 2
            gridy = (gridy - midy) ** 2
            grid = paddle.cat((gridx + gridy, gridz), dim=-1)
        grid = grid.to(x.device)
        return paddle.cat((x, grid), dim=-1)
diff --git a/ppcfd/models/g_fno/paddle_utils.py b/ppcfd/models/g_fno/paddle_utils.py
new file mode 100644
index 0000000..9e93e39
--- /dev/null
+++ b/ppcfd/models/g_fno/paddle_utils.py
@@ -0,0 +1,90 @@
import os

import paddle

############################## utility helpers, begin ##############################
########################### code auto-generated by PaConvert ###########################

def _set_num_threads(int):
    # Torch-style set_num_threads shim: Paddle reads the CPU_NUM env var.
    # NOTE(review): the parameter shadows the builtin `int`; rename to e.g.
    # `num_threads` when the compat surface allows.
    os.environ['CPU_NUM'] = str(int)

def _Tensor_split(self, split_size, dim=0):
    # Torch Tensor.split takes a chunk SIZE; paddle.split takes a section
    # COUNT, hence the shape[dim] // split_size conversion for int inputs.
    # NOTE(review): silently wrong when shape[dim] is not divisible by
    # split_size (torch keeps a smaller trailing chunk) — TODO confirm all
    # call sites split evenly.
    if isinstance(split_size, int):
        return paddle.split(self, self.shape[dim] // split_size, dim)
    else:
        return paddle.split(self, split_size, dim)

setattr(paddle.Tensor, "split", _Tensor_split)

def device2int(device):
    # "cuda:0" / "gpu:0" -> 0.
    # NOTE(review): raises ValueError for "cpu" or bare "gpu" — presumably
    # only called with indexed accelerator specs; verify against callers.
    if isinstance(device, str):
        device = device.replace('cuda', 'gpu')
        device = device.replace('gpu:', '')
    return int(device)


def normalize_device_spec(device):
    """Normalize a device spec: map torch-style "cuda[:N]" to paddle's
    "gpu[:N]"; pass None through; reject non-strings and empty strings."""
    if device is None:
        return None
    if not isinstance(device, str):
        raise TypeError(f"device must be a string or None, got {type(device).__name__}")
    device = device.strip()
    if not device:
        raise ValueError("device must not be empty")
    if device.startswith("cuda"):
        return device.replace("cuda", "gpu", 1)
    return device


def _iter_auto_device_candidates():
    # Candidate order for device="auto": the currently selected non-CPU device
    # first, then any custom device types, then gpu/xpu, cpu as last resort.
    current_device = normalize_device_spec(paddle.device.get_device())
    if current_device != "cpu":
        yield current_device

    get_custom_device_types = getattr(
        paddle.device, "get_all_custom_device_type", None
    )
    # getattr-guarded: older Paddle builds lack get_all_custom_device_type.
    if callable(get_custom_device_types):
        for custom_device in get_custom_device_types() or []:
            yield normalize_device_spec(custom_device)

    yield "gpu"
    yield "xpu"
    yield "cpu"


def resolve_runtime_device(device="auto"):
    """Resolve a device spec to a concrete Paddle device string.

    Non-"auto" specs are only normalized (not validated). "auto" tries each
    candidate by actually calling paddle.set_device, returning the first that
    works; note this leaves that device selected as a side effect.
    """
    normalized_device = normalize_device_spec(device or "auto")
    if normalized_device != "auto":
        return normalized_device

    seen = set()
    last_error = None
    for candidate in _iter_auto_device_candidates():
        if candidate in seen:
            continue
        seen.add(candidate)
        try:
            paddle.set_device(candidate)
            return candidate
        except ValueError as err:
            last_error = err

    if last_error is not None:
        raise last_error
    return "cpu"


def set_runtime_device(device="auto"):
    """Resolve and select the runtime device, returning the resolved spec.

    For explicit specs set_device is called here; in the "auto" path
    resolve_runtime_device has already selected the device as a side effect.
    """
    resolved_device = resolve_runtime_device(device)
    if normalize_device_spec(device or "auto") != "auto":
        paddle.set_device(resolved_device)
    return resolved_device


def move_to_device(obj, device):
    """Move any .to()-capable object to `device` (None is a no-op)."""
    normalized_device = normalize_device_spec(device)
    if normalized_device is None:
        return obj
    if not hasattr(obj, "to"):
        raise TypeError(f"object of type {type(obj).__name__} cannot be moved to a device")
    return obj.to(device=normalized_device)
############################## utility helpers, end ##############################
diff --git a/ppcfd/models/g_fno/radialNO.py b/ppcfd/models/g_fno/radialNO.py
new file mode 100644
index 0000000..91efb9c
--- /dev/null
+++ b/ppcfd/models/g_fno/radialNO.py
@@ -0,0 +1,395 @@
import paddle

from .FNO import MLP2d, MLP3d
from .grid import grid


# Spectral conv whose weights are constrained to be rotation- (and optionally
# reflection-) symmetric in Fourier space, giving a radial neural operator.
# NOTE(review): torch-style calls (paddle.nn.Module, paddle.nn.Parameter,
# paddle.rand/zeros with *sizes and device= kwargs, .cfloat(), dim= args)
# presumably resolve via the PaConvert/paddle.compat shims — TODO confirm on
# the target Paddle version.
class radialSpectralConv2d(paddle.nn.Module):
    def __init__(self, in_channels, out_channels, modes, reflection):
        super(radialSpectralConv2d, self).__init__()
        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.
+ """ + self.reflection = reflection + self.in_channels = in_channels + self.out_channels = out_channels + self.modes = modes + self.scale = 1 / (in_channels * out_channels) + self.dtype = paddle.float32 + if reflection: + self.inds_lower = paddle.tril_indices( + row=self.modes + 1, col=self.modes + 1 + ) + self.W = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.inds_lower.shape[1], + dtype=self.dtype, + ) + ) + else: + self.W_LC = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, out_channels, self.modes + 1, 1, dtype=self.dtype + ) + ) + self.W_LR = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, out_channels, self.modes, self.modes, dtype=self.dtype + ) + ) + self.eval_build = True + self.get_weight() + + def get_weight(self): + if self.training: + self.eval_build = True + elif self.eval_build: + self.eval_build = False + else: + return + if self.reflection: + W_LR = paddle.zeros( + self.in_channels, + self.out_channels, + self.modes + 1, + self.modes + 1, + dtype=self.dtype, + ).to(self.W.device) + W_LR[..., self.inds_lower[0], self.inds_lower[1]] = self.W + W_LR.transpose(-1, -2)[..., self.inds_lower[0], self.inds_lower[1]] = self.W + self.weights = paddle.cat( + [W_LR[..., 1:, :].flip(axis=-2), W_LR], dim=-2 + ).cfloat() + else: + W_LR = paddle.cat([self.W_LC[:, :, 1:], self.W_LR], dim=-1) + W_UR = paddle.cat( + [self.W_LC.flip(axis=-2), W_LR.rot90(axes=[-2, -1])], dim=-1 + ) + self.weights = paddle.cat([W_UR, W_LR], dim=-2).cfloat() + + def compl_mul2d(self, input, weights): + return paddle.einsum("bixy,ioxy->boxy", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + freq0_y = ( + (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-2])) == 0) + .nonzero() + .item() + ) + self.get_weight() + x_ft = paddle.fft.fftshift(paddle.fft.rfft2(x), dim=-2) + x_ft = x_ft[ + ..., freq0_y - self.modes : freq0_y + self.modes + 1, : self.modes + 1 + ] + out_ft = paddle.zeros( + 
            batchsize,
            self.out_channels,
            x.size(-2),
            x.size(-1) // 2 + 1,
            dtype=paddle.complex64,
            device=x.device,
        )
        out_ft[
            ..., freq0_y - self.modes : freq0_y + self.modes + 1, : self.modes + 1
        ] = self.compl_mul2d(x_ft, self.weights)
        # Undo the shift and return to physical space at the input resolution.
        x = paddle.fft.irfft2(
            paddle.fft.ifftshift(out_ft, dim=-2), s=(x.size(-2), x.size(-1))
        )
        return x


class radialNO2d(paddle.nn.Module):
    def __init__(self, num_channels, modes, width, initial_step, reflection, grid_type):
        super(radialNO2d, self).__init__()
        # NOTE(review): bare string after super().__init__(), not a real
        # docstring; kept in place to avoid a code change.
        """
        The overall network. It contains 4 layers of the radial Fourier layer.
        1. Lift the input to the desired channel dimension by self.p .
        2. 4 layers of the integral operators u' = (W + K)(u),
           W defined by self.w0..w3; K defined by self.conv0..conv3 .
        3. Project from the channel space to the output space by self.q .

        input: the solution of the previous 10 timesteps + the grid feature
        input shape: (batchsize, x=64, y=64, c=12)
        output: the solution of the next timestep
        output shape: (batchsize, x=64, y=64, 1, c)
        """
        self.act = paddle.nn.ReLU()
        self.norm = paddle.nn.InstanceNorm2D(num_features=width)
        self.modes = modes
        self.width = width
        self.grid = grid(twoD=True, grid_type=grid_type)
        # Lifting layer; grid.grid_dim extra channels come from the grid module.
        self.p = paddle.compat.nn.Linear(
            initial_step * num_channels + self.grid.grid_dim, self.width
        )
        self.conv0 = radialSpectralConv2d(
            self.width, self.width, self.modes, reflection
        )
        self.conv1 = radialSpectralConv2d(
            self.width, self.width, self.modes, reflection
        )
        self.conv2 = radialSpectralConv2d(
            self.width, self.width, self.modes, reflection
        )
        self.conv3 = radialSpectralConv2d(
            self.width, self.width, self.modes, reflection
        )
        self.mlp0 = MLP2d(self.width, self.width, self.width)
        self.mlp1 = MLP2d(self.width, self.width, self.width)
        self.mlp2 = MLP2d(self.width, self.width, self.width)
        self.mlp3 = MLP2d(self.width, self.width, self.width)
        self.w0 = paddle.nn.Conv2d(self.width,
                                   self.width, 1)
        self.w1 = paddle.nn.Conv2d(self.width, self.width, 1)
        self.w2 = paddle.nn.Conv2d(self.width, self.width, 1)
        self.w3 = paddle.nn.Conv2d(self.width, self.width, 1)
        self.q = MLP2d(self.width, num_channels, self.width * 4)  # projection head

    def forward(self, x):
        # Flatten (time, channel), append grid feature, lift, move channels to
        # dim 1, then run the four (norm + spectral conv + MLP) + pointwise-conv
        # residual blocks; ReLU between blocks, none after the last.
        x = x.view(x.shape[0], x.shape[1], x.shape[2], -1)
        x = self.grid(x)
        x = self.p(x)
        x = x.permute(0, 3, 1, 2)
        x1 = self.norm(self.conv0(self.norm(x)))
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.norm(self.conv1(self.norm(x)))
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.norm(self.conv2(self.norm(x)))
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.norm(self.conv3(self.norm(x)))
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        # Re-insert a singleton time dimension: (B, H, W, 1, C).
        return x.unsqueeze(-2)


# 3D (x, y, t) variant: rotation symmetry is imposed on the x/y plane only;
# the time axis keeps `time_modes` unconstrained modes.
class radialSpectralConv3d(paddle.nn.Module):
    def __init__(self, in_channels, out_channels, modes, time_modes, reflection):
        super(radialSpectralConv3d, self).__init__()
        """
        3D Fourier layer. It does FFT, linear transform, and Inverse FFT.
+ """ + self.reflection = reflection + self.in_channels = in_channels + self.out_channels = out_channels + self.modes = modes + self.time_modes = time_modes + self.scale = 1 / (in_channels * out_channels) + self.dtype = paddle.float32 + if reflection: + self.inds_lower = paddle.tril_indices(row=self.modes, col=self.modes) + self.W = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.inds_lower.shape[1], + self.time_modes, + dtype=self.dtype, + ) + ) + else: + self.W_LC = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes, + 1, + self.time_modes, + dtype=self.dtype, + ) + ) + self.W_LR = paddle.nn.Parameter( + self.scale + * paddle.rand( + in_channels, + out_channels, + self.modes - 1, + self.modes - 1, + self.time_modes, + dtype=self.dtype, + ) + ) + self.eval_build = True + self.get_weight() + + def get_weight(self): + if self.training: + self.eval_build = True + elif self.eval_build: + self.eval_build = False + else: + return + if self.reflection: + W_LR = paddle.zeros( + self.in_channels, + self.out_channels, + self.modes, + self.modes, + self.time_modes, + dtype=self.dtype, + ).to(self.W.device) + W_LR[..., self.inds_lower[0], self.inds_lower[1], :] = self.W + W_LR.transpose(-2, -3)[ + ..., self.inds_lower[0], self.inds_lower[1], : + ] = self.W + W_R = paddle.cat([W_LR[..., 1:, :, :].flip(axis=-3), W_LR], dim=-3) + else: + W_LR = paddle.cat([self.W_LC[:, :, 1:], self.W_LR], dim=-2) + W_UR = paddle.cat( + [self.W_LC.flip(axis=-3), W_LR.rot90(axes=[-3, -2])], dim=-2 + ) + W_R = paddle.cat([W_UR, W_LR], dim=-3) + self.weights = paddle.cat( + [W_R[..., 1:, :].flip(axis=[-3, -2]), W_R], dim=-2 + ).cfloat() + + def compl_mul3d(self, input, weights): + return paddle.einsum("bixyz,ioxyz->boxyz", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + freq0_x = ( + (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-2])) == 0) + .nonzero() + .item() + ) + freq0_y = ( + 
            (paddle.fft.fftshift(paddle.fft.fftfreq(n=x.shape[-3])) == 0)
            .nonzero()
            .item()
        )
        self.get_weight()
        # rfft along time (last axis); shift only the spatial axes.
        x_ft = paddle.fft.fftshift(paddle.fft.rfftn(x, dim=[-3, -2, -1]), dim=[-3, -2])
        x_ft = x_ft[
            ...,
            freq0_y - self.modes + 1 : freq0_y + self.modes,
            freq0_x - self.modes + 1 : freq0_x + self.modes,
            : self.time_modes,
        ]
        out_ft = paddle.zeros(
            batchsize,
            self.out_channels,
            x.size(-3),
            x.size(-2),
            x.size(-1) // 2 + 1,
            dtype=paddle.complex64,
            device=x.device,
        )
        out_ft[
            ...,
            freq0_y - self.modes + 1 : freq0_y + self.modes,
            freq0_x - self.modes + 1 : freq0_x + self.modes,
            : self.time_modes,
        ] = self.compl_mul3d(x_ft, self.weights)
        x = paddle.fft.irfftn(
            paddle.fft.ifftshift(out_ft, dim=[-3, -2]),
            s=(x.size(-3), x.size(-2), x.size(-1)),
        )
        return x


class radialNO3d(paddle.nn.Module):
    def __init__(
        self,
        num_channels,
        modes,
        time_modes,
        width,
        initial_step,
        reflection,
        grid_type,
        time_pad=False,
    ):
        super(radialNO3d, self).__init__()
        # NOTE(review): bare string after super().__init__(), not a real
        # docstring; kept in place to avoid a code change.
        """
        The overall network. It contains 4 layers of the radial Fourier layer.
        1. Lift the input to the desired channel dimension by self.p .
        2. 4 layers of the integral operators u' = (W + K)(u),
           W defined by self.w0..w3; K defined by self.conv0..conv3 .
        3. Project from the channel space to the output space by self.q .

        input: the solution of the previous 10 timesteps + the grid feature
               (u(t-10, x, y, z), ..., u(t-1, x, y, z), x, y, z)
        input shape: (batchsize, x=64, y=64, z, c=12)
        output: the solution of the next timestep
        output shape: (batchsize, x=64, y=64, z, c)
        """
        self.act = paddle.nn.ReLU()
        self.modes = modes
        self.time_modes = time_modes
        self.width = width
        # Optional zero-padding of the (non-periodic) time axis before the FFT.
        self.time_pad = time_pad
        self.padding = 6
        self.grid = grid(twoD=False, grid_type=grid_type)
        self.p = paddle.compat.nn.Linear(
            initial_step * num_channels + self.grid.grid_dim, self.width
        )
        self.conv0 = radialSpectralConv3d(
            self.width, self.width, self.modes, self.time_modes, reflection
        )
        self.conv1 = radialSpectralConv3d(
            self.width, self.width, self.modes, self.time_modes, reflection
        )
        self.conv2 = radialSpectralConv3d(
            self.width, self.width, self.modes, self.time_modes, reflection
        )
        self.conv3 = radialSpectralConv3d(
            self.width, self.width, self.modes, self.time_modes, reflection
        )
        self.mlp0 = MLP3d(self.width, self.width, self.width)
        self.mlp1 = MLP3d(self.width, self.width, self.width)
        self.mlp2 = MLP3d(self.width, self.width, self.width)
        self.mlp3 = MLP3d(self.width, self.width, self.width)
        self.w0 = paddle.nn.Conv3d(self.width, self.width, 1)
        self.w1 = paddle.nn.Conv3d(self.width, self.width, 1)
        self.w2 = paddle.nn.Conv3d(self.width, self.width, 1)
        self.w3 = paddle.nn.Conv3d(self.width, self.width, 1)
        self.q = MLP3d(self.width, num_channels, self.width * 4)  # projection head

    def forward(self, x):
        # Flatten (time, channel), append grid feature, lift, move channels to
        # dim 1. Unlike radialNO2d, no instance norm is applied around the
        # spectral convs here.
        x = x.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3], -1)
        x = self.grid(x)
        x = self.p(x)
        x = x.permute(0, 4, 1, 2, 3)
        if self.time_pad:
            x = paddle.compat.nn.functional.pad(x, [0, self.padding])
        x1 = self.conv0(x)
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.conv1(x)
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.conv2(x)
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2
        x = self.act(x)
        x1 = self.conv3(x)
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2
        if self.time_pad:
            # Strip the zero-padding added to the time axis before the FFTs.
            x = x[..., : -self.padding]
        x = self.q(x)
        x = x.permute(0, 2, 3, 4, 1)
        return x