diff --git a/.gitignore b/.gitignore index e26e1a609..bb2c58d98 100644 --- a/.gitignore +++ b/.gitignore @@ -274,3 +274,6 @@ tests/validation/ CLAUDE.md .claude/ .mcp.json + +# pickle files +*.pkl diff --git a/docs/How-to/usage.md b/docs/How-to/usage.md index f1bd17af4..91937860e 100644 --- a/docs/How-to/usage.md +++ b/docs/How-to/usage.md @@ -172,6 +172,39 @@ Use the CLI to package the results of a grid into a zip file; e.g. for sharing o ```console proteus grid-pack -o output/grid_demo/ ``` +## Postprocessing of grid results + +Results from a PROTEUS grid can be post-processed using the `proteus grid-analyse` command. Running the analysis generates summary CSV files containing all tested input parameters and the final time-step outputs for each simulation, along with a status overview plot showing the distribution of simulation outcomes (e.g., completed, running, error) and empirical cumulative distribution function (ECDF) plots that summarize selected output variables across all completed simulations. (For more details on ECDF plots, see the [Seaborn `ecdfplot` documentation](https://seaborn.pydata.org/generated/seaborn.ecdfplot.html).) + +Before running the command, configure the post-processing options in your grid configuration file:`input/ensembles/example.grid.toml`. The following options can be set: + +- `update_csv` — Generate or update the CSV summary files +- `plot_format` — Output format for plots (e.g., `png`, `pdf`) +- `plot_status` — Enable/disable the status summary plot +- `plot_ecdf` — Enable/disable ECDF plot generation +- `output_variables` — List of variables to include in ECDF plots + +> **Note:** +> The variables specified in `output_variables` must match column names in `runtime_helpfile.csv`. +> The `solidification_time` is the only variable computed during post-processing and is not directly stored in the simulation output. + +--- + +To post-process a grid and generate ECDF plots for further analysis, run the following command: + +``` +proteus grid-analyse --config input/ensembles/example.grid.toml +``` + + +After execution, a `post_processing/` directory is created inside the grid folder with the following structure: + +- `extracted_data/` This directory contains three CSV files: + - `{grid_name}_final_extracted_data_all.csv` which includes every run in the grid + - `{grid_name}_final_extracted_data_completed.csv` which contains only successful runs (used for ECDF plots) + - `{grid_name}_final_extracted_data_running_error.csv` for only failed simulations with status `Running` or `Error`. + +- `grid_plots/` This directory contains a status summary plot and a ECDF plot ## Retrieval scheme (Bayesian optimisation) diff --git a/input/ensembles/example.grid.toml b/input/ensembles/example.grid.toml index 157b82a58..9a3838f79 100644 --- a/input/ensembles/example.grid.toml +++ b/input/ensembles/example.grid.toml @@ -1,11 +1,21 @@ # Config file for running a grid of forward models # Path to output folder where grid will be saved (relative to PROTEUS output folder) -output = "grid_demo/" +output = "scratch/grid_demo/" # Make `output` a symbolic link to this absolute location. To disable: set to empty string. symlink = "" +# Post-processing options +update_csv = true # Whether to update the summary CSV files before plotting +plot_format = "png" # Format for saving plots ("png" or "pdf") +plot_status = true # Generate status summary plot of the grid +plot_ecdf = true # Generate ECDF grid plot for input parameters tested in the grid and output_variables defined below +colormap = "berlin" # Colormap for ECDF plot (e.g. "viridis", "managua") +output_variables = ["solidification_time", "Phi_global", + "T_surf", "P_surf", "atm_kg_per_mol", + "esc_rate_total", "rho_obs"] # List of output variables to include in ECDF plot (name must match variable name in runtime_helpfile.csv) + # Path to base (reference) config file relative to PROTEUS root folder ref_config = "input/demos/dummy.toml" diff --git a/input/planets/toi561b.toml b/input/planets/toi561b.toml new file mode 100644 index 000000000..0bfc2f86a --- /dev/null +++ b/input/planets/toi561b.toml @@ -0,0 +1,465 @@ +# PROTEUS configuration file + +# This is a comprehensive outline of all configuration options. It includes variables +# which have default values, in order to showcase the range of potential options available. +# Variable defaults are defined in `src/proteus/config/*.py` + +# Root tables should be physical, with the exception of "params" +# Software related options should go within the appropriate physical table +# For configuration see https://fwl-proteus.readthedocs.io/en/latest/config.html + +# ---------------------------------------------------- + +# The general structure is: +# [params] parameters for code execution, output files, time-stepping, convergence +# [star] stellar parameters, model selection +# [orbit] planetary orbital parameters +# [struct] planetary structure (mass, radius) +# [atmos_clim] atmosphere climate parameters, model selection +# [atmos_chem] atmosphere chemistry parameters, model selection +# [escape] escape parameters, model selection +# [interior] magma ocean model selection and parameters +# [outgas] outgassing parameters (fO2) and included volatiles +# [delivery] initial volatile inventory, and delivery model selection +# [observe] synthetic observations + +# ---------------------------------------------------- + +version = "2.0" + +# Parameters +[params] + # output files + [params.out] + path = "scratch/toi561b" + logging = "INFO" + plot_mod = 50 # Plotting frequency, 0: wait until completion | n: every n iterations + plot_fmt = "pdf" # Plotting image file format, "png" or "pdf" recommended + write_mod = 50 # Write CSV frequency, 0: wait until completion | n: every n iterations + archive_mod = 100 # Archive frequency, 0: wait until completion | n: every n iterations | none: do not archive + remove_sf = true # Remove SOCRATES spectral file when simulation ends. + + # time-stepping + [params.dt] + minimum = 5e4 # yr, minimum time-step + minimum_rel = 1e-5 # relative minimum time-step [dimensionless] + maximum = 5e8 # yr, maximum time-step # if higher like 1e9, will produce too few snapshots + initial = 1e3 # yr, inital step size + starspec = 1e8 # yr, interval to re-calculate the stellar spectrum + starinst = 1e2 # yr, interval to re-calculate the instellation + method = "adaptive" # proportional | adaptive | maximum + + [params.dt.proportional] + propconst = 52.0 # Proportionality constant + + [params.dt.adaptive] + atol = 0.02 # Step size atol + rtol = 0.10 # Step size rtol + + # Termination criteria + # Set enabled=true/false in each section to enable/disable that termination criterion + [params.stop] + + # Require criteria to be satisfied twice before model will exit? + strict = false + + # required number of iterations + [params.stop.iters] + enabled = false + minimum = 5 + maximum = 5e6 + + # required time constraints + [params.stop.time] + enabled = true + minimum = 1.0e3 # yr, model will certainly run to t > minimum + maximum = 1.1e10 # yr, model will terminate when t > maximum + + # solidification + [params.stop.solid] + enabled = false + phi_crit = 0.005 # non-dim., model will terminate when global melt fraction < phi_crit + + # radiative equilibrium + [params.stop.radeqm] + enabled = false + atol = 0.1 # absolute tolerance [W m-2] + rtol = 1e-3 # relative tolerance + + [params.stop.escape] + enabled = false + p_stop = 1.0 # Stop surface pressure is less than this value + + # disintegration + [params.stop.disint] + enabled = false + + roche_enabled = false + offset_roche = 0 # correction to calculated Roche limit [m] + + spin_enabled = false + offset_spin = 0 # correction to calculated Breakup period [s] + + +# ---------------------------------------------------- +# Star +[star] + + # Physical parameters + mass = 0.806 # from Lacedelli et al., 2022 [M_sun] + age_ini = 0.1 # Gyr, model initialisation/start age + + module = "mors" + + [star.mors] + rot_pcntle = 50.0 # rotation percentile + rot_period = 'none' # rotation period [days] + tracks = "spada" # evolution tracks: spada | baraffe + age_now = 11 # [Gyr] from Lacedelli et al., 2022 + spectrum_source = "solar" # Spectrum source: 'solar' for solar spectra; 'muscles' for MUSCLES spectra; 'phoenix' for synthetic PHOENIX spectrum; see https://proteus-framework.org/proteus/data.html#stellar-spectra + star_name = "sun" # star name, relevant for when spectrum_source = 'solar' (use e.g. 'sun' or 'Sun0.6Ga') or when spectrum_source = 'muscles' (use e.g. 'trappist-1' or 'gj1214'). Not relevent when spectrum_source = 'phoenix'. + star_path = "../FWL_DATA/stellar_spectra/Named/toi561.txt" # optional override star path to custom stellar spectrum, e.g. "$FWL_DATA/stellar_spectra/solar/sun.txt" + + # PHOENIX parameters, only relevant if spectrum_source = "phoenix". Defaults to solar (0.0). + phoenix_FeH = 0.0 # metallicity [Fe/H] + phoenix_alpha = 0.0 # alpha enhancement [α/M] + + # if None, calculated by mors + phoenix_radius = "none" # Stellar radius [R_sun] used for PHOENIX spectrum scaling + phoenix_log_g = "none" # Stellar surface gravity [dex] + phoenix_Teff = "none" # Stellar effective temperature [K] + + [star.dummy] + radius = 0.843 # from Lacedelli et al., 2022 [R_sun] + calculate_radius = false # Calculate star radius using scaling from Teff? + Teff = 5372.0 # from Lacedelli et al., 2022 [K] + +# Orbital system +[orbit] + instellation_method = 'sma' # whether to define orbit using semi major axis ('sma') or instellation flux ('inst') + instellationflux = 1.0 # instellation flux received from the planet in [Earth units] + semimajoraxis = 0.0106 # AU from Patel et al. 2023 + eccentricity = 0.0 # dimensionless from Brinkman et al. 2023 + zenith_angle = 48.19 # degrees + s0_factor = 0.375 # dimensionless + + evolve = false # whether to evolve the SMaxis and eccentricity + module = "lovepy" # module used to calculate tidal heating + + axial_period = "none" # planet's initial day length [hours]; will use orbital period if 'none' + satellite = false # include satellite (moon)? + mass_sat = 7.347e+22 # mass of satellite [kg] + semimajoraxis_sat = 3e8 # initial SMA of satellite's orbit [m] + + [orbit.dummy] + H_tide = 1e-11 # Fixed tidal power density [W kg-1] + Phi_tide = "<0.3" # Tidal heating applied when inequality locally satisfied + Imk2 = 0.0 # Fixed imaginary part of k2 love number, cannot be positive + + [orbit.lovepy] + visc_thresh = 1e9 # Minimum viscosity required for heating [Pa s] + +# Planetary structure - physics table +[struct] + mass_tot = 2.24 # M_earth from Brinkman et al. 2023 + #radius_int = 1.4195 # R_earth from Patel et al. 2023 + corefrac = 0.55 # non-dim., radius fraction 0.20 from Brinkman et al. 2023 + core_density = 10738.33 # Core density [kg m-3] + core_heatcap = 880.0 # Core specific heat capacity [J K-1 kg-1] + + module = "self" # self | zalmoxis + + [struct.zalmoxis] + EOSchoice = "Tabulated:iron/Tdep_silicate" # EOS choices: "Tabulated:iron/silicate", "Tabulated:iron/Tdep_silicate", "Tabulated:water" + coremassfrac = 0.325 # core mass fraction [non-dim.] + mantle_mass_fraction = 0 # mantle mass fraction [non-dim.] + weight_iron_frac = 0.325 # iron fraction in the planet [non-dim.] + temperature_mode = "linear" # Input temperature profile choices: "isothermal", "linear", "prescribed" + surface_temperature = 3500 # Surface temperature [K], required for temperature_mode="isothermal" or "linear" + center_temperature = 6000 # Center temperature [K], required for temperature_mode="linear" + temperature_profile_file = "zalmoxis_ini_input_temp.txt" # filename with a prescribed temperature profile, required for temperature_mode="prescribed" + num_levels = 150 # number of Zalmoxis radius layers + max_iterations_outer = 100 # max. iterations for the outer loop + tolerance_outer = 3e-3 # tolerance for the outer loop + max_iterations_inner = 100 # max. iterations for the inner loop + tolerance_inner = 1e-4 # tolerance for the inner loop + relative_tolerance = 1e-5 # relative tolerance for solve_ivp + absolute_tolerance = 1e-6 # absolute tolerance for solve_ivp + maximum_step = 250000 # maximum integration step size [m] + adaptive_radial_fraction = 0.98 # radial fraction for transition from adaptive integration to fixed-step integration when using "Tabulated:iron/Tdep_silicate" EOS + max_center_pressure_guess = 0.99e12 # maximum central pressure guess based on "Tabulated:iron/Tdep_silicate" EOS limit [Pa] + target_surface_pressure = 101325 # target surface pressure [Pa] + pressure_tolerance = 1e9 # tolerance surface pressure [Pa] + max_iterations_pressure = 200 # max. iterations for the innermost loop + pressure_adjustment_factor = 1.1 # factor for adjusting the pressure in the innermost loop + verbose = false # detailed convergence info and warnings printing? + iteration_profiles_enabled = false # pressure and density profiles for each iteration logging? + +# Atmosphere - physics table +[atmos_clim] + prevent_warming = true # do not allow the planet to heat up + surface_d = 0.01 # m, conductive skin thickness + surface_k = 2.0 # W m-1 K-1, conductive skin thermal conductivity + cloud_enabled = false # enable water cloud radiative effects + cloud_alpha = 0.0 # condensate retention fraction (1 -> fully retained) + surf_state = "fixed" # surface scheme: "mixed_layer" | "fixed" | "skin" + surf_greyalbedo = 0.1 # surface grey albedo + albedo_pl = 0.0 # Enforced Bond albedo (do not use with `rayleigh = true`) from Lacedelli et al. 2022 + rayleigh = false # Enable rayleigh scattering + tmp_minimum = 0.5 # temperature floor on solver + tmp_maximum = 5000.0 # temperature ceiling on solver + + module = "agni" # Which atmosphere module to use + + [atmos_clim.agni] + verbosity = 1 # output verbosity for agni (0:none, 1:info, 2:debug) + p_top = 1.0e-5 # bar, top of atmosphere grid pressure + p_obs = 0.02 # bar, level probed in transmission + spectral_group = "Honeyside" # which gas opacities to include + spectral_bands = "48" # how many spectral bands? + num_levels = 50 # Number of atmospheric grid levels + chemistry = "none" # "none" | "eq" + surf_material = "greybody" # surface material file for scattering + solve_energy = true # solve for energy-conserving atmosphere profile + solution_atol = 0.01 # solver absolute tolerance + solution_rtol = 0.05 # solver relative tolerance + overlap_method = "ee" # gas overlap method + surf_roughness = 1e-3 # characteristic surface roughness [m] + surf_windspeed = 2.0 # characteristic surface wind speed [m/s]. + rainout = true # include volatile condensation/evaporation aloft + latent_heat = true # include latent heat release when `rainout=true`? + oceans = true # form liquid oceans at planet surface? + convection = true # include convective heat transport, with MLT + conduction = true # include conductive heat transport, with Fourier's law + sens_heat = true # include sensible heat flux near surface, with TKE scheme + real_gas = true # use real-gas equations of state + psurf_thresh = 0.1 # bar, surface pressure where we switch to 'transparent' mode + dx_max_ini = 300.0 # initial maximum temperature step [kelvin] allowed by solver + dx_max = 35.0 # maximum temperature step [kelvin] allowed by solver + max_steps = 70 # max steps allowed by solver during each iteration + perturb_all = true # updated entire jacobian each step? + mlt_criterion = "s" # MLT convection stability criterion; (l)edoux or (s)chwarzschild + fastchem_floor = 150.0 # Minimum temperature allowed to be sent to FC + fastchem_maxiter_chem = 60000 # Maximum FC iterations (chemistry) + fastchem_maxiter_solv = 20000 # Maximum FC iterations (internal solver) + fastchem_xtol_chem = 1e-4 # FC solver tolerance (chemistry) + fastchem_xtol_elem = 1e-4 # FC solver tolerance (elemental) + ini_profile = 'isothermal' # Initial guess for temperature profile shape + ls_default = 2 # Default linesearch method (0:none, 1:gs, 2:bt) + fdo = 2 # finite-difference order (options: 2, 4) + + [atmos_clim.janus] + p_top = 1.0e-5 # bar, top of atmosphere grid pressure + p_obs = 1.0e-3 # bar, observed pressure level + spectral_group = "Honeyside" # which gas opacities to include + spectral_bands = "48" # how many spectral bands? + F_atm_bc = 0 # measure outgoing flux at: (0) TOA | (1) Surface + num_levels = 60 # Number of atmospheric grid levels + tropopause = "none" # none | skin | dynamic + overlap_method = "ee" # gas overlap method + + [atmos_clim.dummy] + gamma = 0.7 # atmosphere opacity between 0 and 1 + height_factor = 3.0 # observed height is this times the scale height + +# Volatile escape - physics table +[escape] + + module = "zephyrus" # Which escape module to use + reservoir = "outgas" # Escaping reservoir: "bulk", "outgas", "pxuv". + + [escape.zephyrus] + Pxuv = 1e-2 # Pressure at which XUV radiation become opaque in the planetary atmosphere [bar] + efficiency = 0.1 # Escape efficiency factor + tidal = false # Tidal contribution enabled + + [escape.dummy] + rate = 0.0 # Bulk unfractionated escape rate [kg s-1] + + [escape.boreas] + fractionate = true # Include fractionation in outflow? + efficiency = 0.1 # Escape efficiency factor + sigma_H = 1.89e-18 # H absorption cross-section in XUV [cm2] + sigma_O = 2.00e-18 # O absorption ^ + sigma_C = 2.50e-18 # C absorption ^ + sigma_N = 3.00e-18 # N absorption ^ + sigma_S = 6.00e-18 # S absorption ^ + kappa_H2 = 0.01 # H2 opacity in IR, grey [cm2 g-1] + kappa_H2O = 1.0 # H2O opacity ^ + kappa_O2 = 1.0 # O2 opacity ^ + kappa_CO2 = 1.0 # CO2 opacity ^ + kappa_CO = 1.0 # CO opacity ^ + kappa_CH4 = 1.0 # CH4 opacity ^ + kappa_N2 = 1.0 # N2 opacity ^ + kappa_NH3 = 1.0 # NH3 opacity ^ + kappa_H2S = 1.0 # H2S opacity ^ + kappa_SO2 = 1.0 # SO2 opacity ^ + kappa_S2 = 1.0 # S2 opacity ^ + +# Interior - physics table +[interior] + grain_size = 0.1 # crystal settling grain size [m] + F_initial = 1e5 # Initial heat flux guess [W m-2] + radiogenic_heat = true # enable radiogenic heat production + tidal_heat = false # enable tidal heat production + rheo_phi_loc = 0.4 # Centre of rheological transition + rheo_phi_wid = 0.15 # Width of rheological transition + melting_dir = "Monteux-600" # Name of folder constaining melting curves + lookup_dir = "1TPa-dK09-elec-free/MgSiO3_Wolf_Bower_2018_1TPa" # Name of folder with EOS tables, etc. + + + module = "spider" # Which interior module to use + + [interior.spider] + num_levels = 60 # Number of SPIDER grid levels + mixing_length = 2 # Mixing length parameterization + tolerance = 1.0e-10 # solver tolerance + tolerance_rel = 1.0e-8 # relative solver tolerance + solver_type = "bdf" # SUNDIALS solver method + tsurf_atol = 20.0 # tsurf_poststep_change + tsurf_rtol = 0.01 # tsurf_poststep_change_frac + ini_entropy = 4000.0 # Surface entropy conditions [J K-1 kg-1] + ini_dsdr = -4.698e-6 # Interior entropy gradient [J K-1 kg-1 m-1] + conduction = true # enable conductive heat transfer + convection = true # enable convective heat transfer + gravitational_separation = true # enable gravitational separation + mixing = true # enable mixing + matprop_smooth_width = 1e-2 # melt-fraction window width over which to smooth material properties + + [interior.aragog] + logging = "ERROR" # Aragog log verbosity + num_levels = 220 # Number of Aragog grid levels + tolerance = 1.0e-10 # solver tolerance + initial_condition = 3 # Initial T(p); 1: linear, 2: user defined, 3: adiabat + ini_tmagma = 3200.0 # Initial magma surface temperature [K] + basal_temperature = 7000.0 # CMB temperature when initial boundary = 1 + inner_boundary_condition = 1 # 1 = core cooling model, 2 = prescribed heat flux, 3 = prescribed temperature + inner_boundary_value = 4000 # core temperature [K], if inner_boundary_condition = 3. CMB heat flux [W/m^2], if if inner_boundary_condition = 2 + conduction = true # enable conductive heat transfer + convection = true # enable convective heat transfer + gravitational_separation = false # enable gravitational separation + mixing = false # enable mixing + dilatation = false # enable dilatation source term + mass_coordinates = false # enable mass coordinates + tsurf_poststep_change = 30 # threshold of maximum change on surface temperature + event_triggering = true # enable events triggering to avoid abrupt jumps in surface temperature + bulk_modulus = 260e9 # Adiabatic bulk modulus AW-EOS parameter [Pa]. + + [interior.dummy] + ini_tmagma = 3300.0 # Initial magma surface temperature [K] + tmagma_atol = 30.0 # Max absolute Tsurf change in each step + tmagma_rtol = 0.05 # Max relative Tsurf change in each step + mantle_tliq = 2700.0 # Liquidus temperature + mantle_tsol = 1700.0 # Solidus temperature + mantle_rho = 4550.0 # Mantle density [kg m-3] + mantle_cp = 1792.0 # Mantle heat capacity [J K-1 kg-1] + H_ratio = 0.0 # Radiogenic heating [W/kg] + +# Outgassing - physics table +[outgas] + fO2_shift_IW = 4 # log10(ΔIW), atmosphere/interior boundary oxidation state + + module = "calliope" # Which outgassing module to use + + [outgas.calliope] + include_H2O = true # Include H2O compound + include_CO2 = true # Include CO2 compound + include_N2 = true # Include N2 compound + include_S2 = true # Include S2 compound + include_SO2 = true # Include SO2 compound + include_H2S = true # Include H2S compound + include_NH3 = true # Include NH3 compound + include_H2 = true # Include H2 compound + include_CH4 = true # Include CH4 compound + include_CO = true # Include CO compound + T_floor = 700.0 # Temperature floor applied to outgassing calculation [K]. + rtol = 0.0001 # Relative mass tolerance + xtol = 1e-06 # Absolute mass tolerance + solubility = true # Enable solubility? + + [outgas.atmodeller] + some_parameter = "some_value" + +# Volatile delivery - physics table +[delivery] + + # Radionuclide parameters + radio_tref = 4.55 # Reference age for concentrations [Gyr] + radio_K = 310.0 # ppmw of potassium (all isotopes) + radio_U = 0.031 # ppmw of uranium (all isotopes) + radio_Th = 0.124 # ppmw of thorium (all isotopes) + + # Which initial inventory to use? + initial = 'elements' # "elements" | "volatiles" + + # No module for accretion as of yet + module = "none" + + # Set initial volatile inventory by planetary element abundances + [delivery.elements] + use_metallicity = false # whether or not to specify the elemental abundances in terms of solar metallicity + metallicity = 1000 # metallicity relative to solar metallicity + + H_oceans = 1.0 # Hydrogen inventory in units of equivalent Earth oceans + #H_ppmw = 109.0 # Hydrogen inventory in ppmw relative to mantle mass + # H_kg = 1e20 # Hydrogen inventory in kg + + CH_ratio = 1.0 # C/H mass ratio in mantle/atmosphere system + #C_ppmw = 109.0 # Carbon inventory in ppmw relative to mantle mass + # C_kg = 1e20 # Carbon inventory in kg + + # NH_ratio = 0.018 # N/H mass ratio in mantle/atmosphere system + N_ppmw = 20.1 # Nitrogen inventory in ppmw relative to mantle mass + # N_kg = 1e20 # Nitrogen inventory in kg + + SH_ratio = 2.16 # S/H mass ratio in mantle/atmosphere system + #S_ppmw = 235.0 # Sulfur inventory in ppmw relative to mantle mass + # S_kg = 1e20 # Sulfur inventory in kg + + # Set initial volatile inventory by partial pressures in atmosphere + [delivery.volatiles] + H2O = 0.0 # partial pressure of H2O + CO2 = 0.0 # partial pressure of CO2 + N2 = 0.0 # etc + S2 = 0.0 + SO2 = 0.0 + H2S = 0.0 + NH3 = 0.0 + H2 = 0.0 + CH4 = 0.0 + CO = 0.0 + +# Atmospheric chemistry postprocessing +[atmos_chem] + + module = "vulcan" # Atmospheric chemistry module + when = "manually" # When to run chemistry (manually, offline, online) + + # Physics flags + photo_on = true # Enable photochemistry + Kzz_on = true # Enable eddy diffusion + Kzz_const = "none" # Constant eddy diffusion coefficient (none => use profile) + moldiff_on = true # Enable molecular diffusion in the atmosphere + updraft_const = 0.0 # Set constant updraft velocity + + # Vulcan-specific atmospheric chemistry parameters + [atmos_chem.vulcan] + clip_fl = 1e-20 # Floor on stellar spectrum [erg s-1 cm-2 nm-1] + clip_vmr = 1e-10 # Neglect species with vmr < clip_vmr + make_funs = true # Generate reaction network functions + ini_mix = "profile" # Initial mixing ratios (profile, outgas) + fix_surf = false # Fixed surface mixing ratios + network = "SNCHO" # Class of chemical network to use (CHO, NCHO, SNCHO) + save_frames = true # Plot frames during iterations + yconv_cri = 0.05 # Convergence criterion, value of mixing ratios + slope_cri = 0.0001 # Convergence criterion, rate of change of mixing ratios + +# Calculate simulated observations +[observe] + + # Module with which to calculate the synthetic observables + synthesis = "none" + + [observe.platon] + downsample = 8 # Factor to downsample opacities + clip_vmr = 1e-8 # Minimum VMR for a species to be included diff --git a/pyproject.toml b/pyproject.toml index a1f48f898..c79f7cf85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ "ruff", "sympy", "astropy", + "seaborn", "torch", "botorch", "gpytorch", diff --git a/src/proteus/cli.py b/src/proteus/cli.py index e1c39ccc0..aa0c5559c 100644 --- a/src/proteus/cli.py +++ b/src/proteus/cli.py @@ -554,6 +554,24 @@ def observe(config_path: Path): cli.add_command(offchem) cli.add_command(observe) +# ---------------- +# 'grid_analyse' postprocessing commands +# ---------------- + + +@click.command() +@config_option +def grid_analyse(config_path: Path): + """Generate grid analysis plots and CSV summary files from a grid + config_path : Path to the toml file containing grid analysis configuration + """ + from proteus.grid.post_processing import main + + main(config_path) + + +cli.add_command(grid_analyse) + # ---------------- # GridPROTEUS and BO inference scheme, runners # ---------------- diff --git a/src/proteus/config/_escape.py b/src/proteus/config/_escape.py index 8d33063f8..e55bbfe66 100644 --- a/src/proteus/config/_escape.py +++ b/src/proteus/config/_escape.py @@ -26,7 +26,7 @@ class Zephyrus: Attributes ---------- Pxuv: float - Pressure at which XUV radiation become opaque in the planetary atmosphere [bar] + Pressure at which XUV radiation becomes opaque in the planetary atmosphere (should be above Pxuv > 0 bar) [bar] efficiency: float Escape efficiency factor tidal: bool diff --git a/src/proteus/grid/post_processing.py b/src/proteus/grid/post_processing.py new file mode 100644 index 000000000..6529bc049 --- /dev/null +++ b/src/proteus/grid/post_processing.py @@ -0,0 +1,1007 @@ +from __future__ import annotations + +import tomllib +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import toml +from matplotlib import cm + +from proteus.utils.plot import _preset_labels, _preset_log_scales, _preset_scales + +# --------------------------------------------------------- +# Data loading, extraction, and CSV generation functions +# --------------------------------------------------------- + + +def get_grid_name(grid_path: str | Path) -> str: + """ + Returns the grid name (last part of the path) from the given grid path. + + Parameters + ---------- + grid_path : str or Path + Full path to the grid directory. + + Returns + ------- + grid_name : str + Name of the grid directory. + """ + grid_path = Path(grid_path) + if not grid_path.is_dir(): + raise ValueError(f'{grid_path} is not a valid directory') + return grid_path.name + + +def load_grid_cases(grid_dir: Path): + """ + Load information for each simulation of a PROTEUS grid. + Read 'runtime_helpfile.csv', 'init_coupler.toml' and status + files for each simulation of the grid. + + Parameters + ---------- + grid_dir : Path or str + Path to the grid directory containing the 'case_*' folders + + Returns + ---------- + combined_data : list + List of dictionaries, each containing: + - 'init_parameters' (dict): All input parameters loaded from `init_coupler.toml`. + - 'output_values' (pandas.DataFrame): Data from `runtime_helpfile.csv`. + - 'status' (str): Status string from the `status` file, or 'Unknown' if unavailable. + """ + + combined_data = [] + grid_dir = Path(grid_dir) + + # Collect and sort the case directories + case_dirs = list(grid_dir.glob('case_*')) + case_dirs.sort(key=lambda p: int(p.name.split('_')[1])) + + for case in case_dirs: + runtime_file = case / 'runtime_helpfile.csv' + init_file = case / 'init_coupler.toml' + status_file = case / 'status' + + # Load init parameters + init_params = {} + if init_file.exists(): + try: + init_params = toml.load(init_file) + except Exception as e: + print(f'Error reading init file in {case.name}: {e}') + + # Read runtime_helpfile.csv + df = None + if runtime_file.exists(): + try: + df = pd.read_csv(runtime_file, sep='\t') + except Exception as e: + print(f'WARNING : Error reading runtime_helpfile.csv for {case.name}: {e}') + + # Read status file + status = 'Unknown' + if status_file.exists(): + try: + raw_lines = [ + ln.strip() + for ln in status_file.read_text(encoding='utf-8').splitlines() + if ln.strip() + ] + if len(raw_lines) >= 2: + status = raw_lines[1] + elif raw_lines: + status = raw_lines[0] + else: + status = 'Empty' + except Exception as e: + print(f'WARNING : Error reading status file in {case.name}: {e}') + else: + print(f'WARNING : Missing status file in {case.name}') + + # Combine all info about simulations into a list of dictionaries + combined_data.append( + {'init_parameters': init_params, 'output_values': df, 'status': status} + ) + + # Print summary of statuses + statuses = [c['status'] for c in combined_data] + status_counts = pd.Series(statuses).value_counts().sort_values(ascending=False) + print('-----------------------------------------------------------') + print(f'Total number of simulations: {len(statuses)}') + print('-----------------------------------------------------------') + print('Number of simulations per status:') + for st, count in status_counts.items(): + print(f' - {st:<45} : {count}') + print('-----------------------------------------------------------') + + return combined_data + + +def get_tested_grid_parameters(cases_data: list, grid_dir: str | Path): + """ + Extract tested grid parameters per case using: + - copy.grid.toml to determine which parameters were varied in the grid + - init_parameters already loaded by load_grid_cases for each simulation of the grid + + Parameters + ---------- + cases_data : list + Output of load_grid_cases. + grid_dir : str or Path + Path to the grid directory containing copy.grid.toml. + + Returns + ------- + case_params : dict + Dictionary mapping case index -> {parameter_name: value} + tested_params : dict + Dictionary of tested grid parameters and their grid values (directly from copy.grid.toml) + """ + + grid_dir = Path(grid_dir) + + # 1. Load tested input parameters in the grid + raw_params = toml.load(grid_dir / 'copy.grid.toml') + + # Keep only the parameters and their values + tested_params = {} + + for key, value in raw_params.items(): + if ( + isinstance(value, dict) and 'method' in value + ): # filter to only get tested parameters (those with a "method" key) + method = value['method'] + + if method == 'direct': + tested_params[key] = value['values'] + + elif method == 'linspace': + tested_params[key] = np.linspace(value['start'], value['stop'], value['count']) + + elif method == 'logspace': + tested_params[key] = np.logspace( + np.log10(value['start']), np.log10(value['stop']), value['count'] + ) + + elif method == 'arange': + arr = list(np.arange(value['start'], value['stop'], value['step'])) + # Ensure endpoint is included + if not np.isclose(arr[-1], value['stop']): + arr.append(value['stop']) + tested_params[key] = np.array(arr, dtype=float) + + else: + print(f'⚠️ Unknown method for {key}: {method}') + continue + + grid_param_paths = list(tested_params.keys()) + + # 2.Extract those parameters from loaded cases for each case of the grid + case_params = {} + + if cases_data: + for idx, case in enumerate(cases_data): + params_for_case = {} + init_params = case['init_parameters'] + + for path in grid_param_paths: + keys = path.split('.') + val = init_params + + try: + for k in keys: + val = val[k] + params_for_case[path] = val + except (KeyError, TypeError): + params_for_case[path] = None + + case_params[idx] = params_for_case + + return case_params, tested_params + + +def load_phi_crit(grid_dir: str | Path): + """ + Load the critical melt fraction (phi_crit) from the reference configuration file of the grid. + + Parameters + ---------- + grid_dir : str or Path + Path to the grid directory containing ref_config.toml. + + Returns + ------- + phi_crit : float + The critical melt fraction value loaded from the reference configuration. + + """ + grid_dir = Path(grid_dir) + ref_file = grid_dir / 'ref_config.toml' + + if not ref_file.exists(): + raise FileNotFoundError(f'ref_config.toml not found in {grid_dir}') + + with ref_file.open('r', encoding='utf-8') as f: + ref = toml.load(f) + + try: + phi_crit = ref['params']['stop']['solid']['phi_crit'] + except KeyError: + raise KeyError('phi_crit not found in ref_config.toml') + + return phi_crit + + +def extract_solidification_time(cases_data: list, grid_dir: str | Path): + """ + Extract solidification time for each simulation of the grid for + the condition Phi_global < phi_crit at last time step. + + Parameters + ---------- + cases_data : list + List of dictionaries containing simulation data. + + grid_dir : str or Path + Path to the grid directory containing ref_config.toml. + + Returns + ------- + solidification_times : list + A list containing the solidification times for all cases of the grid. + """ + + # Load phi_crit once + phi_crit = load_phi_crit(grid_dir) + + solidification_times = [] + columns_printed = False + + for i, case in enumerate(cases_data): + df = case['output_values'] + + if df is None: + solidification_times.append(np.nan) + continue + + # Condition for complete solidification + if 'Phi_global' in df.columns and 'Time' in df.columns: + condition = df['Phi_global'] < phi_crit + + if condition.any(): + idx = condition.idxmax() + solidification_times.append(df.loc[idx, 'Time']) + else: + solidification_times.append(np.nan) # if planet is not solidified, append NaN + + else: + if not columns_printed: + print('Warning: Missing Phi_global or Time column.') + print('Columns available:', df.columns.tolist()) + columns_printed = True + solidification_times.append(np.nan) + + return solidification_times + + +def validate_output_variables(df: pd.DataFrame, requested_outputs: list): + """ + Check that requested output variables exist in the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + DataFrame loaded from runtime_helpfile.csv (or summary CSV). + + requested_outputs : list + List of output variable names from config. + + Returns + ------- + valid_outputs : list + Outputs that exist in the DataFrame. + """ + + available = list(df.columns) + valid_outputs = [] + + for var in requested_outputs: + if var not in available: + # Find index of 'Time' column + try: + idx = available.index('Time') + except ValueError: + idx = 0 # fallback if Time is not in columns + + print(f"WARNING: Output variable '{var}' not found in data.") + print( + f'Available columns include: {available[idx:]}' + ) # show all columns starting from 'Time' + else: + valid_outputs.append(var) + + return valid_outputs + + +def generate_summary_csv( + cases_data: list, + case_params: dict, + grid_dir: str | Path, + grid_name: str, +): + """ + Generate CSV files summarizing simulation cases: + - All cases + - Completed cases only + - Running + Error cases only + """ + + def include_case(status: str, mode: str) -> bool: + status = status.lower() + if mode == 'all': + return True + elif mode == 'completed': + return status.startswith('completed') + elif mode == 'running_error': + return status.startswith('running') or status.startswith('error') + else: + raise ValueError(f'Unknown mode: {mode}') + + # Compute solidification times once + solidification_times = extract_solidification_time(cases_data, grid_dir) + + output_dir = grid_dir / 'post_processing' / 'extracted_data' + output_dir.mkdir(parents=True, exist_ok=True) + + modes = ['all', 'completed', 'running_error'] + + for mode in modes: + summary_rows = [] + + for case_index, case in enumerate(cases_data): + status = case.get('status', '') + + if not include_case(status, mode): + continue + + row = { + 'case_number': case_index, + 'status': status, + } + + # Parameters + row.update(case_params.get(case_index, {})) + + # Output values + df = case.get('output_values') + if df is not None and not df.empty: + row.update(df.iloc[-1].to_dict()) + + # Solidification time + row['solidification_time'] = solidification_times[case_index] + + summary_rows.append(row) + + summary_df = pd.DataFrame(summary_rows) + + output_file = output_dir / f'{grid_name}_final_extracted_data_{mode}.csv' + summary_df.to_csv(output_file, sep='\t', index=False) + + +# --------------------------------------------------------- +# Plotting functions +# --------------------------------------------------------- + + +def get_label(quant): + """ + Get label for a given quantity, using preset labels if available. + If not found in _preset_labels, use the last part of the dot-separated path. + + Parameters + ---------- + quant : str + Quantity for which to get label. + + Returns + ------- + str + Label for the quantity. + """ + if quant in _preset_labels: + return _preset_labels[quant] + else: + # Take only the last part after the last dot + return quant.split('.')[-1] + + +def get_scale(quant): + """ + Get scale factor for a given quantity, using preset scales if available. + Parameters + ---------- + quant : str + Quantity for which to get scale factor. + + Returns + ------- + float + Scale factor for the quantity. + """ + + if quant in _preset_scales: + return _preset_scales[quant] + else: + return 1.0 + + +def get_log_scale(quant): + """ + Get log scale flag for a given quantity, using preset log scales if available. + Parameters + ---------- + quant : str + Quantity for which to get log scale flag. + + Returns + ------- + bool + Log scale flag for the quantity. + """ + + if quant in _preset_log_scales: + return _preset_log_scales[quant] + else: + return False + + +def plot_grid_status(df: pd.DataFrame, cfg: dict, grid_dir: str | Path, grid_name: str): + """ + Plot histogram summary of number of simulation statuses in + the grid using the generated CSV file for all cases. + + Parameters + ---------- + df : pandas.DataFrame + DataFrame loaded from grid_name_final_extracted_data_all.csv. + + cfg : dict + Configuration dictionary containing plotting options. + + grid_dir : Path + Path to the grid directory. + + grid_name : str + Name of the grid. + """ + # Extract plot_format from cfg + plot_format = cfg.get('plot_format') + + if 'status' not in df.columns: + raise ValueError("CSV must contain a 'status' column") + + # Clean and count statuses + statuses = df['status'].astype(str) + status_counts = statuses.value_counts().sort_values(ascending=False) + total_simulations = len(df) + + # Format status labels for better readability + formatted_status_keys = [s.replace(' (', ' \n (') for s in status_counts.index] + palette = sns.color_palette('Accent', len(status_counts)) + palette = dict(zip(formatted_status_keys, palette)) + + # Prepare DataFrame for plotting + plot_df = pd.DataFrame({'Status': formatted_status_keys, 'Count': status_counts.values}) + + # Plot histogram + plt.figure(figsize=(11, 7)) + ax = sns.barplot( + data=plot_df, + x='Status', + y='Count', + hue='Status', + palette=palette, + dodge=False, + edgecolor='black', + ) + + # Remove legend + if ax.legend_: + ax.legend_.remove() + + # Add counts and percentages above bars per status + for i, count in enumerate(status_counts.values): + percentage = 100 * count / total_simulations + offset = 0.01 * status_counts.max() + ax.text( + i, + count + offset, + f'{count} ({percentage:.1f}%)', + ha='center', + va='bottom', + fontsize=14, + ) + + # Add total number of simulations text + ax.text( + 0.97, + 0.94, + f'Total number of simulations : {total_simulations}', + transform=ax.transAxes, + ha='right', + va='top', + fontsize=16, + ) + + # Formatting + ax.grid(alpha=0.2, axis='y') + ax.set_title(f'Simulation status summary for grid {grid_name}', fontsize=16) + ax.set_xlabel('Simulation status', fontsize=16) + ax.set_ylabel('Number of simulations', fontsize=16) + ax.tick_params(axis='x', labelsize=14) + ax.tick_params(axis='y', labelsize=14) + + # Save + output_dir = Path(grid_dir) / 'post_processing' / 'grid_plots' + output_dir.mkdir(parents=True, exist_ok=True) + output_file = output_dir / f'summary_grid_statuses_{grid_name}.{plot_format}' + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + +def flatten_input_parameters(d: dict, parent_key: str = '') -> dict: + """ + Flattens a nested input-parameter dictionary from a TOML configuration + into a flat mapping of dot-separated parameter paths to their plotting + configuration. + + Parameters + ---------- + d : dict + Nested dictionary describing input parameters (from TOML). + parent_key : str, optional + Accumulated parent key for recursive calls. + + Returns + ------- + flat : dict + Dictionary mapping parameter paths (e.g. ``"escape.zephyrus.Pxuv"``) + to their corresponding configuration dictionaries. + """ + + flat = {} + + for k, v in d.items(): + if k == 'colormap': + continue + + new_key = f'{parent_key}.{k}' if parent_key else k + + if isinstance(v, dict) and 'label' in v: + # Leaf parameter block + flat[new_key] = v + elif isinstance(v, dict): + # Recurse deeper + flat.update(flatten_input_parameters(v, new_key)) + + return flat + + +def load_ecdf_plot_settings(cfg, tested_params=None): + """ + Load ECDF plotting settings for both input parameters and output variables + from a configuration dictionary loaded from TOML. + + Parameters + ---------- + cfg : dict + Configuration dictionary loaded from a TOML file. + + tested_params : dict, optional + Dictionary of tested grid parameters and their grid values (directly from copy.grid.toml). + + Returns + ------- + param_settings : dict + Mapping of input-parameter paths to plotting settings. Each value + is a dict containing: + - "label" : str + Label for the parameter (used in colorbar). + - "colormap" : matplotlib colormap + Colormap used to color ECDF curves. + - "log_scale" : bool + Whether to normalize colors on a logarithmic scale. + + output_settings : dict + Mapping of output variable names to plotting settings. Each value + is a dict containing: + - "label" : str + X-axis label for the ECDF plot. + - "log_scale" : bool + Whether to use a logarithmic x-axis. + - "scale" : float + Factor applied to raw output values before plotting. + + plot_format : str + Format for saving plots ("png" or "pdf"). + + """ + + if tested_params is None or len(tested_params) == 0: + raise ValueError('No tested parameters found for ECDF plotting') + + # Optional colormap from config + cmap_name = cfg['colormap'] if 'colormap' in cfg else 'viridis' + default_cmap = getattr(cm, cmap_name, cm.viridis) + + # Build parameter settings from config + param_settings = { + key: { + 'label': get_label(key), + 'colormap': default_cmap, + 'log_scale': get_log_scale(key), + } + for key in tested_params + } + + # Build output settings from config + output_settings = {} + output_list = cfg.get('output_variables', []) + + for key in output_list: + output_settings[key] = { + 'label': get_label(key), + 'log_scale': get_log_scale(key), + 'scale': get_scale(key), + } + + # Extract plot format + plot_format = cfg.get('plot_format') + + return param_settings, output_settings, plot_format + + +def clean_series(s, log_scale): + """Cleans a pandas Series by replacing inf values with NaN and dropping NaN values.""" + s_clean = s.replace([np.inf, -np.inf], np.nan).dropna() + if log_scale: + s_clean = s_clean.loc[lambda x: x > 0] + return s_clean + + +def group_output_by_parameter(df, grid_parameters, outputs): + """ + Groups output values (like P_surf) by one or more grid parameters. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing simulation results including values of the grid parameters and the corresponding extracted outputs. + + grid_parameters : list of str + Column names of the grid parameters to group by (for example, ['escape.zephyrus.efficiency']). + + outputs : list of str + Column names of the outputs to extract (for example, ['P_surf']). + + Returns + ------- + dict + Dictionary where each key is of the form '[output]_per_[parameter]', and each value is a dict {param_value: [output_values]}. + """ + grouped = {} + + for param in grid_parameters: + for output in outputs: + key_name = f'{output}_per_{param}' + value_dict = {} + for param_value in df[param].dropna().unique(): + subset = df[df[param] == param_value] + output_values = clean_series(subset[output], get_log_scale(output)) * get_scale( + output + ) + + value_dict[param_value] = output_values + + grouped[key_name] = value_dict + + return grouped + + +def latex(label: str) -> str: + """ + Wraps a label in dollar signs for LaTeX formatting if it contains a backslash. + """ + return f'${label}$' if '\\' in label else label + + +def ecdf_grid_plot( + tested_params: dict, + grouped_data: dict, + param_settings: dict, + output_settings: dict, + plot_format: str, + grid_dir: str | Path, + grid_name: str, +): + """ + Creates ECDF grid plots where each row corresponds to one input parameter + and each column corresponds to one output. Saves the resulting figure as a {plot_format}. + + Parameters + ---------- + + tested_params : dict + Dictionary of tested grid parameters and their grid values (directly from copy.grid.toml). + + grouped_data : dict + Dictionary where each key is of the form '[output]_per_[parameter]', and each value is a dict {param_value: [output_values]}. + + param_settings : dict + For each input-parameter key, a dict containing: + - "label": label of the colormap for the corresponding input parameter + - "colormap": a matplotlib colormap (e.g. mpl.cm.plasma) + - "log_scale": bool, whether to color-normalize on a log scale + + output_settings : dict + For each output key, a dict containing: + - "label": label of the x-axis for the corresponding output column + - "log_scale": bool, whether to plot the x-axis on log scale + - "scale": float, a factor to multiply raw values by before plotting + + plot_format : str + Format for saving plots ("png" or "pdf"). + + grid_dir : str or Path + Path to the grid directory (used for saving the plot and loading tested parameters). + + grid_name : str + Name of the grid (used for saving the plot). + """ + + # Load tested grid parameters + grid_params = tested_params + + # List of parameter names (rows) and output names (columns) + param_names = list(param_settings.keys()) + out_names = list(output_settings.keys()) + + # Create subplot grid: rows = input parameters, columns = outputs variables + n_rows = len(param_names) + n_cols = len(out_names) + fig, axes = plt.subplots( + n_rows, + n_cols, + figsize=(4 * n_cols, 2.75 * n_rows), + squeeze=False, + gridspec_kw={'wspace': 0.1, 'hspace': 0.2}, + ) + + # Loop through parameters (rows) and outputs (columns) + for i, param_name in enumerate(param_names): + tested_param = grid_params.get(param_name, []) + if tested_param is None or len(tested_param) == 0: + print(f'Skipping {param_name} — no tested values found in grid_params') + continue + settings = param_settings[param_name] + + # Determine if parameter is numeric or string for coloring + is_numeric = np.issubdtype(np.array(tested_param).dtype, np.number) + if is_numeric: + vmin, vmax = min(tested_param), max(tested_param) + if vmin == vmax: + vmin, vmax = vmin - 1e-9, vmax + 1e-9 + if settings.get('log_scale', False): + norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax) + else: + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + def color_func(v): + return settings['colormap'](norm(v)) + + colorbar_needed = True + else: + unique_vals = sorted(set(tested_param)) + cmap = mpl.colormaps.get_cmap(settings['colormap']).resampled(len(unique_vals)) + color_map = {val: cmap(j) for j, val in enumerate(unique_vals)} + + def color_func(v): + return color_map[v] + + colorbar_needed = False + + for j, output_name in enumerate(out_names): + ax = axes[i][j] + out_settings = output_settings[output_name] + + # Add panel number in upper-left corner + panel_number = i * n_cols + j + 1 + ax.text( + 0.03, + 0.95, + str(panel_number), + transform=ax.transAxes, + fontsize=18, + fontweight='bold', + va='top', + ha='left', + color='black', + bbox=dict( + facecolor='white', edgecolor='silver', boxstyle='round,pad=0.2', alpha=0.8 + ), + ) + + # Plot one ECDF per tested parameter value + for val in tested_param: + data_key = f'{output_name}_per_{param_name}' + # if val not in grouped_data.get(data_key, {}): + # continue + # raw = np.array(grouped_data[data_key][val]) * out_settings["scale"] + data_dict = grouped_data.get(data_key, {}) + if val not in data_dict: + continue + raw = np.array(data_dict[val]) # * out_settings["scale"] + + # Plot ECDF + sns.ecdfplot( + data=raw, + # log_scale=out_settings["log_scale"], + stat='percent', + color=color_func(val), + linewidth=4, + linestyle='-', + ax=ax, + ) + + # Configure x-axis labels, ticks, grids + if i == n_rows - 1: + ax.set_xlabel(latex(out_settings['label']), fontsize=22) + ax.xaxis.set_label_coords(0.5, -0.3) + ax.tick_params(axis='x', labelsize=22) + else: + ax.tick_params(axis='x', labelbottom=False) + + # Configure y-axis (shared label added later) + if j == 0: + ax.set_ylabel('') + ticks = [0.0, 50, 100] + ax.set_yticks(ticks) + ax.tick_params(axis='y', labelsize=22) + else: + ax.set_ylabel('') + ax.set_yticks(ticks) + ax.tick_params(axis='y', labelleft=False) + ax.tick_params( + axis='x', which='minor', direction='in', top=True, bottom=True, length=2 + ) + ax.tick_params( + axis='x', which='major', direction='inout', top=True, bottom=True, length=6 + ) + + ax.grid(alpha=0.4) + + # Configure log scale for x-axis if needed + if out_settings['log_scale']: + ax.set_xscale('log') + + # After plotting all outputs for this parameter (row), add colorbar or legend + if colorbar_needed: # colorbar for numeric parameters + sm = mpl.cm.ScalarMappable(cmap=settings['colormap'], norm=norm) + rightmost_ax = axes[i, -1] # Get the rightmost axis in the current row + cbar = fig.colorbar(sm, ax=rightmost_ax, pad=0.03, aspect=10) + cbar.set_label(latex(settings['label']), fontsize=24) + cbar.ax.yaxis.set_label_coords(6, 0.5) + ticks = sorted(set(tested_param)) + cbar.set_ticks(ticks) + cbar.ax.tick_params(labelsize=22) + else: # legend for string parameters + handles = [ + mpl.lines.Line2D([0], [0], color=color_map[val], lw=4, label=str(val)) + for val in unique_vals + ] + ax.legend(handles=handles, fontsize=24, bbox_to_anchor=(1.01, 1), loc='upper left') + + # Add a single, shared y-axis label + fig.text( + 0.07, + 0.5, + 'Empirical cumulative distribution of grid simulations [%]', + va='center', + rotation='vertical', + fontsize=40, + ) + + # Save figure + output_dir = grid_dir / 'post_processing' / 'grid_plots' + output_dir.mkdir(parents=True, exist_ok=True) + output_file = output_dir / f'ecdf_grid_plot_{grid_name}.{plot_format}' + fig.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close(fig) + + +# --------------------------------------------------------- +# main +# --------------------------------------------------------- + + +def main(grid_analyse_toml_file: str | Path): + # Load configuration from example.grid.toml + with open(grid_analyse_toml_file, 'rb') as f: + cfg = tomllib.load(f) + + # Get grid path and name + grid_path = Path('output/' + cfg['output'] + '/') + print(f'Grid path: {grid_path}') + grid_name = get_grid_name(grid_path) + + # --- Summary CSVs --- + update_csv = cfg.get('update_csv', True) + + summary_dir = grid_path / 'post_processing' / 'extracted_data' + summary_csv_all = summary_dir / f'{grid_name}_final_extracted_data_all.csv' + summary_csv_completed = summary_dir / f'{grid_name}_final_extracted_data_completed.csv' + summary_csv_running_error = ( + summary_dir / f'{grid_name}_final_extracted_data_running_error.csv' + ) + + if update_csv: + # Load grid data + data = load_grid_cases(grid_path) + input_param_grid_per_case, tested_params_grid = get_tested_grid_parameters( + data, grid_path + ) + + # Write CSV + generate_summary_csv(data, input_param_grid_per_case, grid_path, grid_name) + else: + # Check that CSVs exist + for f in [summary_csv_all, summary_csv_completed, summary_csv_running_error]: + if not f.exists(): + raise FileNotFoundError( + f'{f.name} not found in {summary_dir}, ' + 'but update_csv is set to False. Please set update_csv to True to generate it.' + ) + # Only load tested parameters from grid config + _, tested_params_grid = get_tested_grid_parameters([], grid_path) + + # --- Plot grid status --- + if cfg.get('plot_status', True): + all_simulations_data_csv = pd.read_csv(summary_csv_all, sep='\t') + plot_grid_status(all_simulations_data_csv, cfg, grid_path, grid_name) + print('Plot grid status summary is available.') + + # --- ECDF plots --- + if cfg.get('plot_ecdf', True): + completed_simulations_data_csv = pd.read_csv(summary_csv_completed, sep='\t') + columns_output = validate_output_variables( + completed_simulations_data_csv, cfg['output_variables'] + ) + if len(columns_output) == 0: + raise ValueError('No valid output variables found. Check your config file.') + grouped_data = group_output_by_parameter( + completed_simulations_data_csv, + list(tested_params_grid.keys()), + columns_output, + ) + + param_settings_grid, output_settings_grid, plot_format = load_ecdf_plot_settings( + cfg, tested_params_grid + ) + ecdf_grid_plot( + tested_params_grid, + grouped_data, + param_settings_grid, + output_settings_grid, + plot_format, + grid_path, + grid_name, + ) + print('ECDF grid plot is available.') diff --git a/src/proteus/utils/plot.py b/src/proteus/utils/plot.py index 583e7d42c..85400ef48 100644 --- a/src/proteus/utils/plot.py +++ b/src/proteus/utils/plot.py @@ -10,6 +10,7 @@ import numpy as np from proteus.utils.archive import archive_exists +from proteus.utils.constants import M_earth, R_earth from proteus.utils.helper import mol_to_ele log = logging.getLogger('fwl.' + __name__) @@ -86,6 +87,181 @@ 'nacljet': '#ee29f5', } +# Standard label for input and output variables +_preset_labels = { + ## Input parameters (from input.toml files) + # Orbit module + 'orbit.semimajoraxis': 'a [AU]', + 'orbit.eccentricity': 'e', + # Structure module + 'struct.mass_tot': 'M_{\\mathrm{tot}} [M_\\oplus]', + 'struct.radius_int': 'R_{\\mathrm{int}} [R_\\oplus]', + 'struct.corefrac': 'CRF', + # Atmosphere module + 'atmos_clim.module': 'Atmospheric\ntreatment', + # Escape module + 'escape.zephyrus.efficiency': '\\rm \\epsilon', + 'escape.zephyrus.Pxuv': 'P_{\\rm XUV}\\,[bar]', + # Outgassing module + 'outgas.fO2_shift_IW': '\\Delta\\,\\rm IW', + # Delivery module + 'delivery.elements.H_oceans': 'H [Earth oceans]', + 'delivery.elements.H_ppmw': 'H [ppmw]', + 'delivery.elements.H_kg': 'H [kg]', + 'delivery.elements.CH_ratio': 'C/H ratio', + 'delivery.elements.C_ppmw': 'C [ppmw]', + 'delivery.elements.C_kg': 'C [kg]', + 'delivery.elements.NH_ratio': 'N/H ratio', + 'delivery.elements.N_ppmw': 'N [ppmw]', + 'delivery.elements.N_kg': 'N [kg]', + 'delivery.elements.SH_ratio': 'S/H ratio', + 'delivery.elements.S_ppmw': 'S [ppmw]', + 'delivery.elements.S_kg': 'S [kg]', + ## Output variables (from runtime_helpfile.csv) + # Model tracking + 'Time': 'Time [yr]', + 'solidification_time': 'Solidification time [yr]', # computed in post-processing script, not in runtime_helpfile.csv + # Orbital parameters + 'semimajorax': 'a [m]', + 'eccentricity': 'e', + # Planet structure + 'R_int': 'R_{\\mathrm{int}} [R_\\oplus]', + 'M_int': 'M_{\\mathrm{int}} [M_\\oplus]', + 'M_planet': 'M_{\\mathrm{planet}} [M_\\oplus]', + # Temperatures + 'T_surf': 'T_{\\rm surf}\\,[\\mathrm{K}]', + 'T_magma': 'T_{\\rm magma}\\,[\\mathrm{K}]', + 'T_eqm': 'T_{\\rm eqm}\\,[\\mathrm{K}]', + 'T_skin': 'T_{\\rm skin}\\,[\\mathrm{K}]', + # Planet interior properties + 'Phi_global': 'Melt fraction [%]', + # Planet observational properties + 'R_obs': 'R_{\\rm obs}\\,[R_\\oplus]', + 'rho_obs': '\\rho_{\\rm obs}\\,[\\mathrm{g/cm^3}]', + # Atmospheric composition from outgassing + 'M_atm': 'Atmosphere mass [kg]', + 'P_surf': 'P_{\\rm surf}\\,[\\mathrm{bar}]', + 'atm_kg_per_mol': 'MMW [g/mol]', + # Atmospheric escape + 'esc_rate_total': 'Escape rate [g/s]', +} + +_preset_scales = { + ## Input parameters (from input.toml files) + # Orbit module + 'orbit.semimajoraxis': 1.0, + 'orbit.eccentricity': 1.0, + # Structure module + 'struct.mass_tot': 1.0, + 'struct.radius_int': 1.0, + 'struct.corefrac': 1.0, + # Atmosphere module + 'atmos_clim.module': 1.0, + # Escape module + 'escape.zephyrus.efficiency': 1.0, + 'escape.zephyrus.Pxuv': 1.0, + # Outgassing module + 'outgas.fO2_shift_IW': 1.0, + # Delivery module + 'delivery.elements.H_oceans': 1.0, + 'delivery.elements.H_ppmw': 1.0, + 'delivery.elements.H_kg': 1.0, + 'delivery.elements.CH_ratio': 1.0, + 'delivery.elements.C_ppmw': 1.0, + 'delivery.elements.C_kg': 1.0, + 'delivery.elements.NH_ratio': 1.0, + 'delivery.elements.N_ppmw': 1.0, + 'delivery.elements.N_kg': 1.0, + 'delivery.elements.SH_ratio': 1.0, + 'delivery.elements.S_ppmw': 1.0, + 'delivery.elements.S_kg': 1.0, + ## Output variables (from runtime_helpfile.csv) + # Model tracking + 'Time': 1.0, + 'solidification_time': 1.0, # computed in post-processing script, not in runtime_helpfile.csv + # Orbital parameters + 'semimajorax': 1.0, + 'eccentricity': 1.0, + # Planet structure + 'R_int': 1.0 / R_earth, + 'M_int': 1.0 / M_earth, + 'M_planet': 1.0 / M_earth, + # Temperatures + 'T_surf': 1.0, + 'T_magma': 1.0, + 'T_eqm': 1.0, + 'T_skin': 1.0, + # Planet interior properties + 'Phi_global': 100.0, + # Planet observational properties + 'R_obs': 1.0 / R_earth, + 'rho_obs': 0.001, + # Atmospheric composition from outgassing + 'M_atm': 1.0, + 'P_surf': 1.0, + 'atm_kg_per_mol': 1000.0, + # Atmospheric escape + 'esc_rate_total': 1000.0, +} + +_preset_log_scales = { + ## Input parameters (from input.toml files) + # Orbit module + 'orbit.semimajoraxis': False, + 'orbit.eccentricity': False, + # Structure module + 'struct.mass_tot': False, + 'struct.radius_int': False, + 'struct.corefrac': False, + # Atmosphere module + 'atmos_clim.module': False, + # Escape module + 'escape.zephyrus.efficiency': True, + 'escape.zephyrus.Pxuv': True, + # Outgassing module + 'outgas.fO2_shift_IW': False, + # Delivery module + 'delivery.elements.H_oceans': True, + 'delivery.elements.H_ppmw': True, + 'delivery.elements.H_kg': True, + 'delivery.elements.CH_ratio': True, + 'delivery.elements.C_ppmw': True, + 'delivery.elements.C_kg': True, + 'delivery.elements.NH_ratio': True, + 'delivery.elements.N_ppmw': True, + 'delivery.elements.N_kg': True, + 'delivery.elements.SH_ratio': True, + 'delivery.elements.S_ppmw': True, + 'delivery.elements.S_kg': True, + ## Output variables (from runtime_helpfile.csv) + # Model tracking + 'Time': True, + 'solidification_time': True, # computed in post-processing script, not in runtime_helpfile.csv + # Orbital parameters + 'semimajorax': False, + 'eccentricity': False, + # Planet structure + 'R_int': False, + 'M_int': False, + 'M_planet': False, + # Temperatures + 'T_surf': False, + 'T_magma': False, + 'T_eqm': False, + 'T_skin': False, + # Planet interior properties + 'Phi_global': False, + # Planet observational properties + 'R_obs': False, + 'rho_obs': False, + # Atmospheric composition from outgassing + 'M_atm': False, + 'P_surf': True, + 'atm_kg_per_mol': False, + # Atmospheric escape + 'esc_rate_total': True, +} + def _generate_colour(gas: str): """ diff --git a/tests/grid/dummy.grid.toml b/tests/grid/dummy.grid.toml index 171511211..52c24b06a 100644 --- a/tests/grid/dummy.grid.toml +++ b/tests/grid/dummy.grid.toml @@ -6,6 +6,14 @@ output = "dummy_grid" # Make `output` a symbolic link to this absolute location. To disable: set to empty string. symlink = "" +# Post-processing options +update_csv = true # Whether to update the summary CSV files before plotting +plot_format = "png" # Format for saving plots ("png" or "pdf") +plot_status = true # Generate status summary plot of the grid +plot_ecdf = true # Generate ECDF grid plot for input parameters tested in the grid and output_variables defined below +colormap = "viridis" # Colormap for ECDF plot +output_variables = ["solidification_time", "Phi_global", "T_surf", "P_surf", "atm_kg_per_mol", "esc_rate_total", "rho_obs", "H2_bar"] # List of output variables to include in ECDF plot + # Path to base (reference) config file relative to PROTEUS root folder ref_config = "tests/grid/base.toml" diff --git a/tests/grid/test_grid.py b/tests/grid/test_grid.py index 18ac797be..014df7e95 100644 --- a/tests/grid/test_grid.py +++ b/tests/grid/test_grid.py @@ -9,9 +9,11 @@ from proteus.grid.manage import grid_from_config from proteus.grid.pack import pack as gpack +from proteus.grid.post_processing import main as gpostprocess from proteus.grid.summarise import summarise as gsummarise OUT_DIR = PROTEUS_ROOT / 'output' / 'dummy_grid' +GRID_NAME = 'dummy_grid' GRID_CONFIG = PROTEUS_ROOT / 'tests' / 'grid' / 'dummy.grid.toml' BASE_CONFIG = PROTEUS_ROOT / 'tests' / 'grid' / 'base.toml' @@ -74,3 +76,27 @@ def test_grid_pack(grid_run): # check zip exists assert os.path.isfile(OUT_DIR / 'pack.zip') + + +@pytest.mark.integration +def test_grid_post_process(grid_run): + # Test running grid-post-process command + gpostprocess(GRID_CONFIG) + + # check post-processed summary CSV file exists + assert os.path.isfile( + OUT_DIR + / 'post_processing' + / 'extracted_data' + / f'{GRID_NAME}_final_extracted_data_all.csv' + ) + + # check that status summary plot was generated + assert os.path.isfile( + OUT_DIR / 'post_processing' / 'grid_plots' / f'summary_grid_statuses_{GRID_NAME}.png' + ) + + # check that ECDF plot was generated + assert os.path.isfile( + OUT_DIR / 'post_processing' / 'grid_plots' / f'ecdf_grid_plot_{GRID_NAME}.png' + ) diff --git a/tests/grid/test_post_processing.py b/tests/grid/test_post_processing.py new file mode 100644 index 000000000..df1d60d23 --- /dev/null +++ b/tests/grid/test_post_processing.py @@ -0,0 +1,342 @@ +""" +Unit tests for proteus.grid.post_processing helper functions. + +Tests the pure-Python helper functions that require no running grid, +no file I/O, and no compiled binaries. Each test is fast (<100 ms). + +See also: + docs/How-to/test_infrastructure.md + docs/How-to/test_categorization.md + docs/How-to/test_building.md +""" + +from __future__ import annotations + +import pandas as pd +import pytest + +from proteus.grid.post_processing import ( + clean_series, + flatten_input_parameters, + get_grid_name, + get_label, + get_log_scale, + get_scale, + group_output_by_parameter, + latex, + load_ecdf_plot_settings, + validate_output_variables, +) + +# --------------------------------------------------------- +# get_grid_name +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_get_grid_name_returns_directory_name(tmp_path): + """get_grid_name should return the last component of a valid directory path.""" + grid_dir = tmp_path / 'my_test_grid' + grid_dir.mkdir() + assert get_grid_name(grid_dir) == 'my_test_grid' + + +@pytest.mark.unit +def test_get_grid_name_raises_for_nonexistent_path(tmp_path): + """get_grid_name should raise ValueError if the path is not a directory.""" + missing = tmp_path / 'does_not_exist' + with pytest.raises(ValueError, match='not a valid directory'): + get_grid_name(missing) + + +@pytest.mark.unit +def test_get_grid_name_raises_for_file_path(tmp_path): + """get_grid_name should raise ValueError when given a file path instead of a directory.""" + file_path = tmp_path / 'not_a_dir.txt' + file_path.write_text('data') + with pytest.raises(ValueError, match='not a valid directory'): + get_grid_name(file_path) + + +@pytest.mark.unit +def test_get_grid_name_accepts_string(tmp_path): + """get_grid_name should accept a plain string as well as a Path.""" + grid_dir = tmp_path / 'str_grid' + grid_dir.mkdir() + assert get_grid_name(str(grid_dir)) == 'str_grid' + + +# --------------------------------------------------------- +# get_label +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_get_label_known_quantity(): + """Preset quantities should return their human-readable label from _preset_labels.""" + from proteus.utils.plot import _preset_labels + + assert get_label('T_surf') == _preset_labels['T_surf'] + + +@pytest.mark.unit +def test_get_label_unknown_quantity_returns_last_segment(): + """Unknown dotted path should return only the last segment.""" + label = get_label('some.deeply.nested.param') + assert label == 'param' + + +@pytest.mark.unit +def test_get_label_simple_unknown(): + """Unknown non-dotted key should be returned unchanged.""" + label = get_label('totally_unknown_var') + assert label == 'totally_unknown_var' + + +# --------------------------------------------------------- +# get_scale +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_get_scale_known_quantity(): + """Preset quantities should return a non-default (or explicitly 1.0) scale.""" + scale = get_scale('Phi_global') + # Phi_global is stored as fraction but plotted as %, so scale should be 100 + assert scale == pytest.approx(100.0, rel=1e-5) + + +@pytest.mark.unit +def test_get_scale_unknown_returns_one(): + """Unknown quantities should fall back to scale factor of 1.0.""" + assert get_scale('not_a_real_quantity') == pytest.approx(1.0, rel=1e-5) + + +# --------------------------------------------------------- +# get_log_scale +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_get_log_scale_known_log_quantity(): + """escape.zephyrus.efficiency is a known log-scale quantity.""" + from proteus.utils.plot import _preset_log_scales + + assert 'escape.zephyrus.efficiency' in _preset_log_scales + assert get_log_scale('escape.zephyrus.efficiency') is True + + +@pytest.mark.unit +def test_get_log_scale_unknown_returns_false(): + """Unknown quantities default to linear scale (False).""" + assert get_log_scale('some_unknown_output') is False + + +# --------------------------------------------------------- +# latex +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_latex_wraps_backslash_label(): + """Labels containing a backslash should be wrapped in dollar signs.""" + assert latex('\\rm surf') == '$\\rm surf$' + + +@pytest.mark.unit +def test_latex_plain_label_unchanged(): + """Labels without backslash should not be wrapped.""" + assert latex('T_surf') == 'T_surf' + + +@pytest.mark.unit +def test_latex_empty_string(): + """Empty string has no backslash, so it should be returned as-is.""" + assert latex('') == '' + + +# --------------------------------------------------------- +# clean_series +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_clean_series_removes_nan(): + """NaN values should be dropped from the series.""" + s = pd.Series([1.0, float('nan'), 3.0]) + result = clean_series(s, log_scale=False) + assert len(result) == 2 + assert not result.isna().any() + + +@pytest.mark.unit +def test_clean_series_removes_inf(): + """Infinite values should be replaced by NaN and then dropped.""" + s = pd.Series([1.0, float('inf'), -float('inf'), 2.0]) + result = clean_series(s, log_scale=False) + assert len(result) == 2 + + +@pytest.mark.unit +def test_clean_series_log_scale_removes_nonpositive(): + """With log_scale=True, zero and negative values should be dropped.""" + s = pd.Series([-1.0, 0.0, 0.5, 2.0]) + result = clean_series(s, log_scale=True) + assert (result > 0).all() + assert len(result) == 2 + + +@pytest.mark.unit +def test_clean_series_linear_keeps_zeros(): + """With log_scale=False, zero values should be retained.""" + s = pd.Series([0.0, 1.0, 2.0]) + result = clean_series(s, log_scale=False) + assert 0.0 in result.values + assert len(result) == 3 + + +# --------------------------------------------------------- +# validate_output_variables +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_validate_output_variables_all_present(): + """All requested outputs that exist in the DataFrame should be returned.""" + df = pd.DataFrame({'Time': [1], 'T_surf': [300], 'P_surf': [1e5]}) + valid = validate_output_variables(df, ['T_surf', 'P_surf']) + assert valid == ['T_surf', 'P_surf'] + + +@pytest.mark.unit +def test_validate_output_variables_missing_excluded(capsys): + """Missing output variables should be excluded and a warning printed.""" + df = pd.DataFrame({'Time': [1], 'T_surf': [300]}) + valid = validate_output_variables(df, ['T_surf', 'not_a_column']) + assert valid == ['T_surf'] + captured = capsys.readouterr() + assert 'WARNING' in captured.out and 'not_a_column' in captured.out + + +@pytest.mark.unit +def test_validate_output_variables_empty_request(): + """Empty request list should return empty list.""" + df = pd.DataFrame({'Time': [1], 'T_surf': [300]}) + valid = validate_output_variables(df, []) + assert valid == [] + + +# --------------------------------------------------------- +# flatten_input_parameters +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_flatten_input_parameters_simple(): + """Flat dict with leaf blocks should be returned unchanged.""" + d = {'escape': {'efficiency': {'label': 'eff', 'log': True}}} + result = flatten_input_parameters(d) + assert 'escape.efficiency' in result + assert result['escape.efficiency']['label'] == 'eff' + + +@pytest.mark.unit +def test_flatten_input_parameters_skips_colormap(): + """The 'colormap' key should be skipped at any nesting level.""" + d = {'colormap': 'viridis', 'orbit': {'sma': {'label': 'a'}}} + result = flatten_input_parameters(d) + assert 'colormap' not in result + assert 'orbit.sma' in result + + +@pytest.mark.unit +def test_flatten_input_parameters_deeply_nested(): + """Multi-level nesting should produce correct dot-separated keys.""" + d = {'a': {'b': {'c': {'label': 'deep'}}}} + result = flatten_input_parameters(d) + assert 'a.b.c' in result + + +# --------------------------------------------------------- +# load_ecdf_plot_settings +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_load_ecdf_plot_settings_basic(): + """Should return three items: param_settings, output_settings, plot_format.""" + cfg = { + 'colormap': 'viridis', + 'output_variables': ['T_surf', 'P_surf'], + 'plot_format': 'png', + } + tested_params = {'struct.mass_tot': [0.7, 1.0]} + param_settings, output_settings, plot_format = load_ecdf_plot_settings(cfg, tested_params) + + assert 'struct.mass_tot' in param_settings + assert 'T_surf' in output_settings + assert 'P_surf' in output_settings + assert plot_format == 'png' + + +@pytest.mark.unit +def test_load_ecdf_plot_settings_respects_colormap(): + """The colormap from config should be applied to param_settings.""" + import matplotlib.cm as cm + + cfg = { + 'colormap': 'plasma', + 'output_variables': ['T_surf'], + 'plot_format': 'pdf', + } + tested_params = {'orbit.semimajoraxis': [0.1, 1.0]} + param_settings, _, _ = load_ecdf_plot_settings(cfg, tested_params) + # The colormap object stored should correspond to 'plasma' + assert param_settings['orbit.semimajoraxis']['colormap'] is cm.plasma + + +@pytest.mark.unit +def test_load_ecdf_plot_settings_raises_for_empty_params(): + """Should raise ValueError when no tested parameters are provided.""" + cfg = {'colormap': 'viridis', 'output_variables': ['T_surf'], 'plot_format': 'png'} + with pytest.raises(ValueError, match='No tested parameters'): + load_ecdf_plot_settings(cfg, {}) + + +# --------------------------------------------------------- +# group_output_by_parameter +# --------------------------------------------------------- + + +@pytest.mark.unit +def test_group_output_by_parameter_basic(): + """Should group output values correctly by input parameter.""" + df = pd.DataFrame( + { + 'struct.mass_tot': [0.7, 0.7, 1.0, 1.0], + 'T_surf': [500.0, 600.0, 700.0, 800.0], + } + ) + result = group_output_by_parameter(df, ['struct.mass_tot'], ['T_surf']) + key = 'T_surf_per_struct.mass_tot' + assert key in result + assert 0.7 in result[key] + assert 1.0 in result[key] + assert len(result[key][0.7]) == 2 + + +@pytest.mark.unit +def test_group_output_by_parameter_applies_scale(): + """Scale factor from _preset_scales should be applied to output values.""" + # Phi_global has a scale of 100 (fraction → percentage) + df = pd.DataFrame( + { + 'struct.corefrac': [0.35, 0.35], + 'Phi_global': [0.5, 0.8], + } + ) + result = group_output_by_parameter(df, ['struct.corefrac'], ['Phi_global']) + key = 'Phi_global_per_struct.corefrac' + values = list(result[key][0.35]) + # Values should be scaled by 100 + assert pytest.approx(values, rel=1e-5) == [50.0, 80.0]