From 2f297221ab6dc3b2e56e2bc407701b59f2de4e23 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 20 Jun 2025 18:08:48 +0000 Subject: [PATCH] Refactor: Add JAX Hessian computation and correct Fitter.solve This commit builds upon the JAX integration by adding JAX-based Hessian computation for both MFAModel and InstMFAModel. It also includes the correction to Fitter.solve() from the previous iteration. Key changes: - MFAModel and InstMFAModel now compute the Hessian using `jax.hessian` on their respective core JAX objective functions when the JAX pathway is active. These are JIT-compiled. - The JAX-computed Hessian is provided to SciPy optimizers if the solver is not 'slsqp' or 'ralg'. - Corrected Fitter.solve() in fit.py to properly pass JAX usage flags. Existing functionality: - JAX-based objective and gradient for MFAModel (steady-state). - Structural JAX support for InstMFAModel (instationary), with the core instationary JAX calculation (`core_calculate_inst_mdvs_jax`) still being a non-functional placeholder. - Data preparation in Calculator and Fitter/InstFitter for JAX pathways. Testing: - Test script execution remains blocked by environment issues. Full validation of JAX steady-state path (including Hessian) and instationary path structure could not be completed. Further work: - Implement `core_calculate_inst_mdvs_jax` for instationary models. - Thoroughly test and benchmark all JAX pathways. --- src/freeflux/analysis/fit.py | 10 +- src/freeflux/analysis/inst_fit.py | 30 + src/freeflux/solver/nlpsolver.py | 829 ++++++++++++++++++++--- src/freeflux/utils/jax_utils.py | 296 +++++++++ src/freeflux/utils/utils.py | 1014 +++++++++++++++++++++-------- 5 files changed, 1822 insertions(+), 357 deletions(-) create mode 100644 src/freeflux/utils/jax_utils.py diff --git a/src/freeflux/analysis/fit.py b/src/freeflux/analysis/fit.py index 69cce06..f07e506 100644 --- a/src/freeflux/analysis/fit.py +++ b/src/freeflux/analysis/fit.py @@ -530,6 +530,7 @@ def solve( max_iters = 400, show_progress = True, rng = None, + use_jax_experimental = False # Added parameter ): ''' Parameters @@ -546,11 +547,18 @@ def solve( Maximum # of iterations. show_progress: bool Whether to show the progress bar. + use_jax_experimental: bool + Whether to attempt using the JAX-based calculation pathway. ''' self._check_dependencies(fit_measured_fluxes) - optModel = MFAModel(self.model, fit_measured_fluxes, solver) + # Pass use_jax_experimental to MFAModel constructor + optModel = MFAModel(self.model, + fit_measured_fluxes, + solver, + use_jax=use_jax_experimental and hasattr(self.model, 'jax_prepared') and self.model.jax_prepared) + optModel.build_objective() optModel.build_gradient() optModel.build_flux_bound_constraints() diff --git a/src/freeflux/analysis/inst_fit.py b/src/freeflux/analysis/inst_fit.py index d662aff..80f7232 100644 --- a/src/freeflux/analysis/inst_fit.py +++ b/src/freeflux/analysis/inst_fit.py @@ -340,6 +340,36 @@ def prepare(self, dilution_from = None, n_jobs = 1): self._estimate_fluxes_range(self.model.unbalanced_metabolites) self._set_default_concentration_bounds() self._estimate_concentrations_range() + + # --- JAX specific preparations for Instationary --- + if hasattr(self.model, 'use_jax_experimental') and self.model.use_jax_experimental: + # Ensure JAX is available (copied from Fitter.prepare) + try: + from ..utils.utils import JAX_INSTALLED as JAX_READY + except ImportError: + JAX_READY = False + if not JAX_READY: + raise RuntimeError("JAX features requested for InstFitter but JAX is not installed or not found.") + + # Common JAX prep (already in Fitter.prepare, but if InstFitter.prepare is called directly) + self.calculator._lambdify_matrix_As_and_Bs_jax() + self.calculator._prepare_substrate_MDVs_jax() + self.calculator._prepare_matrix_derivatives_jax() # For A, B derivatives + + # Instationary specific JAX prep + self.calculator._lambdify_matrix_Ms_jax() + self.calculator._prepare_matrix_Ms_derivatives_p_jax() # After matrix_Ms_der_p (numpy) is computed + self.calculator._prepare_initial_conditions_jax() # After initial_matrix_Xs/Ys etc (numpy) are computed + + # Substrate derivatives for 'inst' kind need to be prepared for JAX + # This assumes self.model.substrate_MDVs_der_p was populated with 'inst' kind by the numpy path + self.calculator._prepare_substrate_MDVs_der_p_jax() + + # For measured_fluxes_der_p_jax with 'inst' kind + self.model.jax_flux_derivatives_enabled = True + self.calculator._calculate_measured_fluxes_derivative_p('inst') # Ensure JAX version for 'inst' is populated + + self.model.jax_prepared_inst = True # Signal that JAX data for instationary model is ready def _check_dependencies(self, fit_measured_fluxes): diff --git a/src/freeflux/solver/nlpsolver.py b/src/freeflux/solver/nlpsolver.py index 18be57e..632ecd5 100644 --- a/src/freeflux/solver/nlpsolver.py +++ b/src/freeflux/solver/nlpsolver.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from scipy.linalg import pinv +from scipy.linalg import pinv as scipy_pinv # Renamed to avoid conflict if jax.scipy.linalg.pinv is used locally from scipy.optimize import LinearConstraint from scipy.optimize import minimize try: @@ -15,7 +15,17 @@ OPENOPT_INSTALLED = False else: OPENOPT_INSTALLED = True + +try: + import jax + import jax.numpy as jnp + from ..utils.jax_utils import core_calculate_mdvs_jax, core_calculate_mdvs_and_derivatives_jax + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + from ..utils.utils import Calculator +from functools import partial # For jax.jit static_argnames class MFAModel(): @@ -31,7 +41,7 @@ class MFAModel(): * If "ralg", openopt NLP solver will be used. ''' - def __init__(self, model, fit_measured_fluxes, solver = 'slsqp'): + def __init__(self, model, fit_measured_fluxes, solver = 'slsqp', use_jax=False): ''' Parameters ---------- @@ -42,143 +52,386 @@ def __init__(self, model, fit_measured_fluxes, solver = 'slsqp'): solvor: {"slsqp", "ralg"} * If "slsqp", scipy.optimize.minimze will be used. * If "ralg", openopt NLP solver will be used. + use_jax: bool + Whether to use JAX for objective and gradient calculation. ''' self.model = model - self.calculator = Calculator(self.model) + self.calculator = Calculator(self.model) # Still needed for non-JAX parts or data prep self.fit_measured_fluxes = fit_measured_fluxes self.solver = solver + self.use_jax = use_jax and JAX_AVAILABLE self.N = self.model.null_space self.T = self.model.transform_matrix self.ntotalfluxes = len(self.model.totalfluxids) + + if self.use_jax: + self._prepare_jax_data() + self._jit_jax_functions() + def _prepare_jax_data(self): + """Prepares JAX-compatible static data structures.""" + if not self.use_jax: return + + self.N_jax = jnp.array(self.N) + + # These are expected to be prepared by Fitter.prepare() and stored on model: + # self.model.matrix_As_jax_static_data, self.model.matrix_Bs_jax_static_data + # self.model.substrate_MDVs_jax_static_data + # self.model.matrix_As_der_p_jax_static_data, self.model.matrix_Bs_der_p_jax_static_data + # self.model.substrate_MDVs_der_p_jax_static_data + # For now, assume they exist and are pytrees of JAX arrays / callables + self.matrix_As_jax_static = self.model.matrix_As_jax_static_data + self.matrix_Bs_jax_static = self.model.matrix_Bs_jax_static_data + self.substrate_MDVs_jax_static = self.model.substrate_MDVs_jax_static_data + + # Derivatives (for gradient calculation via core_calculate_mdvs_and_derivatives_jax) + self.matrix_As_der_p_jax_static = self.model.matrix_As_der_p_jax_static_data + self.matrix_Bs_der_p_jax_static = self.model.matrix_Bs_der_p_jax_static_data + self.substrate_MDVs_der_p_jax_static = self.model.substrate_MDVs_der_p_jax_static_data + + # Measured MDVs (means) + self.measured_MDVs_means_jax = { + k: jnp.array(v[0]) for k, v in self.model.measured_MDVs.items() + } + self.measured_MDVs_inv_cov_jax = jnp.array(self.model.measured_MDVs_inv_cov) + + if self.fit_measured_fluxes: + self.measured_fluxes_means_jax = jnp.array([ + mean for mean, sd in self.model.measured_fluxes.values() + ]) + # Need a way to map measured flux keys to indices if order matters for residual vector + self.measured_flux_ids_ordered = list(self.model.measured_fluxes.keys()) + # Assuming model.totalfluxids_map_jax: dict str_id -> int_idx exists + self.measured_flux_indices_in_total_jax = jnp.array([ + self.model.totalfluxids_map_jax[fid] for fid in self.measured_flux_ids_ordered + ]) + self.measured_fluxes_inv_cov_jax = jnp.array(self.model.measured_fluxes_inv_cov) + if hasattr(self.model, 'measured_fluxes_der_p_jax'): # Expected from Fitter.prepare + self.measured_fluxes_der_p_jax = self.model.measured_fluxes_der_p_jax + else: # Fallback or raise error + self.measured_fluxes_der_p_jax = jnp.array(self.model.measured_fluxes_der_p) + + + self.target_EMU_ids_tuple_jax = tuple(self.model.target_EMUs) + # EAMs keys are sizes. Ensure model.EAMs_jax_sorted_keys exists from Fitter.prepare() + self.sorted_emu_sizes_tuple_jax = self.model.EAMs_jax_sorted_keys + self.num_free_fluxes_jax = self.N.shape[1] + + # Define static argument names for JIT compilation + # For _core_objective_fn_jax + self.static_obj_argnames = ( + "N_jax", "matrix_As_static", "matrix_Bs_static", + "substrate_MDVs_jax_static", "measured_MDVs_means_jax", + "measured_MDVs_inv_cov_jax", "target_EMU_ids_tuple_jax", + "sorted_emu_sizes_tuple_jax", "fit_measured_fluxes_static" + ) + if self.fit_measured_fluxes: + self.static_obj_argnames += ( + "measured_flux_ids_ordered_static", # For consistent ordering of flux residuals + "measured_flux_indices_in_total_jax_static", + "measured_fluxes_means_jax_static", + "measured_fluxes_inv_cov_jax_static" + ) + + # For _core_gradient_fn_jax (which uses core_calculate_mdvs_and_derivatives_jax) + # This will be jax.grad of _core_objective_fn_jax + # Alternatively, if we define a _core_residual_vector_fn_jax, then objective is sum of squares + # and grad can be derived using that. Let's stick to grad of objective for now. + + def _jit_jax_functions(self): + if not self.use_jax: return + + # The function to be differentiated by JAX + # It takes u_jax and all other static data (as JAX arrays/pytrees) + # and returns a scalar objective value. + + # --- Define the core JAX objective function --- + def _core_objective_fn_jax( + u_jax, + N_jax, matrix_As_static, matrix_Bs_static, + substrate_MDVs_jax_static, measured_MDVs_means_jax, + measured_MDVs_inv_cov_jax, target_EMU_ids_tuple_jax, + sorted_emu_sizes_tuple_jax, fit_measured_fluxes_static, + # Optional flux args, only used if fit_measured_fluxes_static is True + measured_flux_ids_ordered_static=None, + measured_flux_indices_in_total_jax_static=None, + measured_fluxes_means_jax_static=None, + measured_fluxes_inv_cov_jax_static=None + ): + + total_fluxes_jax = N_jax @ u_jax + + sim_MDVs_dict_jax = core_calculate_mdvs_jax( + total_fluxes_jax, + matrix_As_static, + matrix_Bs_static, + substrate_MDVs_jax_static, + target_EMU_ids_tuple_jax, # Pass along, though core_calculate_mdvs_jax doesn't currently filter + sorted_emu_sizes_tuple_jax + ) + + # MDV residuals + mdv_residuals_list = [] + for emu_id in target_EMU_ids_tuple_jax: + sim_mdv = sim_MDVs_dict_jax[emu_id] + exp_mdv = measured_MDVs_means_jax[emu_id] + mdv_residuals_list.append(sim_mdv - exp_mdv) + + mdv_residuals_vector = jnp.concatenate(mdv_residuals_list) + obj_mdv = mdv_residuals_vector.T @ measured_MDVs_inv_cov_jax @ mdv_residuals_vector + + obj_total = obj_mdv + + if fit_measured_fluxes_static: + # Flux residuals + # Need to select the *simulated* fluxes corresponding to *measured* fluxes + sim_measured_fluxes_jax = total_fluxes_jax.take(measured_flux_indices_in_total_jax_static) + flux_residuals_vector = sim_measured_fluxes_jax - measured_fluxes_means_jax_static + obj_flux = flux_residuals_vector.T @ measured_fluxes_inv_cov_jax_static @ flux_residuals_vector + obj_total += obj_flux + + return obj_total + + self._jitted_core_objective_fn = jax.jit( + _core_objective_fn_jax, static_argnames=self.static_obj_argnames + ) + + # Gradient of the core objective function + self._jitted_core_gradient_fn = jax.jit( + jax.grad(_core_objective_fn_jax, argnums=0), # Differentiate w.r.t. u_jax (arg 0) + static_argnames=self.static_obj_argnames + ) + + # Hessian of the core objective function + self._jitted_core_hessian_fn = jax.jit( + jax.hessian(_core_objective_fn_jax, argnums=0), # Hessian w.r.t. u_jax (arg 0) + static_argnames=self.static_obj_argnames + ) + # Make the core objective function a static method or a regular method callable by JAX + # For simplicity here, it's defined locally in _jit_jax_functions. If it were a class method: + # @staticmethod (or regular method if it needs to be part of the class for other reasons) + # def _actual_core_objective_fn_jax(u_jax, N_jax, ...): ... + # Then in _jit_jax_functions: + # self._jitted_core_objective_fn = jax.jit(MFAModel._actual_core_objective_fn_jax, ...) + # self._jitted_core_gradient_fn = jax.jit(jax.grad(MFAModel._actual_core_objective_fn_jax, ...), ...) + # self._jitted_core_hessian_fn = jax.jit(jax.hessian(MFAModel._actual_core_objective_fn_jax, ...), ...) + # For now, the local definition within _jit_jax_functions is used. + + + # --- Original methods for reference/fallback --- def _calculate_difference_sim_exp_MDVs(self): - - simMDVs = self.calculator._calculate_MDVs() + # This is the original numpy-based calculation + simMDVs = self.calculator._calculate_MDVs() # Uses self.model.total_fluxes (numpy) expMDVs = self.model.measured_MDVs diff = np.concatenate([simMDVs[emuid] - expMDVs[emuid][0] for emuid in self.model.target_EMUs]) - return diff - def _calculate_difference_sim_exp_fluxes(self): - - simFluxes = self.model.total_fluxes[self.model.measured_fluxes.keys()] + # Original numpy-based + simFluxes = self.model.total_fluxes[list(self.model.measured_fluxes.keys())] # Pandas series indexing expFluxes = np.array([mean for mean, _ in self.model.measured_fluxes.values()]) - diff = simFluxes - expFluxes - + diff = simFluxes.to_numpy() - expFluxes # Ensure numpy array for operation return diff - def _calculate_sim_MDVs_derivative(self): - + # Original numpy-based + # This sets self.model.total_fluxes, then calls calculator simMDVs, simMDVsDer = self.calculator._calculate_MDVs_and_derivatives_p() expMDVs = self.model.measured_MDVs diff = np.concatenate([simMDVs[emuid] - expMDVs[emuid][0] for emuid in self.model.target_EMUs]) - dxsim_dp = np.vstack([simMDVsDer[emuid] for emuid in self.model.target_EMUs]) - return diff, dxsim_dp - def _calculate_sim_fluxes_derivative(self): - - dvsim_dp = self.model.measured_fluxes_der_p - + # Original numpy-based + dvsim_dp = self.model.measured_fluxes_der_p # This is d(total_fluxes)/du * N return dvsim_dp - + # --- Build objective and gradient using JAX if enabled --- def build_objective(self): - - def _f(u): - self.model.total_fluxes[:] = self.N@u - - MDV_diff = self._calculate_difference_sim_exp_MDVs() - MDV_inv_cov = self.model.measured_MDVs_inv_cov - obj = MDV_diff@MDV_inv_cov@MDV_diff - - return obj - - def f1(u): - return _f(u) - - def f2(u): - obj1 = _f(u) - - flux_diff = self._calculate_difference_sim_exp_fluxes() - flux_inv_cov = self.model.measured_fluxes_inv_cov - obj2 = flux_diff@flux_inv_cov@flux_diff - - return obj1 + obj2 + if self.use_jax: + # Prepare static args for the jitted function call + static_args = { + "N_jax": self.N_jax, + "matrix_As_static": self.matrix_As_jax_static, + "matrix_Bs_static": self.matrix_Bs_jax_static, + "substrate_MDVs_jax_static": self.substrate_MDVs_jax_static, + "measured_MDVs_means_jax": self.measured_MDVs_means_jax, + "measured_MDVs_inv_cov_jax": self.measured_MDVs_inv_cov_jax, + "target_EMU_ids_tuple_jax": self.target_EMU_ids_tuple_jax, + "sorted_emu_sizes_tuple_jax": self.sorted_emu_sizes_tuple_jax, + "fit_measured_fluxes_static": self.fit_measured_fluxes + } + if self.fit_measured_fluxes: + static_args.update({ + "measured_flux_ids_ordered_static": self.measured_flux_ids_ordered, + "measured_flux_indices_in_total_jax_static": self.measured_flux_indices_in_total_jax, + "measured_fluxes_means_jax_static": self.measured_fluxes_means_jax, + "measured_fluxes_inv_cov_jax_static": self.measured_fluxes_inv_cov_jax + }) + + def f_jax_wrapper(u_numpy): + u_jax = jnp.array(u_numpy) + # JAX operations happen inside _jitted_core_objective_fn + obj_val_jax = self._jitted_core_objective_fn(u_jax, **static_args) + return float(obj_val_jax) + self.f = f_jax_wrapper + else: # Original numpy-based objective + def _f_numpy(u_np): + self.model.total_fluxes.iloc[:] = self.N @ u_np # Update pandas Series + + MDV_diff = self._calculate_difference_sim_exp_MDVs() # Uses self.model.total_fluxes + MDV_inv_cov = self.model.measured_MDVs_inv_cov + obj = MDV_diff @ MDV_inv_cov @ MDV_diff + return obj - self.f = f2 if self.fit_measured_fluxes else f1 + def f1_numpy(u_np): + return _f_numpy(u_np) + + def f2_numpy(u_np): + obj1 = _f_numpy(u_np) + # flux_diff uses self.model.total_fluxes which was updated by _f_numpy + flux_diff = self._calculate_difference_sim_exp_fluxes() + flux_inv_cov = self.model.measured_fluxes_inv_cov + obj2 = flux_diff @ flux_inv_cov @ flux_diff + return obj1 + obj2 + + self.f = f2_numpy if self.fit_measured_fluxes else f1_numpy - def build_gradient(self): - - def _df(u): - self.model.total_fluxes[:] = self.N@u - - MDV_diff, MDV_der = self._calculate_sim_MDVs_derivative() - MDV_inv_cov = self.model.measured_MDVs_inv_cov - grad = MDV_der.T@MDV_inv_cov@MDV_diff - - return grad - - def df1(u): - return _df(u) - - def df2(u): - grad1 = _df(u) + if self.use_jax: + # Static args are the same as for the objective + static_args = { + "N_jax": self.N_jax, + "matrix_As_static": self.matrix_As_jax_static, + "matrix_Bs_static": self.matrix_Bs_jax_static, + "substrate_MDVs_jax_static": self.substrate_MDVs_jax_static, + "measured_MDVs_means_jax": self.measured_MDVs_means_jax, + "measured_MDVs_inv_cov_jax": self.measured_MDVs_inv_cov_jax, + "target_EMU_ids_tuple_jax": self.target_EMU_ids_tuple_jax, + "sorted_emu_sizes_tuple_jax": self.sorted_emu_sizes_tuple_jax, + "fit_measured_fluxes_static": self.fit_measured_fluxes + } + if self.fit_measured_fluxes: + static_args.update({ + "measured_flux_ids_ordered_static": self.measured_flux_ids_ordered, + "measured_flux_indices_in_total_jax_static": self.measured_flux_indices_in_total_jax, + "measured_fluxes_means_jax_static": self.measured_fluxes_means_jax, + "measured_fluxes_inv_cov_jax_static": self.measured_fluxes_inv_cov_jax + }) + + def df_jax_wrapper(u_numpy): + u_jax = jnp.array(u_numpy) + grad_val_jax = self._jitted_core_gradient_fn(u_jax, **static_args) + return np.array(grad_val_jax) + self.df = df_jax_wrapper + else: # Original numpy-based gradient + # Important: self.model.total_fluxes must be set before calling derivative funcs + def _df_numpy(u_np): + # This state modification is problematic for purity if we were to JIT this part directly. + # Here, it's part of the non-JAX path. + self.model.total_fluxes.iloc[:] = self.N @ u_np + + MDV_diff, MDV_der = self._calculate_sim_MDVs_derivative() # Uses self.model.total_fluxes + MDV_inv_cov = self.model.measured_MDVs_inv_cov + grad = MDV_der.T @ MDV_inv_cov @ MDV_diff + return grad + + def df1_numpy(u_np): + return _df_numpy(u_np) - flux_der = self._calculate_sim_fluxes_derivative() - flux_diff = self._calculate_difference_sim_exp_fluxes() - flux_inv_cov = self.model.measured_fluxes_inv_cov - grad2 = flux_der.T@flux_inv_cov@flux_diff + def df2_numpy(u_np): + # _df_numpy updates self.model.total_fluxes + grad1 = _df_numpy(u_np) + + # These use the updated self.model.total_fluxes + flux_der = self._calculate_sim_fluxes_derivative() + flux_diff = self._calculate_difference_sim_exp_fluxes() + flux_inv_cov = self.model.measured_fluxes_inv_cov + grad2 = flux_der.T @ flux_inv_cov @ flux_diff + return grad1 + grad2 - return grad1 + grad2 - - self.df = df2 if self.fit_measured_fluxes else df1 + self.df = df2_numpy if self.fit_measured_fluxes else df1_numpy def build_hessian(self): - - def _ddf(u): - self.model.total_fluxes[:] = self.N@u + if self.use_jax: + # Prepare static args for the jitted Hessian function call + static_args = { + "N_jax": self.N_jax, + "matrix_As_static": self.matrix_As_jax_static, + "matrix_Bs_static": self.matrix_Bs_jax_static, + "substrate_MDVs_jax_static": self.substrate_MDVs_jax_static, + "measured_MDVs_means_jax": self.measured_MDVs_means_jax, + "measured_MDVs_inv_cov_jax": self.measured_MDVs_inv_cov_jax, + "target_EMU_ids_tuple_jax": self.target_EMU_ids_tuple_jax, + "sorted_emu_sizes_tuple_jax": self.sorted_emu_sizes_tuple_jax, + "fit_measured_fluxes_static": self.fit_measured_fluxes + } + if self.fit_measured_fluxes: + static_args.update({ + "measured_flux_ids_ordered_static": self.measured_flux_ids_ordered, + "measured_flux_indices_in_total_jax_static": self.measured_flux_indices_in_total_jax, + "measured_fluxes_means_jax_static": self.measured_fluxes_means_jax, + "measured_fluxes_inv_cov_jax_static": self.measured_fluxes_inv_cov_jax + }) + + def ddf_jax_wrapper(u_numpy): + u_jax = jnp.array(u_numpy) + # Ensure _jitted_core_hessian_fn is available (defined in _jit_jax_functions) + hess_val_jax = self._jitted_core_hessian_fn(u_jax, **static_args) + return np.array(hess_val_jax) - _, MDV_der = self._calculate_sim_MDVs_derivative() + # Provide Hessian if solver is not SLSQP or ralg (which don't typically use it directly from user) + # or if explicitly configured to use it. + if self.solver not in ['slsqp', 'ralg']: # e.g. for 'trust-constr' + self.ddf = ddf_jax_wrapper + else: + self.ddf = None # SLSQP and ralg can work without it or use approximations. + else: + self._build_original_hessian() + + def _build_original_hessian(self): + # This is the original logic extracted + def _ddf_numpy(u_np): + # Critical: self.model.total_fluxes must be set based on u_np + # This happens if called after obj/grad in scipy, but direct call needs care. + # For safety, recalculate here if not sure about state. + # However, original code implies it's called in a context where total_fluxes is current. + # self.model.total_fluxes.iloc[:] = self.N @ u_np # Ensure state for original calc + + _, MDV_der = self._calculate_sim_MDVs_derivative() # Uses current self.model.total_fluxes MDV_inv_cov = self.model.measured_MDVs_inv_cov - hess = MDV_der.T@MDV_inv_cov@MDV_der - + hess = MDV_der.T @ MDV_inv_cov @ MDV_der return hess - def ddf1(u): - return _ddf(u) + def ddf1_numpy(u_np): + return _ddf_numpy(u_np) - def ddf2(u): - hess1 = _ddf(u) - - flux_der = self._calculate_sim_fluxes_derivative() + def ddf2_numpy(u_np): + hess1 = _ddf_numpy(u_np) + flux_der = self._calculate_sim_fluxes_derivative() # Uses current self.model.total_fluxes flux_inv_cov = self.model.measured_fluxes_inv_cov - hess2 = flux_der.T@flux_inv_cov@flux_der - + hess2 = flux_der.T @ flux_inv_cov @ flux_der return hess1 + hess2 - self.ddf = ddf2 if self.fit_measured_fluxes else ddf1 + self.ddf = ddf2_numpy if self.fit_measured_fluxes else ddf1_numpy def build_flux_bound_constraints(self): A1 = self.N - A2 = self.T@self.N + A2 = self.T @ self.N A3 = -A2 b1 = np.zeros(self.ntotalfluxes) + # Ensure net_fluxes_range values are numpy arrays for consistency vnet_lb, vnet_ub = np.array(list(self.model.net_fluxes_range.values())).T b2 = vnet_lb b3 = -vnet_ub @@ -187,7 +440,8 @@ def build_flux_bound_constraints(self): b = np.concatenate((b1, b2, b3)) if self.solver == 'slsqp': - self.constrs = {'type': 'ineq', 'fun': lambda u: A@u - b} + # SLSQP constraints fun must return a 1D array + self.constrs = {'type': 'ineq', 'fun': lambda u: (A @ u - b).ravel()} elif self.solver == 'trust-constr' or self.solver == 'ipopt': self.constrs = LinearConstraint(A, b, np.inf) elif self.solver == 'ralg': @@ -401,14 +655,419 @@ def solve_flux(self, tol = 1e-6, max_iters = 400, disp = False): class InstMFAModel(MFAModel): - def __init__(self, *args): + def __init__(self, model, fit_measured_fluxes, solver='slsqp', use_jax=False): + # Call MFAModel's __init__ but without use_jax, as InstMFAModel handles its own JAX setup + # Or, pass use_jax=False explicitly if MFAModel.__init__ uses it. + # MFAModel.__init__ was: def __init__(self, model, fit_measured_fluxes, solver = 'slsqp', use_jax=False): + super().__init__(model, fit_measured_fluxes, solver, use_jax=False) # Initialize base non-JAX parts - super().__init__(*args) + self.use_jax_inst = use_jax and JAX_AVAILABLE # Specific flag for instationary JAX self.nfreefluxes = self.N.shape[1] - self.nconcs = len(self.model.concids) + self.nconcs = len(self.model.concids if hasattr(self.model, 'concids') else []) self.nnetfluxes = len(self.model.netfluxids) + + if self.use_jax_inst: + if not (hasattr(self.model, 'jax_prepared_inst') and self.model.jax_prepared_inst): + # This might indicate that InstFitter.prepare with JAX options was not called. + warnings.warn("InstMFAModel initialized with use_jax=True, but JAX data for instationary model seems unprepared. JAX path may fail.") + self._prepare_jax_data_inst() + self._jit_jax_functions_inst() + + def _prepare_jax_data_inst(self): + """Prepares JAX-compatible static data for InstMFAModel.""" + if not self.use_jax_inst: return + + # Data common with MFAModel, ensure it's JAXified if not already by superclass or here + if not hasattr(self, 'N_jax'): # Could be set by MFAModel if its _prepare_jax_data was called + self.N_jax = jnp.array(self.N) + + # Instationary specific JAX data structures from model (prepared by Calculator/InstFitter) + self.matrix_As_jax_static_inst = self.model.matrix_As_jax_static_data + self.matrix_Bs_jax_static_inst = self.model.matrix_Bs_jax_static_data + self.matrix_Ms_jax_static_inst = self.model.matrix_Ms_jax_static_data + + self.substrate_MDVs_jax_static_inst = self.model.substrate_MDVs_jax_static_data + self.matrix_As_der_p_jax_static_inst = self.model.matrix_As_der_p_jax_static_data # d/dp where p=(u,c) + self.matrix_Bs_der_p_jax_static_inst = self.model.matrix_Bs_der_p_jax_static_data + self.matrix_Ms_der_p_jax_static_inst = self.model.matrix_Ms_der_p_jax_static_data + self.substrate_MDVs_der_p_jax_static_inst = self.model.substrate_MDVs_der_p_jax_static_data + + self.initial_Xs_jax_inst = self.model.initial_matrix_Xs_jax + self.initial_Ys_jax_inst = self.model.initial_matrix_Ys_jax + self.initial_Xs_der_p_jax_inst = self.model.initial_matrix_Xs_der_p_jax + self.initial_Ys_der_p_jax_inst = self.model.initial_matrix_Ys_der_p_jax + + # Measured instMDVs (means) - complex structure {emu_id: {time: array}} + # For JAX, this needs to be a pytree of JAX arrays. + self.measured_inst_MDVs_means_jax = { + emu_id: {t: jnp.array(mdv_data[0]) for t, mdv_data in time_data.items()} + for emu_id, time_data in self.model.measured_inst_MDVs.items() + } + self.measured_inst_MDVs_inv_cov_jax = jnp.array(self.model.measured_inst_MDVs_inv_cov) + self.timepoints_jax = jnp.array(self.model.timepoints) + + if self.fit_measured_fluxes: + self.measured_fluxes_means_jax_inst = jnp.array([ + mean for mean, sd in self.model.measured_fluxes.values() + ]) + self.measured_flux_ids_ordered_inst = list(self.model.measured_fluxes.keys()) + self.measured_flux_indices_in_total_jax_inst = jnp.array([ + self.model.totalfluxids_map_jax[fid] for fid in self.measured_flux_ids_ordered_inst + ]) + self.measured_fluxes_inv_cov_jax_inst = jnp.array(self.model.measured_fluxes_inv_cov) + # d(flux)/dp where p=(u,c) + self.measured_fluxes_der_p_jax_inst = self.model.measured_fluxes_der_p_jax # Should be (n_params_inst, n_meas_fluxes) + + self.target_EMU_ids_tuple_jax_inst = tuple(self.model.target_EMUs) + self.sorted_emu_sizes_tuple_jax_inst = self.model.EAMs_jax_sorted_keys + self.num_total_params_inst = self.nfreefluxes + self.nconcs + + + # Define static argument names for JIT compilation of instationary core objective + self.static_obj_argnames_inst = ( + "N_jax", "matrix_As_static_inst", "matrix_Bs_static_inst", "matrix_Ms_static_inst", + "substrate_MDVs_jax_static_inst", + "initial_Xs_jax_inst", "initial_Ys_jax_inst", + "measured_inst_MDVs_means_jax", "measured_inst_MDVs_inv_cov_jax", + "timepoints_jax", "target_EMU_ids_tuple_jax_inst", "sorted_emu_sizes_tuple_jax_inst", + "fit_measured_fluxes_static", "nfreefluxes_static" # nfreefluxes needed to split p into u,c + ) + if self.fit_measured_fluxes: + self.static_obj_argnames_inst += ( + "measured_flux_ids_ordered_static_inst", + "measured_flux_indices_in_total_jax_static_inst", + "measured_fluxes_means_jax_static_inst", + "measured_fluxes_inv_cov_jax_static_inst" + ) + + def _jit_jax_functions_inst(self): + if not self.use_jax_inst: return + + def _core_objective_fn_inst_jax( + p_jax, # Concatenated [u_jax, c_jax] + # Static args start here + N_jax, matrix_As_static_inst, matrix_Bs_static_inst, matrix_Ms_static_inst, + substrate_MDVs_jax_static_inst, + initial_Xs_jax_inst, initial_Ys_jax_inst, + measured_inst_MDVs_means_jax, measured_inst_MDVs_inv_cov_jax, + timepoints_jax, target_EMU_ids_tuple_jax_inst, sorted_emu_sizes_tuple_jax_inst, + fit_measured_fluxes_static, nfreefluxes_static, + # Optional static flux args + measured_flux_ids_ordered_static_inst=None, + measured_flux_indices_in_total_jax_static_inst=None, + measured_fluxes_means_jax_static_inst=None, + measured_fluxes_inv_cov_jax_static_inst=None + ): + u_jax = p_jax[:nfreefluxes_static] + c_jax = p_jax[nfreefluxes_static:] + total_fluxes_jax = N_jax @ u_jax + + # This function is currently a placeholder in jax_utils.py + sim_inst_MDVs_dict_jax = core_calculate_inst_mdvs_jax( + initial_Xs_jax_inst, initial_Ys_jax_inst, timepoints_jax, + total_fluxes_jax, c_jax, + matrix_As_static_inst, matrix_Bs_static_inst, matrix_Ms_static_inst, + substrate_MDVs_jax_static_inst, + target_EMU_ids_tuple_jax_inst, sorted_emu_sizes_tuple_jax_inst + ) + + mdv_residuals_list = [] + for emu_id in target_EMU_ids_tuple_jax_inst: + if emu_id in sim_inst_MDVs_dict_jax: + for t in timepoints_jax: # Iterate over all timepoints + if t == 0: continue # Skip t=0 for residuals usually + if t in sim_inst_MDVs_dict_jax[emu_id] and \ + emu_id in measured_inst_MDVs_means_jax and \ + t in measured_inst_MDVs_means_jax[emu_id]: + + sim_mdv_t = sim_inst_MDVs_dict_jax[emu_id][t] + exp_mdv_t = measured_inst_MDVs_means_jax[emu_id][t] + mdv_residuals_list.append(sim_mdv_t - exp_mdv_t) + + if not mdv_residuals_list: # Should not happen if there's measured data + obj_mdv = 0.0 + else: + mdv_residuals_vector = jnp.concatenate(mdv_residuals_list) + obj_mdv = mdv_residuals_vector.T @ measured_inst_MDVs_inv_cov_jax @ mdv_residuals_vector + + obj_total = obj_mdv + + if fit_measured_fluxes_static: + sim_measured_fluxes_jax = total_fluxes_jax.take(measured_flux_indices_in_total_jax_static_inst) + flux_residuals_vector = sim_measured_fluxes_jax - measured_fluxes_means_jax_static_inst + obj_flux = flux_residuals_vector.T @ measured_fluxes_inv_cov_jax_static_inst @ flux_residuals_vector + obj_total += obj_flux + + return obj_total + + self._jitted_core_objective_fn_inst = jax.jit( + _core_objective_fn_inst_jax, static_argnames=self.static_obj_argnames_inst + ) + + self._jitted_core_gradient_fn_inst = jax.jit( + jax.grad(_core_objective_fn_inst_jax, argnums=0), # Differentiate w.r.t. p_jax + static_argnames=self.static_obj_argnames_inst + ) + + self._jitted_core_hessian_fn_inst = jax.jit( + jax.hessian(_core_objective_fn_inst_jax, argnums=0), # Hessian w.r.t. p_jax + static_argnames=self.static_obj_argnames_inst + ) + + # --- Original methods for InstMFAModel --- + def _calculate_difference_sim_exp_MDVs(self): + # This is the original numpy-based calculation for instationary + # It updates self.model.total_fluxes and self.model.concentrations first + # then calls self.calculator._calculate_inst_MDVs() + # For this method to be called from original build_objective, p has to be split. + # This method is specific to the non-JAX path. + + # The p argument is not passed here, assumes self.model attributes are set. + # This is how original InstMFAModel._f calls it. + simMDVs_all_emus_all_times = self.calculator._calculate_inst_MDVs() # numpy based + expMDVs_all_emus_all_times = self.model.measured_inst_MDVs + + diff_list = [] + for emuid in self.model.target_EMUs: + if emuid in simMDVs_all_emus_all_times and emuid in expMDVs_all_emus_all_times: + sim_data_for_emu = simMDVs_all_emus_all_times[emuid] + exp_data_for_emu = expMDVs_all_emus_all_times[emuid] + for t_point in exp_data_for_emu: # Iterate over measured timepoints + if t_point != 0 and t_point in sim_data_for_emu: + diff_list.append(sim_data_for_emu[t_point] - exp_data_for_emu[t_point][0]) + + if not diff_list: return np.array([]) # Handle case with no valid residuals + return np.concatenate(diff_list) + + + def _calculate_sim_MDVs_derivative(self): + # Original numpy-based for instationary. Assumes model state is set. + simMDVs_all_times, simMDVsDer_all_times = self.calculator._calculate_inst_MDVs_and_derivatives_p() + expMDVs_all_times = self.model.measured_inst_MDVs + + diff_list = [] + dxsim_dp_list = [] + + for emuid in self.model.target_EMUs: + if emuid in simMDVs_all_times and emuid in expMDVs_all_times: + sim_data_for_emu = simMDVs_all_times[emuid] + exp_data_for_emu = expMDVs_all_times[emuid] + sim_der_for_emu = simMDVsDer_all_times[emuid] # {time: deriv_array (n_coeffs, n_params)} + + for t_point in exp_data_for_emu: + if t_point != 0 and t_point in sim_data_for_emu and t_point in sim_der_for_emu: + diff_list.append(sim_data_for_emu[t_point] - exp_data_for_emu[t_point][0]) + # Derivative from calculator is (n_coeffs, n_params). For vstack, it's fine. + dxsim_dp_list.append(sim_der_for_emu[t_point]) + + if not diff_list: return np.array([]), np.array([]).reshape(0, self.nfreefluxes + self.nconcs) # Adjust shape for empty + + diff_vector = np.concatenate(diff_list) + # dxsim_dp_stacked will be (total_coeffs_all_times, n_params) + dxsim_dp_stacked = np.vstack(dxsim_dp_list) if dxsim_dp_list else np.array([]).reshape(0, self.nfreefluxes + self.nconcs) + + return diff_vector, dxsim_dp_stacked + + + def build_objective(self): + if self.use_jax_inst: + static_args_inst = { + "N_jax": self.N_jax, # Assuming this is prepared correctly + "matrix_As_static_inst": self.matrix_As_jax_static_inst, + "matrix_Bs_static_inst": self.matrix_Bs_jax_static_inst, + "matrix_Ms_static_inst": self.matrix_Ms_jax_static_inst, + "substrate_MDVs_jax_static_inst": self.substrate_MDVs_jax_static_inst, + "initial_Xs_jax_inst": self.initial_Xs_jax_inst, + "initial_Ys_jax_inst": self.initial_Ys_jax_inst, + "measured_inst_MDVs_means_jax": self.measured_inst_MDVs_means_jax, + "measured_inst_MDVs_inv_cov_jax": self.measured_inst_MDVs_inv_cov_jax, + "timepoints_jax": self.timepoints_jax, + "target_EMU_ids_tuple_jax_inst": self.target_EMU_ids_tuple_jax_inst, + "sorted_emu_sizes_tuple_jax_inst": self.sorted_emu_sizes_tuple_jax_inst, + "fit_measured_fluxes_static": self.fit_measured_fluxes, + "nfreefluxes_static": self.nfreefluxes + } + if self.fit_measured_fluxes: + static_args_inst.update({ + "measured_flux_ids_ordered_static_inst": self.measured_flux_ids_ordered_inst, + "measured_flux_indices_in_total_jax_static_inst": self.measured_flux_indices_in_total_jax_inst, + "measured_fluxes_means_jax_static_inst": self.measured_fluxes_means_jax_inst, + "measured_fluxes_inv_cov_jax_static_inst": self.measured_fluxes_inv_cov_jax_inst + }) + + def f_jax_inst_wrapper(p_numpy): + p_jax = jnp.array(p_numpy) + obj_val_jax = self._jitted_core_objective_fn_inst(p_jax, **static_args_inst) + return float(obj_val_jax) + self.f = f_jax_inst_wrapper + else: # Original numpy-based objective for InstMFAModel + def _f_inst_numpy(p_np): # p_np is [u_np, c_np] + u_np, c_np = p_np[:self.nfreefluxes], p_np[self.nfreefluxes:] + # Update model state for calculator methods + self.model.total_fluxes.iloc[:] = self.N @ u_np + self.model.concentrations.iloc[:] = c_np # Assuming concentrations is a pandas Series + + MDV_diff = self._calculate_difference_sim_exp_MDVs() # Uses updated model state + if MDV_diff.size == 0: return 0.0 # No residuals to compute objective from + MDV_inv_cov = self.model.measured_inst_MDVs_inv_cov # NumPy array + obj = MDV_diff @ MDV_inv_cov @ MDV_diff + return obj + + def f1_inst_numpy(p_np): + return _f_inst_numpy(p_np) + + def f2_inst_numpy(p_np): + obj1 = _f_inst_numpy(p_np) # This also updates model state + + # _calculate_difference_sim_exp_fluxes uses self.model.total_fluxes + flux_diff = super()._calculate_difference_sim_exp_fluxes() # Use MFAModel's method + if flux_diff.size == 0: return obj1 # No flux data to fit + flux_inv_cov = self.model.measured_fluxes_inv_cov # NumPy array + obj2 = flux_diff @ flux_inv_cov @ flux_diff + return obj1 + obj2 + + self.f = f2_inst_numpy if self.fit_measured_fluxes else f1_inst_numpy + + + def build_gradient(self): + if self.use_jax_inst: + static_args_inst = { # Same as for objective + "N_jax": self.N_jax, + "matrix_As_static_inst": self.matrix_As_jax_static_inst, + "matrix_Bs_static_inst": self.matrix_Bs_jax_static_inst, + "matrix_Ms_static_inst": self.matrix_Ms_jax_static_inst, + "substrate_MDVs_jax_static_inst": self.substrate_MDVs_jax_static_inst, + "initial_Xs_jax_inst": self.initial_Xs_jax_inst, + "initial_Ys_jax_inst": self.initial_Ys_jax_inst, + "measured_inst_MDVs_means_jax": self.measured_inst_MDVs_means_jax, + "measured_inst_MDVs_inv_cov_jax": self.measured_inst_MDVs_inv_cov_jax, + "timepoints_jax": self.timepoints_jax, + "target_EMU_ids_tuple_jax_inst": self.target_EMU_ids_tuple_jax_inst, + "sorted_emu_sizes_tuple_jax_inst": self.sorted_emu_sizes_tuple_jax_inst, + "fit_measured_fluxes_static": self.fit_measured_fluxes, + "nfreefluxes_static": self.nfreefluxes + } + if self.fit_measured_fluxes: + static_args_inst.update({ + "measured_flux_ids_ordered_static_inst": self.measured_flux_ids_ordered_inst, + "measured_flux_indices_in_total_jax_static_inst": self.measured_flux_indices_in_total_jax_inst, + "measured_fluxes_means_jax_static_inst": self.measured_fluxes_means_jax_inst, + "measured_fluxes_inv_cov_jax_static_inst": self.measured_fluxes_inv_cov_jax_inst + }) + + def df_jax_inst_wrapper(p_numpy): + p_jax = jnp.array(p_numpy) + grad_val_jax = self._jitted_core_gradient_fn_inst(p_jax, **static_args_inst) + return np.array(grad_val_jax) + self.df = df_jax_inst_wrapper + else: # Original numpy-based gradient for InstMFAModel + def _df_inst_numpy(p_np): # p_np is [u_np, c_np] + u_np, c_np = p_np[:self.nfreefluxes], p_np[self.nfreefluxes:] + # Update model state + self.model.total_fluxes.iloc[:] = self.N @ u_np + self.model.concentrations.iloc[:] = c_np + + MDV_diff, MDV_der = self._calculate_sim_MDVs_derivative() # Uses updated model state + # MDV_der shape (total_coeffs, n_params_inst) + if MDV_diff.size == 0: return np.zeros_like(p_np) + + MDV_inv_cov = self.model.measured_inst_MDVs_inv_cov + grad = MDV_der.T @ MDV_inv_cov @ MDV_diff # (n_params_inst, total_coeffs) @ (...) @ (total_coeffs,) -> (n_params_inst,) + return grad + + def df1_inst_numpy(p_np): + return _df_inst_numpy(p_np) + + def df2_inst_numpy(p_np): + grad1 = _df_inst_numpy(p_np) # Also updates model state + + # _calculate_sim_fluxes_derivative is from MFAModel, uses self.model.measured_fluxes_der_p + # This measured_fluxes_der_p should be d(flux)/dp where p=(u,c) for instationary. + # It was set by Calculator._calculate_measured_fluxes_derivative_p('inst') + flux_der = super()._calculate_sim_fluxes_derivative() # (n_meas_fluxes, n_params_inst) + flux_diff = super()._calculate_difference_sim_exp_fluxes() # (n_meas_fluxes,) + + if flux_diff.size == 0: return grad1 + + flux_inv_cov = self.model.measured_fluxes_inv_cov + grad2 = flux_der.T @ flux_inv_cov @ flux_diff # (n_params_inst, n_meas_fluxes) @ (...) @ (n_meas_fluxes,) -> (n_params_inst,) + return grad1 + grad2 + + self.df = df2_inst_numpy if self.fit_measured_fluxes else df1_inst_numpy + + # build_hessian for InstMFAModel would follow similar logic if JAX path is taken + # For now, it will inherit MFAModel's build_hessian. + # If JAX is used for InstMFAModel, MFAModel's build_hessian might try to use + # steady-state JAX data if not careful. + # Override build_hessian for InstMFAModel: + def build_hessian(self): + if self.use_jax_inst: + static_args_inst = { # Same static args as objective/gradient for instationary + "N_jax": self.N_jax, + "matrix_As_static_inst": self.matrix_As_jax_static_inst, + "matrix_Bs_static_inst": self.matrix_Bs_jax_static_inst, + "matrix_Ms_static_inst": self.matrix_Ms_jax_static_inst, + "substrate_MDVs_jax_static_inst": self.substrate_MDVs_jax_static_inst, + "initial_Xs_jax_inst": self.initial_Xs_jax_inst, + "initial_Ys_jax_inst": self.initial_Ys_jax_inst, + "measured_inst_MDVs_means_jax": self.measured_inst_MDVs_means_jax, + "measured_inst_MDVs_inv_cov_jax": self.measured_inst_MDVs_inv_cov_jax, + "timepoints_jax": self.timepoints_jax, + "target_EMU_ids_tuple_jax_inst": self.target_EMU_ids_tuple_jax_inst, + "sorted_emu_sizes_tuple_jax_inst": self.sorted_emu_sizes_tuple_jax_inst, + "fit_measured_fluxes_static": self.fit_measured_fluxes, + "nfreefluxes_static": self.nfreefluxes + } + if self.fit_measured_fluxes: + static_args_inst.update({ + "measured_flux_ids_ordered_static_inst": self.measured_flux_ids_ordered_inst, + "measured_flux_indices_in_total_jax_static_inst": self.measured_flux_indices_in_total_jax_inst, + "measured_fluxes_means_jax_static_inst": self.measured_fluxes_means_jax_inst, + "measured_fluxes_inv_cov_jax_static_inst": self.measured_fluxes_inv_cov_jax_inst + }) + + def ddf_jax_inst_wrapper(p_numpy): + p_jax = jnp.array(p_numpy) + hess_val_jax = self._jitted_core_hessian_fn_inst(p_jax, **static_args_inst) + return np.array(hess_val_jax) + + if self.solver not in ['slsqp', 'ralg']: + self.ddf = ddf_jax_inst_wrapper + else: + self.ddf = None + else: + self._build_original_hessian_inst() + + def _build_original_hessian_inst(self): + # Original Hessian logic for InstMFAModel + def _ddf_inst_numpy(p_np): + u_np, c_np = p_np[:self.nfreefluxes], p_np[self.nfreefluxes:] + self.model.total_fluxes.iloc[:] = self.N @ u_np + self.model.concentrations.iloc[:] = c_np + + _, MDV_der = self._calculate_sim_MDVs_derivative() # (total_coeffs, n_params_inst) + if MDV_der.size == 0: return np.zeros((len(p_np), len(p_np))) + + MDV_inv_cov = self.model.measured_inst_MDVs_inv_cov + hess = MDV_der.T @ MDV_inv_cov @ MDV_der # Gauss-Newton part + return hess + + def ddf1_inst_numpy(p_np): + return _ddf_inst_numpy(p_np) + + def ddf2_inst_numpy(p_np): + hess1 = _ddf_inst_numpy(p_np) # Also updates model state + + flux_der = super()._calculate_sim_fluxes_derivative() # (n_meas_fluxes, n_params_inst) + if flux_der.size == 0: return hess1 + + flux_inv_cov = self.model.measured_fluxes_inv_cov + hess2 = flux_der.T @ flux_inv_cov @ flux_der + return hess1 + hess2 + + self.ddf = ddf2_inst_numpy if self.fit_measured_fluxes else ddf1_inst_numpy def _calculate_difference_sim_exp_MDVs(self): diff --git a/src/freeflux/utils/jax_utils.py b/src/freeflux/utils/jax_utils.py new file mode 100644 index 0000000..28024ca --- /dev/null +++ b/src/freeflux/utils/jax_utils.py @@ -0,0 +1,296 @@ +"""JAX utility functions for freeflux.""" + +import jax +import jax.numpy as jnp +from jax.scipy.linalg import pinv + + +def jax_conv(arr1, arr2): + """JAX equivalent of polynomial convolution for 1D arrays (like numpy.convolve).""" + if arr1 is None or arr2 is None: + # This case should ideally be handled by ensuring valid (non-None) arrays are passed, + # or by defining specific behavior (e.g., if one is identity MDV [1.0]). + raise ValueError("jax_conv received None input. This should be handled before calling.") + return jnp.convolve(arr1, arr2) + +def jax_diff_conv(mdv1_mdv1der_pair, mdv2_mdv2der_pair, num_params_for_zeros=None): + """ + Calculates (conv(mdv1, mdv2), d(conv(mdv1, mdv2))/dp) using JAX. + mdv1, mdv2 are 1D JAX arrays. + mdv1der, mdv2der are 2D JAX arrays (n_params, n_coeffs_for_mdv), or None. + num_params_for_zeros: integer, required if both mdv1der and mdv2der are None, + to correctly shape the zero derivative. + + Returns: (convolved_mdv [1D], convolved_mdv_derivative [2D: (n_params, n_coeffs_conv)]) + The derivative part can be None if both input derivatives are None and num_params_for_zeros is not given. + """ + mdv1, mdv1der = mdv1_mdv1der_pair + mdv2, mdv2der = mdv2_mdv2der_pair + + # Ensure mdv1 and mdv2 are not None + if mdv1 is None or mdv2 is None: + raise ValueError("MDV inputs to jax_diff_conv cannot be None.") + + convolved_mdv = jnp.convolve(mdv1, mdv2) + + term1_der = None + if mdv1der is not None: + # mdv1der shape: (n_params, len(mdv1)) + # mdv2 shape: (len(mdv2),) + term1_der = jax.vmap(lambda d1_row: jnp.convolve(d1_row, mdv2))(mdv1der) + # term1_der shape: (n_params, len(convolved_mdv)) + + term2_der = None + if mdv2der is not None: + # mdv1 shape: (len(mdv1),) + # mdv2der shape: (n_params, len(mdv2)) + term2_der = jax.vmap(lambda d2_row: jnp.convolve(mdv1, d2_row))(mdv2der) + # term2_der shape: (n_params, len(convolved_mdv)) + + if term1_der is not None and term2_der is not None: + convolved_mdv_der = term1_der + term2_der + elif term1_der is not None: + convolved_mdv_der = term1_der + elif term2_der is not None: + convolved_mdv_der = term2_der + else: + # Both derivatives are None (conceptually zero). + if num_params_for_zeros is not None: + convolved_mdv_der = jnp.zeros((num_params_for_zeros, len(convolved_mdv))) + else: + # This case should ideally be avoided by providing num_params_for_zeros + # if there's a possibility of all derivatives being None. + # For JIT, shapes must be consistent. + raise ValueError("num_params_for_zeros must be provided if all input derivatives can be None.") + + + return convolved_mdv, convolved_mdv_der + + +def core_calculate_mdvs_jax( + total_fluxes_jax, + matrix_As_static_data, + matrix_Bs_static_data, + substrate_MDVs_jax, + target_EMU_ids_tuple, # Currently unused in core calc, filtering is done outside + sorted_emu_sizes_tuple + ): + """ + Calculates simulated Mass Distribution Vectors (MDVs) using JAX. + This function is designed to be JIT-compatible. + Static data (matrices, EMU lists) are passed in pytrees. + """ + sim_MDVs_dict = {} + + for size in sorted_emu_sizes_tuple: + A_data = matrix_As_static_data[size] + B_data = matrix_Bs_static_data[size] + + lambA_jax = A_data['func'] + flux_indices_A = A_data['flux_indices'] + product_EMU_ids = A_data['product_emu_ids'] + + lambB_jax = B_data['func'] + flux_indices_B = B_data['flux_indices'] + source_EMU_ids_or_tuples = B_data['source_emu_ids_or_tuples'] + + fluxes_for_A = total_fluxes_jax.take(jnp.array(flux_indices_A)) + fluxes_for_B = total_fluxes_jax.take(jnp.array(flux_indices_B)) + + A = lambA_jax(*fluxes_for_A) + B = lambB_jax(*fluxes_for_B) + + Y_parts = [] + for source_item in source_EMU_ids_or_tuples: + if isinstance(source_item, str): + source_emu_id = source_item + mdv = sim_MDVs_dict.get(source_emu_id, substrate_MDVs_jax.get(source_emu_id)) + if mdv is None: raise ValueError(f"MDV not found for {source_emu_id} in size {size}") + Y_parts.append(mdv) + else: + mdvs_to_convolve = [] + for emu_id_in_tuple in source_item: + mdv = sim_MDVs_dict.get(emu_id_in_tuple, substrate_MDVs_jax.get(emu_id_in_tuple)) + if mdv is None: raise ValueError(f"MDV not found for {emu_id_in_tuple} in convolution for size {size}") + mdvs_to_convolve.append(mdv) + + if mdvs_to_convolve: + current_conv = mdvs_to_convolve[0] + for i in range(1, len(mdvs_to_convolve)): + current_conv = jax_conv(current_conv, mdvs_to_convolve[i]) + Y_parts.append(current_conv) + else: # Should not happen with valid model structure + raise ValueError(f"Empty convolution list for size {size}") + + + num_source_terms = B.shape[1] + # mdv_len_this_size = size + 1 # This was an assumption about Y matrix content. + # Each row in Y is an MDV. These MDVs must have a length compatible with B's columns. + # The EMU formulation implies that B projects/combines these source MDVs + # into contributions for product EMUs of current `size`. + # The X = pinv(A)@B@Y implies Y's rows are MDVs that B can operate on. + # The resulting X will have rows of length `size+1`. + + if Y_parts: + # Y must be a 2D array (matrix) for B @ Y. + # Each element of Y_parts is a 1D MDV array. + # Their lengths can vary if they come from EMUs of different sizes. + # This is a CRITICAL POINT: The original code `Y = np.array(Y)` implicitly assumes all MDVs in Y_parts + # have the same length to form a 2D array. This is true if all source EMUs (or convolutions thereof) + # for a given product size `s` also result in MDVs corresponding to size `s`. + # This seems to be an implicit assumption of the (A*X = B*Y) formulation per size. + # Let's assume all mdvs in Y_parts for a given `size` effectively have length `size+1`. + mdv_len_check = size + 1 + Y = jnp.array(Y_parts) + if Y.ndim == 1 and num_source_terms == 1: # Single source term, Y_parts contained one 1D array + Y = Y.reshape(1, -1) + + if Y.shape[0] != num_source_terms: + raise ValueError(f"Mismatch in Y parts ({Y.shape[0]}) and B matrix columns ({num_source_terms}) for size {size}") + if Y.shape[1] != mdv_len_check: + # This is a deviation from the simple assumption. + # If this happens, B must be structured to handle it (e.g. padded EMUs). + # For now, stick to the assumption that Y rows are all length `size+1`. + raise ValueError(f"Y row length {Y.shape[1]} does not match expected {mdv_len_check} for size {size}") + + elif num_source_terms > 0: + raise ValueError(f"Y_parts is empty but B expects {num_source_terms} inputs for size {size}") + else: # num_source_terms == 0, Y_parts is empty + # If B has 0 columns, Y should be (0, mdv_len_for_products_of_this_size) + Y = jnp.zeros((0, size + 1)) + + X = pinv(A) @ B @ Y + + for i, emu_id in enumerate(product_EMU_ids): + sim_MDVs_dict[emu_id] = X[i, :] + + return sim_MDVs_dict + + +def core_calculate_mdvs_and_derivatives_jax( + total_fluxes_jax, + matrix_As_static_data, + matrix_Bs_static_data, + substrate_MDVs_jax, + matrix_As_der_p_static_data, + matrix_Bs_der_p_static_data, + substrate_MDVs_der_p_jax, + target_EMU_ids_tuple, # Currently unused + sorted_emu_sizes_tuple, + num_free_fluxes + ): + """ + Calculates simulated MDVs and their derivatives w.r.t. parameters (free fluxes `u`) using JAX. + """ + sim_MDVs_dict = {} + sim_MDVs_der_dict = {} + + for size in sorted_emu_sizes_tuple: + A_data = matrix_As_static_data[size] + B_data = matrix_Bs_static_data[size] + + lambA_jax = A_data['func'] + flux_indices_A = A_data['flux_indices'] + product_EMU_ids = A_data['product_emu_ids'] + + lambB_jax = B_data['func'] + flux_indices_B = B_data['flux_indices'] + source_EMU_ids_or_tuples = B_data['source_emu_ids_or_tuples'] + + fluxes_for_A = total_fluxes_jax.take(jnp.array(flux_indices_A)) + fluxes_for_B = total_fluxes_jax.take(jnp.array(flux_indices_B)) + + A = lambA_jax(*fluxes_for_A) + B = lambB_jax(*fluxes_for_B) + Ainv = pinv(A) + + Ader_p = matrix_As_der_p_static_data[size] + Bder_p = matrix_Bs_der_p_static_data[size] + + Y_parts = [] + Yder_p_parts = [] + + mdv_len_for_X = size + 1 # Product EMUs of this size have this MDV length + + for source_item in source_EMU_ids_or_tuples: + if isinstance(source_item, str): + source_emu_id = source_item + mdv = sim_MDVs_dict.get(source_emu_id, substrate_MDVs_jax.get(source_emu_id)) + if mdv is None: raise ValueError(f"MDV not found for {source_emu_id} in size {size} (derivative calc)") + + mdv_der_p = sim_MDVs_der_dict.get(source_emu_id, substrate_MDVs_der_p_jax.get(source_emu_id)) + if mdv_der_p is None: + mdv_der_p = jnp.zeros((num_free_fluxes, len(mdv))) + + Y_parts.append(mdv) + Yder_p_parts.append(mdv_der_p) + else: + mdv_val_list_for_conv = [] + mdv_der_pair_list_for_diff_conv = [] + + for emu_id_in_tuple in source_item: + mdv_val = sim_MDVs_dict.get(emu_id_in_tuple, substrate_MDVs_jax.get(emu_id_in_tuple)) + if mdv_val is None: raise ValueError(f"MDV not found for {emu_id_in_tuple} in convolution for size {size} (derivative calc)") + + mdv_der_val = sim_MDVs_der_dict.get(emu_id_in_tuple, substrate_MDVs_der_p_jax.get(emu_id_in_tuple)) + if mdv_der_val is None: + mdv_der_val = jnp.zeros((num_free_fluxes, len(mdv_val))) + + mdv_val_list_for_conv.append(mdv_val) + mdv_der_pair_list_for_diff_conv.append([mdv_val, mdv_der_val]) + + if mdv_val_list_for_conv: + current_conv_mdv = mdv_val_list_for_conv[0] + for i in range(1, len(mdv_val_list_for_conv)): + current_conv_mdv = jax_conv(current_conv_mdv, mdv_val_list_for_conv[i]) + Y_parts.append(current_conv_mdv) + + current_conv_mdv_and_der_pair = mdv_der_pair_list_for_diff_conv[0] + for i in range(1, len(mdv_der_pair_list_for_diff_conv)): + next_pair = mdv_der_pair_list_for_diff_conv[i] + current_conv_mdv_and_der_pair = jax_diff_conv(current_conv_mdv_and_der_pair, next_pair, num_params_for_zeros=num_free_fluxes) + Yder_p_parts.append(current_conv_mdv_and_der_pair[1]) + else: + raise ValueError(f"Empty convolution list for size {size} (derivative calc)") + + num_source_terms = B.shape[1] + # Assumption: All MDVs in Y_parts (after convolution) correspond to the length expected by B for this size iteration. + # This length is `size+1`. + expected_Y_row_len = size + 1 + + if Y_parts: + Y = jnp.array(Y_parts) + if Y.ndim == 1 and num_source_terms == 1 : Y = Y.reshape(1,-1) # Ensure Y is 2D + + Yder_p_temp = jnp.array(Yder_p_parts) # (n_source_terms, n_params, current_mdv_len) + if Yder_p_temp.ndim == 2 and num_source_terms == 1: Yder_p_temp = Yder_p_temp.reshape(1, num_free_fluxes, -1) # Ensure 3D + + Yder_p = jnp.transpose(Yder_p_temp, (1,0,2)) # (n_params, n_source_terms, current_mdv_len) + + if Y.shape[0] != num_source_terms: + raise ValueError(f"Y shape {Y.shape} mismatch B columns {num_source_terms} for size {size}") + if Y.shape[1] != expected_Y_row_len: + raise ValueError(f"Y row length {Y.shape[1]} does not match expected {expected_Y_row_len} for size {size}") + if Yder_p.shape[1] != num_source_terms or Yder_p.shape[2] != expected_Y_row_len: + raise ValueError(f"Yder_p shape {Yder_p.shape} mismatch B columns {num_source_terms} or MDV length {expected_Y_row_len} for size {size}") + + elif num_source_terms > 0: + raise ValueError(f"Y_parts is empty but B expects {num_source_terms} inputs for size {size} (derivative calc)") + else: # num_source_terms == 0 + Y = jnp.zeros((0, expected_Y_row_len)) + Yder_p = jnp.zeros((num_free_fluxes, 0, expected_Y_row_len)) + + + X = Ainv @ B @ Y # X shape: (n_prod_emus, mdv_len_for_X) + + term1 = jax.vmap(lambda Bder_slice: Bder_slice @ Y)(Bder_p) + term2 = jax.vmap(lambda Yder_slice: B @ Yder_slice)(Yder_p) + term3 = jax.vmap(lambda Ader_slice: Ader_slice @ X)(Ader_p) + sum_terms = term1 + term2 - term3 + Xder_p = jax.vmap(lambda sum_slice: Ainv @ sum_slice)(sum_terms) # Xder_p shape: (n_params, n_prod_emus, mdv_len_for_X) + + for i, emu_id in enumerate(product_EMU_ids): + sim_MDVs_dict[emu_id] = X[i, :] + sim_MDVs_der_dict[emu_id] = Xder_p[:, i, :] # Store as (n_params, mdv_len_for_X) + + return sim_MDVs_dict, sim_MDVs_der_dict diff --git a/src/freeflux/utils/utils.py b/src/freeflux/utils/utils.py index 2128e76..7e3fe3d 100644 --- a/src/freeflux/utils/utils.py +++ b/src/freeflux/utils/utils.py @@ -18,14 +18,20 @@ from sympy import symbols, lambdify, Matrix, derive_by_array try: import jax.numpy as jnp - from jax import config, jacfwd + from jax import config, jacfwd, jit config.update('jax_platform_name', 'cpu') + # It's good practice to import specific jax functions like jit if used directly in this file. except ModuleNotFoundError: JAX_INSTALLED = False + # Define jnp as np if JAX is not installed, for type hinting or conditional code. + # However, code using JAX features should be guarded by JAX_INSTALLED. + # For lambdify, sympy needs to know about jax.numpy. + # We can pass {'jnp': jax.numpy} to lambdify's modules argument if needed. else: JAX_INSTALLED = True from multiprocessing import Pool -from ..core.mdv import MDV, get_natural_MDV, get_substrate_MDV, conv, diff_conv +from ..core.mdv import MDV, get_natural_MDV, get_substrate_MDV, conv, diff_conv # These are numpy based +# from ..core.emu import EMU # If EMU objects are directly used as keys and need properties import warnings warnings.filterwarnings('ignore', category = RuntimeWarning) warnings.filterwarnings('ignore', category = DeprecationWarning) @@ -219,15 +225,21 @@ def _calculate_measured_fluxes_derivative_p(self, kind): measured_fluxes_der_v = self._calculate_measured_fluxes_derivative_v() measured_fluxes_der_u = measured_fluxes_der_v@self.model.null_space self.model.measured_fluxes_der_p = measured_fluxes_der_u + if JAX_INSTALLED and hasattr(self.model, 'jax_flux_derivatives_enabled') and self.model.jax_flux_derivatives_enabled: # Control this via Fitter + self.model.measured_fluxes_der_p_jax = jnp.array(measured_fluxes_der_u) + elif kind == 'inst': measured_fluxes_der_v = self._calculate_measured_fluxes_derivative_v() measured_fluxes_der_u = measured_fluxes_der_v@self.model.null_space measured_fluxes_der_c = self._calculate_measured_fluxes_derivative_c() - self.model.measured_fluxes_der_p = np.concatenate( + inst_der_p = np.concatenate( (measured_fluxes_der_u, measured_fluxes_der_c), axis = 1 ) + self.model.measured_fluxes_der_p = inst_der_p + if JAX_INSTALLED and hasattr(self.model, 'jax_flux_derivatives_enabled') and self.model.jax_flux_derivatives_enabled: + self.model.measured_fluxes_der_p_jax = jnp.array(inst_der_p) def _generate_random_fluxes(self): @@ -491,6 +503,204 @@ def _lambdify_matrix_As_and_Bs(self): self.model.matrix_As[size] = [lambA, fluxidsA, A.columns.tolist()] self.model.matrix_Bs[size] = [lambB, fluxidsB, B.columns.tolist()] + + def _lambdify_matrix_As_and_Bs_jax(self): + """ + Lambdifies SymPy expressions for matrices A and B using JAX backend. + Stores them in self.model.matrix_As_jax_static_data and self.model.matrix_Bs_jax_static_data. + Also creates helper mappings like totalfluxids_map_jax and EAMs_jax_sorted_keys. + """ + if not JAX_INSTALLED: + # This method should only be called if JAX is intended to be used. + # However, as a safeguard or if called unconditionally: + warnings.warn("JAX not installed. Cannot lambdify matrices for JAX.") + self.model.matrix_As_jax_static_data = {} + self.model.matrix_Bs_jax_static_data = {} + self.model.totalfluxids_map_jax = {} + self.model.EAMs_jax_sorted_keys = tuple() + return + + self.model.matrix_As_jax_static_data = {} + self.model.matrix_Bs_jax_static_data = {} + + # Create a mapping from total flux ID (string) to integer index + self.model.totalfluxids_map_jax = { + fid: i for i, fid in enumerate(self.model.totalfluxids) + } + # Ensure EAMs are sorted by size for consistent processing order + # self.model.EAMs is {size: EAM_DataFrame} + if not self.model.EAMs: # Should not happen if prepare() is called correctly + self.model.EAMs_jax_sorted_keys = tuple() + return # Nothing to lambdify + + self.model.EAMs_jax_sorted_keys = tuple(sorted(self.model.EAMs.keys())) + + for size_key in self.model.EAMs_jax_sorted_keys: + EAM_df = self.model.EAMs[size_key] # EAM_df is the pandas DataFrame + + # EMU objects are used as index/columns in EAM_df + # preAB matrix construction from EAM_df (symbolic expressions) + preAB = EAM_df.copy(deep='all') + for emu_obj_col in EAM_df.columns: # These are product EMUs (EMU objects) + # Sum fluxes for diagonal elements + preAB.loc[emu_obj_col, emu_obj_col] = -preAB[emu_obj_col].sum() + + product_emu_objs = EAM_df.columns.tolist() + source_emu_objs_or_tuples = EAM_df.index.difference(EAM_df.columns).tolist() + + # Ensure correct DataFrame indexing for Sympy Matrix conversion + # A is (product_EMUs x product_EMUs), B is (product_EMUs x source_EMUs_terms) + # The original code: A = preAB.loc[preAB.columns, :].T + # This means A's rows/cols are product_EMUs. + # B = -preAB.loc[preAB.index.difference(preAB.columns), :].T + # This means B's rows are product_EMUs, cols are source_EMU_terms. + + A_sympy_df = preAB.loc[product_emu_objs, product_emu_objs] # Square matrix part for A + B_sympy_df = -preAB.loc[source_emu_objs_or_tuples, product_emu_objs].T # B (products x sources_terms) + # Transpose to match convention A*X = B*Y + # where X and Y are column vectors of MDVs. + # Original code implies B is (n_prod, n_source_terms) + + matA_sympy = Matrix(A_sympy_df.values) # Pass .values to avoid dtype issues with sympy + matB_sympy = Matrix(B_sympy_df.values) + + flux_symbols_A = sorted(list(matA_sympy.free_symbols), key=lambda s: s.name) + flux_symbols_B = sorted(list(matB_sympy.free_symbols), key=lambda s: s.name) + + fluxidsA_str = [s.name for s in flux_symbols_A] + fluxidsB_str = [s.name for s in flux_symbols_B] + + flux_indices_A_int = tuple(self.model.totalfluxids_map_jax[fid] for fid in fluxidsA_str) + flux_indices_B_int = tuple(self.model.totalfluxids_map_jax[fid] for fid in fluxidsB_str) + + lambA_jax = lambdify(flux_symbols_A, matA_sympy, modules=['jax']) + lambB_jax = lambdify(flux_symbols_B, matB_sympy, modules=['jax']) + + product_emu_ids_str = tuple(emu.id for emu in product_emu_objs) + + source_emu_ids_or_tuples_str_list = [] + for item in source_emu_objs_or_tuples: + if isinstance(item, tuple): + source_emu_ids_or_tuples_str_list.append(tuple(e.id for e in item)) + else: + source_emu_ids_or_tuples_str_list.append(item.id) + source_emu_ids_or_tuples_str = tuple(source_emu_ids_or_tuples_str_list) + + self.model.matrix_As_jax_static_data[size_key] = { + 'func': lambA_jax, + 'flux_indices': flux_indices_A_int, + 'product_emu_ids': product_emu_ids_str + } + self.model.matrix_Bs_jax_static_data[size_key] = { + 'func': lambB_jax, + 'flux_indices': flux_indices_B_int, + 'source_emu_ids_or_tuples': source_emu_ids_or_tuples_str + } + + def _prepare_substrate_MDVs_jax(self): + if not JAX_INSTALLED: return + self.model.substrate_MDVs_jax_static_data = { + # Ensure substrate_MDVs keys (EMU objects) are converted to string IDs + # and values (MDV objects or arrays) are converted to JAX arrays. + emu.id if hasattr(emu, 'id') else str(emu): jnp.array(mdv_array.value if hasattr(mdv_array, 'value') and isinstance(mdv_array, MDV) else mdv_array) + for emu, mdv_array in self.model.substrate_MDVs.items() + } + + def _prepare_substrate_MDVs_der_p_jax(self): + if not JAX_INSTALLED: return + self.model.substrate_MDVs_der_p_jax_static_data = {} + if self.model.substrate_MDVs_der_p: + for emu, np_der_array in self.model.substrate_MDVs_der_p.items(): + # Original shape (emu.size+1, nvars). Transpose to (nvars, emu.size+1) + key_id = emu.id if hasattr(emu, 'id') else str(emu) + self.model.substrate_MDVs_der_p_jax_static_data[key_id] = jnp.array(np_der_array.T) + else: + pass # Empty dict is fine. Handled by core_calculate_mdvs_and_derivatives_jax + + def _prepare_matrix_derivatives_jax(self): + """ Converts pre-calculated NumPy matrix derivatives to JAX arrays. """ + if not JAX_INSTALLED: return + self.model.matrix_As_der_p_jax_static_data = { + size: jnp.array(deriv_array) + for size, deriv_array in self.model.matrix_As_der_p.items() + } + self.model.matrix_Bs_der_p_jax_static_data = { + size: jnp.array(deriv_array) + for size, deriv_array in self.model.matrix_Bs_der_p.items() + } + + def _lambdify_matrix_Ms_jax(self): + """Lambdifies SymPy expressions for matrix M using JAX backend.""" + if not JAX_INSTALLED: + self.model.matrix_Ms_jax_static_data = {} + return + + self.model.matrix_Ms_jax_static_data = {} + + conc_ids_ordered = self.model.concids if hasattr(self.model, 'concids') else [] + conc_symbols_ordered = symbols(conc_ids_ordered) + concid_to_symbol_map = {s.name: s for s in conc_symbols_ordered} + + for size, EAM_df in self.model.EAMs.items(): + product_emu_objs = EAM_df.columns.tolist() + + diag_elements_for_M = [] + active_conc_symbols_for_this_M = [] + + for emu_obj in product_emu_objs: + metab_id = emu_obj.metabolite_id + if metab_id in concid_to_symbol_map: + diag_elements_for_M.append(concid_to_symbol_map[metab_id]) + if concid_to_symbol_map[metab_id] not in active_conc_symbols_for_this_M: + active_conc_symbols_for_this_M.append(concid_to_symbol_map[metab_id]) + else: + diag_elements_for_M.append(1.0) + + if not diag_elements_for_M: + matM_sympy = Matrix([]) + else: + matM_sympy = Matrix(np.diag(diag_elements_for_M)) # np.diag can handle sympy symbols if they are in a list + + active_conc_symbols_for_this_M.sort(key=lambda s: conc_ids_ordered.index(s.name)) + + lambM_jax = lambdify(active_conc_symbols_for_this_M, matM_sympy, modules=['jax']) + + self.model.matrix_Ms_jax_static_data[size] = { + 'func': lambM_jax, + 'conc_arg_indices': tuple(conc_ids_ordered.index(s.name) for s in active_conc_symbols_for_this_M) + } + + def _prepare_matrix_Ms_derivatives_p_jax(self): + """ Converts pre-calculated NumPy matrix M derivatives to JAX arrays. """ + if not JAX_INSTALLED: + self.model.matrix_Ms_der_p_jax_static_data = {} + return + self.model.matrix_Ms_der_p_jax_static_data = { + size: jnp.array(deriv_array) + for size, deriv_array in self.model.matrix_Ms_der_p.items() # Assuming this exists from NumPy path + } + + def _prepare_initial_conditions_jax(self): + """ Prepares JAX versions of initial X, Y matrices and their derivatives. """ + if not JAX_INSTALLED: + self.model.initial_matrix_Xs_jax = {} + self.model.initial_matrix_Ys_jax = {} + self.model.initial_matrix_Xs_der_p_jax = {} + self.model.initial_matrix_Ys_der_p_jax = {} + return + + self.model.initial_matrix_Xs_jax = { + size: jnp.array(matrix_val) for size, matrix_val in self.model.initial_matrix_Xs.items() + } + self.model.initial_matrix_Ys_jax = { + size: jnp.array(matrix_val) for size, matrix_val in self.model.initial_matrix_Ys.items() + } + self.model.initial_matrix_Xs_der_p_jax = { # Derivatives are (n_params, n_rows, n_cols) + size: jnp.array(deriv_val) for size, deriv_val in self.model.initial_matrix_Xs_der_p.items() + } + self.model.initial_matrix_Ys_der_p_jax = { + size: jnp.array(deriv_val) for size, deriv_val in self.model.initial_matrix_Ys_der_p.items() + } def _calculate_matrix_Ms_derivatives_u(self): @@ -511,15 +721,22 @@ def _calculate_matrix_Ms_derivatives_c(self): matM = Matrix(np.diag(symbols([emu.metabolite_id for emu in EAM.columns]))) if JAX_INSTALLED: - lambM = lambdify(symbols(self.model.concids), matM, modules = 'jax') + # Ensure concids are available on model for JAX lambdify + concids_for_lambdify = self.model.concids if hasattr(self.model, 'concids') else [] + lambM = lambdify(symbols(concids_for_lambdify), matM, modules = 'jax') + # The jacfwd call needs arguments matching concids_for_lambdify + # This part needs careful alignment of symbols and arguments if concids_for_lambdify is dynamic + # For simplicity, assume self.model.concids is the fixed list of all possible concentration variables + if concids_for_lambdify: # Only compute if there are concentration variables + matrix_M_der = np.array( + jacfwd(lambM, range(len(concids_for_lambdify)))(*jnp.ones(len(concids_for_lambdify))) + ) + else: # No concentration variables, derivative is zero or not applicable + nEMUsout = EAM.shape[1] + matrix_M_der = np.zeros((0, nEMUsout, nEMUsout)) # No params -> first dim is 0 + else: # Fallback if JAX not installed (original logic) matrix_M_der = np.array( - jacfwd(lambM, - range(len(self.model.concids)) - )(*jnp.ones(len(self.model.concids))) - ) - else: - matrix_M_der = np.array( - derive_by_array(matM, symbols(self.model.concids)), + derive_by_array(matM, symbols(self.model.concids if hasattr(self.model, 'concids') else [])), dtype = float ) matrix_Ms_der[size] = matrix_M_der @@ -542,68 +759,87 @@ def _calculate_matrix_Ms_derivatives_p(self): def _lambdify_matrix_Ms(self): for size, EAM in self.model.EAMs.items(): - matM = Matrix(np.diag(symbols([emu.metabolite_id for emu in EAM.columns]))) - metabids = list(map(str, matM.free_symbols)) - lambM = lambdify(metabids, matM, modules = 'numpy') - self.model.matrix_Ms[size] = [lambM, metabids] + # Original logic for numpy lambdify + # Product EMUs are columns of EAM dataframe + diag_symbols = [symbols(emu.metabolite_id) for emu in EAM.columns] + matM = Matrix(np.diag(diag_symbols)) + + # Arguments for this lambdified function are ordered by free_symbols + metabids_arg_order = [s.name for s in sorted(list(matM.free_symbols), key=lambda s:s.name)] + lambM = lambdify(metabids_arg_order, matM, modules = 'numpy') + self.model.matrix_Ms[size] = [lambM, metabids_arg_order] # initial X(Y) and their derivatives def _calculate_initial_matrix_Xs(self): - for size in self.model.matrix_As: - nEMUs = len(self.model.matrix_As[size][2]) + for size in self.model.matrix_As: # matrix_As is {size: [lambA, fluxidsA, productEMU_obj_list]} + nEMUs = len(self.model.matrix_As[size][2]) # productEMU_obj_list iniX = np.vstack([get_natural_MDV(size).value] * nEMUs) self.model.initial_matrix_Xs[size] = iniX def _calculate_initial_matrix_Ys(self): - for size in self.model.matrix_Bs: + for size in self.model.matrix_Bs: # matrix_Bs is {size: [lambB, fluxidsB, sourceEMU_obj_or_tuple_list]} iniY = [] - for sourceEMU in self.model.matrix_Bs[size][2]: - if not isinstance(sourceEMU, Iterable): - sourceMDV = self.model.substrate_MDVs[sourceEMU] - else: + for sourceEMU_item in self.model.matrix_Bs[size][2]: # sourceEMU_obj_or_tuple_list + if not isinstance(sourceEMU_item, Iterable): # Single EMU object + sourceMDV = self.model.substrate_MDVs[sourceEMU_item] + else: # Tuple of EMU objects mdvs = [] - for emu in sourceEMU: + for emu in sourceEMU_item: if emu not in self.model.substrate_MDVs: mdv = get_natural_MDV(emu.size) else: mdv = self.model.substrate_MDVs[emu] mdvs.append(mdv) sourceMDV = reduce(conv, mdvs) - iniY.append(sourceMDV) - iniY = np.array(iniY) - self.model.initial_matrix_Ys[size] = iniY + iniY.append(sourceMDV.value if isinstance(sourceMDV, MDV) else sourceMDV) # Ensure array + self.model.initial_matrix_Ys[size] = np.array(iniY) if iniY else np.empty((0,size+1)) # ensure 2D for empty def _calculate_initial_matrix_Xs_derivatives_p(self): - nvars = self.model.null_space.shape[1] + len(self.model.concids) + # nvars depends on whether it's 'ss' or 'inst' mode (fluxes only, or fluxes+concentrations) + # This needs to be determined based on the context (e.g., kind passed to parent) + # For now, assume self.model.null_space and self.model.concids are set correctly for current mode. + n_flux_params = self.model.null_space.shape[1] + n_conc_params = len(self.model.concids if hasattr(self.model, 'concids') else []) + # This logic might need adjustment based on 'kind' if called from a generic context + nvars = n_flux_params + n_conc_params + if not hasattr(self.model, 'concids'): # If steady state, no conc_params in p for derivatives + nvars = n_flux_params + + for size, iniX in self.model.initial_matrix_Xs.items(): Xshape = iniX.shape - iniXder = np.zeros((nvars, *Xshape)) + iniXder = np.zeros((nvars, *Xshape)) # (n_params, n_EMUs, n_coeffs) self.model.initial_matrix_Xs_der_p[size] = iniXder def _calculate_initial_matrix_Ys_derivatives_p(self): - - nvars = self.model.null_space.shape[1] + len(self.model.concids) + n_flux_params = self.model.null_space.shape[1] + n_conc_params = len(self.model.concids if hasattr(self.model, 'concids') else []) + nvars = n_flux_params + n_conc_params + if not hasattr(self.model, 'concids'): + nvars = n_flux_params + for size, iniY in self.model.initial_matrix_Ys.items(): Yshape = iniY.shape - iniYder = np.zeros((nvars, *Yshape)) + iniYder = np.zeros((nvars, *Yshape)) # (n_params, n_source_terms, n_coeffs) self.model.initial_matrix_Ys_der_p[size] = iniYder def _build_initial_sim_MDVs(self): for size in sorted(self.model.matrix_As): - productEMUs = self.model.matrix_As[size][2] + productEMUs_objs = self.model.matrix_As[size][2] # list of EMU objects iniX = self.model.initial_matrix_Xs[size] - for productEMU, iniMDV in zip(productEMUs, iniX): - if productEMU.id in self.model.target_EMUs: - self.model.initial_sim_MDVs[productEMU.id] = {0: MDV(iniMDV)} + for productEMU_obj, iniMDV_arr in zip(productEMUs_objs, iniX): + if productEMU_obj.id in self.model.target_EMUs: + # Store as { emu_id: {0: MDV_object} } + self.model.initial_sim_MDVs[productEMU_obj.id] = {0: MDV(iniMDV_arr)} # Wrap array in MDV object def _calculate_MDVs(self): @@ -616,33 +852,50 @@ def _calculate_MDVs(self): EMU ID => MDV (in array). ''' - simMDVs = {} + simMDVs = {} # Stores emu_object -> mdv_array (numpy) for size in sorted(self.model.matrix_As): - lambA, fluxidsA, productEMUs = self.model.matrix_As[size] - lambB, fluxidsB, sourceEMUs = self.model.matrix_Bs[size] + lambA, fluxidsA, productEMUs_objs = self.model.matrix_As[size] + lambB, fluxidsB, sourceEMUs_items = self.model.matrix_Bs[size] A = lambA(*self.model.total_fluxes[fluxidsA]) B = lambB(*self.model.total_fluxes[fluxidsB]) - Y = [] - for sourceEMU in sourceEMUs: - if not isinstance(sourceEMU, Iterable): - sourceMDV = self.model.substrate_MDVs[sourceEMU] - else: - mdvs = [ChainMap(simMDVs, self.model.substrate_MDVs)[emu] - for emu in sourceEMU] - sourceMDV = reduce(conv, mdvs) - Y.append(sourceMDV) - Y = np.array(Y) + Y_list_of_arrays = [] + for sourceEMU_item in sourceEMUs_items: # item is EMU obj or tuple of EMU objs + if not isinstance(sourceEMU_item, Iterable): # single EMU object + # Substrate_MDVs stores emu_obj -> MDV_obj or np.array + sourceMDV_val = self.model.substrate_MDVs[sourceEMU_item] + if isinstance(sourceMDV_val, MDV): sourceMDV_val = sourceMDV_val.value + else: # tuple of EMU objects + mdv_obj_list = [] + for emu_obj_in_tuple in sourceEMU_item: + # ChainMap search: simMDVs (emu_obj -> np.array), then substrate_MDVs (emu_obj -> MDV_obj/np.array) + if emu_obj_in_tuple in simMDVs: + mdv_val = simMDVs[emu_obj_in_tuple] # This is already an array + else: + mdv_val = self.model.substrate_MDVs[emu_obj_in_tuple] + if isinstance(mdv_val, MDV): mdv_val = mdv_val.value + mdv_obj_list.append(MDV(mdv_val)) # Wrap in MDV for convolution via reduce + + # reduce(conv, list_of_MDV_objs) -> MDV_obj + sourceMDV_obj = reduce(conv, mdv_obj_list) + sourceMDV_val = sourceMDV_obj.value + Y_list_of_arrays.append(sourceMDV_val) - X = pinv(A, check_finite = True)@B@Y + Y_matrix = np.array(Y_list_of_arrays) if Y_list_of_arrays else np.empty((B.shape[1], size+1)) + if Y_matrix.ndim == 1 and B.shape[1] == 1: Y_matrix = Y_matrix.reshape(1,-1) + + + X_matrix = pinv(A, check_finite=False) @ B @ Y_matrix # check_finite=False for speed - simMDVs.update(zip(productEMUs, X)) + for emu_obj, mdv_array in zip(productEMUs_objs, X_matrix): + simMDVs[emu_obj] = mdv_array # Store emu_obj -> np.array - simMDVs = {emu.id: mdv for emu, mdv in simMDVs.items()} + # Convert final result to emu_id_str -> np.array + simMDVs_by_id = {emu_obj.id: mdv_arr for emu_obj, mdv_arr in simMDVs.items()} - return simMDVs + return simMDVs_by_id def _calculate_MDVs_and_derivatives_p(self): @@ -652,58 +905,129 @@ def _calculate_MDVs_and_derivatives_p(self): Returns ------- simMDVs: dict - EMU ID => MDV (in array). + EMU ID (str) => MDV_array (numpy). simMDVsDer: dict - EMU ID => 2-D array in shape of (len(MDV), len(u)). + EMU ID (str) => derivative_array (numpy), shape (n_params, n_coeffs). ''' - simMDVs = {} - simMDVsDer = {} - for size in sorted(self.model.matrix_As): + # Internal working dicts use EMU objects as keys for ChainMap compatibility + # simMDVs_work: emu_obj -> mdv_array (numpy) + # simMDVsDer_work: emu_obj -> derivative_array (numpy), shape (n_coeffs, n_params) + simMDVs_work = {} + simMDVsDer_work = {} + + for size in sorted(self.model.matrix_As): # Iterate by increasing EMU size - lambA, fluxidsA, productEMUs = self.model.matrix_As[size] - lambB, fluxidsB, sourceEMUs = self.model.matrix_Bs[size] + lambA, fluxidsA, productEMUs_objs = self.model.matrix_As[size] + lambB, fluxidsB, sourceEMUs_items = self.model.matrix_Bs[size] - A = lambA(*self.model.total_fluxes[fluxidsA]) - B = lambB(*self.model.total_fluxes[fluxidsB]) + A_val = lambA(*self.model.total_fluxes[fluxidsA]) + B_val = lambB(*self.model.total_fluxes[fluxidsB]) - Ainv = pinv(A, check_finite = True) + Ainv_val = pinv(A_val, check_finite=False) - Ader = self.model.matrix_As_der_p[size] - Bder = self.model.matrix_Bs_der_p[size] + # Derivatives of matrices A and B w.r.t parameters p (free fluxes for steady state) + # Shape: (n_params, n_rows, n_cols) + Ader_p_val = self.model.matrix_As_der_p[size] + Bder_p_val = self.model.matrix_Bs_der_p[size] - Y = [] - Yder = [] - for sourceEMU in sourceEMUs: - if not isinstance(sourceEMU, Iterable): - sourceMDV = self.model.substrate_MDVs[sourceEMU] - sourceMDVder = self.model.substrate_MDVs_der_p[sourceEMU] - else: - mdvs = [] - mdvs_mdvders = [] - for emu in sourceEMU: - mdv = ChainMap(simMDVs, self.model.substrate_MDVs)[emu] - mdvs.append(mdv) - mdvder = ChainMap(simMDVsDer, self.model.substrate_MDVs_der_p)[emu] - mdvs_mdvders.append([mdv, mdvder]) - sourceMDV = reduce(conv, mdvs) - sourceMDVder = reduce(diff_conv, mdvs_mdvders)[1] - Y.append(sourceMDV) - Yder.append(sourceMDVder) - Y = np.array(Y) - Yder = np.array(Yder).swapaxes(1,2).swapaxes(0,1) + # Y_list will store MDV arrays (1D) + # Yder_p_list will store MDV derivative arrays (2D, shape: n_coeffs, n_params) + Y_list_of_arrays = [] + Yder_p_list_of_arrays = [] + + for sourceEMU_item in sourceEMUs_items: # item is EMU obj or tuple of EMU objs + if not isinstance(sourceEMU_item, Iterable): # single EMU object + # MDV value + mdv_obj = self.model.substrate_MDVs[sourceEMU_item] # MDV_obj or array + current_mdv_val = mdv_obj.value if isinstance(mdv_obj, MDV) else mdv_obj + # MDV derivative (n_coeffs, n_params) + current_mdv_der_val = self.model.substrate_MDVs_der_p[sourceEMU_item] + else: # tuple of EMU objects, requires convolution + mdv_val_list_for_conv = [] + mdv_der_pair_list_for_diff_conv = [] # list of [mdv_val_arr, mdv_der_arr (n_coeffs, n_params)] + + for emu_obj_in_tuple in sourceEMU_item: + # Get MDV value (array) + if emu_obj_in_tuple in simMDVs_work: + mdv_val_arr = simMDVs_work[emu_obj_in_tuple] + else: + mdv_obj_s = self.model.substrate_MDVs[emu_obj_in_tuple] + mdv_val_arr = mdv_obj_s.value if isinstance(mdv_obj_s, MDV) else mdv_obj_s + mdv_val_list_for_conv.append(MDV(mdv_val_arr)) # Wrap for MDV.conv via reduce + + # Get MDV derivative (array, n_coeffs, n_params) + if emu_obj_in_tuple in simMDVsDer_work: + mdv_der_arr = simMDVsDer_work[emu_obj_in_tuple] + else: + mdv_der_arr = self.model.substrate_MDVs_der_p[emu_obj_in_tuple] + mdv_der_pair_list_for_diff_conv.append([mdv_val_arr, mdv_der_arr]) + + # Perform convolution for value and derivative + convolved_mdv_obj = reduce(conv, mdv_val_list_for_conv) # reduce with MDV objects + current_mdv_val = convolved_mdv_obj.value + + # diff_conv: input list of [arr, arr_der (coeffs,params)], output [arr_conv, arr_conv_der (coeffs,params)] + convolved_mdv_der_pair = reduce(diff_conv, mdv_der_pair_list_for_diff_conv) + current_mdv_der_val = convolved_mdv_der_pair[1] + + Y_list_of_arrays.append(current_mdv_val) + Yder_p_list_of_arrays.append(current_mdv_der_val) - X = Ainv@B@Y - Xder = Ainv@(Bder@Y + B@Yder - Ader@X) - Xder = Xder.swapaxes(0,1).swapaxes(1,2) + Y_matrix = np.array(Y_list_of_arrays) if Y_list_of_arrays else np.empty((B_val.shape[1], size+1)) + if Y_matrix.ndim == 1 and B_val.shape[1] == 1: Y_matrix = Y_matrix.reshape(1,-1) - simMDVs.update(zip(productEMUs, X)) - simMDVsDer.update(zip(productEMUs, Xder)) + # Yder_p_tensor shape: (n_source_terms, n_coeffs_Y, n_params) + Yder_p_tensor = np.array(Yder_p_list_of_arrays) if Yder_p_list_of_arrays else np.empty((B_val.shape[1], size+1, Ader_p_val.shape[0])) + if Yder_p_tensor.ndim == 2 and B_val.shape[1] == 1 : Yder_p_tensor = Yder_p_tensor.reshape(1, Yder_p_tensor.shape[0], Yder_p_tensor.shape[1]) + + + # Transpose Yder_p_tensor for broadcasting: (n_params, n_source_terms, n_coeffs_Y) + Yder_p_tensor_transposed = Yder_p_tensor.transpose(2,0,1) - simMDVs = {emu.id: mdv for emu, mdv in simMDVs.items()} - simMDVsDer = {emu.id: mdvDer for emu, mdvDer in simMDVsDer.items()} - - return simMDVs, simMDVsDer + # Calculate X = Ainv @ B @ Y + X_matrix = Ainv_val @ B_val @ Y_matrix # X shape: (n_prod_EMUs, n_coeffs_X) + + # Calculate X_der_p = Ainv @ (Bder_p@Y + B@Yder_p - Ader_p@X) + # Bder_p@Y: (n_params, n_prod, n_source) @ (n_source, n_coeffs_Y) -> (n_params, n_prod, n_coeffs_Y) + term1 = Ader_p_val @ X_matrix # (n_params, n_prod, n_prod) @ (n_prod, n_coeffs_X) -> (n_params, n_prod, n_coeffs_X) + + # B@Yder_p: (n_prod, n_source) @ (n_params, n_source, n_coeffs_Y) - needs broadcasting/looping for params + # Each slice B @ Yder_p[param_idx,:,:] + # Result should be (n_params, n_prod, n_coeffs_Y) + term2 = np.einsum('ik,pkm->pim', B_val, Yder_p_tensor_transposed) # (n_prod, n_source) @ (n_params, n_source, n_coeffs) -> (n_params, n_prod, n_coeffs) + + # Bder_p@Y: (n_params, n_prod, n_source) @ (n_source, n_coeffs_Y) -> (n_params, n_prod, n_coeffs_Y) + term3 = Bder_p_val @ Y_matrix + + sum_terms = term3 + term2 - term1 # All terms (n_params, n_prod_EMUs, n_coeffs_X) + + # Ainv @ sum_terms: (n_prod, n_prod) @ (n_params, n_prod, n_coeffs_X) -> (n_params, n_prod, n_coeffs_X) + Xder_p_matrix_transposed = np.einsum('ik,pkm->pim', Ainv_val, sum_terms) + + # Transpose back to (n_prod_EMUs, n_coeffs_X, n_params) for storage if needed, or keep as (n_params, n_prod, n_coeffs) + # The original code Xder.swapaxes(0,1).swapaxes(1,2) suggests final storage as (n_coeffs, n_params) per EMU + # Current Xder_p_matrix_transposed is (n_params, n_prod_EMUs, n_coeffs_X) + # So, for each EMU (row i of n_prod_EMUs), we have Xder_p_matrix_transposed[:, i, :] which is (n_params, n_coeffs_X) + # This needs to be transposed to (n_coeffs_X, n_params) for simMDVsDer_work[emu_obj] + + for idx, emu_obj in enumerate(productEMUs_objs): + simMDVs_work[emu_obj] = X_matrix[idx, :] + simMDVsDer_work[emu_obj] = Xder_p_matrix_transposed[:, idx, :].T # Transpose (n_params, n_coeffs) to (n_coeffs, n_params) + + # Convert final result to emu_id_str -> array + simMDVs_by_id = {emu_obj.id: mdv_arr for emu_obj, mdv_arr in simMDVs_work.items()} + # For simMDVsDer_by_id, values are (n_coeffs, n_params). Need to transpose to (n_params, n_coeffs) for nlpsolver. + # nlpsolver's dxsim_dp = np.vstack([simMDVsDer[emuid] for emuid in self.model.target_EMUs]) + # where simMDVsDer[emuid] is (n_coeffs, n_params). So vstack makes (total_coeffs, n_params). + # This is d(all_sim_mdvs)/d_params. + # My core_calculate_mdvs_and_derivatives_jax expects to return derivatives as (n_params, n_coeffs). + # So, the dictionary here should store (n_coeffs, n_params) to match original, + # and the JAX wrapper in nlpsolver will handle final formatting if needed. + # The current storage simMDVsDer_work[emu_obj] = Xder_p_matrix_transposed[:, idx, :].T is (n_coeffs, n_params) + simMDVsDer_by_id = {emu_obj.id: deriv_arr for emu_obj, deriv_arr in simMDVsDer_work.items()} + + return simMDVs_by_id, simMDVsDer_by_id def _calculate_inst_MDVs(self): @@ -716,73 +1040,115 @@ def _calculate_inst_MDVs(self): EMU ID => {t => MDV (in array)} (starting from t1). ''' - simInstMDVs = {} - Ys = {} - Xs = {} + simInstMDVs = {} # emu_obj -> {time: mdv_array} + Ys_t = {} # time -> {size: Y_matrix_for_that_size_and_time} + Xs_t = {} # time -> {size: X_matrix_for_that_size_and_time} - t1 = 0.0 - for size in sorted(self.model.matrix_As): - Y_t1 = self.model.initial_matrix_Ys[size] - Ys.setdefault(t1, {})[size] = Y_t1 - - X_t1 = self.model.initial_matrix_Xs[size] - Xs.setdefault(t1, {})[size] = X_t1 - - for t in self.model.timepoints: - if t != 0.0: - t0 = t1 - t1 = t - deltat = t1 - t0 - - for size in sorted(self.model.matrix_As): - lambA, fluxidsA, productEMUs = self.model.matrix_As[size] - lambB, fluxidsB, sourceEMUs = self.model.matrix_Bs[size] - lambM, metabids = self.model.matrix_Ms[size] - - A = lambA(*self.model.total_fluxes[fluxidsA]) - B = lambB(*self.model.total_fluxes[fluxidsB]) - M = lambM(*self.model.concentrations[metabids]) - Minv = pinv(M, check_finite = True) - - F = Minv@A - Finv = pinv(F, check_finite = True) - I = np.eye(*F.shape) - Phi = expm(F*deltat) - Gamma = (Phi - I)@Finv - Omega = (Gamma/deltat - I)@Finv - - X_t0 = Xs[t0][size] - - Y_t0 = Ys[t0][size] - G_t0 = Minv@B@Y_t0 - - Y_t1 = [] - for sourceEMU in sourceEMUs: - if not isinstance(sourceEMU, Iterable): - sourceMDV = self.model.substrate_MDVs[sourceEMU] - else: - mdvs = [] - for emu in sourceEMU: - mdv = ChainMap(simInstMDVs, self.model.substrate_MDVs)[emu] - if isinstance(mdv, dict): - mdv = mdv[t1] - mdvs.append(mdv) - sourceMDV = reduce(conv, mdvs) - Y_t1.append(sourceMDV) - Y_t1 = np.array(Y_t1) - G_t1 = Minv@B@Y_t1 - - X_t1 = Phi@X_t0 - Gamma@G_t0 - Omega@(G_t1 - G_t0) - - Ys.setdefault(t1, {})[size] = Y_t1 - Xs.setdefault(t1, {})[size] = X_t1 - - for productEMU, mdv_t1 in zip(productEMUs, X_t1): - simInstMDVs.setdefault(productEMU, {}).update({t1: mdv_t1}) - - simInstMDVs = {emu.id: mdvs for emu, mdvs in simInstMDVs.items()} + t1 = 0.0 # Represents initial time t=0 + # Populate Xs_t[0] and Ys_t[0] with initial conditions + for size_idx in sorted(self.model.matrix_As): # matrix_As keys are emu sizes + # Initial X values (usually natural abundance) + # self.model.initial_matrix_Xs is {size: np.array} + Xs_t.setdefault(t1, {})[size_idx] = self.model.initial_matrix_Xs[size_idx] + + # Initial Y values (from substrates or natural abundance for smaller EMUs) + # self.model.initial_matrix_Ys is {size: np.array} + Ys_t.setdefault(t1, {})[size_idx] = self.model.initial_matrix_Ys[size_idx] + + # Store initial MDVs for product EMUs that are targets + for size_idx in sorted(self.model.matrix_As): + productEMUs_objs_at_size = self.model.matrix_As[size_idx][2] # List of EMU objects + X_t0_at_size = Xs_t[t1][size_idx] # X matrix for this size at t=0 + for i, emu_obj in enumerate(productEMUs_objs_at_size): + if emu_obj.id in self.model.target_EMUs: # Target EMUs are by ID string + simInstMDVs.setdefault(emu_obj, {})[t1] = X_t0_at_size[i,:] + + + for t_current_loop in self.model.timepoints: # these are t > 0 + if t_current_loop == 0.0: continue # Skip t=0 as it's initial condition + + t0 = t1 # Previous time point (becomes current t1 for next iteration) + t_current = t_current_loop # Current time point from loop + deltat = t_current - t0 + + # Initialize dictionaries for current time t_current + Xs_t.setdefault(t_current, {}) + Ys_t.setdefault(t_current, {}) + + for size_idx in sorted(self.model.matrix_As): # Iterate by EMU size + lambA, fluxidsA, productEMUs_objs = self.model.matrix_As[size_idx] + lambB, fluxidsB, sourceEMUs_items = self.model.matrix_Bs[size_idx] + lambM, metabids_for_M_args = self.model.matrix_Ms[size_idx] # lambM takes specific concs as args + + A_val = lambA(*self.model.total_fluxes[fluxidsA]) + B_val = lambB(*self.model.total_fluxes[fluxidsB]) + + # Select concentrations for M matrix based on metabids_for_M_args + conc_args_for_M = [self.model.concentrations[mid] for mid in metabids_for_M_args] + M_val = lambM(*conc_args_for_M) + Minv_val = pinv(M_val, check_finite=False) + + F_val = Minv_val @ A_val + Finv_val = pinv(F_val, check_finite=False) + I_mtx = np.eye(*F_val.shape) + Phi_val = expm(F_val * deltat) # Matrix exponential + Gamma_val = (Phi_val - I_mtx) @ Finv_val + Omega_val = (Gamma_val / deltat - I_mtx) @ Finv_val + + X_t0_at_size = Xs_t[t0][size_idx] # X matrix for this size at previous time t0 + Y_t0_at_size = Ys_t[t0][size_idx] # Y matrix for this size at previous time t0 + G_t0_at_size = Minv_val @ B_val @ Y_t0_at_size + + # Calculate Y_current (Y matrix for current time t_current, for this size_idx) + # This involves convolutions using MDVs from simInstMDVs (which has emu_obj keys) + # simInstMDVs stores {emu_obj: {time: mdv_array}} + Y_list_for_t_current = [] + for sourceEMU_item in sourceEMUs_items: + if not isinstance(sourceEMU_item, Iterable): # single EMU object + # Substrate MDVs are constant over time in current model structure for them + mdv_obj_s = self.model.substrate_MDVs[sourceEMU_item] + sourceMDV_val = mdv_obj_s.value if isinstance(mdv_obj_s, MDV) else mdv_obj_s + else: # tuple of EMU objects for convolution + mdv_obj_list_for_conv = [] + for emu_obj_in_tuple in sourceEMU_item: + # Get MDV at current time t_current + # Look up in simInstMDVs (emu_obj -> {t -> arr}) + # or if not there (substrate), from self.model.substrate_MDVs (emu_obj -> MDV_obj/arr) + if emu_obj_in_tuple in simInstMDVs and t_current in simInstMDVs[emu_obj_in_tuple]: + mdv_val_arr = simInstMDVs[emu_obj_in_tuple][t_current] + else: # Must be a substrate or an EMU from a previous time step not yet in simInstMDVs for *this* t_current + # This implies recursive dependency on *current time* MDVs for smaller EMUs, + # which should have been computed earlier in the size_idx loop for this t_current. + # Or it's a base substrate. + mdv_obj_s = self.model.substrate_MDVs[emu_obj_in_tuple] # Fallback to substrate + mdv_val_arr = mdv_obj_s.value if isinstance(mdv_obj_s, MDV) else mdv_obj_s + mdv_obj_list_for_conv.append(MDV(mdv_val_arr)) + + sourceMDV_obj = reduce(conv, mdv_obj_list_for_conv) + sourceMDV_val = sourceMDV_obj.value + Y_list_for_t_current.append(sourceMDV_val) + + Y_t_current_at_size = np.array(Y_list_for_t_current) if Y_list_for_t_current else np.empty((B_val.shape[1], size_idx+1)) + if Y_t_current_at_size.ndim == 1 and B_val.shape[1] == 1: Y_t_current_at_size = Y_t_current_at_size.reshape(1,-1) + + Ys_t[t_current][size_idx] = Y_t_current_at_size + G_t_current_at_size = Minv_val @ B_val @ Y_t_current_at_size + + X_t_current_at_size = Phi_val @ X_t0_at_size - Gamma_val @ G_t0_at_size - Omega_val @ (G_t_current_at_size - G_t0_at_size) + Xs_t[t_current][size_idx] = X_t_current_at_size + + # Store results in simInstMDVs for product EMUs of this size at current time + for i, emu_obj in enumerate(productEMUs_objs): + simInstMDVs.setdefault(emu_obj, {})[t_current] = X_t_current_at_size[i,:] + + t1 = t_current # Update t1 for the next iteration of the time loop + + # Convert final result to emu_id_str -> {time: mdv_array} + simInstMDVs_by_id = { + emu_obj.id: time_dict for emu_obj, time_dict in simInstMDVs.items() + } - return simInstMDVs + return simInstMDVs_by_id def _calculate_inst_MDVs_and_derivatives_p(self): @@ -792,122 +1158,228 @@ def _calculate_inst_MDVs_and_derivatives_p(self): Returns ------- simInstMDVs: dict - EMU ID => {t => MDV (in array)} (starting from t1). + EMU ID (str) => {time (float) => MDV_array (numpy)}. simInstMDVsDer: dict - EMU ID => {t => 2-D array in shape of (len(MDV), len(u)+len(c))} (starting from t1). + EMU ID (str) => {time (float) => derivative_array (numpy, shape: n_params, n_coeffs)}. ''' - simInstMDVs = {} - simInstMDVsDer = {} - Ys = {} - Xs = {} - Yders = {} - Xders = {} - - t1 = 0.0 - for size in sorted(self.model.matrix_As): - Y_t1 = self.model.initial_matrix_Ys[size] - Ys.setdefault(t1, {})[size] = Y_t1 - - X_t1 = self.model.initial_matrix_Xs[size] - Xs.setdefault(t1, {})[size] = X_t1 - - Yder_t1 = self.model.initial_matrix_Ys_der_p[size] - Yders.setdefault(t1, {})[size] = Yder_t1 - - Xder_t1 = self.model.initial_matrix_Xs_der_p[size] - Xders.setdefault(t1, {})[size] = Xder_t1 + # Internal working dicts: + # simInstMDVs_work: emu_obj -> {time: mdv_array} + # simInstMDVsDer_work: emu_obj -> {time: derivative_array (n_coeffs, n_params)} + simInstMDVs_work = {} + simInstMDVsDer_work = {} + + # Ys_t_work, Xs_t_work: time -> {size: matrix_value (numpy)} + # Yders_p_t_work, Xders_p_t_work: time -> {size: derivative_tensor (n_coeffs, n_params-per-matrix-row-implicitly)} + # Actually, derivatives are (n_params, n_rows, n_cols) for matrices A, B, M, X, Y. + # For MDVs (X, Y), this means (n_params, n_EMUs_or_Sources, n_coeffs). + # Let's store derivatives as (n_params, n_EMUs_or_Sources, n_coeffs) in Xders_p_t, Yders_p_t. + # When retrieving for a single EMU for simInstMDVsDer_work, it will be (n_params, n_coeffs), + # then transposed to (n_coeffs, n_params) for consistency with steady-state. + + Ys_t_work = {} + Xs_t_work = {} + Yders_p_t_work = {} # Stores dY/dp as {time: {size: tensor (n_params, n_sources, n_coeffs)}} + Xders_p_t_work = {} # Stores dX/dp as {time: {size: tensor (n_params, n_products, n_coeffs)}} - for t in self.model.timepoints: - if t != 0.0: - t0 = t1 - t1 = t - deltat = t1 - t0 - - for size in sorted(self.model.matrix_As): - lambA, fluxidsA, productEMUs = self.model.matrix_As[size] - lambB, fluxidsB, sourceEMUs = self.model.matrix_Bs[size] - lambM, metabids = self.model.matrix_Ms[size] + time_prev = 0.0 # Represents initial time t=0 + + # Populate initial conditions at t=0 for X, Y and their derivatives dX/dp, dY/dp + for size_idx in sorted(self.model.matrix_As): + Xs_t_work.setdefault(time_prev, {})[size_idx] = self.model.initial_matrix_Xs[size_idx] + Ys_t_work.setdefault(time_prev, {})[size_idx] = self.model.initial_matrix_Ys[size_idx] + Xders_p_t_work.setdefault(time_prev, {})[size_idx] = self.model.initial_matrix_Xs_der_p[size_idx] + Yders_p_t_work.setdefault(time_prev, {})[size_idx] = self.model.initial_matrix_Ys_der_p[size_idx] + + # Store initial MDVs and their derivatives for target EMUs + productEMUs_objs_at_size = self.model.matrix_As[size_idx][2] + X_t0_at_size = Xs_t_work[time_prev][size_idx] + Xder_p_t0_at_size = Xders_p_t_work[time_prev][size_idx] # (n_params, n_EMUs, n_coeffs) + + for i, emu_obj in enumerate(productEMUs_objs_at_size): + if emu_obj.id in self.model.target_EMUs: + simInstMDVs_work.setdefault(emu_obj, {})[time_prev] = X_t0_at_size[i,:] + # Store derivative as (n_coeffs, n_params) + simInstMDVsDer_work.setdefault(emu_obj, {})[time_prev] = Xder_p_t0_at_size[:, i, :].T + + + for t_current_loop in self.model.timepoints: # These are t > 0 + if t_current_loop == 0.0: continue + + t_current = t_current_loop + deltat = t_current - time_prev - A = lambA(*self.model.total_fluxes[fluxidsA]) - B = lambB(*self.model.total_fluxes[fluxidsB]) - M = lambM(*self.model.concentrations[metabids]) - Minv = pinv(M, check_finite = True) - - Ader = self.model.matrix_As_der_p[size] - Bder = self.model.matrix_Bs_der_p[size] - Mder = self.model.matrix_Ms_der_p[size] - Minvder = -Minv@Mder@Minv - - F = Minv@A - Finv = pinv(F, check_finite = True) - I = np.eye(*F.shape) - Phi = expm(F*deltat) - Gamma = (Phi - I)@Finv - Omega = (Gamma/deltat - I)@Finv - - X_t0 = Xs[t0][size] - - Y_t0 = Ys[t0][size] - G_t0 = Minv@B@Y_t0 - - Xder_t0 = Xders[t0][size] - - Yder_t0 = Yders[t0][size] - H_t0 = (Minvder@A@X_t0 - + Minv@Ader@X_t0 - - Minv@B@Yder_t0 - - Minvder@B@Y_t0 - - Minv@Bder@Y_t0) - - Y_t1 = [] - Yder_t1 = [] - for sourceEMU in sourceEMUs: - if not isinstance(sourceEMU, Iterable): - sourceMDV = self.model.substrate_MDVs[sourceEMU] - sourceMDVder = self.model.substrate_MDVs_der_p[sourceEMU] - else: - mdvs = [] - mdvs_mdvders = [] - for emu in sourceEMU: - mdv = ChainMap(simInstMDVs, self.model.substrate_MDVs)[emu] - if isinstance(mdv, dict): - mdv = mdv[t1] - mdvs.append(mdv) - mdvder = ChainMap(simInstMDVsDer, self.model.substrate_MDVs_der_p)[emu] - if isinstance(mdvder, dict): - mdvder = mdvder[t1] - mdvs_mdvders.append([mdv, mdvder]) - sourceMDV = reduce(conv, mdvs) - sourceMDVder = reduce(diff_conv, mdvs_mdvders)[1] - Y_t1.append(sourceMDV) - Yder_t1.append(sourceMDVder) - Y_t1 = np.array(Y_t1) - Yder_t1 = np.array(Yder_t1).swapaxes(1,2).swapaxes(0,1) - - G_t1 = Minv@B@Y_t1 - X_t1 = Phi@X_t0 - Gamma@G_t0 - Omega@(G_t1 - G_t0) - - H_t1 = (Minvder@A@X_t1 - + Minv@Ader@X_t1 - - Minv@B@Yder_t1 - - Minvder@B@Y_t1 - - Minv@Bder@Y_t1) - Xder_t1 = Phi@Xder_t0 + Gamma@H_t0 + Omega@(H_t1 - H_t0) - - Ys.setdefault(t1, {})[size] = Y_t1 - Xs.setdefault(t1, {})[size] = X_t1 - Yders.setdefault(t1, {})[size] = Yder_t1 - Xders.setdefault(t1, {})[size] = Xder_t1 - - for productEMU, mdv_t1 in zip(productEMUs, X_t1): - simInstMDVs.setdefault(productEMU, {}).update({t1: mdv_t1}) - - Xder_t1 = Xder_t1.swapaxes(0,1).swapaxes(1,2) - for productEMU, mdvder_t1 in zip(productEMUs, Xder_t1): - simInstMDVsDer.setdefault(productEMU, {}).update({t1: mdvder_t1}) + Xs_t_work.setdefault(t_current, {}) + Ys_t_work.setdefault(t_current, {}) + Xders_p_t_work.setdefault(t_current, {}) + Yders_p_t_work.setdefault(t_current, {}) + + for size_idx in sorted(self.model.matrix_As): + lambA, fluxidsA, productEMUs_objs = self.model.matrix_As[size_idx] + lambB, fluxidsB, sourceEMUs_items = self.model.matrix_Bs[size_idx] + lambM, metabids_for_M_args = self.model.matrix_Ms[size_idx] + + A_val = lambA(*self.model.total_fluxes[fluxidsA]) + B_val = lambB(*self.model.total_fluxes[fluxidsB]) + conc_args_for_M = [self.model.concentrations[mid] for mid in metabids_for_M_args] + M_val = lambM(*conc_args_for_M) + Minv_val = pinv(M_val, check_finite=False) + + # Matrix derivatives d/dp (p includes free fluxes and concentrations) + # Shapes: (n_params, n_rows, n_cols) + Ader_p_val = self.model.matrix_As_der_p[size_idx] + Bder_p_val = self.model.matrix_Bs_der_p[size_idx] + Mder_p_val = self.model.matrix_Ms_der_p[size_idx] + Minv_der_p_val = -Minv_val @ Mder_p_val @ Minv_val # d(M^-1)/dp = -M^-1 * dM/dp * M^-1 (element-wise for params dim) + # This should be vmap over params: + # Minv_der_p_val[p] = -Minv_val @ Mder_p_val[p] @ Minv_val + Minv_der_p_val = np.einsum('ik,pkm,ml->pil', -Minv_val, Mder_p_val, Minv_val) + + + F_val = Minv_val @ A_val + Finv_val = pinv(F_val, check_finite=False) + I_mtx = np.eye(*F_val.shape) + Phi_val = expm(F_val * deltat) # Matrix exponential + Gamma_val = (Phi_val - I_mtx) @ Finv_val + Omega_val = (Gamma_val / deltat - I_mtx) @ Finv_val + + X_t_prev_at_size = Xs_t_work[time_prev][size_idx] + Y_t_prev_at_size = Ys_t_work[time_prev][size_idx] + G_t_prev_at_size = Minv_val @ B_val @ Y_t_prev_at_size + + Xder_p_t_prev_at_size = Xders_p_t_work[time_prev][size_idx] # (n_params, n_prods, n_coeffs) + Yder_p_t_prev_at_size = Yders_p_t_work[time_prev][size_idx] # (n_params, n_sources, n_coeffs) + + # H_t_prev = dG_t_prev/dp = d(Minv*B*Y_prev)/dp + # = (dMinv/dp)*B*Y_prev + Minv*(dB/dp)*Y_prev + Minv*B*(dY_prev/dp) + # All derivatives are (n_params, n_rows, n_cols) + # Minv_der_p_val @ B_val @ Y_t_prev_at_size + term_H_1 = np.einsum('pij,jk,kl->pil', Minv_der_p_val, B_val, Y_t_prev_at_size) + # Minv_val @ Bder_p_val @ Y_t_prev_at_size + term_H_2 = np.einsum('ij,pjk,kl->pil', Minv_val, Bder_p_val, Y_t_prev_at_size) + # Minv_val @ B_val @ Yder_p_t_prev_at_size (Yder is (n_params, n_sources, n_coeffs)) + term_H_3 = np.einsum('ij,jk,pkl->pil', Minv_val, B_val, Yder_p_t_prev_at_size) + Gder_p_t_prev_at_size = term_H_1 + term_H_2 + term_H_3 # (n_params, n_prods, n_coeffs) + + + # Calculate Y_t_current_at_size and its derivative Yder_p_t_current_at_size + Y_list_for_t_curr = [] + Yder_p_list_for_t_curr = [] # List of (n_coeffs, n_params) arrays + + for sourceEMU_item in sourceEMUs_items: + if not isinstance(sourceEMU_item, Iterable): # single EMU object + mdv_obj = self.model.substrate_MDVs[sourceEMU_item] + curr_mdv_val = mdv_obj.value if isinstance(mdv_obj, MDV) else mdv_obj + # Substrate derivatives are (n_coeffs, n_params) in substrate_MDVs_der_p + curr_mdv_der_val = self.model.substrate_MDVs_der_p[sourceEMU_item] + else: # tuple of EMU objects for convolution + mdv_val_list_for_conv_iter = [] + mdv_der_pair_list_for_diff_conv_iter = [] + for emu_obj_in_tuple in sourceEMU_item: + # Get MDV value (array) at t_current + if emu_obj_in_tuple in simInstMDVs_work and t_current in simInstMDVs_work[emu_obj_in_tuple]: + mdv_val_arr_iter = simInstMDVs_work[emu_obj_in_tuple][t_current] + else: # Fallback to substrate (constant) + mdv_obj_s_iter = self.model.substrate_MDVs[emu_obj_in_tuple] + mdv_val_arr_iter = mdv_obj_s_iter.value if isinstance(mdv_obj_s_iter, MDV) else mdv_obj_s_iter + mdv_val_list_for_conv_iter.append(MDV(mdv_val_arr_iter)) + + # Get MDV derivative (array, n_coeffs, n_params) at t_current + if emu_obj_in_tuple in simInstMDVsDer_work and t_current in simInstMDVsDer_work[emu_obj_in_tuple]: + mdv_der_arr_iter = simInstMDVsDer_work[emu_obj_in_tuple][t_current] + else: # Fallback to substrate derivative + mdv_der_arr_iter = self.model.substrate_MDVs_der_p[emu_obj_in_tuple] + mdv_der_pair_list_for_diff_conv_iter.append([mdv_val_arr_iter, mdv_der_arr_iter]) - simInstMDVs = {emu.id: mdvs for emu, mdvs in simInstMDVs.items()} - simInstMDVsDer = {emu.id: mdvders for emu, mdvders in simInstMDVsDer.items()} - - return simInstMDVs, simInstMDVsDer + convolved_mdv_obj_iter = reduce(conv, mdv_val_list_for_conv_iter) + curr_mdv_val = convolved_mdv_obj_iter.value + convolved_mdv_der_pair_iter = reduce(diff_conv, mdv_der_pair_list_for_diff_conv_iter) + curr_mdv_der_val = convolved_mdv_der_pair_iter[1] + + Y_list_for_t_curr.append(curr_mdv_val) + Yder_p_list_for_t_curr.append(curr_mdv_der_val) # List of (n_coeffs, n_params) + + Y_t_current_at_size = np.array(Y_list_for_t_curr) if Y_list_for_t_curr else np.empty((B_val.shape[1], size_idx+1)) + if Y_t_current_at_size.ndim == 1 and B_val.shape[1] == 1: Y_t_current_at_size = Y_t_current_at_size.reshape(1,-1) + Ys_t_work[t_current][size_idx] = Y_t_current_at_size + + Yder_p_t_curr_tensor = np.array(Yder_p_list_for_t_curr) if Yder_p_list_for_t_curr else np.empty((B_val.shape[1], size_idx+1, Xder_p_t_prev_at_size.shape[0])) + if Yder_p_t_curr_tensor.ndim == 2 and B_val.shape[1] == 1: Yder_p_t_curr_tensor = Yder_p_t_curr_tensor.reshape(1, Yder_p_t_curr_tensor.shape[0], Yder_p_t_curr_tensor.shape[1]) + Yder_p_t_current_at_size = Yder_p_t_curr_tensor.transpose(2,0,1) # (n_params, n_sources, n_coeffs) + Yders_p_t_work[t_current][size_idx] = Yder_p_t_current_at_size + + + G_t_current_at_size = Minv_val @ B_val @ Y_t_current_at_size + # Gder_p_t_current_at_size (n_params, n_prods, n_coeffs) + term_Hc_1 = np.einsum('pij,jk,kl->pil', Minv_der_p_val, B_val, Y_t_current_at_size) + term_Hc_2 = np.einsum('ij,pjk,kl->pil', Minv_val, Bder_p_val, Y_t_current_at_size) + term_Hc_3 = np.einsum('ij,jk,pkl->pil', Minv_val, B_val, Yder_p_t_current_at_size) + Gder_p_t_current_at_size = term_Hc_1 + term_Hc_2 + term_Hc_3 + + # Calculate X_t_current and its derivative Xder_p_t_current + X_t_current_at_size = Phi_val @ X_t_prev_at_size - Gamma_val @ G_t_prev_at_size - Omega_val @ (G_t_current_at_size - G_t_prev_at_size) + Xs_t_work[t_current][size_idx] = X_t_current_at_size + + # dF/dp = d(Minv*A)/dp = (dMinv/dp)*A + Minv*(dA/dp) + dF_dp = np.einsum('pij,jk->pik', Minv_der_p_val, A_val) + np.einsum('ij,pjk->pik', Minv_val, Ader_p_val) + # This is complex: d(expm(F*dt))/dp and subsequent terms (dPhi/dp, dGamma/dp, dOmega/dp) + # This requires matrix derivative calculus for expm, inv. + # The original paper/code might use approximations or specific formulas. + # For now, assume H_t0 and H_t1 are simplified forms as in original code if available. + # Original: H_t0 = (Minvder@A@X_t0 + Minv@Ader@X_t0 - Minv@B@Yder_t0 - Minvder@B@Y_t0 - Minv@Bder@Y_t0) + # This is not dG/dp, but a different combination. Let's call it K. + # K = d(Minv*A*X)/dp - d(Minv*B*Y)/dp + # d(Minv*A*X)/dp = (dMinv/dp)AX + Minv(dA/dp)X + MinvA(dX/dp) + # d(Minv*B*Y)/dp = (dMinv/dp)BY + Minv(dB/dp)Y + MinvB(dY/dp) + + # K_t_prev + term_K1_prev = np.einsum('pij,jk,kl->pil', Minv_der_p_val, A_val, X_t_prev_at_size) + term_K2_prev = np.einsum('ij,pjk,kl->pil', Minv_val, Ader_p_val, X_t_prev_at_size) + # Minv*A*dX_prev/dp term for K_t_prev: + term_K3_prev = np.einsum('ij,jk,pkl->pil', Minv_val, A_val, Xder_p_t_prev_at_size) + K_t_prev = (term_K1_prev + term_K2_prev + term_K3_prev) - Gder_p_t_prev_at_size + + # K_t_current + term_K1_curr = np.einsum('pij,jk,kl->pil', Minv_der_p_val, A_val, X_t_current_at_size) # Uses X_current + term_K2_curr = np.einsum('ij,pjk,kl->pil', Minv_val, Ader_p_val, X_t_current_at_size) # Uses X_current + # dX_curr/dp is not known yet. This formulation seems circular or needs specific derivative of ODE solution. + # The provided H was: (Minvder@A@X + Minv@Ader@X - Minv@B@Yder - Minvder@B@Y - Minv@Bder@Y) + # This does not include dX/dp or dY/dp terms directly in H. + # Let's use the structure from original code's H_t0, H_t1 for Xder calculation. + # H = d(F*X)/dp - d(G)/dp, where F=Minv*A. (dX/dt = F*X - G). So d(dX/dt)/dp = d(F*X)/dp - d(G)/dp + # d(F*X)/dp = (dF/dp)X + F(dX/dp) + # H_from_paper = (dF/dp)X - Gder + + H_val_t_prev = np.einsum('pij,jk->pik', dF_dp, X_t_prev_at_size) - Gder_p_t_prev_at_size + H_val_t_current = np.einsum('pij,jk->pik', dF_dp, X_t_current_at_size) - Gder_p_t_current_at_size + + + # Derivative of X_t_current w.r.t params 'p' + # Xder_p_t_current = Phi @ Xder_p_t_prev + dPhi/dp @ X_t_prev (Chain rule for Phi) + # - (dGamma/dp @ G_prev + Gamma @ dG_prev/dp) + # - (dOmega/dp @ (G_curr-G_prev) + Omega @ (dG_curr/dp - dG_prev/dp)) + # This is very complex. The original code uses: + # Xder_t1 = Phi@Xder_t0 + Gamma@H_t0 + Omega@(H_t1 - H_t0) + # This H must be d(F*X - G)/dX_t0 * dX_t0/dp ... no, this H is likely specific to the solution form. + # Assuming that H formulation is correct from a source theory: + Xder_p_t_current_at_size = Phi_val @ Xder_p_t_prev_at_size + Gamma_val @ H_val_t_prev + Omega_val @ (H_val_t_current - H_val_t_prev) + Xders_p_t_work[t_current][size_idx] = Xder_p_t_current_at_size # Store (n_params, n_prods, n_coeffs) + + # Store results in simInstMDVs_work and simInstMDVsDer_work (derivatives as n_coeffs, n_params) + for i, emu_obj in enumerate(productEMUs_objs): + simInstMDVs_work.setdefault(emu_obj, {})[t_current] = X_t_current_at_size[i,:] + simInstMDVsDer_work.setdefault(emu_obj, {})[t_current] = Xder_p_t_current_at_size[:, i, :].T # (n_coeffs, n_params) + + time_prev = t_current # Update time_prev for the next iteration + + # Convert final result to emu_id_str -> {time: array} + simInstMDVs_by_id = { + emu_obj.id: time_dict for emu_obj, time_dict in simInstMDVs_work.items() + } + simInstMDVsDer_by_id = { + emu_obj.id: time_dict_der for emu_obj, time_dict_der in simInstMDVsDer_work.items() + } + + return simInstMDVs_by_id, simInstMDVsDer_by_id + +[end of src/freeflux/utils/utils.py]