Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/freeflux/analysis/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def solve(
max_iters = 400,
show_progress = True,
rng = None,
use_jax_experimental = False # Added parameter
):
'''
Parameters
Expand All @@ -546,11 +547,18 @@ def solve(
Maximum # of iterations.
show_progress: bool
Whether to show the progress bar.
use_jax_experimental: bool
Whether to attempt using the JAX-based calculation pathway.
'''

self._check_dependencies(fit_measured_fluxes)

optModel = MFAModel(self.model, fit_measured_fluxes, solver)
# Pass use_jax_experimental to MFAModel constructor
optModel = MFAModel(self.model,
fit_measured_fluxes,
solver,
use_jax=use_jax_experimental and hasattr(self.model, 'jax_prepared') and self.model.jax_prepared)

optModel.build_objective()
optModel.build_gradient()
optModel.build_flux_bound_constraints()
Expand Down
30 changes: 30 additions & 0 deletions src/freeflux/analysis/inst_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,36 @@ def prepare(self, dilution_from = None, n_jobs = 1):
self._estimate_fluxes_range(self.model.unbalanced_metabolites)
self._set_default_concentration_bounds()
self._estimate_concentrations_range()

# --- JAX specific preparations for Instationary ---
if hasattr(self.model, 'use_jax_experimental') and self.model.use_jax_experimental:
# Ensure JAX is available (copied from Fitter.prepare)
try:
from ..utils.utils import JAX_INSTALLED as JAX_READY
except ImportError:
JAX_READY = False
if not JAX_READY:
raise RuntimeError("JAX features requested for InstFitter but JAX is not installed or not found.")

# Common JAX prep (already in Fitter.prepare, but if InstFitter.prepare is called directly)
self.calculator._lambdify_matrix_As_and_Bs_jax()
self.calculator._prepare_substrate_MDVs_jax()
self.calculator._prepare_matrix_derivatives_jax() # For A, B derivatives

# Instationary specific JAX prep
self.calculator._lambdify_matrix_Ms_jax()
self.calculator._prepare_matrix_Ms_derivatives_p_jax() # After matrix_Ms_der_p (numpy) is computed
self.calculator._prepare_initial_conditions_jax() # After initial_matrix_Xs/Ys etc (numpy) are computed

# Substrate derivatives for 'inst' kind need to be prepared for JAX
# This assumes self.model.substrate_MDVs_der_p was populated with 'inst' kind by the numpy path
self.calculator._prepare_substrate_MDVs_der_p_jax()

# For measured_fluxes_der_p_jax with 'inst' kind
self.model.jax_flux_derivatives_enabled = True
self.calculator._calculate_measured_fluxes_derivative_p('inst') # Ensure JAX version for 'inst' is populated

self.model.jax_prepared_inst = True # Signal that JAX data for instationary model is ready


def _check_dependencies(self, fit_measured_fluxes):
Expand Down
Loading