Skip to content
20 changes: 20 additions & 0 deletions adept/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ def electron_debye_normalization(n0_str, T0_str):
x0 = (v0 / wp0).to("nm")

return PlasmaNormalization(m0=UREG.m_e, q0=UREG.e, n0=n0, T0=T0, L0=x0, v0=v0, tau=tau)


def skin_depth_normalization(n0_str, T0_str):
"""
Returns the VFP-1D normalization.
Unit quantities are:
- c/wp0 (collisionless skin depth)
- Electron thermal velocity
- 1/wp0
"""
n0 = UREG.Quantity(n0_str)
T0 = UREG.Quantity(T0_str)

wp0 = ((n0 * UREG.e**2.0 / (UREG.m_e * UREG.epsilon_0)) ** 0.5).to("rad/s")
tau = 1 / wp0

v0 = ((2.0 * T0 / UREG.m_e) ** 0.5).to("m/s")
x0 = (UREG.c / wp0).to("nm")

return PlasmaNormalization(m0=UREG.m_e, q0=UREG.e, n0=n0, T0=T0, L0=x0, v0=v0, tau=tau)
185 changes: 64 additions & 121 deletions adept/vfp1d/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
from astropy import constants as csts
from astropy import units as u
from astropy.units import Quantity as _Q
from diffrax import ODETerm, SaveAt, SubSaveAt, diffeqsolve
from jax import numpy as jnp

from adept._base_ import ADEPTModule, Stepper
from adept.normalization import UREG, skin_depth_normalization
from adept.utils import filter_scalars
from adept.vfp1d.grid import Grid
from adept.vfp1d.helpers import _initialize_total_distribution_, calc_logLambda
from adept.vfp1d.storage import get_save_quantities, post_process
from adept.vfp1d.vector_field import OSHUN1D
Expand All @@ -14,199 +14,142 @@
class BaseVFP1D(ADEPTModule):
def __init__(self, cfg) -> None:
super().__init__(cfg)
# n0 = critical density for 351nm (3ω Nd:glass)
self.plasma_norm = skin_depth_normalization("9.0663e21/cm^3", cfg["units"]["reference electron temperature"])
self.grid = Grid.from_config(cfg["grid"], self.plasma_norm)

def post_process(self, solver_result: dict, td: str) -> dict:
return post_process(solver_result["solver result"], cfg=self.cfg, td=td, args=self.args)

def write_units(self) -> dict:
ne = u.Quantity(self.cfg["units"]["reference electron density"]).to("1/cm^3")
ni = ne / self.cfg["units"]["Z"]
Te = u.Quantity(self.cfg["units"]["reference electron temperature"]).to("eV")
Ti = u.Quantity(self.cfg["units"]["reference ion temperature"]).to("eV")
norm = self.plasma_norm
Z = self.cfg["units"]["Z"]
# Should we change this to reference electron density? or allow it to be user set?
n0 = u.Quantity("9.0663e21/cm^3")
ion_species = self.cfg["units"]["Ion"]

wp0 = np.sqrt(n0 * csts.e.to("C") ** 2.0 / (csts.m_e * csts.eps0)).to("Hz")
tp0 = (1 / wp0).to("fs")

vth = np.sqrt(2 * Te / csts.m_e).to("m/s") # mean square velocity eq 4-51a in Shkarofsky

x0 = (csts.c / wp0).to("nm")

beta = vth / csts.c
ne = UREG.Quantity(self.cfg["units"]["reference electron density"]).to("1/cc")
Te_eV = norm.T0.to("eV")
logLambda_ei, logLambda_ee = calc_logLambda(
self.cfg, ne, Te_eV, Z, self.cfg["units"]["Ion"], force_ee_equal_ei=True
)

logLambda_ei, logLambda_ee = calc_logLambda(self.cfg, ne, Te, Z, ion_species, force_ee_equal_ei=True)
# Local aliases for quantities used in multiple expressions below
wp0 = (1 / norm.tau).to("rad/s")
vth = norm.v0.to("m/s")
beta = 1.0 / norm.speed_of_light_norm()
# Elementary charge in Gaussian CGS (pint cannot convert SI↔Gaussian charge dimensions)
e_gauss = UREG.Quantity(4.803204712570263e-10, "Fr")
Comment thread
jpbrodrick89 marked this conversation as resolved.

nD_NRL = 1.72e9 * Te.value**1.5 / np.sqrt(ne.value)
nD_Shkarofsky = np.exp(logLambda_ei) * Z / 9
ne_cc = ne.to("1/cc").magnitude
nD_NRL = 1.72e9 * Te_eV.magnitude**1.5 / np.sqrt(ne_cc)

nuei_shk = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / np.exp(logLambda_ei)
# JPB - Maybe comment which page/eq this is from? There are lots of collision times in NRL
# For example, nu_ei on page 32 does include Z^2
nuei_nrl = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / nD_NRL

lambda_mfp_shk = (vth / nuei_shk).to("micron")
lambda_mfp_nrl = (vth / nuei_nrl).to("micron")

nuei_epphaines = (
1
/ (
0.75
* np.sqrt(csts.m_e)
* Te**1.5
/ (np.sqrt(2 * np.pi) * ni * Z**2.0 * csts.e.gauss**4.0 * logLambda_ei)
* UREG.Quantity(1, "electron_mass") ** 0.5
* Te_eV**1.5
/ (np.sqrt(2 * np.pi) * (ne / Z) * Z**2.0 * e_gauss**4.0 * logLambda_ei)
)
).to("Hz")

all_quantities = {
"wp0": wp0,
"n0": n0,
"tp0": tp0,
"n0": norm.n0.to("1/cc"),
"tp0": norm.tau.to("fs"),
"ne": ne,
"vth": vth,
"Te": Te,
"Ti": Ti,
"Te": Te_eV,
"Ti": UREG.Quantity(self.cfg["units"]["reference ion temperature"]).to("eV"),
"logLambda_ei": logLambda_ei,
"logLambda_ee": logLambda_ee,
"beta": beta,
"x0": x0,
"x0": norm.L0.to("nm"),
"nuei_shk": nuei_shk,
"nuei_nrl": nuei_nrl,
"nuei_epphaines": nuei_epphaines,
"nuei_shk_norm": nuei_shk / wp0,
"nuei_nrl_norm": nuei_nrl / wp0,
"nuei_epphaines_norm": nuei_epphaines / wp0,
"lambda_mfp_shk": lambda_mfp_shk,
"lambda_mfp_nrl": lambda_mfp_nrl,
"nuei_shk_norm": (nuei_shk / wp0).to(""),
"nuei_nrl_norm": (nuei_nrl / wp0).to(""),
"nuei_epphaines_norm": (nuei_epphaines / wp0).to(""),
"lambda_mfp_shk": (vth / nuei_shk).to("micron"),
"lambda_mfp_nrl": (vth / nuei_nrl).to("micron"),
"lambda_mfp_epphaines": (vth / nuei_epphaines).to("micron"),
"nD_NRL": nD_NRL,
"nD_Shkarofsky": nD_Shkarofsky,
"nD_Shkarofsky": np.exp(logLambda_ei) * Z / 9,
}

self.cfg["units"]["derived"] = all_quantities
self.cfg["grid"]["beta"] = beta.value
self.cfg["grid"]["beta"] = beta

return {k: str(v) for k, v in all_quantities.items()}

def get_derived_quantities(self):
"""
This function just updates the config with the derived quantities that are only integers or strings.
"""Sync scalar grid values into cfg for logging.

This is run prior to the log params step

:param cfg_grid:
:return:
This is run prior to the log params step.
"""
cfg_grid = self.cfg["grid"]
grid = self.grid

# Default save.*.t.tmin/tmax to grid values (preserves unit strings)
# Default save.*.t.tmin/tmax to computed grid values
for save_type in self.cfg.get("save", {}).keys():
if "t" in self.cfg["save"][save_type]:
t_cfg = self.cfg["save"][save_type]["t"]
t_cfg.setdefault("tmin", cfg_grid.get("tmin", "0ps"))
t_cfg.setdefault("tmax", cfg_grid["tmax"])

cfg_grid["xmax"] = (_Q(cfg_grid["xmax"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value
cfg_grid["xmin"] = (_Q(cfg_grid["xmin"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value
cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]

# sqrt(2 * k * T / m)
cfg_grid["vmax"] = (
8
* np.sqrt((_Q(self.cfg["units"]["reference electron temperature"]) / (csts.m_e * csts.c**2.0)).to("")).value
)

cfg_grid["dv"] = cfg_grid["vmax"] / cfg_grid["nv"]

cfg_grid["tmax"] = (_Q(cfg_grid["tmax"]) / self.cfg["units"]["derived"]["tp0"]).to("").value
cfg_grid["dt"] = (_Q(cfg_grid["dt"]) / self.cfg["units"]["derived"]["tp0"]).to("").value

cfg_grid["nt"] = int(cfg_grid["tmax"] / cfg_grid["dt"]) + 1
# Merge scalar grid values into cfg for logging (arrays come later in get_solver_quantities)
cfg_grid.update(filter_scalars(grid.as_dict()))

if cfg_grid["nt"] > 1e6:
cfg_grid["max_steps"] = int(1e6)
print(r"Only running $10^6$ steps")
else:
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

cfg_grid["tmax"] = cfg_grid["dt"] * cfg_grid["nt"]

print("tmax", cfg_grid["tmax"], "dt", cfg_grid["dt"])
print("xmax", cfg_grid["xmax"], "dx", cfg_grid["dx"])
print("tmax", grid.tmax, "dt", grid.dt)
print("xmax", grid.xmax, "dx", grid.dx)

self.cfg["grid"] = cfg_grid

def get_solver_quantities(self):
"""
This function just updates the config with the derived quantities that are arrays

This is run after the log params step

:param cfg_grid:
:return:
"""
cfg_grid = self.cfg["grid"]

cfg_grid.setdefault("boundary", "periodic")

cfg_grid = {
**cfg_grid,
**{
"x": jnp.linspace(
cfg_grid["xmin"] + cfg_grid["dx"] / 2, cfg_grid["xmax"] - cfg_grid["dx"] / 2, cfg_grid["nx"]
),
"x_edge": jnp.linspace(cfg_grid["xmin"], cfg_grid["xmax"], cfg_grid["nx"] + 1),
"t": jnp.linspace(0, cfg_grid["tmax"], cfg_grid["nt"]),
"v": jnp.linspace(cfg_grid["dv"] / 2, cfg_grid["vmax"] - cfg_grid["dv"] / 2, cfg_grid["nv"]),
"kx": jnp.fft.fftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi,
"kxr": jnp.fft.rfftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi,
},
}

self.cfg["grid"] = cfg_grid
"""Merge all grid values (including arrays) into cfg['grid'] for backward compatibility."""
self.cfg["grid"].update(self.grid.as_dict())

def init_state_and_args(self) -> dict:
"""
This function initializes the state
grid = self.grid
beta = 1.0 / self.plasma_norm.speed_of_light_norm()
f0, f10, ne_prof = _initialize_total_distribution_(self.cfg, grid, beta, self.plasma_norm)

:param cfg:
:return:
"""
nx = self.cfg["grid"]["nx"]
nv = self.cfg["grid"]["nv"]
f0, f10, ne_prof = _initialize_total_distribution_(self.cfg, self.cfg["grid"])
# Scale density to physical ne (f10 is invariant in the big-dt collision limit)
ne = UREG.Quantity(self.cfg["units"]["reference electron density"])
ne_over_n0 = (ne / self.plasma_norm.n0).to("").magnitude
f0 *= ne_over_n0
ne_prof *= ne_over_n0

state = {"f0": f0}
# not currently necessary but kept for completeness
for il in range(1, self.cfg["grid"]["nl"] + 1):
for il in range(1, grid.nl + 1):
for im in range(0, il + 1):
state[f"f{il}{im}"] = jnp.zeros((nx + 1, nv))
state[f"f{il}{im}"] = jnp.zeros((grid.nx + 1, grid.nv))

state["f10"] = f10

for field in ["e", "b"]:
state[field] = jnp.zeros(nx + 1)
state[field] = jnp.zeros(grid.nx + 1)

state["Z"] = jnp.ones(nx)
state["Z"] = jnp.ones(grid.nx)
state["ni"] = ne_prof / self.cfg["units"]["Z"]

self.state = state
self.args = {"drivers": self.cfg["drivers"]}

def init_diffeqsolve(self):
self.cfg = get_save_quantities(self.cfg)
grid = self.grid
self.cfg = get_save_quantities(self.cfg, self.plasma_norm)
self.time_quantities = {
"t0": 0.0,
"t1": self.cfg["grid"]["tmax"],
"max_steps": self.cfg["grid"]["max_steps"],
"save_t0": 0.0,
"save_t1": self.cfg["grid"]["tmax"],
"save_nt": self.cfg["grid"]["tmax"],
"t1": grid.tmax,
"max_steps": grid.max_steps,
}
self.diffeqsolve_quants = dict(
terms=ODETerm(OSHUN1D(self.cfg)),
terms=ODETerm(OSHUN1D(self.cfg, grid=grid)),
solver=Stepper(),
saveat=dict(subs={k: SubSaveAt(ts=v["t"]["ax"], fn=v["func"]) for k, v in self.cfg["save"].items()}),
)
Expand All @@ -217,8 +160,8 @@ def __call__(self, trainable_modules: dict, args: dict):
solver=self.diffeqsolve_quants["solver"],
t0=self.time_quantities["t0"],
t1=self.time_quantities["t1"],
max_steps=self.cfg["grid"]["max_steps"],
dt0=self.cfg["grid"]["dt"],
max_steps=self.grid.max_steps,
dt0=self.grid.dt,
y0=self.state,
args=args,
saveat=SaveAt(**self.diffeqsolve_quants["saveat"]),
Expand Down
Loading
Loading