diff --git a/adept/normalization.py b/adept/normalization.py index 59ae8f69..80548003 100644 --- a/adept/normalization.py +++ b/adept/normalization.py @@ -88,3 +88,23 @@ def electron_debye_normalization(n0_str, T0_str): x0 = (v0 / wp0).to("nm") return PlasmaNormalization(m0=UREG.m_e, q0=UREG.e, n0=n0, T0=T0, L0=x0, v0=v0, tau=tau) + + +def skin_depth_normalization(n0_str, T0_str): + """ + Returns the VFP-1D normalization. + Unit quantities are: + - c/wp0 (collisionless skin depth) + - Electron thermal velocity + - 1/wp0 + """ + n0 = UREG.Quantity(n0_str) + T0 = UREG.Quantity(T0_str) + + wp0 = ((n0 * UREG.e**2.0 / (UREG.m_e * UREG.epsilon_0)) ** 0.5).to("rad/s") + tau = 1 / wp0 + + v0 = ((2.0 * T0 / UREG.m_e) ** 0.5).to("m/s") + x0 = (UREG.c / wp0).to("nm") + + return PlasmaNormalization(m0=UREG.m_e, q0=UREG.e, n0=n0, T0=T0, L0=x0, v0=v0, tau=tau) diff --git a/adept/vfp1d/base.py b/adept/vfp1d/base.py index 38dfb549..57bc1b38 100644 --- a/adept/vfp1d/base.py +++ b/adept/vfp1d/base.py @@ -1,11 +1,11 @@ import numpy as np -from astropy import constants as csts -from astropy import units as u -from astropy.units import Quantity as _Q from diffrax import ODETerm, SaveAt, SubSaveAt, diffeqsolve from jax import numpy as jnp from adept._base_ import ADEPTModule, Stepper +from adept.normalization import UREG, skin_depth_normalization +from adept.utils import filter_scalars +from adept.vfp1d.grid import Grid from adept.vfp1d.helpers import _initialize_total_distribution_, calc_logLambda from adept.vfp1d.storage import get_save_quantities, post_process from adept.vfp1d.vector_field import OSHUN1D @@ -14,199 +14,142 @@ class BaseVFP1D(ADEPTModule): def __init__(self, cfg) -> None: super().__init__(cfg) + # n0 = critical density for 351nm (3ω Nd:glass) + self.plasma_norm = skin_depth_normalization("9.0663e21/cm^3", cfg["units"]["reference electron temperature"]) + self.grid = Grid.from_config(cfg["grid"], self.plasma_norm) def post_process(self, solver_result: dict, td: str) -> dict: return post_process(solver_result["solver result"], cfg=self.cfg, td=td, args=self.args) def write_units(self) -> dict: - ne = u.Quantity(self.cfg["units"]["reference electron density"]).to("1/cm^3") - ni = ne / self.cfg["units"]["Z"] - Te = u.Quantity(self.cfg["units"]["reference electron temperature"]).to("eV") - Ti = u.Quantity(self.cfg["units"]["reference ion temperature"]).to("eV") + norm = self.plasma_norm Z = self.cfg["units"]["Z"] - # Should we change this to reference electron density? or allow it to be user set? - n0 = u.Quantity("9.0663e21/cm^3") - ion_species = self.cfg["units"]["Ion"] - - wp0 = np.sqrt(n0 * csts.e.to("C") ** 2.0 / (csts.m_e * csts.eps0)).to("Hz") - tp0 = (1 / wp0).to("fs") - - vth = np.sqrt(2 * Te / csts.m_e).to("m/s") # mean square velocity eq 4-51a in Shkarofsky - - x0 = (csts.c / wp0).to("nm") - - beta = vth / csts.c + ne = UREG.Quantity(self.cfg["units"]["reference electron density"]).to("1/cc") + Te_eV = norm.T0.to("eV") + logLambda_ei, logLambda_ee = calc_logLambda( + self.cfg, ne, Te_eV, Z, self.cfg["units"]["Ion"], force_ee_equal_ei=True + ) - logLambda_ei, logLambda_ee = calc_logLambda(self.cfg, ne, Te, Z, ion_species, force_ee_equal_ei=True) + # Local aliases for quantities used in multiple expressions below + wp0 = (1 / norm.tau).to("rad/s") + vth = norm.v0.to("m/s") + beta = 1.0 / norm.speed_of_light_norm() + # Elementary charge in Gaussian CGS (pint cannot convert SI↔Gaussian charge dimensions) + e_gauss = UREG.Quantity(4.803204712570263e-10, "Fr") - nD_NRL = 1.72e9 * Te.value**1.5 / np.sqrt(ne.value) - nD_Shkarofsky = np.exp(logLambda_ei) * Z / 9 + ne_cc = ne.to("1/cc").magnitude + nD_NRL = 1.72e9 * Te_eV.magnitude**1.5 / np.sqrt(ne_cc) nuei_shk = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / np.exp(logLambda_ei) # JPB - Maybe comment which page/eq this is from? There are lots of collision times in NRL # For example, nu_ei on page 32 does include Z^2 nuei_nrl = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / nD_NRL - lambda_mfp_shk = (vth / nuei_shk).to("micron") - lambda_mfp_nrl = (vth / nuei_nrl).to("micron") - nuei_epphaines = ( 1 / ( 0.75 - * np.sqrt(csts.m_e) - * Te**1.5 - / (np.sqrt(2 * np.pi) * ni * Z**2.0 * csts.e.gauss**4.0 * logLambda_ei) + * UREG.Quantity(1, "electron_mass") ** 0.5 + * Te_eV**1.5 + / (np.sqrt(2 * np.pi) * (ne / Z) * Z**2.0 * e_gauss**4.0 * logLambda_ei) ) ).to("Hz") all_quantities = { "wp0": wp0, - "n0": n0, - "tp0": tp0, + "n0": norm.n0.to("1/cc"), + "tp0": norm.tau.to("fs"), "ne": ne, "vth": vth, - "Te": Te, - "Ti": Ti, + "Te": Te_eV, + "Ti": UREG.Quantity(self.cfg["units"]["reference ion temperature"]).to("eV"), "logLambda_ei": logLambda_ei, "logLambda_ee": logLambda_ee, "beta": beta, - "x0": x0, + "x0": norm.L0.to("nm"), "nuei_shk": nuei_shk, "nuei_nrl": nuei_nrl, "nuei_epphaines": nuei_epphaines, - "nuei_shk_norm": nuei_shk / wp0, - "nuei_nrl_norm": nuei_nrl / wp0, - "nuei_epphaines_norm": nuei_epphaines / wp0, - "lambda_mfp_shk": lambda_mfp_shk, - "lambda_mfp_nrl": lambda_mfp_nrl, + "nuei_shk_norm": (nuei_shk / wp0).to(""), + "nuei_nrl_norm": (nuei_nrl / wp0).to(""), + "nuei_epphaines_norm": (nuei_epphaines / wp0).to(""), + "lambda_mfp_shk": (vth / nuei_shk).to("micron"), + "lambda_mfp_nrl": (vth / nuei_nrl).to("micron"), "lambda_mfp_epphaines": (vth / nuei_epphaines).to("micron"), "nD_NRL": nD_NRL, - "nD_Shkarofsky": nD_Shkarofsky, + "nD_Shkarofsky": np.exp(logLambda_ei) * Z / 9, } self.cfg["units"]["derived"] = all_quantities - self.cfg["grid"]["beta"] = beta.value + self.cfg["grid"]["beta"] = beta return {k: str(v) for k, v in all_quantities.items()} def get_derived_quantities(self): - """ - This function just updates the config with the derived quantities that are only integers or strings. + """Sync scalar grid values into cfg for logging. - This is run prior to the log params step - - :param cfg_grid: - :return: + This is run prior to the log params step. """ cfg_grid = self.cfg["grid"] + grid = self.grid - # Default save.*.t.tmin/tmax to grid values (preserves unit strings) + # Default save.*.t.tmin/tmax to computed grid values for save_type in self.cfg.get("save", {}).keys(): if "t" in self.cfg["save"][save_type]: t_cfg = self.cfg["save"][save_type]["t"] t_cfg.setdefault("tmin", cfg_grid.get("tmin", "0ps")) t_cfg.setdefault("tmax", cfg_grid["tmax"]) - cfg_grid["xmax"] = (_Q(cfg_grid["xmax"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value - cfg_grid["xmin"] = (_Q(cfg_grid["xmin"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value - cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"] - - # sqrt(2 * k * T / m) - cfg_grid["vmax"] = ( - 8 - * np.sqrt((_Q(self.cfg["units"]["reference electron temperature"]) / (csts.m_e * csts.c**2.0)).to("")).value - ) - - cfg_grid["dv"] = cfg_grid["vmax"] / cfg_grid["nv"] - - cfg_grid["tmax"] = (_Q(cfg_grid["tmax"]) / self.cfg["units"]["derived"]["tp0"]).to("").value - cfg_grid["dt"] = (_Q(cfg_grid["dt"]) / self.cfg["units"]["derived"]["tp0"]).to("").value - - cfg_grid["nt"] = int(cfg_grid["tmax"] / cfg_grid["dt"]) + 1 + # Merge scalar grid values into cfg for logging (arrays come later in get_solver_quantities) + cfg_grid.update(filter_scalars(grid.as_dict())) - if cfg_grid["nt"] > 1e6: - cfg_grid["max_steps"] = int(1e6) - print(r"Only running $10^6$ steps") - else: - cfg_grid["max_steps"] = cfg_grid["nt"] + 4 - - cfg_grid["tmax"] = cfg_grid["dt"] * cfg_grid["nt"] - - print("tmax", cfg_grid["tmax"], "dt", cfg_grid["dt"]) - print("xmax", cfg_grid["xmax"], "dx", cfg_grid["dx"]) + print("tmax", grid.tmax, "dt", grid.dt) + print("xmax", grid.xmax, "dx", grid.dx) self.cfg["grid"] = cfg_grid def get_solver_quantities(self): - """ - This function just updates the config with the derived quantities that are arrays - - This is run after the log params step - - :param cfg_grid: - :return: - """ - cfg_grid = self.cfg["grid"] - - cfg_grid.setdefault("boundary", "periodic") - - cfg_grid = { - **cfg_grid, - **{ - "x": jnp.linspace( - cfg_grid["xmin"] + cfg_grid["dx"] / 2, cfg_grid["xmax"] - cfg_grid["dx"] / 2, cfg_grid["nx"] - ), - "x_edge": jnp.linspace(cfg_grid["xmin"], cfg_grid["xmax"], cfg_grid["nx"] + 1), - "t": jnp.linspace(0, cfg_grid["tmax"], cfg_grid["nt"]), - "v": jnp.linspace(cfg_grid["dv"] / 2, cfg_grid["vmax"] - cfg_grid["dv"] / 2, cfg_grid["nv"]), - "kx": jnp.fft.fftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi, - "kxr": jnp.fft.rfftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi, - }, - } - - self.cfg["grid"] = cfg_grid + """Merge all grid values (including arrays) into cfg['grid'] for backward compatibility.""" + self.cfg["grid"].update(self.grid.as_dict()) def init_state_and_args(self) -> dict: - """ - This function initializes the state + grid = self.grid + beta = 1.0 / self.plasma_norm.speed_of_light_norm() + f0, f10, ne_prof = _initialize_total_distribution_(self.cfg, grid, beta, self.plasma_norm) - :param cfg: - :return: - """ - nx = self.cfg["grid"]["nx"] - nv = self.cfg["grid"]["nv"] - f0, f10, ne_prof = _initialize_total_distribution_(self.cfg, self.cfg["grid"]) + # Scale density to physical ne (f10 is invariant in the big-dt collision limit) + ne = UREG.Quantity(self.cfg["units"]["reference electron density"]) + ne_over_n0 = (ne / self.plasma_norm.n0).to("").magnitude + f0 *= ne_over_n0 + ne_prof *= ne_over_n0 state = {"f0": f0} # not currently necessary but kept for completeness - for il in range(1, self.cfg["grid"]["nl"] + 1): + for il in range(1, grid.nl + 1): for im in range(0, il + 1): - state[f"f{il}{im}"] = jnp.zeros((nx + 1, nv)) + state[f"f{il}{im}"] = jnp.zeros((grid.nx + 1, grid.nv)) state["f10"] = f10 for field in ["e", "b"]: - state[field] = jnp.zeros(nx + 1) + state[field] = jnp.zeros(grid.nx + 1) - state["Z"] = jnp.ones(nx) + state["Z"] = jnp.ones(grid.nx) state["ni"] = ne_prof / self.cfg["units"]["Z"] self.state = state self.args = {"drivers": self.cfg["drivers"]} def init_diffeqsolve(self): - self.cfg = get_save_quantities(self.cfg) + grid = self.grid + self.cfg = get_save_quantities(self.cfg, self.plasma_norm) self.time_quantities = { "t0": 0.0, - "t1": self.cfg["grid"]["tmax"], - "max_steps": self.cfg["grid"]["max_steps"], - "save_t0": 0.0, - "save_t1": self.cfg["grid"]["tmax"], - "save_nt": self.cfg["grid"]["tmax"], + "t1": grid.tmax, + "max_steps": grid.max_steps, } self.diffeqsolve_quants = dict( - terms=ODETerm(OSHUN1D(self.cfg)), + terms=ODETerm(OSHUN1D(self.cfg, grid=grid)), solver=Stepper(), saveat=dict(subs={k: SubSaveAt(ts=v["t"]["ax"], fn=v["func"]) for k, v in self.cfg["save"].items()}), ) @@ -217,8 +160,8 @@ def __call__(self, trainable_modules: dict, args: dict): solver=self.diffeqsolve_quants["solver"], t0=self.time_quantities["t0"], t1=self.time_quantities["t1"], - max_steps=self.cfg["grid"]["max_steps"], - dt0=self.cfg["grid"]["dt"], + max_steps=self.grid.max_steps, + dt0=self.grid.dt, y0=self.state, args=args, saveat=SaveAt(**self.diffeqsolve_quants["saveat"]), diff --git a/adept/vfp1d/fokker_planck.py b/adept/vfp1d/fokker_planck.py index 59f189f7..6de7624d 100644 --- a/adept/vfp1d/fokker_planck.py +++ b/adept/vfp1d/fokker_planck.py @@ -25,6 +25,7 @@ CentralDifferencing, ChangCooper, ) +from adept.vfp1d.grid import Grid class FastVFP(AbstractMaxwellianPreservingModel): @@ -201,10 +202,7 @@ class F0Collisions(eqx.Module): The model and scheme are configurable via config["terms"]["fokker_planck"]["f00"]. """ - v: Array - v_edge: Array - dv: float - nv: int + grid: Grid nuee_coeff: float model: AbstractKernelBasedModel scheme: AbstractDriftDiffusionDifferencingScheme @@ -212,26 +210,16 @@ class F0Collisions(eqx.Module): _sc_rtol: float _sc_atol: float - def __init__(self, cfg: dict): + def __init__(self, cfg: dict, grid): """ - Initialize F0Collisions from config. - - Config should have: - - grid.v, grid.dv, grid.nv - - terms.fokker_planck.f00.model (e.g., "CoulombianKernel") - - terms.fokker_planck.f00.scheme (e.g., "central" or "chang_cooper") - - terms.fokker_planck.self_consistent_beta (optional): dict with enabled, max_steps, rtol, atol - - For collision frequency, provide ONE of: - - terms.fokker_planck.nuee_coeff: Direct override (for testing with dimensionless units) - - units.derived.n0, units.derived.logLambda_ee: Physical units (production) + Initialize F0Collisions from config and grid. - The nuee_coeff override allows tests to use nu=1.0 without dealing with physical units. + Args: + cfg: Full config dict. Reads terms.fokker_planck for collision settings + and units.derived for physical constants. + grid: Grid object providing velocity grid quantities (v, dv, nv, v_edge). """ - self.v = cfg["grid"]["v"] - self.dv = cfg["grid"]["dv"] - self.nv = cfg["grid"]["nv"] - self.v_edge = 0.5 * (self.v[1:] + self.v[:-1]) + self.grid = grid # Collision coefficient: allow direct override for testing fp_cfg = cfg["terms"]["fokker_planck"] @@ -245,7 +233,7 @@ def __init__(self, cfg: dict): # nuee0 = 4π n0 r_e^2 c logΛ_ee normalised to plasma frequency ω_p0 = √(4πn0 r_e)) # => nuee0/ω_p0 = r_e ω_p0 logΛ_ee / c = k_p0 r_e logΛ_ee, where k_p0 = ω_p/c r_e = 2.8179403205e-13 # Classical electron radius in cm (CODATA 2022 value) - kp0re = r_e * np.sqrt(4 * np.pi * cfg["units"]["derived"]["n0"].to("1/cm^3").value * r_e) + kp0re = r_e * np.sqrt(4 * np.pi * cfg["units"]["derived"]["n0"].to("1/cc").magnitude * r_e) self.nuee_coeff = kp0re * cfg["units"]["derived"]["logLambda_ee"] # Create model and scheme from config @@ -254,8 +242,8 @@ def __init__(self, cfg: dict): model_name = f00_cfg.get("model", "CoulombianKernel") scheme_name = f00_cfg.get("scheme", "central") - self.model = _get_model(model_name, self.v, self.dv) - self.scheme = _get_scheme(scheme_name, self.dv) + self.model = _get_model(model_name, grid.v, grid.dv) + self.scheme = _get_scheme(scheme_name, grid.dv) # Self-consistent beta config (defaults to disabled) fp_cfg = cfg["terms"]["fokker_planck"] @@ -279,8 +267,8 @@ def _solve_one_vslice_(self, nu: float, f0: Array, dt: float) -> Array: beta = _find_self_consistent_beta_single( f0, - self.v, - self.dv, + self.grid.v, + self.grid.dv, spherical=True, # spherical=True for positive-only grid rtol=self._sc_rtol, atol=self._sc_atol, @@ -294,10 +282,10 @@ def _solve_one_vslice_(self, nu: float, f0: Array, dt: float) -> Array: # Compute C_edge using the general formula: C = 2*beta*D*v # This ensures Chang-Cooper achieves Maxwellian equilibrium - C_edge = 2.0 * beta * D * self.v_edge + C_edge = 2.0 * beta * D * self.grid.v_edge # For spherical geometry, nu ~ 1/v² to account for the Jacobian - nu_arr = self.nuee_coeff / self.v**2 + nu_arr = self.nuee_coeff / self.grid.v**2 op = self.scheme.get_operator(C_edge=C_edge, D=D, nu=nu_arr, dt=dt) return lx.linear_solve(op, f0, solver=lx.AutoLinearSolver(well_posed=True)).value @@ -324,34 +312,33 @@ class FLMCollisions: operator are ignored and a contribution along the diagonal is scaled by a factor depending on Z. """ - def __init__(self, cfg: dict): - self.v = cfg["grid"]["v"] - self.dv = cfg["grid"]["dv"] + def __init__(self, cfg: dict, grid): + self.grid = grid + self.Z = cfg["units"]["Z"] r_e = 2.8179402894e-13 - kp = np.sqrt(4 * np.pi * cfg["units"]["derived"]["n0"].to("1/cm^3").value * r_e) + kp = np.sqrt(4 * np.pi * cfg["units"]["derived"]["n0"].to("1/cc").magnitude * r_e) kpre = r_e * kp self.nuee_coeff = kpre * cfg["units"]["derived"]["logLambda_ee"] self.nuei_coeff = ( kpre * self.Z**2.0 * cfg["units"]["derived"]["logLambda_ei"] ) # will be multiplied by ni = ne / Z - - self.nl = cfg["grid"]["nl"] self.ee = cfg["terms"]["fokker_planck"]["flm"]["ee"] self.Z_nuei_scaling = (cfg["units"]["Z"] + 4.2) / (cfg["units"]["Z"] + 0.24) + nl = grid.nl self.a1, self.a2, self.b1, self.b2, self.b3, self.b4 = ( - np.zeros(self.nl + 1), - np.zeros(self.nl + 1), - np.zeros(self.nl + 1), - np.zeros(self.nl + 1), - np.zeros(self.nl + 1), - np.zeros(self.nl + 1), + np.zeros(nl + 1), + np.zeros(nl + 1), + np.zeros(nl + 1), + np.zeros(nl + 1), + np.zeros(nl + 1), + np.zeros(nl + 1), ) - for il in range(1, self.nl + 1): + for il in range(1, nl + 1): self.a1[il] = (il + 1) * (il + 2) / (2 * il + 1) / (2 * il + 3) self.a2[il] = -(il - 1) * il / (2 * il + 1) / (2 * il - 1) self.b1[il] = (-il * (il + 1) / 2 - (il + 1)) / (2 * il + 1) / (2 * il + 3) @@ -370,7 +357,13 @@ def calc_ros_i(self, flm: Array, power: int) -> Array: :return: the Rosenbluth integral """ - return 4 * jnp.pi * self.v**-power * jnp.cumsum(self.v[None, :] ** (2.0 + power) * flm, axis=1) * self.dv + return ( + 4 + * jnp.pi + * self.grid.v**-power + * jnp.cumsum(self.grid.v[None, :] ** (2.0 + power) * flm, axis=1) + * self.grid.dv + ) def calc_ros_j(self, flm: Array, power: int) -> Array: r""" @@ -382,9 +375,9 @@ def calc_ros_j(self, flm: Array, power: int) -> Array: return ( 4 * jnp.pi - * self.v[None, :] ** -power - * jnp.cumsum((self.v[None, :] ** (2.0 + power) * flm)[:, ::-1], axis=1)[:, ::-1] - * self.dv + * self.grid.v[None, :] ** -power + * jnp.cumsum((self.grid.v[None, :] ** (2.0 + power) * flm)[:, ::-1], axis=1)[:, ::-1] + * self.grid.dv ) def get_ee_offdiagonal_contrib(self, t, y: Array, args: dict) -> Array: @@ -424,14 +417,17 @@ def get_ee_diagonal_contrib(self, f0: Array) -> Array: diag_term1 = 8 * jnp.pi * f0 - lower_d2dv2 = (i2 + jm1) / (3.0 * self.v[None, :]) / self.dv**2.0 - diag_d2dv2 = (i2 + jm1) / (3.0 * self.v[None, :]) / self.dv**2.0 - upper_d2dv2 = (i2 + jm1) / (3.0 * self.v[None, :]) / self.dv**2.0 + v = self.grid.v[None, :] + dv = self.grid.dv + + lower_d2dv2 = (i2 + jm1) / (3.0 * v) / dv**2.0 + diag_d2dv2 = (i2 + jm1) / (3.0 * v) / dv**2.0 + upper_d2dv2 = (i2 + jm1) / (3.0 * v) / dv**2.0 - diag_angular = -(-i2 + 2 * jm1 + 3 * i0) / (3.0 * self.v[None, :] ** 3.0) + diag_angular = -(-i2 + 2 * jm1 + 3 * i0) / (3.0 * v**3.0) - lower_ddv = (-i2 + 2 * jm1 + 3 * i0) / (3.0 * self.v[None, :] ** 2.0) / 2 / self.dv - upper_ddv = (-i2 + 2 * jm1 + 3 * i0) / (3.0 * self.v[None, :] ** 2.0) / 2 / self.dv + lower_ddv = (-i2 + 2 * jm1 + 3 * i0) / (3.0 * v**2.0) / 2 / dv + upper_ddv = (-i2 + 2 * jm1 + 3 * i0) / (3.0 * v**2.0) / 2 / dv # adding spatial differencing coefficients here # 1 -2 1 for d2dv2 @@ -463,18 +459,18 @@ def __call__(self, Z, ni, f0, f10, dt, include_ee_offdiag_explicitly=True): 2. The ee collision operator is ignored and the Z* scaling is used instead """ - for il in range(1, self.nl + 1): - ei_diag = -il * (il + 1) / 2.0 * (Z[:, None] ** 2.0) * ni[:, None] / self.v[None, :] ** 3.0 + v = self.grid.v[None, :] + dv = self.grid.dv + for il in range(1, self.grid.nl + 1): + ei_diag = -il * (il + 1) / 2.0 * (Z[:, None] ** 2.0) * ni[:, None] / v**3.0 if self.ee: ee_diag, ee_lower, ee_upper = self.get_ee_diagonal_contrib(f0) pad_f0 = jnp.concatenate([f0[:, 1::-1], f0], axis=1) # - d2dv2 = ( - 0.5 / self.v[None, :] * jnp.gradient(jnp.gradient(pad_f0, self.dv, axis=1), self.dv, axis=1)[:, 2:] - ) + d2dv2 = 0.5 / v * jnp.gradient(jnp.gradient(pad_f0, dv, axis=1), dv, axis=1)[:, 2:] - ddv = self.v[None, :] ** -2.0 * jnp.gradient(pad_f0, self.dv, axis=1)[:, 2:] + ddv = v**-2.0 * jnp.gradient(pad_f0, dv, axis=1)[:, 2:] diag = 1 - dt * (self.nuei_coeff * ei_diag + self.nuee_coeff * ee_diag) lower = -dt * self.nuee_coeff * ee_lower diff --git a/adept/vfp1d/grid.py b/adept/vfp1d/grid.py new file mode 100644 index 00000000..82f1ddfc --- /dev/null +++ b/adept/vfp1d/grid.py @@ -0,0 +1,136 @@ +"""Configuration-space and velocity-space grid for VFP-1D simulations.""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING + +import equinox as eqx +import jax.numpy as jnp +import numpy as np + +if TYPE_CHECKING: + from adept.normalization import PlasmaNormalization + + +class Grid(eqx.Module): + """Spatial, velocity, and temporal grid for VFP-1D. + + Simpler than the Vlasov-1D grid: no FFT / Fourier-dual arrays are needed + because VFP-1D does not use spectral solvers. + + All derived quantities (dx, dv, nt, arrays, etc.) are computed in __init__ + from the minimal set of input parameters. + """ + + # Spatial grid + xmin: float + xmax: float + nx: int + dx: float + x: jnp.ndarray # cell centres (nx,) + x_edge: jnp.ndarray # cell edges (nx+1,) + + # Velocity grid (positive-only, 0 to vmax) + nv: int + vmax: float + dv: float + v: jnp.ndarray # cell centres (nv,) + v_edge: jnp.ndarray # cell edges (nv-1,) + + # Temporal grid + tmin: float + tmax: float # actual tmax (aligned to dt) + dt: float + nt: int + max_steps: int + t: jnp.ndarray + + # Physics / mode parameters stored on the grid for convenience + nl: int # number of Legendre harmonics + boundary: str # "periodic" or "reflective" + + def __init__( + self, + *, + xmin: float, + xmax: float, + nx: int, + tmin: float, + tmax: float, + dt: float, + nv: int, + vmax: float, + nl: int, + boundary: str = "periodic", + ): + # -- Spatial ---------------------------------------------------------- + self.xmin = xmin + self.xmax = xmax + self.nx = nx + self.dx = xmax / nx + + self.x = jnp.linspace(xmin + self.dx / 2, xmax - self.dx / 2, nx) + self.x_edge = jnp.linspace(xmin, xmax, nx + 1) + + # -- Velocity --------------------------------------------------------- + self.nv = nv + self.vmax = vmax + self.dv = vmax / nv + + self.v = jnp.linspace(self.dv / 2, vmax - self.dv / 2, nv) + self.v_edge = 0.5 * (self.v[1:] + self.v[:-1]) + + # -- Temporal --------------------------------------------------------- + self.tmin = tmin + self.dt = dt + self.nt = int(tmax / dt) + 1 + + max_steps = 1e6 + if self.nt > max_steps: + print(f"Requested {self.nt} steps, only running {int(max_steps)} steps") + self.max_steps = int(max_steps) + else: + self.max_steps = self.nt + 4 + + self.tmax = self.dt * self.nt + self.t = jnp.linspace(0, self.tmax, self.nt) + + # -- Physics ---------------------------------------------------------- + self.nl = nl + self.boundary = boundary + + @staticmethod + def from_config(cfg_grid: dict, norm: PlasmaNormalization) -> Grid: + """Construct Grid from config dict and plasma normalization. + + Args: + cfg_grid: The ``cfg["grid"]`` dict (raw, with unit strings). + norm: Plasma normalization (provides L0, tau, T0 for unit conversion). + """ + from adept.normalization import normalize + + xmax = normalize(cfg_grid["xmax"], norm, dim="x") + xmin = normalize(cfg_grid["xmin"], norm, dim="x") + tmax = normalize(cfg_grid["tmax"], norm, dim="t") + dt = normalize(cfg_grid["dt"], norm, dim="t") + + beta = 1.0 / norm.speed_of_light_norm() + vmax = 8.0 * beta / np.sqrt(2.0) + + return Grid( + xmin=xmin, + xmax=xmax, + nx=cfg_grid["nx"], + tmin=0.0, + tmax=tmax, + dt=dt, + nv=cfg_grid["nv"], + vmax=vmax, + nl=cfg_grid["nl"], + boundary=cfg_grid.get("boundary", "periodic"), + ) + + def as_dict(self) -> dict: + """Return all grid fields as a plain dict (delegates to ``dataclasses.asdict``).""" + return asdict(self) diff --git a/adept/vfp1d/helpers.py b/adept/vfp1d/helpers.py index beff225b..c0b355ac 100644 --- a/adept/vfp1d/helpers.py +++ b/adept/vfp1d/helpers.py @@ -3,12 +3,12 @@ import numpy as np -from astropy.units import Quantity as _Q from jax import Array from jax import numpy as jnp from scipy.special import gamma from adept._base_ import get_envelope +from adept.normalization import PlasmaNormalization, normalize # ideally this should be passed as as an argument and not re-initialised from adept.vfp1d.vector_field import OSHUN1D @@ -57,15 +57,15 @@ def calc_logLambda( """ if isinstance(cfg["units"]["logLambda"], str): if cfg["units"]["logLambda"].casefold() == "nrl": - log_ne = np.log(ne.to("1/cm^3").value) - log_Te = np.log(Te.to("eV").value) + log_ne = np.log(ne.to("1/cc").magnitude) + log_Te = np.log(Te.to("eV").magnitude) log_Z = np.log(Z) logLambda_ee = max( 2.0, 23.5 - 0.5 * log_ne + 1.25 * log_Te - np.sqrt(1e-5 + 0.0625 * (log_Te - 2.0) ** 2.0) ) - if Te.to("eV").value > 10 * Z**2.0: + if Te.to("eV").magnitude > 10 * Z**2.0: logLambda_ei = max(2.0, 24.0 - 0.5 * log_ne + log_Te) else: logLambda_ei = max(2.0, 23.0 - 0.5 * log_ne + 1.5 * log_Te - log_Z) @@ -117,20 +117,24 @@ def _initialize_distribution_( return f, vax -def _initialize_total_distribution_(cfg: dict, cfg_grid: dict) -> tuple[Array, Array, Array]: +def _initialize_total_distribution_( + cfg: dict, grid, beta: float, norm: PlasmaNormalization +) -> tuple[Array, Array, Array]: """ This function initializes the distribution function as a sum of the individual species :param cfg: Dict - :param cfg_grid: Dict + :param grid: Grid object + :param beta: vth/c (dimensionless thermal velocity) + :param norm: Plasma normalization (used for unit conversion of spatial profiles) :return: distribution function, density profile (nx, nv), (nx,) """ params = cfg["density"] - prof_total = {"n": np.zeros([cfg_grid["nx"]]), "T": np.zeros([cfg_grid["nx"]])} + prof_total = {"n": np.zeros([grid.nx]), "T": np.zeros([grid.nx])} - f0 = np.zeros([cfg_grid["nx"], cfg_grid["nv"]]) - f10 = np.zeros([cfg_grid["nx"] + 1, cfg_grid["nv"]]) + f0 = np.zeros([grid.nx, grid.nv]) + f10 = np.zeros([grid.nx + 1, grid.nv]) species_found = False for name, species_params in cfg["density"].items(): if name.startswith("species-"): @@ -145,14 +149,13 @@ def _initialize_total_distribution_(cfg: dict, cfg_grid: dict) -> tuple[Array, A profs[k] = species_params[k]["baseline"] * np.ones_like(prof_total[k]) elif species_params[k]["basis"] == "tanh": - center = (_Q(species_params[k]["center"]) / cfg["units"]["derived"]["x0"]).to("").value - width = (_Q(species_params[k]["width"]) / cfg["units"]["derived"]["x0"]).to("").value - rise = (_Q(species_params[k]["rise"]) / cfg["units"]["derived"]["x0"]).to("").value + center = normalize(species_params[k]["center"], norm, dim="x") + width = normalize(species_params[k]["width"], norm, dim="x") + rise = normalize(species_params[k]["rise"], norm, dim="x") left = center - width * 0.5 right = center + width * 0.5 - # rise = species_params[k]["rise"] - prof = get_envelope(rise, rise, left, right, cfg_grid["x"]) + prof = get_envelope(rise, rise, left, right, grid.x) if species_params[k]["bump_or_trough"] == "trough": prof = 1 - prof @@ -161,51 +164,45 @@ def _initialize_total_distribution_(cfg: dict, cfg_grid: dict) -> tuple[Array, A elif species_params[k]["basis"] == "sine": baseline = species_params[k]["baseline"] amp = species_params[k]["amplitude"] - ll = (_Q(species_params[k]["wavelength"]) / cfg["units"]["derived"]["x0"]).to("").value + ll = normalize(species_params[k]["wavelength"], norm, dim="x") - profs[k] = baseline * (1.0 + amp * jnp.sin(2 * jnp.pi / ll * cfg_grid["x"])) + profs[k] = baseline * (1.0 + amp * jnp.sin(2 * jnp.pi / ll * grid.x)) elif species_params[k]["basis"] == "cosine": baseline = species_params[k]["baseline"] amp = species_params[k]["amplitude"] - ll = (_Q(species_params[k]["wavelength"]) / cfg["units"]["derived"]["x0"]).to("").value + ll = normalize(species_params[k]["wavelength"], norm, dim="x") - profs[k] = baseline * (1.0 + amp * jnp.cos(2 * jnp.pi / ll * cfg_grid["x"])) + profs[k] = baseline * (1.0 + amp * jnp.cos(2 * jnp.pi / ll * grid.x)) else: raise NotImplementedError - profs["n"] *= (cfg["units"]["derived"]["ne"] / cfg["units"]["derived"]["n0"]).value - prof_total["n"] += profs["n"] # Distribution function temp_f0, _ = _initialize_distribution_( - nv=int(cfg_grid["nv"]), + nv=grid.nv, m=m, - vth=cfg_grid["beta"], - vmax=cfg_grid["vmax"], + vth=beta, + vmax=grid.vmax, n_prof=profs["n"], T_prof=profs["T"], ) f0 += temp_f0 # initialize f1 by taking a big time step while keeping f0 fixed (essentially sets electron inertia to 0) - # I don't like having to reinitialise oshun to get helper functions, - # either we pass as an argument or refactor # TODO: add switch to opt in/out - oshun = OSHUN1D(cfg) + oshun = OSHUN1D(cfg, grid=grid) big_dt = 1e12 ni = prof_total["n"] / cfg["units"]["Z"] - # f10 lives at cell edges (nx+1, nv), use ddx_c2e to get derivative at edges - nx = cfg["grid"]["nx"] - Z_edge = oshun.interp_c2e(jnp.ones(nx)) + Z_edge = oshun.interp_c2e(jnp.ones(grid.nx)) ni_edge = oshun.interp_c2e(jnp.array(ni)) f0_at_edges = oshun.interp_c2e(jnp.array(f0)) - f10_star = -big_dt * oshun.v[None, :] * oshun.ddx_c2e(jnp.array(f0)) + f10_star = -big_dt * grid.v[None, :] * oshun.ddx_c2e(jnp.array(f0)) f10_from_adv = oshun.ei( Z=Z_edge, ni=ni_edge, diff --git a/adept/vfp1d/storage.py b/adept/vfp1d/storage.py index ef781b52..50c74b36 100644 --- a/adept/vfp1d/storage.py +++ b/adept/vfp1d/storage.py @@ -3,11 +3,12 @@ import numpy as np import xarray as xr -from astropy.units import Quantity as _Q from diffrax import Solution from jax import numpy as jnp from matplotlib import pyplot as plt +from adept.normalization import PlasmaNormalization, normalize + def calc_EH(this_Z: int, this_wt: float) -> float: """ @@ -96,8 +97,8 @@ def store_fields(cfg: dict, binary_dir: str, fields: dict, this_t: np.ndarray, p :return: """ - xax = cfg["units"]["derived"]["x0"].to("micron").value * cfg["grid"]["x"] - tax = this_t * cfg["units"]["derived"]["tp0"].to("ps").value + xax = cfg["units"]["derived"]["x0"].to("micron").magnitude * cfg["grid"]["x"] + tax = this_t * cfg["units"]["derived"]["tp0"].to("ps").magnitude if any(x in ["x", "kx"] for x in cfg["save"][prefix].keys()): crds = set(cfg["save"][prefix].keys()) - {"t", "func"} @@ -146,7 +147,7 @@ def calc_kappa(cfg: dict, T: xr.DataArray, q: xr.DataArray, n: xr.DataArray) -> / n.data / T.data / np.gradient(T.data, cfg["grid"]["dx"], axis=1) - * (cfg["units"]["derived"]["nuei_epphaines"] / cfg["units"]["derived"]["wp0"]).to("").value + * (cfg["units"]["derived"]["nuei_epphaines"] / cfg["units"]["derived"]["wp0"]).to("").magnitude ) return xr.DataArray(kappa, coords=(("t (ps)", T.coords["t (ps)"].data), ("x (um)", T.coords["x (um)"].data))) @@ -181,10 +182,10 @@ def store_f(cfg: dict, this_t: dict, td: str, ys: dict) -> xr.Dataset: :param ys: :return: """ - x0_um = cfg["units"]["derived"]["x0"].to("micron").value + x0_um = cfg["units"]["derived"]["x0"].to("micron").magnitude xax_center = x0_um * cfg["grid"]["x"] xax_edge = x0_um * cfg["grid"]["x_edge"] - tax = this_t["electron"] * cfg["units"]["derived"]["tp0"].to("ps").value + tax = this_t["electron"] * cfg["units"]["derived"]["tp0"].to("ps").magnitude das = {} for dist in ys["electron"].keys(): @@ -254,7 +255,7 @@ def post_process(soln: Solution, cfg: dict, td: str, args: dict | None = None) - plt.close() elif k.startswith("default"): - tax = soln.ts["default"] * cfg["units"]["derived"]["tp0"].to("ps").value + tax = soln.ts["default"] * cfg["units"]["derived"]["tp0"].to("ps").magnitude scalars_xr = xr.Dataset( {k: xr.DataArray(v, coords=(("t (ps)", tax),)) for k, v in soln.ys["default"].items()} ) @@ -354,17 +355,18 @@ def dist_save_func(t, y, args): return dist_save_func -def get_save_quantities(cfg: dict) -> dict: +def get_save_quantities(cfg: dict, norm: PlasmaNormalization) -> dict: """ This function updates the config with the quantities required for the diagnostics and saving routines :param cfg: + :param norm: Plasma normalization :return: The updated config """ for k in cfg["save"].keys(): # this can be fields or electron or scalar? - tmin = (_Q(cfg["save"][k]["t"]["tmin"]) / cfg["units"]["derived"]["tp0"]).to("").value - tmax = (_Q(cfg["save"][k]["t"]["tmax"]) / cfg["units"]["derived"]["tp0"]).to("").value + tmin = normalize(cfg["save"][k]["t"]["tmin"], norm, dim="t") + tmax = normalize(cfg["save"][k]["t"]["tmax"], norm, dim="t") cfg["save"][k]["t"]["ax"] = jnp.linspace(tmin, tmax, cfg["save"][k]["t"]["nt"]) if k.startswith("fields"): diff --git a/adept/vfp1d/vector_field.py b/adept/vfp1d/vector_field.py index 538916ba..3a062721 100644 --- a/adept/vfp1d/vector_field.py +++ b/adept/vfp1d/vector_field.py @@ -13,20 +13,14 @@ class OSHUN1D: """ - def __init__(self, cfg: dict): - self.cfg = cfg - self.v = cfg["grid"]["v"] - self.dv = cfg["grid"]["dv"] - - self.dx = cfg["grid"]["dx"] - self.dt = cfg["grid"]["dt"] - self.nx = cfg["grid"]["nx"] - self.boundary = cfg["grid"].get("boundary", "periodic") + def __init__(self, cfg: dict, grid): + self.grid = grid + self.e_solver = cfg["terms"]["e_solver"] self.ampere_coeff = 1e-6 - self.lb = F0Collisions(cfg) - self.ei = FLMCollisions(cfg) + self.lb = F0Collisions(cfg, grid) + self.ei = FLMCollisions(cfg, grid) self.large_eps = 1e-6 self.eps = 1e-12 @@ -42,7 +36,7 @@ def ddv(self, f: Array) -> Array: Array: df/dv """ temp = jnp.concatenate([f[:, :1], f], axis=1) - return jnp.gradient(temp, self.dv, axis=1)[:, 1:] + return jnp.gradient(temp, self.grid.dv, axis=1)[:, 1:] def ddv_f1(self, f: Array) -> Array: """ @@ -55,21 +49,21 @@ def ddv_f1(self, f: Array) -> Array: Array: df1/dv """ temp = jnp.concatenate([-f[:, :1], f], axis=1) - return jnp.gradient(temp, self.dv, axis=1)[:, 1:] + return jnp.gradient(temp, self.grid.dv, axis=1)[:, 1:] def ddx_c2e(self, f: Array) -> Array: """d/dx of center field, evaluated at edges. (nx, ...) -> (nx+1, ...)""" - if self.boundary == "periodic": + if self.grid.boundary == "periodic": left_ghost = f[-1:] right_ghost = f[:1] else: # reflective left_ghost = f[:1] right_ghost = f[-1:] - return jnp.diff(f, axis=0, prepend=left_ghost, append=right_ghost) / self.dx + return jnp.diff(f, axis=0, prepend=left_ghost, append=right_ghost) / self.grid.dx def ddx_e2c(self, f: Array) -> Array: """d/dx of edge field, evaluated at centers. (nx+1, ...) -> (nx, ...)""" - return jnp.diff(f, axis=0) / self.dx + return jnp.diff(f, axis=0) / self.grid.dx def interp_e2c(self, f: Array) -> Array: """Average edge values to centers. (nx+1, ...) -> (nx, ...)""" @@ -77,7 +71,7 @@ def interp_e2c(self, f: Array) -> Array: def interp_c2e(self, f: Array) -> Array: """Average center values to edges. (nx, ...) -> (nx+1, ...)""" - if self.boundary == "periodic": + if self.grid.boundary == "periodic": left_ghost = f[-1:] right_ghost = f[:1] else: # reflective @@ -97,7 +91,7 @@ def calc_j(self, f1: Array) -> Array: Array: j(x) """ - return -4 * jnp.pi / 3.0 * jnp.sum(f1 * self.v[None, :] ** 3.0, axis=1) * self.dv + return -4 * jnp.pi / 3.0 * jnp.sum(f1 * self.grid.v[None, :] ** 3.0, axis=1) * self.grid.dv def implicit_e_solve(self, Z: Array, ni: Array, f0: Array, f10: Array, e: Array) -> Array: """ @@ -124,7 +118,7 @@ def implicit_e_solve(self, Z: Array, ni: Array, f0: Array, f10: Array, e: Array) f0_at_edges = self.interp_c2e(f0) # calculate j without any e field - f10_after_coll = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=f10, dt=self.dt) + f10_after_coll = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=f10, dt=self.grid.dt) j0 = self.calc_j(f10_after_coll) # get perturbation @@ -132,14 +126,14 @@ def implicit_e_solve(self, Z: Array, ni: Array, f0: Array, f10: Array, e: Array) # calculate effect of dex _, f10_after_dex = self.push_edfdv(f0, f10, de) - f10_after_dex = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=f10_after_dex, dt=self.dt) + f10_after_dex = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=f10_after_dex, dt=self.grid.dt) jx_dx = self.calc_j(f10_after_dex) # directly solve for ex new_e = -j0 * de / (jx_dx - j0) # For reflective BCs, E=0 at boundary edges (j0=jx_dx=0 there causes 0/0) - if self.boundary == "reflective": + if self.grid.boundary == "reflective": new_e = new_e.at[0].set(0.0).at[-1].set(0.0) return new_e @@ -153,14 +147,15 @@ def nonlinear_implicit_e_f0_f1_operator(self, y, args): new_f0, new_f1, new_e = y["f0"], y["f1"], y["e"] old_f0, old_f1, old_e = args["f0"], args["f1"], args["e"] - res_f0 = (new_f0 - old_f0) / self.dt - new_e[:, None] / 3 * (self.ddv_f1(new_f1) + 2 / self.v * new_f1) + v = self.grid.v + res_f0 = (new_f0 - old_f0) / self.grid.dt - new_e[:, None] / 3 * (self.ddv_f1(new_f1) + 2 / v * new_f1) # C_f1 = self.step_f10_coll(f1) res_f1 = ( - (new_f1 - old_f1) / self.dt - new_e[:, None] * self.ddv(new_f0) + 1e-4 * new_f1 / self.v[None, :] ** 3.0 + (new_f1 - old_f1) / self.grid.dt - new_e[:, None] * self.ddv(new_f0) + 1e-4 * new_f1 / v[None, :] ** 3.0 ) - new_j = -4 * jnp.pi / 3.0 * jnp.sum(new_f1 * self.v[None, :] ** 3.0, axis=1) * self.dv - res_e = (new_e - old_e) / self.dt + new_j + new_j = -4 * jnp.pi / 3.0 * jnp.sum(new_f1 * v[None, :] ** 3.0, axis=1) * self.grid.dv + res_e = (new_e - old_e) / self.grid.dt + new_j # return {"f0": res_f0, "f1": res_f1, "e": res_e} @@ -231,12 +226,12 @@ def _edfdv_(self, t: float, y: dict, args: dict) -> dict: e_at_centers = self.interp_e2c(e_field) # (nx+1,) -> (nx,) g00 = self.ddv(f0_at_edges) - h10_c = 2.0 / self.v * f10_at_centers + self.ddv_f1(f10_at_centers) + h10_c = 2.0 / self.grid.v * f10_at_centers + self.ddv_f1(f10_at_centers) df0dt_e = e_at_centers[:, None] / 3.0 * h10_c df10dt_e = e_field[:, None] * g00 - if self.boundary == "reflective": + if self.grid.boundary == "reflective": df10dt_e = df10dt_e.at[0].set(0.0).at[-1].set(0.0) return {"f0": df0dt_e, "f10": df10dt_e} @@ -257,8 +252,8 @@ def push_edfdv(self, f0, f10, e): diffrax.ODETerm(self._edfdv_), solver=diffrax.Tsit5(), t0=0.0, - t1=self.dt, - dt0=self.dt, + t1=self.grid.dt, + dt0=self.grid.dt, y0={"f0": f0, "f10": f10}, args={"e": e}, ) @@ -283,10 +278,10 @@ def _vdfdx_(self, t: float, y: dict, args: dict) -> dict: f0 = y["f0"] f10 = y["f10"] - df0dt_sa = -self.v[None, :] / 3.0 * self.ddx_e2c(f10) # (nx+1,nv) -> (nx,nv) - df10dt_sa = -self.v[None, :] * self.ddx_c2e(f0) # (nx,nv) -> (nx+1,nv) + df0dt_sa = -self.grid.v[None, :] / 3.0 * self.ddx_e2c(f10) # (nx+1,nv) -> (nx,nv) + df10dt_sa = -self.grid.v[None, :] * self.ddx_c2e(f0) # (nx,nv) -> (nx+1,nv) - if self.boundary == "reflective": + if self.grid.boundary == "reflective": df10dt_sa = df10dt_sa.at[0].set(0.0).at[-1].set(0.0) return {"f0": df0dt_sa, "f10": df10dt_sa} @@ -307,8 +302,8 @@ def push_vdfdx(self, f0: Array, f10: Array) -> Array: diffrax.ODETerm(self._vdfdx_), solver=diffrax.Tsit5(), t0=0.0, - t1=self.dt, - dt0=self.dt, + t1=self.grid.dt, + dt0=self.grid.dt, y0={"f0": f0, "f10": f10}, ) return result.ys["f0"][-1], result.ys["f10"][-1] @@ -336,7 +331,7 @@ def __call__(self, t, y, args) -> dict: # explicit push for v df/dx f0_star, f10_star = self.push_vdfdx(f0, f10) # implicit solve f00 coll - f0_star = self.lb(None, f0_star, self.dt) + f0_star = self.lb(None, f0_star, self.grid.dt) # Interpolate center quantities to edges for FLM collision operator Z_edge = self.interp_c2e(Z) @@ -350,24 +345,24 @@ def __call__(self, t, y, args) -> dict: new_f0, new_f10 = self.push_edfdv(f0_star, f10_star, new_e) # solve f10 coll f0_at_edges = self.interp_c2e(f0_star) - new_f10 = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=new_f10, dt=self.dt) + new_f10 = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=new_f10, dt=self.grid.dt) elif self.e_solver == "edfdv-ampere-implicit": # implicit E, f0, f1 using a nonlinear iterative inversion new_f0, new_f10, new_e = self.implicit_e_f0_f1_solve(f0=f0_star, f1=f10_star, e=y["e"]) elif self.e_solver == "ampere": - new_e = y["e"] + self.dt * self.ampere_coeff * self.calc_j(f10_star) + new_e = y["e"] + self.grid.dt * self.ampere_coeff * self.calc_j(f10_star) # push e new_f0, new_f10 = self.push_edfdv(f0_star, f10_star, new_e) # solve f10 coll f0_at_edges = self.interp_c2e(new_f0) - new_f10 = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=new_f10, dt=self.dt) + new_f10 = self.ei(Z=Z_edge, ni=ni_edge, f0=f0_at_edges, f10=new_f10, dt=self.grid.dt) else: raise NotImplementedError # Enforce reflective BCs on f10 - if self.boundary == "reflective": + if self.grid.boundary == "reflective": new_f10 = new_f10.at[0].set(0.0).at[-1].set(0.0) return {"f0": new_f0, "f10": new_f10, "f11": y["f11"], "e": new_e, "b": y["b"], "Z": y["Z"], "ni": y["ni"]} diff --git a/tests/test_vfp1d/test_fp_relaxation.py b/tests/test_vfp1d/test_fp_relaxation.py index 48ef1a72..95ed037f 100644 --- a/tests/test_vfp1d/test_fp_relaxation.py +++ b/tests/test_vfp1d/test_fp_relaxation.py @@ -20,6 +20,7 @@ from jax import Array from adept.vfp1d.fokker_planck import F0Collisions +from adept.vfp1d.grid import Grid # ============================================================================= # Test configuration @@ -158,11 +159,23 @@ def make_vector_field( sc_iterations: int, ) -> eqx.Module: """Create an F0Collisions vector field for the given model/scheme combo.""" - # Build config cfg = self._make_config(grid, model_name, scheme_name, nu, sc_iterations) - # Create production class and adapter - collisions = F0Collisions(cfg) + # Build a vfp1d Grid with dummy spatial/temporal values (F0Collisions only uses velocity fields) + vfp_grid = Grid( + xmin=0.0, + xmax=1.0, + nx=1, + tmin=0.0, + tmax=1.0, + dt=1.0, + nv=grid.nv, + vmax=float(grid.vmax), + nl=1, + boundary="periodic", + ) + + collisions = F0Collisions(cfg, vfp_grid) return F0CollisionsVectorField(collisions=collisions, dt=dt) def _make_config( @@ -179,11 +192,6 @@ def _make_config( "CentralDifferencing": "central", } return { - "grid": { - "v": np.asarray(grid.v), - "dv": float(grid.dv), - "nv": grid.nv, - }, "terms": { "fokker_planck": { # Direct nuee_coeff override for dimensionless testing