diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index fe396968f..94ba3c626 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,7 @@ # see https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners # default owners = active maintainers -* @Doresic @PaulJonasJost @vwiela +* @Doresic @PaulJonasJost @MoR1chter # Examples /doc/example/censored_data.ipynb @Doresic @@ -35,22 +35,22 @@ /pypesto/optimize/ @PaulJonasJost /pypesto/petab/ @dweindl @FFroehlich /pypesto/predict/ @dilpath -/pypesto/problem/ @PaulJonasJost @vwiela +/pypesto/problem/ @PaulJonasJost @MoR1chter /pypesto/profile/ @PaulJonasJost @Doresic /pypesto/result/ @PaulJonasJost -/pypesto/sample/ @dilpath @arrjon +/pypesto/sample/ @dilpath @arrjon @vwiela /pypesto/select/ @dilpath /pypesto/startpoint/ @PaulJonasJost /pypesto/store/ @PaulJonasJost -/pypesto/visualize/ @stephanmg +/pypesto/visualize/ @Doresic # Tests -/test/base/ @PaulJonasJost @vwiela +/test/base/ @PaulJonasJost @MoR1chter /test/doc/ @PaulJonasJost /test/hierarchical/ @dweindl @Doresic /test/julia/ @PaulJonasJost @vwiela /test/optimize/ @PaulJonasJost /test/petab/ @dweindl @FFroehlich /test/profile/ @PaulJonasJost @Doresic -/test/sample/ @dilpath @arrjon +/test/sample/ @dilpath @arrjon @vwiela /test/select/ @dilpath diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c670d978..cfe18393f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,7 @@ jobs: CXX: clang++ - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -91,7 +91,7 @@ jobs: run: ulimit -n 65536 65536 && tox -e base - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -163,7 +163,7 @@ jobs: CXX: clang++ - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -174,11 +174,6 @@ jobs: matrix: python-version: ['3.11', '3.13'] - # needed to allow julia-actions/cache to delete old caches that it has created - permissions: - actions: write - contents: read - steps: - name: Check out repository uses: actions/checkout@v6 @@ -196,25 +191,15 @@ jobs: .tox/ key: "${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-ci-${{ github.job }}" - - name: Install julia - uses: julia-actions/setup-julia@v2 - with: - version: 1.11 - - name: Install dependencies run: .github/workflows/install_deps.sh - - name: Install PEtabJL dependencies - run: > - julia -e 'using Pkg; Pkg.add("PEtab"); - Pkg.add("OrdinaryDiffEq"); Pkg.add("Sundials")' - - name: Run tests timeout-minutes: 25 run: tox -e julia - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -250,7 +235,7 @@ jobs: run: tox -e optimize - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -286,7 +271,7 @@ jobs: run: tox -e hierarchical - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml @@ -322,7 +307,7 @@ jobs: run: tox -e select - name: Coverage - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index f69264f8b..c80d70aaf 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -11,6 +11,13 @@ jobs: matrix: python-version: ['3.12'] + environment: + name: PyPI + url: https://pypi.org/p/pypesto + + permissions: + id-token: write + steps: - name: Check out repository uses: actions/checkout@v6 @@ -25,15 +32,14 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + python -m pip install -U build - - name: Build and publish - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} - run: | - python setup.py sdist bdist_wheel - twine upload dist/* + - name: Build distribution + run: + python -m build + + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/doc/example/getting_started.ipynb b/doc/example/getting_started.ipynb index 2e21f2c7e..c363a0a62 100644 --- a/doc/example/getting_started.ipynb +++ b/doc/example/getting_started.ipynb @@ -92,7 +92,9 @@ "name": "#%% md\n" } }, - "source": "Define lower and upper parameter bounds and create an optimization problem." + "source": [ + "Define lower and upper parameter bounds and create an optimization problem." + ] }, { "cell_type": "code", @@ -632,7 +634,7 @@ "outputs": [], "source": [ "# adapt x_labels.\n", - "x_labels = [f\"Log10({name})\" for name in problem.x_names]\n", + "x_labels = [f\"Log({name})\" for name in problem.x_names]\n", "\n", "ax = visualize.profiles(result, x_labels=x_labels, show_bounds=True)\n", "# visualize optimal parameter values\n", @@ -650,6 +652,22 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The 2D profile visualization shows joint parameter paths across all profile pairs. Diagonal plots show 1D profiles (likelihood ratio vs. parameter value). Off-diagonal plots show how each non-profiled parameter moves while another is profiled, with color indicating the likelihood ratio.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = pypesto.visualize.visualize_2d_profile(result)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -675,7 +693,7 @@ " result, confidence_level=0.95, show_bounds=True\n", ")\n", "\n", - "ax.set_xlabel(\"Log10(Parameter value)\");" + "ax.set_xlabel(\"Log(Parameter value)\");" ] }, { @@ -695,7 +713,7 @@ " result, confidence_levels=[0.99, 0.95, 0.9]\n", ")\n", "\n", - "ax.set_xlabel(\"Log10(Parameter value)\");" + "ax.set_xlabel(\"Log(Parameter value)\");" ] }, { diff --git a/pypesto/hierarchical/inner_calculator_collector.py b/pypesto/hierarchical/inner_calculator_collector.py index 8570cbb94..fbd10cdca 100644 --- a/pypesto/hierarchical/inner_calculator_collector.py +++ b/pypesto/hierarchical/inner_calculator_collector.py @@ -475,7 +475,7 @@ def __call__( "Cannot use least squares solver with" "parameter dependent sigma! Support can be " "enabled via " - "amici_model.setAddSigmaResiduals()." + "amici_model.set_add_sigma_residuals()." ) self._known_least_squares_safe = True # don't check this again diff --git a/pypesto/hierarchical/ordinal/solver.py b/pypesto/hierarchical/ordinal/solver.py index c3d4a5397..30109591a 100644 --- a/pypesto/hierarchical/ordinal/solver.py +++ b/pypesto/hierarchical/ordinal/solver.py @@ -236,7 +236,7 @@ def calculate_gradients( par_sim_ids: Ids of outer simulation parameters, includes fixed parameters. par_edata_indices: - Indices of parameters from `amici_model.getParameterIds()` that are needed for + Indices of parameters from `amici_model.get_free_parameter_ids()` that are needed for sensitivity calculation. Comes from `edata.plist` for each condition. snllh: A zero-initialized vector of the same length as ``par_opt_ids`` to store the diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index 42ce9269c..51c9c02da 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -139,7 +139,7 @@ def __init__( derivatives. amici_reporting: Determines which quantities will be computed by AMICI, - see ``amici.Solver.setReturnDataReportingMode``. Set to ``None`` + see ``amici.Solver.set_return_data_reporting_mode``. Set to ``None`` to compute only the minimum required information. """ import amici diff --git a/pypesto/objective/amici/amici_calculator.py b/pypesto/objective/amici/amici_calculator.py index f1ab22c3f..efdc51996 100644 --- a/pypesto/objective/amici/amici_calculator.py +++ b/pypesto/objective/amici/amici_calculator.py @@ -137,7 +137,7 @@ def __call__( "Cannot use least squares solver with" "parameter dependent sigma! Support can be " "enabled via " - "amici_model.setAddSigmaResiduals()." + "amici_model.set_add_sigma_residuals()." ) self._known_least_squares_safe = True # don't check this again diff --git a/pypesto/objective/julia/base.py b/pypesto/objective/julia/base.py index 6e6fd75d9..ac461cd16 100644 --- a/pypesto/objective/julia/base.py +++ b/pypesto/objective/julia/base.py @@ -1,7 +1,14 @@ -"""Interface to Julia via pyjulia.""" +"""Interface to Julia via juliacall.""" from collections.abc import Callable +# Import juliacall early to avoid conflicts with other libraries (especially numpy) +# See: https://juliapy.github.io/PythonCall.jl/dev/faq/ +try: + from juliacall import Main as jl # noqa: F401 +except ImportError: + jl = None + import numpy as np from ..function import Objective @@ -27,10 +34,10 @@ def _read_source(module_name: str, source_file: str) -> None: module_name: Julia module name. source_file: Qualified Julia source file. """ - from julia import Main + from juliacall import Main as jl - if not hasattr(Main, module_name): - Main.include(source_file) + if not hasattr(jl, module_name): + jl.include(source_file) class JuliaObjective(Objective): @@ -40,45 +47,18 @@ class JuliaObjective(Objective): It expects the corresponding Julia objects to be defined in a `source_file` within a `module`. - We use the PyJulia package to access Julia from inside Python. - It can be installed via `pip install pypesto[julia]`, however requires - additional Julia dependencies to be installed via: + We use the juliacall package (part of PythonCall.jl) to access Julia + from inside Python. It can be installed via `pip install pypesto[julia]`. - >>> python -c "import julia; julia.install()" + juliacall automatically manages the Julia installation and configuration, + so no additional setup steps are required beyond pip installation. For further information, see - https://pyjulia.readthedocs.io/en/latest/installation.html. - - There are some known problems, e.g. with statically linked Python - interpreters, see - https://pyjulia.readthedocs.io/en/latest/troubleshooting.html - for details. - Possible solutions are to pass ``compiled_modules=False`` to the Julia - constructor early in your code: - - >>> from julia.api import Julia - >>> jl = Julia(compiled_modules=False) - - This however slows down loading and using Julia packages, especially for - large ones. - An alternative is to use the ``python-jl`` command shipped with PyJulia: - - >>> python-jl MY_SCRIPT.py - - This basically launches a Python interpreter inside Julia. - When using Jupyter notebooks, this wrapper can be installed as an - additional kernel via: - - >>> python -m ipykernel install --name python-jl [--prefix=/path/to/python/env] - - And changing the first argument in - ``/path/to/python/env/share/jupyter/kernels/python-jl/kernel.json`` - to ``python-jl``. + https://juliapy.github.io/PythonCall.jl/stable/juliacall/ - Model simulations are eagerly converted to Python objects - (specifically, `numpy.ndarray` and `pandas.DataFrame`). - This can introduce overhead and could be avoided by an alternative - lazy implementation. + Model simulations are efficiently handled with minimal overhead. + By default, juliacall wraps mutable objects instead of copying them, + providing better performance than PyJulia. Parameters ---------- @@ -103,10 +83,10 @@ def __init__( ): # lazy imports try: - from julia import Main # noqa: F401 + from juliacall import Main as jl # noqa: F401 except ImportError: raise ImportError( - "Install PyJulia, e.g. via `pip install pypesto[julia]`, " + "Install juliacall, e.g. via `pip install pypesto[julia]`, " "and see the class documentation", ) from None @@ -133,10 +113,10 @@ def get(self, name: str, as_array: bool = False) -> Callable | None: Use this function to access any variable from the Julia module. """ - from julia import Main + from juliacall import Main as jl if name is not None: - ret = getattr(getattr(Main, self.module), name, None) + ret = getattr(getattr(jl, self.module), name, None) if as_array: ret = _as_array(ret) return ret diff --git a/pypesto/objective/julia/petabJl.py b/pypesto/objective/julia/petabJl.py index f9ddefb22..5ffebbb64 100644 --- a/pypesto/objective/julia/petabJl.py +++ b/pypesto/objective/julia/petabJl.py @@ -3,6 +3,13 @@ import logging import os +# Import juliacall early to avoid conflicts with other libraries (especially numpy) +# See: https://juliapy.github.io/PythonCall.jl/dev/faq/ +try: + from juliacall import Main as jl # noqa: F401 +except ImportError: + jl = None + import numpy as np from .base import JuliaObjective, _read_source @@ -37,12 +44,14 @@ def __init__( """Initialize objective.""" # lazy imports try: - from julia import Main, Pkg # noqa: F401 + from juliacall import Main as jl # noqa: F401 - Pkg.activate(".") + # Load Pkg into Julia session + jl.seval("using Pkg") + jl.Pkg.activate(".") except ImportError: raise ImportError( - "Install PyJulia, e.g. via `pip install pypesto[julia]`, " + "Install juliacall, e.g. via `pip install pypesto[julia]`, " "and see the class documentation", ) from None @@ -87,15 +96,14 @@ def __setstate__(self, state): setattr(self, key, value) # lazy imports try: - from julia import ( - Main, # noqa: F401 - Pkg, - ) + from juliacall import Main as jl # noqa: F401 - Pkg.activate(".") + # Load Pkg into Julia session + jl.seval("using Pkg") + jl.Pkg.activate(".") except ImportError: raise ImportError( - "Install PyJulia, e.g. via `pip install pypesto[julia]`, " + "Install juliacall, e.g. via `pip install pypesto[julia]`, " "and see the class documentation", ) from None # Include module if not already included @@ -138,19 +146,20 @@ def precompile_model(self, force_compile: bool = False): return None # lazy imports try: - from julia import Main # noqa: F401 + from juliacall import Main as jl # noqa: F401 except ImportError: raise ImportError( - "Install PyJulia, e.g. via `pip install pypesto[julia]`, " + "Install juliacall, e.g. via `pip install pypesto[julia]`, " "and see the class documentation", ) from None # setting up a local project, where the precompilation will be done in - from julia import Pkg - Pkg.activate(".") + # Load Pkg into Julia session + jl.seval("using Pkg") + jl.Pkg.activate(".") # create a Project f"{self.module}_pre". try: - Pkg.generate(f"{directory}/{self.module}_pre") + jl.Pkg.generate(f"{directory}/{self.module}_pre") except Exception: logger.info("Module is already generated. Skipping generate...") # Adjust the precompilation file @@ -169,16 +178,16 @@ def precompile_model(self, force_compile: bool = False): os.rename("dummy_temp_file.jl", self.source_file) try: - Pkg.develop(path=f"{directory}/{self.module}_pre") + jl.Pkg.develop(path=f"{directory}/{self.module}_pre") except Exception: logger.info("Module is already developed. Skipping develop...") - Pkg.activate(f"{directory}/{self.module}_pre/") + jl.Pkg.activate(f"{directory}/{self.module}_pre/") # add dependencies - Pkg.add("PrecompileTools") - Pkg.add("OrdinaryDiffEq") - Pkg.add("PEtab") - Pkg.add("Sundials") - Pkg.precompile() + jl.Pkg.add("PrecompileTools") + jl.Pkg.add("OrdinaryDiffEq") + jl.Pkg.add("PEtab") + jl.Pkg.add("Sundials") + jl.Pkg.precompile() def write_precompilation_module(module, source_file_orig): diff --git a/pypesto/objective/julia/petab_jl_importer.py b/pypesto/objective/julia/petab_jl_importer.py index 1a81ef450..7f17c5582 100644 --- a/pypesto/objective/julia/petab_jl_importer.py +++ b/pypesto/objective/julia/petab_jl_importer.py @@ -6,6 +6,13 @@ import os.path from collections.abc import Iterable +# Import juliacall early to avoid conflicts with other libraries (especially numpy) +# See: https://juliapy.github.io/PythonCall.jl/dev/faq/ +try: + from juliacall import Main as jl # noqa: F401 +except ImportError: + jl = None + import numpy as np from pypesto.objective.julia import PEtabJlObjective @@ -116,10 +123,10 @@ def create_objective( """ # lazy imports try: - from julia import Main # noqa: F401 + from juliacall import Main as jl # noqa: F401 except ImportError: raise ImportError( - "Install PyJulia, e.g. via `pip install pypesto[julia]`, " + "Install juliacall, e.g. via `pip install pypesto[julia]`, " "and see the class documentation", ) from None if self.source_file is None: diff --git a/pypesto/optimize/optimizer.py b/pypesto/optimize/optimizer.py index b07e5db35..6b66ad2f8 100644 --- a/pypesto/optimize/optimizer.py +++ b/pypesto/optimize/optimizer.py @@ -1004,9 +1004,13 @@ def minimize( check_finite_bounds(lb, ub) - xopt, fopt = pyswarm.pso( + result = pyswarm.pso( problem.objective.get_fval, lb, ub, **self.options ) + if hasattr(result, "x") and hasattr(result, "fun"): + xopt, fopt = result.x, result.fun + else: + xopt, fopt = result optimizer_result = OptimizerResult( x=np.array(xopt), fval=fopt, optimizer=str(self) @@ -1700,9 +1704,15 @@ def __init__( Optimizer options. See :meth:`fides.minimize.Optimizer.minimize` and :class:`fides.constants.Options` for details. hessian_update: - Hessian update strategy. If this is ``None``, a hybrid approximation - that switches from the ``problem.objective`` provided Hessian ( - approximation) to a BFGS approximation will be used. + Hessian update strategy. Defaults to a BFGS approximation if + ``problem.objective`` does not provide a Hessian. Otherwise, it is + assumed that the ``problem.objective`` Hessian is actually the + Fisher information matrix (FIM), and hence a Hessian approximation + strategy is the default, which uses the FIM initially but switches + to BFGS during later iterations. + If your ``problem.objective`` Hessian is actually the Hessian, + then use ``None`` to have Fides use the ``problem.objective`` + Hessian for all iterations. """ super().__init__() @@ -1762,15 +1772,23 @@ def minimize( if self.hessian_update == "default": if not problem.objective.has_hess: - warnings.warn( + logger.debug( "Fides is using BFGS as hessian approximation, " "as the problem does not provide a Hessian. " - "Specify a Hessian to use a more efficient " - "hybrid approximation scheme.", + "Specify a Hessian (or Fisher information matrix, to use " + "a more efficient hybrid approximation scheme. See the " + "docstring for `hessian_update` in the class constructor " + "for more details.", stacklevel=1, ) _hessian_update = fides.BFGS() else: + logger.debug( + "A hybrid Hessian approximation strategy will be " + "employed. See the docstring for `hessian_update` in " + "the class constructor for more details.", + stacklevel=1, + ) _hessian_update = fides.HybridFixed() else: _hessian_update = self.hessian_update diff --git a/pypesto/profile/options.py b/pypesto/profile/options.py index bcc3b805e..11bcfb87a 100644 --- a/pypesto/profile/options.py +++ b/pypesto/profile/options.py @@ -1,20 +1,42 @@ +import warnings from typing import Union +#: Deprecated ``ProfileOptions`` step-size names mapped to their replacements. +#: Kept for backward compatibility, both as constructor arguments and as +#: attributes. +_DEPRECATED_STEP_SIZE_NAMES = { + "default_step_size": "default_step_size_absolute", + "min_step_size": "min_step_size_absolute", + "max_step_size": "max_step_size_absolute", +} + class ProfileOptions(dict): """ Options for optimization based profiling. + Step sizes can be configured as absolute values or relative fractions of + the parameter span. A family is disabled by setting its default step size + to `0`. For each profiled parameter, pyPESTO uses either the full absolute + family or the full relative family, whichever has the larger default step + size, i.e. whichever of `default_step_size_absolute` and + `default_step_size_relative * (ub - lb)` is larger. + Attributes ---------- - default_step_size: - Default step size of the profiling routine along the profile path - (adaptive step lengths algorithms will only use this as a first guess - and then refine the update). - min_step_size: - Lower bound for the step size in adaptive methods. - max_step_size: - Upper bound for the step size in adaptive methods. + default_step_size_absolute: + Default absolute profile step size. Set to `0` to disable. + default_step_size_relative: + Default relative profile step size, as fraction of `ub - lb`. Set to + `0` to disable. + min_step_size_absolute: + Minimum absolute step size in adaptive methods. + min_step_size_relative: + Minimum relative step size, as fraction of `ub - lb`. + max_step_size_absolute: + Maximum absolute step size in adaptive methods. + max_step_size_relative: + Maximum relative step size, as fraction of `ub - lb`. step_size_factor: Adaptive methods recompute the likelihood at the predicted point and try to find a good step length by a sort of line search algorithm. @@ -38,13 +60,22 @@ class ProfileOptions(dict): whole_path: Whether to profile the whole bounds or only till we get below the ratio. + step_size_precheck_mode: + Controls the step-size precheck, which estimates how many profile + steps the resolved step sizes imply and reports suspiciously small + steps. One of ``"off"`` (disable the precheck), ``"warn"`` (only ever + emit a warning), or ``"raise"`` (raise an error only for extreme, + worst-case estimates, and warn otherwise). """ def __init__( self, - default_step_size: float = 0.01, - min_step_size: float = 0.001, - max_step_size: float = 0.1, + default_step_size_absolute: float = 0.02, + default_step_size_relative: float = 0.0025, + min_step_size_absolute: float = 0.01, + min_step_size_relative: float = 0.00125, + max_step_size_absolute: float = 0.2, + max_step_size_relative: float = 0.025, step_size_factor: float = 1.25, delta_ratio_max: float = 0.1, ratio_min: float = 0.145, @@ -52,12 +83,46 @@ def __init__( reg_order: int = 4, adaptive_target_scaling_factor: float = 1.5, whole_path: bool = False, + step_size_precheck_mode: str = "warn", + default_step_size: float | None = None, + min_step_size: float | None = None, + max_step_size: float | None = None, ): super().__init__() - self.default_step_size = default_step_size - self.min_step_size = min_step_size - self.max_step_size = max_step_size + # Backward compatibility: the absolute step-size arguments were + # renamed. If an old name is passed, it overrides the new one. + if default_step_size is not None: + warnings.warn( + "`default_step_size` is deprecated. Use " + "`default_step_size_absolute` instead.", + DeprecationWarning, + stacklevel=2, + ) + default_step_size_absolute = default_step_size + if min_step_size is not None: + warnings.warn( + "`min_step_size` is deprecated. Use " + "`min_step_size_absolute` instead.", + DeprecationWarning, + stacklevel=2, + ) + min_step_size_absolute = min_step_size + if max_step_size is not None: + warnings.warn( + "`max_step_size` is deprecated. Use " + "`max_step_size_absolute` instead.", + DeprecationWarning, + stacklevel=2, + ) + max_step_size_absolute = max_step_size + + self.default_step_size_absolute = default_step_size_absolute + self.default_step_size_relative = default_step_size_relative + self.min_step_size_absolute = min_step_size_absolute + self.min_step_size_relative = min_step_size_relative + self.max_step_size_absolute = max_step_size_absolute + self.max_step_size_relative = max_step_size_relative self.ratio_min = ratio_min self.step_size_factor = step_size_factor self.delta_ratio_max = delta_ratio_max @@ -65,11 +130,20 @@ def __init__( self.reg_order = reg_order self.adaptive_target_scaling_factor = adaptive_target_scaling_factor self.whole_path = whole_path + self.step_size_precheck_mode = step_size_precheck_mode self.validate() def __getattr__(self, key): """Allow usage of keys like attributes.""" + if key in _DEPRECATED_STEP_SIZE_NAMES: + new_key = _DEPRECATED_STEP_SIZE_NAMES[key] + warnings.warn( + f"`{key}` is deprecated. Use `{new_key}` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self[new_key] try: return self[key] except KeyError: @@ -99,18 +173,45 @@ def validate(self): Raises ``ValueError`` if current settings aren't valid. """ - if self.min_step_size <= 0: - raise ValueError("min_step_size must be > 0.") - if self.max_step_size <= 0: - raise ValueError("max_step_size must be > 0.") - if self.min_step_size > self.max_step_size: - raise ValueError("min_step_size must be <= max_step_size.") - if self.default_step_size <= 0: - raise ValueError("default_step_size must be > 0.") - if self.default_step_size > self.max_step_size: - raise ValueError("default_step_size must be <= max_step_size.") - if self.default_step_size < self.min_step_size: - raise ValueError("default_step_size must be >= min_step_size.") + + def validate_step_size_family(family: str) -> bool: + default_step_size = self[f"default_step_size_{family}"] + min_step_size = self[f"min_step_size_{family}"] + max_step_size = self[f"max_step_size_{family}"] + + if default_step_size < 0: + raise ValueError(f"default_step_size_{family} must be >= 0.") + if default_step_size == 0: + return False + if min_step_size <= 0: + raise ValueError(f"min_step_size_{family} must be > 0.") + if max_step_size <= 0: + raise ValueError(f"max_step_size_{family} must be > 0.") + if min_step_size > default_step_size: + raise ValueError( + f"min_step_size_{family} must be <= " + f"default_step_size_{family}." + ) + if default_step_size > max_step_size: + raise ValueError( + f"default_step_size_{family} must be <= " + f"max_step_size_{family}." + ) + return True + + absolute_enabled = validate_step_size_family("absolute") + relative_enabled = validate_step_size_family("relative") + if not absolute_enabled and not relative_enabled: + raise ValueError( + "At least one step-size family must be enabled by setting " + "default_step_size_absolute > 0 or " + "default_step_size_relative > 0." + ) if self.adaptive_target_scaling_factor < 1: raise ValueError("adaptive_target_scaling_factor must be > 1.") + if self.step_size_precheck_mode not in {"off", "warn", "raise"}: + raise ValueError( + "step_size_precheck_mode must be one of " + "{'off', 'warn', 'raise'}." + ) diff --git a/pypesto/profile/profile.py b/pypesto/profile/profile.py index 58082cb31..9b7934aca 100644 --- a/pypesto/profile/profile.py +++ b/pypesto/profile/profile.py @@ -12,7 +12,11 @@ from .options import ProfileOptions from .profile_next_guess import next_guess from .task import ProfilerTask -from .util import initialize_profile +from .util import ( + _format_profile_step_size_resolution_summary, + initialize_profile, + resolve_profile_step_sizes_for_parameters, +) logger = logging.getLogger(__name__) @@ -94,6 +98,13 @@ def parameter_profile( profile_options = ProfileOptions.create_instance(profile_options) profile_options.validate() + # Resolve the step sizes once, up front + resolved_steps_by_par = resolve_profile_step_sizes_for_parameters( + problem=problem, + parameter_indices=problem.x_free_indices, + options=profile_options, + ) + # Create a function handle that will be called later to get the next point. # This function will be used to generate the initial points of optimization # steps in profiling in `walk_along_profile.py` @@ -119,6 +130,7 @@ def create_next_guess( current_profile_, problem_, global_opt_, + resolved_steps_by_par, min_step_increase_factor_, max_step_reduce_factor_, ) @@ -156,6 +168,14 @@ def create_next_guess( i_par=i_par, profile_list=profile_list, ) + resolved_steps = resolved_steps_by_par[i_par] + logger.debug( + _format_profile_step_size_resolution_summary( + problem=problem, + i_par=i_par, + resolved_steps=resolved_steps, + ) + ) # create two tasks for each parameter: in descending and ascending direction for par_direction in [-1, 1]: @@ -165,6 +185,7 @@ def create_next_guess( optimizer=optimizer, options=profile_options, create_next_guess=create_next_guess, + resolved_steps_by_par=resolved_steps_by_par, global_opt=global_opt, i_par=i_par, par_direction=par_direction, diff --git a/pypesto/profile/profile_next_guess.py b/pypesto/profile/profile_next_guess.py index cbeff3b3a..d09458abf 100644 --- a/pypesto/profile/profile_next_guess.py +++ b/pypesto/profile/profile_next_guess.py @@ -7,6 +7,7 @@ from ..problem import Problem from ..result import ProfilerResult from .options import ProfileOptions +from .util import ResolvedProfileStepSizeMap, ResolvedProfileStepSizes logger = logging.getLogger(__name__) @@ -27,6 +28,7 @@ def next_guess( current_profile: ProfilerResult, problem: Problem, global_opt: float, + resolved_steps_by_par: ResolvedProfileStepSizeMap, min_step_increase_factor: float = 1.0, max_step_reduce_factor: float = 1.0, ) -> np.ndarray: @@ -59,6 +61,8 @@ def next_guess( The problem to be solved. global_opt: Log-posterior value of the global optimum. + resolved_steps_by_par: + Pre-resolved profile step sizes. min_step_increase_factor: Factor to increase the minimal step size bound. Used only in :func:`adaptive_step`. @@ -72,7 +76,12 @@ def next_guess( """ if update_type == "fixed_step": next_initial_guess = fixed_step( - x, par_index, par_direction, profile_options, problem + x, + par_index, + par_direction, + profile_options, + problem, + resolved_steps_by_par, ) elif update_type == "adaptive_step_order_0": order = 0 @@ -93,6 +102,7 @@ def next_guess( current_profile, problem, global_opt, + resolved_steps_by_par, order, min_step_increase_factor, max_step_reduce_factor, @@ -113,11 +123,12 @@ def fixed_step( par_direction: Literal[1, -1], options: ProfileOptions, problem: Problem, + resolved_steps_by_par: ResolvedProfileStepSizeMap, ) -> np.ndarray: """Most simple method to create the next guess. - Computes the next point based on the fixed step size given by - :attr:`pypesto.profile.ProfileOptions.default_step_size`. + Computes the next point based on the resolved default step size for the + profiled parameter. Parameters ---------- @@ -131,13 +142,16 @@ def fixed_step( Various options applied to the profile optimization. problem: The problem to be solved. + resolved_steps_by_par: + Pre-resolved profile step sizes. Returns ------- The updated parameter vector, of size `dim_full`. """ + resolved_steps = resolved_steps_by_par[par_index] delta_x = np.zeros(len(x)) - delta_x[par_index] = par_direction * options.default_step_size + delta_x[par_index] = par_direction * resolved_steps.default_step_size # check whether the next point is maybe outside the bounds # and correct it @@ -158,6 +172,7 @@ def adaptive_step( current_profile: ProfilerResult, problem: Problem, global_opt: float, + resolved_steps_by_par: ResolvedProfileStepSizeMap, order: int = 1, min_step_increase_factor: float = 1.0, max_step_reduce_factor: float = 1.0, @@ -183,6 +198,8 @@ def adaptive_step( The problem to be solved. global_opt: Log-posterior value of the global optimum. + resolved_steps_by_par: + Pre-resolved profile step sizes. order: Specifies the precise algorithm for extrapolation. Available options are: @@ -201,11 +218,15 @@ def adaptive_step( ------- The updated parameter vector, of size `dim_full`. """ + resolved_steps = resolved_steps_by_par[par_index] + trust_region_max_step = np.zeros(len(x)) + for i_step_par, i_resolved_steps in resolved_steps_by_par.items(): + trust_region_max_step[i_step_par] = i_resolved_steps.max_step_size # restrict step proposal to minimum and maximum step size def clip_to_minmax(step_size_proposal): - min_step_size = options.min_step_size * min_step_increase_factor - max_step_size = options.max_step_size * max_step_reduce_factor + min_step_size = resolved_steps.min_step_size * min_step_increase_factor + max_step_size = resolved_steps.max_step_size * max_step_reduce_factor return np.clip(step_size_proposal, min_step_size, max_step_size) # restrict step proposal to bounds @@ -230,13 +251,16 @@ def clip_to_bounds(step_proposal): current_profile, problem, options, + resolved_steps, ) # check whether we must make a minimum step anyway, since we're close to # the next bound min_delta_x = ( x[par_index] - + par_direction * options.min_step_size * min_step_increase_factor + + par_direction + * resolved_steps.min_step_size + * min_step_increase_factor ) if par_direction == -1 and (min_delta_x < problem.lb_full[par_index]): @@ -275,7 +299,9 @@ def par_extrapol(step_length): # Define a trust region for the step size in all directions # to avoid overshooting x_step = np.clip( - x_step, x - options.max_step_size, x + options.max_step_size + x_step, + x - trust_region_max_step, + x + trust_region_max_step, ) return clip_to_bounds(x_step) @@ -287,8 +313,8 @@ def par_extrapol(step_length): # to avoid overshooting step_in_x = np.clip( step_length * delta_x_dir, - -options.max_step_size, - options.max_step_size, + -trust_region_max_step, + trust_region_max_step, ) x_stepped = x + step_in_x return clip_to_bounds(x_stepped) @@ -339,6 +365,8 @@ def par_extrapol(step_length): par_index, problem, options, + resolved_steps.min_step_size, + resolved_steps.max_step_size, min_step_increase_factor, max_step_reduce_factor, ) @@ -353,7 +381,8 @@ def handle_profile_history( current_profile: ProfilerResult, problem: Problem, options: ProfileOptions, -) -> tuple[float, np.array, list[float], float]: + resolved_steps: ResolvedProfileStepSizes, +) -> tuple[float, np.ndarray, list[float], float, float]: """Compute the very first step direction update guesses. Check whether enough steps have been taken for applying regression, @@ -386,7 +415,7 @@ def handle_profile_history( current_profile.x_path[par_index, -2], ): # try to use the default step size - step_size_guess = options.default_step_size + step_size_guess = resolved_steps.default_step_size delta_obj_value = 0.0 last_delta_fval = 0.0 @@ -398,10 +427,11 @@ def handle_profile_history( ) # Bound the step size by default values step_size_guess = min( - last_delta_x_par_index, options.default_step_size + last_delta_x_par_index, + resolved_steps.default_step_size, ) # Step size cannot be smaller than the minimum step size - step_size_guess = max(step_size_guess, options.min_step_size) + step_size_guess = max(step_size_guess, resolved_steps.min_step_size) delta_obj_value = current_profile.fval_path[-1] - global_opt last_delta_fval = ( @@ -491,6 +521,8 @@ def do_line_search( par_index: int, problem: Problem, options: ProfileOptions, + effective_min_step_size: float, + effective_max_step_size: float, min_step_increase_factor: float, max_step_reduce_factor: float, ) -> np.ndarray: @@ -556,18 +588,12 @@ def do_line_search( step_size_guess = clip_to_minmax(step_size_guess * adapt_factor) next_x = clip_to_bounds(par_extrapol(step_size_guess)) - # Check if we hit the bounds - if ( - direction == "decrease" - and step_size_guess - == options.min_step_size * min_step_increase_factor - ): + # Check if the step-size clipping hit the adaptive bounds. + min_step_bound = effective_min_step_size * min_step_increase_factor + max_step_bound = effective_max_step_size * max_step_reduce_factor + if direction == "decrease" and step_size_guess <= min_step_bound: return next_x - if ( - direction == "increase" - and step_size_guess - == options.max_step_size * max_step_reduce_factor - ): + if direction == "increase" and step_size_guess >= max_step_bound: return next_x # compute new objective value diff --git a/pypesto/profile/task.py b/pypesto/profile/task.py index 523db01d4..021e8eb7b 100644 --- a/pypesto/profile/task.py +++ b/pypesto/profile/task.py @@ -8,6 +8,7 @@ from ..problem import Problem from ..result import ProfilerResult from .options import ProfileOptions +from .util import ResolvedProfileStepSizeMap, precheck_profile_step_size from .walk_along_profile import walk_along_profile logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def __init__( global_opt: float, optimizer: "pypesto.optimize.Optimizer", create_next_guess: Callable, + resolved_steps_by_par: ResolvedProfileStepSizeMap, par_direction: Literal[-1, 1], ): """ @@ -44,6 +46,8 @@ def __init__( Various options applied to the profile optimization. create_next_guess: Handle of the method which creates the next profile point proposal + resolved_steps_by_par: + Pre-resolved profile step sizes. i_par: index for the current parameter par_direction: @@ -57,6 +61,7 @@ def __init__( self.current_profile = current_profile self.global_opt = global_opt self.create_next_guess = create_next_guess + self.resolved_steps_by_par = resolved_steps_by_par self.i_par = i_par self.options = options self.par_direction = par_direction @@ -70,6 +75,15 @@ def execute(self) -> dict[str, Any]: # flip profile self.current_profile.flip_profile() + precheck_profile_step_size( + current_profile=self.current_profile, + problem=self.problem, + i_par=self.i_par, + par_direction=self.par_direction, + options=self.options, + resolved_steps=self.resolved_steps_by_par[self.i_par], + ) + # compute the current profile self.current_profile = walk_along_profile( current_profile=self.current_profile, @@ -78,6 +92,7 @@ def execute(self) -> dict[str, Any]: optimizer=self.optimizer, options=self.options, create_next_guess=self.create_next_guess, + resolved_steps_by_par=self.resolved_steps_by_par, global_opt=self.global_opt, i_par=self.i_par, ) diff --git a/pypesto/profile/util.py b/pypesto/profile/util.py index 3ea7a0d00..4a4303f84 100644 --- a/pypesto/profile/util.py +++ b/pypesto/profile/util.py @@ -1,7 +1,9 @@ """Utility function for profile module.""" +import warnings from collections.abc import Iterable -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal import numpy as np import scipy.stats @@ -9,6 +11,42 @@ from ..C import GRAD from ..problem import Problem from ..result import ProfileResult, ProfilerResult, Result +from .options import ProfileOptions + +PROFILE_STEP_PRECHECK_NOMINAL_WARN_THRESHOLD = 500 +PROFILE_STEP_PRECHECK_DENSE_WARN_THRESHOLD = 1000 + + +@dataclass(frozen=True) +class ResolvedProfileStepSizes: + """ + Effective step sizes for one profiled parameter. + + The minimum, default, and maximum values always come from the same + step-size family. + + Attributes + ---------- + mode: + Selected step-size family, either `"absolute"` or `"relative"`. + default_step_size: + Resolved default step size. + min_step_size: + Resolved minimum step size. + max_step_size: + Resolved maximum step size. + span: + Parameter span `ub - lb` on the optimization scale. + """ + + mode: Literal["absolute", "relative"] + default_step_size: float + min_step_size: float + max_step_size: float + span: float + + +ResolvedProfileStepSizeMap = dict[int, ResolvedProfileStepSizes] def chi2_quantile_to_ratio(alpha: float = 0.95, df: int = 1): @@ -80,6 +118,189 @@ def calculate_approximate_ci( return lb, ub +def validate_profile_parameter_bounds(problem: Problem, i_par: int) -> float: + """Validate finite profile bounds for one parameter and return its span.""" + lb = float(problem.lb_full[i_par]) + ub = float(problem.ub_full[i_par]) + if not np.isfinite(lb) or not np.isfinite(ub): + raise ValueError( + "Profiling requires finite lower and upper bounds for parameter " + f"'{problem.x_names[i_par]}' (index={i_par})." + ) + span = ub - lb + if span <= 0: + raise ValueError( + "Profiling requires an upper bound greater than the lower bound " + f"for parameter '{problem.x_names[i_par]}' (index={i_par})." + ) + return span + + +def resolve_profile_step_sizes( + problem: Problem, + i_par: int, + options: ProfileOptions, +) -> ResolvedProfileStepSizes: + """ + Resolve effective profile step sizes for one parameter. + + Relative step sizes are scaled by the parameter span `ub - lb`. If the + resolved relative default is at least as large as the absolute default, + the full relative family is used. Otherwise the full absolute family is + used. + + Parameters + ---------- + problem: + The parameter estimation problem containing bounds and scales. + i_par: + Index of the profiled parameter in full dimension. + options: + Profile options containing absolute and relative step-size settings. + + Returns + ------- + resolved_steps: + Resolved step sizes and selection metadata. + """ + # Bounds are required here because relative steps are defined from the + # finite parameter span on the optimization scale. + span = validate_profile_parameter_bounds(problem, i_par) + + if options.default_step_size_relative > 0: + relative_default_step_size = options.default_step_size_relative * span + relative_min_step_size = options.min_step_size_relative * span + relative_max_step_size = options.max_step_size_relative * span + + # Select one complete step-size family based on the default step size. + if ( + options.default_step_size_absolute == 0 + or relative_default_step_size >= options.default_step_size_absolute + ): + return ResolvedProfileStepSizes( + mode="relative", + default_step_size=relative_default_step_size, + min_step_size=relative_min_step_size, + max_step_size=relative_max_step_size, + span=span, + ) + + return ResolvedProfileStepSizes( + mode="absolute", + default_step_size=options.default_step_size_absolute, + min_step_size=options.min_step_size_absolute, + max_step_size=options.max_step_size_absolute, + span=span, + ) + + +def resolve_profile_step_sizes_for_parameters( + problem: Problem, + parameter_indices: Iterable[int], + options: ProfileOptions, +) -> ResolvedProfileStepSizeMap: + """Resolve effective profile step sizes for multiple parameters.""" + return { + i_par: resolve_profile_step_sizes(problem, i_par, options) + for i_par in parameter_indices + } + + +def _format_profile_step_size_resolution_summary( + problem: Problem, + i_par: int, + resolved_steps: ResolvedProfileStepSizes, +) -> str: + """Create a one-line summary of the resolved step-size family.""" + scale = str(problem.x_scales[i_par]).lower() + parameter_name = problem.x_names[i_par] + + return ( + "Resolved profile step sizes for " + f"{parameter_name} (index={i_par}): " + f"family={resolved_steps.mode}, " + f"scale={scale}, span={resolved_steps.span}, " + f"min={resolved_steps.min_step_size}, " + f"default={resolved_steps.default_step_size}, " + f"max={resolved_steps.max_step_size}." + ) + + +def precheck_profile_step_size( + current_profile: ProfilerResult, + problem: Problem, + i_par: int, + par_direction: int, + options: ProfileOptions, + resolved_steps: ResolvedProfileStepSizes, +) -> None: + """ + Warn or raise if the resolved step sizes imply many profile steps. + + Two estimates are formed: a nominal one from the default step size and a + worst-case one from the minimum step size. In ``"raise"`` mode, an error + is raised only when the worst-case estimate is excessive; a merely large + nominal estimate only triggers a warning, so valid runs are not broken. + + Parameters + ---------- + current_profile: + Current profile path. + problem: + The parameter estimation problem. + i_par: + Index of the profiled parameter in full dimension. + par_direction: + Profiling direction, either `-1` for descending or `1` for ascending. + options: + Profile options. + resolved_steps: + Pre-resolved step sizes for the profiled parameter. + """ + if options.step_size_precheck_mode == "off": + return + + # Estimate how much of the bounded parameter range is left in this + # profiling direction. + x0 = float(current_profile.x_path[i_par, -1]) + if par_direction == -1: + available_span = x0 - float(problem.lb_full[i_par]) + elif par_direction == 1: + available_span = float(problem.ub_full[i_par]) - x0 + else: + raise ValueError("par_direction must be either -1 or 1.") + + if not np.isfinite(available_span) or available_span <= 0: + return + + # Use the resolved default and minimum steps as nominal and dense estimates. + nominal_count = available_span / resolved_steps.default_step_size + dense_count = available_span / resolved_steps.min_step_size + + nominal_warn = nominal_count > PROFILE_STEP_PRECHECK_NOMINAL_WARN_THRESHOLD + dense_warn = dense_count > PROFILE_STEP_PRECHECK_DENSE_WARN_THRESHOLD + if not nominal_warn and not dense_warn: + return + + parameter_name = problem.x_names[i_par] + message = ( + f"Profiling parameter '{parameter_name}' may require many steps " + f"({nominal_count:.1f} with the default step size, " + f"up to {dense_count:.1f} with the minimum step size). " + "Consider increasing the profile step sizes." + ) + if not options.whole_path: + message += ( + " This is a bound-based upper estimate; profiling may stop " + "earlier at the likelihood-ratio threshold." + ) + + if dense_warn and options.step_size_precheck_mode == "raise": + raise ValueError(message) + + warnings.warn(message, UserWarning, stacklevel=2) + + def initialize_profile( problem: Problem, result: Result, diff --git a/pypesto/profile/walk_along_profile.py b/pypesto/profile/walk_along_profile.py index 5cb1d25b3..be85b0523 100644 --- a/pypesto/profile/walk_along_profile.py +++ b/pypesto/profile/walk_along_profile.py @@ -9,6 +9,7 @@ from ..problem import Problem from ..result import OptimizerResult, ProfilerResult from .options import ProfileOptions +from .util import ResolvedProfileStepSizeMap logger = logging.getLogger(__name__) @@ -20,6 +21,7 @@ def walk_along_profile( optimizer: Optimizer, options: ProfileOptions, create_next_guess: Callable, + resolved_steps_by_par: ResolvedProfileStepSizeMap, global_opt: float, i_par: int, max_tries: int = 10, @@ -47,6 +49,8 @@ def walk_along_profile( Various options applied to the profile optimization. create_next_guess: Handle of the method which creates the next profile point proposal + resolved_steps_by_par: + Pre-resolved profile step sizes. i_par: index for the current parameter max_tries: @@ -60,6 +64,8 @@ def walk_along_profile( if par_direction not in (-1, 1): raise AssertionError("par_direction must be -1 or 1") + resolved_steps = resolved_steps_by_par[i_par] + # while loop for profiling (will be exited by break command) while True: # get current position on the profile path @@ -86,8 +92,8 @@ def walk_along_profile( while not optimization_successful: # Check max_step_size is not reduced below min_step_size if ( - options.max_step_size * max_step_reduce_factor - < options.min_step_size + resolved_steps.max_step_size * max_step_reduce_factor + < resolved_steps.min_step_size ): logger.warning( "Max step size reduced below min step size. " @@ -134,7 +140,8 @@ def walk_along_profile( max_step_reduce_factor *= 0.5 logger.warning( f"Optimization at {problem.x_names[i_par]}={x_next[i_par]} failed. " - f"Reducing max_step_size to {options.max_step_size * max_step_reduce_factor}." + "Reducing max_step_size to " + f"{resolved_steps.max_step_size * max_step_reduce_factor}." ) else: # if too many parameters are fixed, there is nothing to do ... @@ -167,8 +174,8 @@ def walk_along_profile( while not optimization_successful: # Check min_step_size is not increased above max_step_size if ( - options.min_step_size * min_step_increase_factor - > options.max_step_size + resolved_steps.min_step_size * min_step_increase_factor + > resolved_steps.max_step_size ): logger.warning( "Min step size increased above max step size. " @@ -208,7 +215,8 @@ def walk_along_profile( min_step_increase_factor *= 1.25 logger.warning( f"Optimization at {problem.x_names[i_par]}={x_next[i_par]} failed. " - f"Increasing min_step_size to {options.min_step_size * min_step_increase_factor}." + "Increasing min_step_size to " + f"{resolved_steps.min_step_size * min_step_increase_factor}." ) if not optimization_successful: diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 3b1de1933..70042c905 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -4,6 +4,7 @@ import importlib import logging +from typing import Any import numpy as np @@ -15,10 +16,9 @@ logger = logging.getLogger(__name__) -# Lazy import of pymc, arviz, and pytensor +# Lazy import of pymc and pytensor # Check availability once at module load time _HAS_PYMC = importlib.util.find_spec("pymc") is not None -_HAS_ARVIZ = importlib.util.find_spec("arviz") is not None if _HAS_PYMC: import pymc @@ -30,15 +30,20 @@ pt = None _PT_OP_BASE = object -if _HAS_ARVIZ: - import arviz as az -else: - az = None - # implementation based on: # https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html +# TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0 +# and simplify this helper to `data.posterior.to_dataset()`. +def _get_posterior_dataset(data: Any) -> Any: + """Return posterior as an xarray Dataset across ArviZ versions.""" + posterior = data.posterior + if hasattr(posterior, "to_array"): + return posterior + return posterior.to_dataset() + + class PymcObjectiveOp(_PT_OP_BASE): """PyTensor wrapper around a (non-normalized) log-probability function.""" @@ -151,7 +156,7 @@ def __init__( self.problem: Problem | None = None self.x0: np.ndarray | None = None self.trace: pymc.backends.Text | None = None - self.data: az.InferenceData | None = None + self.data: Any | None = None @classmethod def translate_options(cls, options): @@ -251,10 +256,10 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" + posterior = _get_posterior_dataset(self.data) + # dimensions - n_par, n_chain, n_iter = np.asarray( - self.data.posterior.to_array() - ).shape + n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape n_par -= 1 # remove log-posterior # parameters @@ -263,10 +268,10 @@ def get_samples(self) -> McmcPtResult: if len(par_ids) != n_par: raise AssertionError("Mismatch of parameter dimension") for i_par, par_id in enumerate(par_ids): - trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id]) + trace_x[:, :, i_par] = np.asarray(posterior[par_id]) # function values - trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"]) + trace_neglogpost = -np.asarray(posterior["loggyposty"]) if ( trace_x.shape[0] != trace_neglogpost.shape[0] diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index d8b4cf711..f31adcb61 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -9,7 +9,11 @@ from ..objective import FD from ..result import McmcPtResult -from ..sample.pymc import PymcObjectiveOp, PymcSampler +from ..sample.pymc import ( + PymcObjectiveOp, + PymcSampler, + _get_posterior_dataset, +) from ..sample.sampler import SamplerImportError logger = logging.getLogger(__name__) @@ -89,7 +93,7 @@ def fit( } # create model context - with pymc.Model(): + with pymc.Model() as model: # parameter bounds as uniform prior _k = [ pymc.Uniform(x_name, lower=lb, upper=ub) @@ -119,6 +123,7 @@ def fit( ) self.data = data + self.model = model def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: """ @@ -130,16 +135,19 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: Number of samples to be computed. """ # get InferenceData object - pymc_data = self.data.sample(n_samples) + with self.model: + pymc_data = self.data.sample(n_samples) + posterior = _get_posterior_dataset(pymc_data) + x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( - [pymc_data.posterior[name].values for name in x_names_free] + [posterior[name].values for name in x_names_free] ).T return McmcPtResult( trace_x=post_samples[np.newaxis, :], - trace_neglogpost=pymc_data.posterior.loggyposty.values, + trace_neglogpost=posterior.loggyposty.values, trace_neglogprior=np.full( - pymc_data.posterior.loggyposty.values.shape, np.nan + posterior.loggyposty.values.shape, np.nan ), betas=np.array([1.0] * post_samples.shape[0]), burn_in=0, diff --git a/pypesto/visualize/__init__.py b/pypesto/visualize/__init__.py index 7318cf9c9..0362aeb83 100644 --- a/pypesto/visualize/__init__.py +++ b/pypesto/visualize/__init__.py @@ -44,7 +44,13 @@ parameters_lowlevel, ) from .profile_cis import profile_cis, profile_nested_cis -from .profiles import profile_lowlevel, profiles, profiles_lowlevel +from .profiles import ( + profile_lowlevel, + profile_lowlevel_2d, + profiles, + profiles_lowlevel, + visualize_2d_profile, +) from .reference_points import ReferencePoint, create_references from .sampling import ( sampling_1d_marginals, diff --git a/pypesto/visualize/dimension_reduction.py b/pypesto/visualize/dimension_reduction.py index d94bae435..7d8bbb335 100644 --- a/pypesto/visualize/dimension_reduction.py +++ b/pypesto/visualize/dimension_reduction.py @@ -3,10 +3,11 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -import matplotlib.pyplot as plt +import matplotlib.axes import numpy as np from ..C import COLOR +from .misc import get_ax, get_axes_array, hide_unused_axes if TYPE_CHECKING: try: @@ -19,7 +20,7 @@ def projection_scatter_umap( umap_coordinates: np.ndarray, components: Sequence[int] = (0, 1), **kwargs -): +) -> matplotlib.axes.Axes | np.ndarray: """ Plot a scatter plots for UMAP coordinates. @@ -38,8 +39,8 @@ def projection_scatter_umap( Returns ------- axs: - Either one axes object, or a dictionary of plot axes (depending on the - number of coordinates passed) + Either a single matplotlib Axes (2 components) or a 2-D NumPy array + of Axes (more than 2 components). """ n_components = len(components) if n_components == 2: @@ -71,8 +72,9 @@ def projection_scatter_umap_original( umap_object: UmapTypeObject, color_by: Sequence[float] = None, components: Sequence[int] = (0, 1), + ax: matplotlib.axes.Axes | None = None, **kwargs, -): +) -> matplotlib.axes.Axes: """ See `projection_scatter_umap` for more documentation. @@ -88,10 +90,12 @@ def projection_scatter_umap_original( A sequence/list of floats, which specify the color in the colormap components: Components to be plotted (corresponds to columns of umap_coordinates) + ax: + Axes object to use. Returns ------- - ax: matplotlib.Axes + ax: matplotlib.axes.Axes The plot axes. """ import umap.plot @@ -100,12 +104,20 @@ def projection_scatter_umap_original( umap_object.embedding_ = umap_object.embedding_[:, components] # use umap's original plotting routine to visualize - umap.plot.points(umap_object, values=color_by, theme="viridis", **kwargs) + if ax is not None: + kwargs["ax"] = ax + + return umap.plot.points( + umap_object, + values=color_by, + theme="viridis", + **kwargs, + ) def projection_scatter_pca( pca_coordinates: np.ndarray, components: Sequence[int] = (0, 1), **kwargs -): +) -> matplotlib.axes.Axes | np.ndarray: """ Plot a scatter plot for PCA coordinates. @@ -123,8 +135,8 @@ def projection_scatter_pca( Returns ------- axs: - Either one axes object, or a dictionary of plot axes (depending on the - number of coordinates passed) + Either a single matplotlib Axes (2 components) or a 2-D NumPy array + of Axes (more than 2 components). """ n_components = len(components) if n_components == 2: @@ -154,8 +166,12 @@ def projection_scatter_pca( def ensemble_crosstab_scatter_lowlevel( - dataset: np.ndarray, component_labels: Sequence[str] = None, **kwargs -): + dataset: np.ndarray, + component_labels: Sequence[str] = None, + axes: np.ndarray | None = None, + size: tuple[float, float] | None = None, + **kwargs, +) -> np.ndarray: """ Plot cross-classification table of scatter plots for different coordinates. @@ -171,17 +187,24 @@ def ensemble_crosstab_scatter_lowlevel( Returns ------- - axs: - A dictionary of plot axes. + axes: + 2-D NumPy array containing one matplotlib Axes per panel. """ # We got more than two components. Create a cross-classification table n_components = dataset.shape[1] - axs = _create_crosstab_axes(n_components) + if component_labels is None: + component_labels = [ + f"component {i_component + 1}" + for i_component in range(n_components) + ] - # wo don't even try to plot this into an existing axes object. - # Overplotting a multi-axes figure is asking for trouble... - if "ax" in kwargs.keys(): - del kwargs["ax"] + if "ax" in kwargs: + if axes is None: + axes = kwargs.pop("ax") + else: + del kwargs["ax"] + + axes = _create_crosstab_axes(n_components, axes=axes, size=size) for x_comp in range(0, n_components - 1): for y_comp in range(x_comp + 1, n_components): @@ -201,17 +224,16 @@ def ensemble_crosstab_scatter_lowlevel( tmp_dataset, x_label=x_label, y_label=y_label, - ax=axs[(x_comp, y_comp)], + ax=axes[y_comp - 1, x_comp], **kwargs, ) - # return dict of axes - return axs + return axes def ensemble_scatter_lowlevel( dataset, - ax: plt.Axes | None = None, - size: tuple[float] | None = (12, 6), + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = (12, 6), x_label: str = "component 1", y_label: str = "component 2", color_by: Sequence[float] = None, @@ -220,7 +242,7 @@ def ensemble_scatter_lowlevel( marker_type: str = ".", scatter_size: float = 0.5, invert_scatter_order: bool = False, -): +) -> matplotlib.axes.Axes: """ Create a scatter plot. @@ -253,15 +275,10 @@ def ensemble_scatter_lowlevel( Returns ------- - ax: matplotlib.Axes + ax: matplotlib.axes.Axes The plot axes. """ - # first get the data to check identifiability - # axes - if ax is None: - fig, ax = plt.subplots() - fig.set_size_inches(*size) - plt.sca(ax) + ax = get_ax(ax, size) if color_by is None: color_by = np.array([1.0] * dataset.shape[0]) @@ -270,7 +287,7 @@ def ensemble_scatter_lowlevel( if invert_scatter_order: ordering = -1 - plt.scatter( + ax.scatter( dataset[::ordering, 0], dataset[::ordering, 1], c=color_by, @@ -281,17 +298,19 @@ def ensemble_scatter_lowlevel( # beautify ax.set_facecolor(background_color) - plt.xlabel(x_label) - plt.ylabel(y_label) - plt.xticks([]) - plt.yticks([]) - - plt.tight_layout() + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_xticks([]) + ax.set_yticks([]) return ax -def _create_crosstab_axes(n_comp: int): +def _create_crosstab_axes( + n_comp: int, + axes: np.ndarray | None = None, + size: tuple[float, float] | None = None, +) -> np.ndarray: """ Create a figure with cross-classification table of axes. @@ -302,15 +321,22 @@ def _create_crosstab_axes(n_comp: int): Returns ------- - axs: - A dictionary of plot axes. + axes: + A 2-D NumPy array of plot axes. """ - axs = {} - - # run over x- and y-coordinate - for x_comp in range(0, n_comp - 1): - for y_comp in range(x_comp + 1, n_comp): - i_ax = (y_comp - 1) * (n_comp - 1) + x_comp + 1 - axs[(x_comp, y_comp)] = plt.subplot(n_comp - 1, n_comp - 1, i_ax) - - return axs + n_grid = n_comp - 1 + if size is None and axes is None: + size = (3.0 * n_grid, 3.0 * n_grid) + + axes = get_axes_array(axes=axes, nrows=n_grid, ncols=n_grid, size=size) + used_indices = [ + (y_comp - 1) * n_grid + x_comp + for x_comp in range(0, n_comp - 1) + for y_comp in range(x_comp + 1, n_comp) + ] + axes = hide_unused_axes( + axes=axes, + used_indices=used_indices, + clear=True, + ) + return axes diff --git a/pypesto/visualize/ensemble.py b/pypesto/visualize/ensemble.py index 4d6bf25f5..1f08b8f49 100644 --- a/pypesto/visualize/ensemble.py +++ b/pypesto/visualize/ensemble.py @@ -1,4 +1,4 @@ -import matplotlib.pyplot as plt +import matplotlib.axes import numpy as np import pandas as pd from matplotlib.collections import PatchCollection @@ -6,13 +6,14 @@ from ..C import COLOR_HIT_BOTH_BOUNDS, COLOR_HIT_NO_BOUNDS, COLOR_HIT_ONE_BOUND from ..ensemble import Ensemble +from .misc import get_ax def ensemble_identifiability( ensemble: Ensemble, - ax: plt.Axes | None = None, - size: tuple[float] | None = (12, 6), -): + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = (12, 6), +) -> matplotlib.axes.Axes: """ Visualize identifiablity of parameter ensemble. @@ -33,7 +34,7 @@ def ensemble_identifiability( Returns ------- - ax: matplotlib.Axes + ax: matplotlib.axes.Axes The plot axes. """ # first get the data to check identifiability @@ -55,9 +56,9 @@ def ensemble_identifiability_lowlevel( lb_hit: np.ndarray, ub_hit: np.ndarray, both_hit: np.ndarray, - ax: plt.Axes | None = None, - size: tuple[float] | None = (16, 10), -): + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = (16, 10), +) -> matplotlib.axes.Axes: """ Low-level identifiablity routine. @@ -89,7 +90,7 @@ def ensemble_identifiability_lowlevel( Returns ------- - ax: matplotlib.Axes + ax: matplotlib.axes.Axes The plot axes. """ # define some short hands for later plotting @@ -113,11 +114,7 @@ def ensemble_identifiability_lowlevel( patches_none_hit, ) = _create_patches(none_hit, lb_hit, ub_hit, both_hit) - # axes - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) # create axes object and add patch collections if patches_both_hit: @@ -193,16 +190,16 @@ def ensemble_identifiability_lowlevel( ax.text(-0.03, 0.0, "lower\nbound", ha="right", va="center") ax.plot([-0.02, 1.03], [0, 0], "k:", linewidth=1.5) ax.plot([-0.02, 1.03], [1, 1], "k:", linewidth=1.5) - plt.xticks([]) - plt.yticks([]) + ax.set_xticks([]) + ax.set_yticks([]) # plot frame ax.plot([0, 0], vert, "k-", linewidth=1.5) ax.plot([1, 1], vert, "k-", linewidth=1.5) # beautify axes - plt.xlim((-0.15, 1.1)) - plt.ylim((-0.78, 1.15)) + ax.set_xlim((-0.15, 1.1)) + ax.set_ylim((-0.78, 1.15)) ax.spines["right"].set_visible(False) ax.spines["left"].set_visible(False) ax.spines["bottom"].set_visible(False) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index a196a0755..b74f2b55f 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -4,8 +4,9 @@ import warnings from collections.abc import Iterable from numbers import Number -from typing import TYPE_CHECKING +import matplotlib.axes +import matplotlib.pyplot as plt import numpy as np from ..C import ( @@ -27,9 +28,6 @@ from ..util import assign_clusters, delete_nan_inf from .clust_color import assign_colors_for_list -if TYPE_CHECKING: - from matplotlib.pyplot import Axes - logger = logging.getLogger(__name__) @@ -158,9 +156,9 @@ def process_offset_y( def process_y_limits( - ax: Axes, + ax: matplotlib.axes.Axes, y_limits: None | Iterable[float] | np.ndarray, -) -> Axes: +) -> matplotlib.axes.Axes: """ Apply user specified limits of y-axis. @@ -420,3 +418,267 @@ def process_parameter_indices( f"{ALL}, {FREE_ONLY} or a list of indices." ) return list(parameter_indices) + + +def make_grid_shape(n_panels: int) -> tuple[int, int]: + """ + Return a near-square ``(nrows, ncols)`` grid for ``n_panels`` subplots. + + Parameters + ---------- + n_panels: + Number of panels to arrange. + + Returns + ------- + nrows, ncols: + Smallest grid with ``nrows * ncols >= n_panels`` and aspect ratio + close to square. + """ + if n_panels < 1: + raise ValueError("n_panels must be at least 1.") + nrows = int(np.ceil(np.sqrt(n_panels))) + ncols = int(np.ceil(n_panels / nrows)) + return nrows, ncols + + +def get_ax( + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = None, +) -> matplotlib.axes.Axes: + """ + Return an Axes, creating one of size ``size`` if ``ax`` is None. + + Parameters + ---------- + ax: + Existing matplotlib Axes. If provided, returned unchanged. + size: + Figure size ``(width, height)`` in inches; only used when ``ax`` is + None. If None, matplotlib's default figure size is used. + + Returns + ------- + ax: + A matplotlib Axes. + """ + if ax is not None: + return ax + _, ax = plt.subplots(figsize=size, layout="constrained") + return ax + + +def get_axes_array( + axes: matplotlib.axes.Axes | np.ndarray | None = None, + nrows: int = 1, + ncols: int = 1, + size: tuple[float, float] | None = None, +) -> np.ndarray: + """ + Return a 2-D array of Axes, creating one if ``axes`` is None. + + Parameters + ---------- + axes: + Existing matplotlib Axes grid. If provided, it is normalized to a + 2-D object array and validated against ``(nrows, ncols)``. + nrows, ncols: + Expected grid shape. + size: + Figure size ``(width, height)`` in inches; only used when ``axes`` + is None. + + Returns + ------- + axes: + A 2-D NumPy object array containing matplotlib Axes. + """ + if axes is None: + _, axes = plt.subplots( + nrows, + ncols, + squeeze=False, + figsize=size, + layout="constrained", + ) + return axes + + axes_array = np.asarray(axes, dtype=object) + if axes_array.ndim == 0: + axes_array = axes_array.reshape(1, 1) + elif axes_array.ndim == 1: + if nrows == 1: + axes_array = axes_array.reshape(1, ncols) + elif ncols == 1: + axes_array = axes_array.reshape(nrows, 1) + else: + raise ValueError(f"Pass `axes` with shape ({nrows}, {ncols}).") + + if axes_array.shape != (nrows, ncols): + raise ValueError(f"Pass `axes` with shape ({nrows}, {ncols}).") + + return axes_array + + +def hide_unused_axes( + axes: np.ndarray, + n_used: int | None = None, + used_indices: Iterable[int] | None = None, + clear: bool = False, +) -> np.ndarray: + """ + Hide unused axes in a 2-D grid and ensure used axes are visible. + + Parameters + ---------- + axes: + 2-D NumPy array containing matplotlib Axes. + n_used: + Number of leading axes in ``axes.flat`` to keep visible. + used_indices: + Flat indices of the axes that should remain visible. + clear: + Whether to clear every axis before toggling visibility. + + Returns + ------- + axes: + The same 2-D NumPy array with updated visibility. + """ + axes_array = np.asarray(axes, dtype=object) + if axes_array.ndim != 2: + raise ValueError("Pass `axes` as a 2-D NumPy array.") + + if (n_used is None) == (used_indices is None): + raise ValueError("Pass exactly one of `n_used` or `used_indices`.") + + if used_indices is None: + if n_used is None or not 0 <= n_used <= axes_array.size: + raise ValueError( + f"`n_used` must be between 0 and {axes_array.size}." + ) + visible_indices = set(range(n_used)) + else: + visible_indices = set(used_indices) + invalid_indices = [ + index + for index in visible_indices + if index < 0 or index >= axes_array.size + ] + if invalid_indices: + raise ValueError( + "Pass `used_indices` within the flattened axes range " + f"[0, {axes_array.size - 1}]." + ) + + for index, ax in enumerate(axes_array.flat): + if clear: + ax.clear() + ax.set_visible(index in visible_indices) + + return axes_array + + +def plot_diagonal_marginal( + ax: matplotlib.axes.Axes, + values: np.ndarray, + diag_kind: str = "kde", + color: str = "C0", +) -> None: + """ + Plot a 1-D marginal on a diagonal scatter-matrix panel. + + Parameters + ---------- + ax: + Axes to draw into. + values: + One-dimensional sample values. + diag_kind: + Marginal visualization mode: ``"kde"`` or ``"hist"``. + color: + Base matplotlib color for the marginal. + """ + from scipy.stats import gaussian_kde + + values = np.asarray(values) + if values.size == 0: + return + data_range = values.max() - values.min() + if data_range == 0: + data_range = max(abs(float(values.mean())) * 0.1, 0.1) + x_pad = data_range * 0.25 + x_grid = np.linspace(values.min() - x_pad, values.max() + x_pad, 300) + + if diag_kind == "kde" and len(values) > 1: + try: + kde = gaussian_kde(values) + y_grid = kde(x_grid) + ax.fill_between(x_grid, y_grid, alpha=0.35, color=color) + ax.plot(x_grid, y_grid, color=color, lw=1.5) + ax.set_ylabel("Density") + return + except np.linalg.LinAlgError: + pass + + ax.hist(values, bins="auto", color=color, alpha=0.6) + ax.set_ylabel("Count") + + +#: Sentinel meaning "this kwarg was not passed at all." +#: Use as the default for deprecated kwargs so that an explicit +#: ``f(old_kwarg=None)`` can be detected and warned about. +_UNSET = object() + + +def process_deprecated_kwarg( + canonical_name: str, + canonical_value, + deprecated_name: str, + deprecated_value=_UNSET, + stacklevel: int = 3, +): + """ + Resolve a kwarg that has been renamed. + + The deprecated kwarg must use :data:`_UNSET` as its default in the + calling function so that an explicit ``f(old_kwarg=None)`` is correctly + detected and warned about. + + Returns the canonical value if the deprecated kwarg was not passed, + the deprecated value (with a ``DeprecationWarning``) if only the old + name was used, or raises ``ValueError`` if both are given. + + Parameters + ---------- + canonical_name: + Name of the canonical (new) kwarg, used in messages. + canonical_value: + Value passed under the canonical name (or ``None``). + deprecated_name: + Name of the deprecated (old) kwarg, used in messages. + deprecated_value: + Value passed under the deprecated name; defaults to :data:`_UNSET`. + stacklevel: + Forwarded to :func:`warnings.warn`. Default 3 attributes the + warning to the caller of the public function that invoked this + helper. + + Returns + ------- + value: + The resolved value, or ``None`` if neither was given. + """ + if deprecated_value is _UNSET: + return canonical_value + if canonical_value is not None: + raise ValueError( + f"Pass either `{canonical_name}` or the deprecated " + f"`{deprecated_name}`, not both." + ) + warnings.warn( + f"`{deprecated_name}` is deprecated; use `{canonical_name}` instead.", + DeprecationWarning, + stacklevel=stacklevel, + ) + return deprecated_value diff --git a/pypesto/visualize/model_fit.py b/pypesto/visualize/model_fit.py index d10076970..8fa249a4f 100644 --- a/pypesto/visualize/model_fit.py +++ b/pypesto/visualize/model_fit.py @@ -301,7 +301,7 @@ def _get_simulation_rdatas( """ # add timepoints as needed if simulation_timepoints is None: - end_time = max(problem.objective.edatas[0].getTimepoints()) + end_time = max(problem.objective.edatas[0].get_timepoints()) simulation_timepoints = np.linspace(start=0, stop=end_time, num=1000) # get optimization result @@ -338,10 +338,10 @@ def _get_simulation_rdatas( # disable sensitivities to improve computation time amici_solver = copy.deepcopy(problem.objective.amici_solver) - amici_solver.setSensitivityOrder(asd.SensitivityOrder.none) + amici_solver.set_sensitivity_order(asd.SensitivityOrder.none) for j in range(len(edatas)): - edatas[j].setTimepoints(simulation_timepoints) + edatas[j].set_timepoints(simulation_timepoints) fill_in_parameters( edatas=edatas, @@ -413,7 +413,7 @@ def _time_trajectory_model_with_states( ] if state_names is not None: state_indices_by_name = [ - model.getStateNames().index(state_name) + model.get_state_names().index(state_name) for state_name in state_names ] state_indices = list(set(state_indices_by_id + state_indices_by_name)) diff --git a/pypesto/visualize/observable_mapping.py b/pypesto/visualize/observable_mapping.py index 5fced3a4d..7da5a9354 100644 --- a/pypesto/visualize/observable_mapping.py +++ b/pypesto/visualize/observable_mapping.py @@ -21,6 +21,7 @@ ) from ..problem import HierarchicalProblem, Problem from ..result import Result +from .misc import get_axes_array, hide_unused_axes, make_grid_shape try: import amici.sim.sundials as asd @@ -41,13 +42,63 @@ pass +def _prepare_observable_mapping_axes( + axes: matplotlib.axes.Axes | np.ndarray | None, + n_panels: int, + **kwargs, +) -> np.ndarray: + """Return a cleared axes grid for observable-mapping plots.""" + n_rows, n_cols = make_grid_shape(n_panels) + + if axes is None: + kwargs.setdefault("layout", "constrained") + _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) + else: + axes = get_axes_array(axes=axes, nrows=n_rows, ncols=n_cols) + + return hide_unused_axes(axes=axes, n_used=n_panels, clear=True) + + +def _plot_observable_mapping_measurements( + ax: matplotlib.axes.Axes, + simulation: np.ndarray, + measurements: np.ndarray, +) -> None: + """Plot measurement points with the shared visualization styling.""" + ax.scatter( + simulation, + measurements, + color="C0", + s=40, + alpha=0.9, + linewidths=0.6, + edgecolors="white", + label="Measurements", + zorder=3, + ) + + +def _finalize_observable_mapping_axes( + ax: matplotlib.axes.Axes, + title: str, +) -> None: + """Apply consistent styling to observable-mapping panels.""" + ax.set_title(title) + ax.set_xlabel("Model output") + ax.set_ylabel("Measurements") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.margins(x=0.05, y=0.08) + ax.legend(frameon=False, loc="best") + + def visualize_estimated_observable_mapping( pypesto_result: Result, pypesto_problem: HierarchicalProblem, start_index: int = 0, - axes: plt.Axes | None = None, + axes: matplotlib.axes.Axes | np.ndarray | None = None, **kwargs, -): +) -> np.ndarray | None: """Visualize the estimated observable mapping for relative and semi-quantitative observables. Visualizes the estimated linear mapping for relative observables and the non-linear @@ -62,7 +113,7 @@ def visualize_estimated_observable_mapping( start_index: The observable mapping from this start's optimized vector will be plotted. axes: - The axes to plot the estimated observable mapping on. + Optional axes grid to draw into. kwargs: Additional arguments to passed to ``matplotlib.pyplot.subplots`` (e.g. `figsize= ...`). @@ -70,7 +121,8 @@ def visualize_estimated_observable_mapping( Returns ------- axes: - The matplotlib axes. + A 2-D NumPy array of matplotlib Axes, or ``None`` if the required + simulation fails. """ # Check if the pyPESTO problem is hierarchical. @@ -114,22 +166,12 @@ def visualize_estimated_observable_mapping( ] rel_and_semiquant_obs_indices.sort() - # If axes are given, check if they are of the correct length. - if ( - axes is not None - and len(axes) != n_relative_observables + n_semiquant_observables - ): - raise ValueError( - "The number of axes must be equal to the number of relative and semi-quantitative observables." - ) - - # If axes are not given, create them. - if axes is None: - n_axes = n_relative_observables + n_semiquant_observables - n_rows = int(np.ceil(np.sqrt(n_axes))) - n_cols = int(np.ceil(n_axes / n_rows)) - _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) - axes = axes.flatten() + n_axes = n_relative_observables + n_semiquant_observables + axes = _prepare_observable_mapping_axes( + axes=axes, + n_panels=n_axes, + **kwargs, + ) # Plot the estimated observable mapping for relative observables. if n_relative_observables > 0: @@ -143,19 +185,15 @@ def visualize_estimated_observable_mapping( # Plot the estimated spline approximations for semi-quantitative observables. if n_semiquant_observables > 0: - axes = plot_splines_from_pypesto_result( + spline_axes = plot_splines_from_pypesto_result( pypesto_result=pypesto_result, start_index=start_index, axes=axes, rel_and_semiquant_obs_indices=rel_and_semiquant_obs_indices, ) - - # Remove any axes that were not used. - for ax in axes[n_relative_observables + n_semiquant_observables :]: - ax.remove() - - # Increase the distance between the subplots. - plt.tight_layout() + if spline_axes is None: + return None + axes = spline_axes return axes @@ -164,10 +202,10 @@ def plot_linear_observable_mappings_from_pypesto_result( pypesto_result: Result, pypesto_problem: HierarchicalProblem, start_index=0, - axes: plt.Axes | None = None, + axes: np.ndarray | None = None, rel_and_semiquant_obs_indices: list[int] | None = None, **kwargs, -): +) -> np.ndarray: """Plot the linear observable mappings from a pyPESTO result. Parameters @@ -179,18 +217,19 @@ def plot_linear_observable_mappings_from_pypesto_result( start_index: The observable mapping from this start's optimized vector will be plotted. axes: - The axes to plot the linear observable mappings on. + Optional 2-D NumPy array of Axes to draw into. Required when + ``rel_and_semiquant_obs_indices`` is set. rel_and_semiquant_obs_indices: - The indices of the relative and semi-quantitative observables in the - amici model. Important if both relative and semi-quantitative observables - will be plotted on the same axes. + Sorted indices of the relative and semi-quantitative observables in + the AMICI model. Each observable is plotted on the subplot at the + corresponding position in ``axes.flat``. **kwargs: Additional arguments to pass to the ``matplotlib.pyplot.subplots`` function. Returns ------- axes: - The matplotlib axes. + A 2-D NumPy array of matplotlib Axes. """ # Check the calculator is the InnerCalculatorCollector. if not isinstance( @@ -218,7 +257,7 @@ def plot_linear_observable_mappings_from_pypesto_result( inner_problem: RelativeInnerProblem = relative_calculator.inner_problem # Get the relative observable ids and indices. - relative_observable_ids = pypesto_problem.relative_observable_ids + relative_observable_ids = pypesto_problem.relative_observable_ids or [] relative_observable_indices = [ amici_model.get_observable_ids().index(observable_id) for observable_id in relative_observable_ids @@ -226,28 +265,31 @@ def plot_linear_observable_mappings_from_pypesto_result( # Get the number of relative observables. n_relative_observables = len(relative_observable_ids) - - # Check if the axes are given. - if axes is not None and len(axes) <= max(relative_observable_indices): + if n_relative_observables == 0: raise ValueError( - "The number of axes must be larger than the largest observable index." + "The problem does not contain any relative observables." ) - # If axes are not given, create them. - if axes is None: - if n_relative_observables == 1: - # Make figure with only one plot - _, ax = plt.subplots(1, 1, **kwargs) - axes = [ax] - else: - # Choose number of rows and columns to be used for the subplots - n_rows = int(np.ceil(np.sqrt(n_relative_observables))) - n_cols = int(np.ceil(n_relative_observables / n_rows)) + if rel_and_semiquant_obs_indices is None: + axes = _prepare_observable_mapping_axes( + axes=axes, + n_panels=n_relative_observables, + **kwargs, + ) + else: + if axes is None: + raise ValueError( + "Pass `axes` when `rel_and_semiquant_obs_indices` is set." + ) + if not isinstance(axes, np.ndarray) or axes.ndim != 2: + raise ValueError("`axes` must be a 2-D NumPy array.") + if axes.size < len(rel_and_semiquant_obs_indices): + raise ValueError( + "The number of axes must be at least equal to the number " + "of relative and semi-quantitative observables." + ) - # Make as many subplots as there are relative observables - _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) - # Flatten the axes array - axes = axes.flatten() + flat_axes = axes.flat ################################################################# # Simulate the model with the parameters from the pypesto result. @@ -328,7 +370,7 @@ def plot_linear_observable_mappings_from_pypesto_result( else: ax_index = relative_observable_indices.index(observable_index) - ax = axes[ax_index] + ax = flat_axes[ax_index] # Get the inner parameters for the current observable. inner_parameters = inner_problem.get_xs_for_obs_idx(observable_index) @@ -366,32 +408,35 @@ def plot_linear_observable_mappings_from_pypesto_result( expdata=sim, mask=observable_data_mask ) - ax.plot(simulation, measurements, "bs", label="Measurements") + _plot_observable_mapping_measurements(ax, simulation, measurements) # Plot the linear mapping. + sorted_simulation = np.sort(simulation) ax.plot( - np.sort(simulation), - scaling_factor_value * np.sort(simulation) + offset_value, + sorted_simulation, + scaling_factor_value * sorted_simulation + offset_value, linestyle="-", - color="orange", + color="C1", + linewidth=1.8, label="Linear mapping", + zorder=2, ) - ax.legend() - ax.set_title(f"Observable {observable_id}") - ax.set_xlabel("Model output") - ax.set_ylabel("Measurements") - - if rel_and_semiquant_obs_indices is None: - for ax in axes[n_relative_observables:]: - ax.remove() + _finalize_observable_mapping_axes( + ax, + title=f"Observable {observable_id}", + ) return axes def plot_splines_from_pypesto_result( - pypesto_result: Result, start_index=0, **kwargs -): + pypesto_result: Result, + start_index=0, + axes: matplotlib.axes.Axes | np.ndarray | None = None, + rel_and_semiquant_obs_indices: list[int] | None = None, + **kwargs, +) -> np.ndarray | None: """Plot the estimated spline approximations from a pypesto result. Parameters @@ -400,13 +445,20 @@ def plot_splines_from_pypesto_result( The pypesto result. start_index: The observable mapping from this start's optimized vector will be plotted. + axes: + Optional axes grid to draw into. + rel_and_semiquant_obs_indices: + The indices of the relative and semi-quantitative observables in the + amici model. Important if both relative and semi-quantitative observables + will be plotted on the same axes. kwargs: Additional arguments to pass to the plotting function. Returns ------- axes: - The matplotlib axes. + A 2-D NumPy array of matplotlib Axes, or ``None`` if the required + simulation fails. """ # Check that the problem contains an objective. if pypesto_result.problem.objective is None: @@ -521,6 +573,8 @@ def plot_splines_from_pypesto_result( inner_results, sim, observable_ids, + axes=axes, + rel_and_semiquant_obs_indices=rel_and_semiquant_obs_indices, **kwargs, ) @@ -531,10 +585,10 @@ def plot_splines_from_inner_result( results: list[dict], sim: list[np.ndarray], observable_ids=None, - axes: plt.Axes | None = None, + axes: np.ndarray | None = None, rel_and_semiquant_obs_indices: list[int] | None = None, **kwargs, -): +) -> np.ndarray: """Plot the estimated spline approximations from inner results. Parameters @@ -550,18 +604,19 @@ def plot_splines_from_inner_result( observable_ids: The ids of the observables. axes: - The axes to plot the estimated spline approximations on. + Optional 2-D NumPy array of Axes to draw into. Required when + ``rel_and_semiquant_obs_indices`` is set. rel_and_semiquant_obs_indices: - The indices of the relative and semi-quantitative observables in the - amici model. Important if both relative and semi-quantitative observables - will be plotted on the same axes. + Sorted indices of the relative and semi-quantitative observables in + the AMICI model. Each observable is plotted on the subplot at the + corresponding position in ``axes.flat``. kwargs: Additional arguments to pass to the plotting function. Returns ------- axes: - The matplotlib axes. + A 2-D NumPy array of matplotlib Axes. """ if len(results) != len(inner_problem.groups): @@ -573,27 +628,26 @@ def plot_splines_from_inner_result( n_groups = len(inner_problem.groups) semiquant_groups = list(inner_problem.groups.keys()) - # Check if the axes are given - if axes is not None and len(axes) < max(semiquant_groups): - raise ValueError( - "The number of axes must be equal to or larger than the largest group index." + if rel_and_semiquant_obs_indices is None: + axes = _prepare_observable_mapping_axes( + axes=axes, + n_panels=n_groups, + **kwargs, ) + else: + if axes is None: + raise ValueError( + "Pass `axes` when `rel_and_semiquant_obs_indices` is set." + ) + if not isinstance(axes, np.ndarray) or axes.ndim != 2: + raise ValueError("`axes` must be a 2-D NumPy array.") + if axes.size < len(rel_and_semiquant_obs_indices): + raise ValueError( + "The number of axes must be at least equal to the number " + "of relative and semi-quantitative observables." + ) - if axes is None: - if n_groups == 1: - # Make figure with only one plot - _, ax = plt.subplots(1, 1, **kwargs) - - axes = [ax] - else: - # Choose number of rows and columns to be used for the subplots - n_rows = int(np.ceil(np.sqrt(n_groups))) - n_cols = int(np.ceil(n_groups / n_rows)) - - # Make as many subplots as there are groups - _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) - # Flatten the axes array - axes = axes.flatten() + flat_axes = axes.flat # for each result and group, plot the inner solution for result, group in zip(results, inner_problem.groups, strict=True): @@ -625,18 +679,30 @@ def plot_splines_from_inner_result( K=len(simulation), ) - axes[ax_index].plot( - simulation, measurements, "bs", label="Measurements" + _plot_observable_mapping_measurements( + flat_axes[ax_index], + simulation, + measurements, ) - axes[ax_index].plot( - spline_bases, spline_knots, "g.", label="Spline knots" + flat_axes[ax_index].scatter( + spline_bases, + spline_knots, + color="C2", + s=30, + alpha=0.9, + linewidths=0.4, + edgecolors="white", + label="Spline knots", + zorder=4, ) - axes[ax_index].plot( + flat_axes[ax_index].plot( spline_bases, spline_knots, linestyle="-", - color="g", + color="C2", + linewidth=1.8, label="Spline function", + zorder=2, ) if inner_solver.options[REGULARIZE_SPLINE]: alpha_opt, beta_opt = _calculate_optimal_regularization( @@ -644,26 +710,25 @@ def plot_splines_from_inner_result( N=len(spline_knots), c=spline_bases, ) - axes[ax_index].plot( + flat_axes[ax_index].plot( spline_bases, alpha_opt * spline_bases + beta_opt, linestyle="--", - color="orange", + color="C1", + linewidth=1.5, label="Regularization line", + zorder=1, ) - axes[ax_index].legend() if observable_ids is not None: - axes[ax_index].set_title(f"Observable {observable_ids[group - 1]}") + title = f"Observable {observable_ids[group - 1]}" else: - axes[ax_index].set_title(f"Group {group}") - - axes[ax_index].set_xlabel("Model output") - axes[ax_index].set_ylabel("Measurements") + title = f"Group {group}" - if rel_and_semiquant_obs_indices is None: - for ax in axes[len(results) :]: - ax.remove() + _finalize_observable_mapping_axes( + flat_axes[ax_index], + title=title, + ) return axes @@ -721,7 +786,7 @@ def _add_spline_mapped_simulations_to_model_fit( result: Result | Sequence[Result], pypesto_problem: Problem, start_index: int = 0, - axes: plt.Axes | None = None, + axes: matplotlib.axes.Axes | None = None, ) -> matplotlib.axes.Axes | None: """Visualize the spline optimized model fit. diff --git a/pypesto/visualize/optimization_stats.py b/pypesto/visualize/optimization_stats.py index 7171dfced..5669627f8 100644 --- a/pypesto/visualize/optimization_stats.py +++ b/pypesto/visualize/optimization_stats.py @@ -1,7 +1,6 @@ from collections.abc import Iterable, Sequence import matplotlib.axes -import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import is_color_like @@ -10,7 +9,14 @@ from ..C import COLOR from ..result import Result from .clust_color import assign_colors, assign_colors_for_list -from .misc import process_result_list, process_start_indices +from .misc import ( + get_ax, + get_axes_array, + hide_unused_axes, + make_grid_shape, + process_result_list, + process_start_indices, +) def optimization_run_properties_one_plot( @@ -21,6 +27,7 @@ def optimization_run_properties_one_plot( colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: str | list[str] | None = None, plot_type: str = "line", + ax: matplotlib.axes.Axes | None = None, ) -> matplotlib.axes.Axes: """ Plot stats for allproperties specified in properties_to_plot on one plot. @@ -29,6 +36,8 @@ def optimization_run_properties_one_plot( ---------- results: Optimization result obtained by 'optimize.py' or list of those + ax: + Axes object to use. properties_to_plot: Optimization run properties that should be plotted size: @@ -99,9 +108,7 @@ def optimization_run_properties_one_plot( "optimization properties to plot" ) - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) for idx, prop_name in enumerate(properties_to_plot): optimization_run_property_per_multistart( @@ -128,7 +135,8 @@ def optimization_run_properties_per_multistart( colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: str | list[str] | None = None, plot_type: str = "line", -) -> dict[str, plt.Subplot]: + axes: np.ndarray | None = None, +) -> np.ndarray: """ One plot per optimization property in properties_to_plot. @@ -155,8 +163,8 @@ def optimization_run_properties_per_multistart( Returns ------- - ax: - The plot axes. + axes: + 2-D NumPy array containing one matplotlib Axes per panel. Examples -------- @@ -195,27 +203,26 @@ def optimization_run_properties_per_multistart( "n_sres", ] + if plot_type not in {"line", "hist"}: + raise ValueError( + "`optimization_run_properties_per_multistart` supports only " + "`plot_type='line'` or `plot_type='hist'`." + ) + num_subplot = len(properties_to_plot) - # compute, how many rows and columns we need for the subplots - num_row = int(np.round(np.sqrt(num_subplot))) - num_col = int(np.ceil(num_subplot / num_row)) - fig, axes = plt.subplots(num_row, num_col, squeeze=False) - fig.set_size_inches(*size) - - for ax in axes.flat[num_subplot:]: - ax.remove() - axes = dict(zip(range(num_subplot), axes.flat, strict=True)) + num_row, num_col = make_grid_shape(num_subplot) + axes = get_axes_array(axes=axes, nrows=num_row, ncols=num_col, size=size) + axes = hide_unused_axes(axes=axes, n_used=num_subplot, clear=True) for idx, prop_name in enumerate(properties_to_plot): - ax = axes[idx] optimization_run_property_per_multistart( results, prop_name, - ax, - size, - start_indices, - colors, - legends, - plot_type, + axes=axes.flat[idx], + size=size, + start_indices=start_indices, + colors=colors, + legends=legends, + plot_type=plot_type, ) return axes @@ -223,13 +230,13 @@ def optimization_run_properties_per_multistart( def optimization_run_property_per_multistart( results: Result | Sequence[Result], opt_run_property: str, - axes: matplotlib.axes.Axes | None = None, + axes: matplotlib.axes.Axes | np.ndarray | None = None, size: tuple[float, float] = (18.5, 10.5), start_indices: int | Iterable[int] | None = None, colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: str | list[str] | None = None, plot_type: str = "line", -) -> matplotlib.axes.Axes: +) -> np.ndarray: """ Plot stats for an optimization run property specified by opt_run_property. @@ -265,7 +272,7 @@ def optimization_run_property_per_multistart( Returns ------- axes: - The plot axes. + 2-D NumPy array containing one matplotlib Axes per panel. """ supported_properties = { "time": "Wall-clock time (seconds)", @@ -286,48 +293,51 @@ def optimization_run_property_per_multistart( # parse input (results, colors, legends) = process_result_list(results, colors, legends) - # axes - if axes is None: - ncols = 2 if plot_type == "both" else 1 - fig, axes = plt.subplots(1, ncols) - fig.set_size_inches(*size) + ncols = 2 if plot_type == "both" else 1 + axes = get_axes_array(axes=axes, nrows=1, ncols=ncols, size=size) + fig = axes.flat[0].figure + for ax in axes.flat: + ax.clear() + ax.set_visible(True) + + if plot_type == "both": fig.suptitle( f"{supported_properties[opt_run_property]} per optimizer run" ) else: - axes.set_title( + axes[0, 0].set_title( f"{supported_properties[opt_run_property]} per optimizer run" ) # loop over results for j, result in enumerate(results): if plot_type == "both": - axes[0] = stats_lowlevel( + stats_lowlevel( result, opt_run_property, supported_properties[opt_run_property], - axes[0], + axes[0, 0], start_indices, colors[j], legends[j], ) - axes[1] = stats_lowlevel( + stats_lowlevel( result, opt_run_property, supported_properties[opt_run_property], - axes[1], + axes[0, 1], start_indices, colors[j], legends[j], plot_type="hist", ) else: - axes = stats_lowlevel( + stats_lowlevel( result, opt_run_property, supported_properties[opt_run_property], - axes, + axes[0, 0], start_indices, colors[j], legends[j], @@ -336,10 +346,10 @@ def optimization_run_property_per_multistart( if sum(legend is not None for legend in legends) > 0: if plot_type == "both": - for ax in axes: + for ax in axes.flat: ax.legend() else: - axes.legend() + axes[0, 0].legend() return axes diff --git a/pypesto/visualize/optimizer_convergence.py b/pypesto/visualize/optimizer_convergence.py index 447d48eac..60954f27e 100644 --- a/pypesto/visualize/optimizer_convergence.py +++ b/pypesto/visualize/optimizer_convergence.py @@ -1,17 +1,18 @@ -import matplotlib.pyplot as plt +import matplotlib.axes import numpy as np import pandas as pd from ..result import Result +from .misc import get_ax def optimizer_convergence( result: Result, - ax: plt.Axes | None = None, + ax: matplotlib.axes.Axes | None = None, xscale: str = "symlog", yscale: str = "log", - size: tuple[float] = (18.5, 10.5), -) -> plt.Axes: + size: tuple[float, float] = (18.5, 10.5), +) -> matplotlib.axes.Axes: """ Visualize to help spotting convergence issues. @@ -45,8 +46,7 @@ def optimizer_convergence( """ import seaborn as sns - if ax is None: - ax = plt.subplots(figsize=size)[1] + ax = get_ax(ax, size) fvals = result.optimize_result.fval grad_norms = [ diff --git a/pypesto/visualize/optimizer_history.py b/pypesto/visualize/optimizer_history.py index b6f076e75..6fc624010 100644 --- a/pypesto/visualize/optimizer_history.py +++ b/pypesto/visualize/optimizer_history.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Iterable -import matplotlib.pyplot as plt +import matplotlib.axes import numpy as np from matplotlib.ticker import MaxNLocator @@ -16,7 +16,12 @@ from ..history import HistoryBase from ..result import Result from .clust_color import assign_colors -from .misc import process_offset_y, process_result_list, process_y_limits +from .misc import ( + get_ax, + process_offset_y, + process_result_list, + process_y_limits, +) from .reference_points import ReferencePoint, create_references logger = logging.getLogger(__name__) @@ -24,8 +29,8 @@ def optimizer_history( results: Result | list[Result], - ax: plt.Axes | None = None, - size: tuple = (18.5, 10.5), + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] = (18.5, 10.5), trace_x: str = TRACE_X_STEPS, trace_y: str = TRACE_Y_FVAL, scale_y: str = "log10", @@ -39,7 +44,7 @@ def optimizer_history( | list[dict] | None = None, legends: str | list[str] | None = None, -) -> plt.Axes: +) -> matplotlib.axes.Axes: """ Plot history of optimizer. @@ -129,12 +134,12 @@ def optimizer_history_lowlevel( vals: list[np.ndarray], scale_y: str = "log10", colors: COLOR | list[COLOR] | np.ndarray | None = None, - ax: plt.Axes | None = None, - size: tuple = (18.5, 10.5), + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] = (18.5, 10.5), x_label: str = "Optimizer steps", y_label: str = "Objective value", legend_text: str | None = None, -) -> plt.Axes: +) -> matplotlib.axes.Axes: """ Plot optimizer history using list of numpy arrays. @@ -163,11 +168,7 @@ def optimizer_history_lowlevel( ax: The plot axes. """ - # axes - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) # parse input fvals = [] @@ -411,13 +412,13 @@ def get_labels(trace_x: str, trace_y: str, offset_y: float) -> tuple[str, str]: def handle_options( - ax: plt.Axes, + ax: matplotlib.axes.Axes, vals: list[np.ndarray], trace_y: str, ref: list[ReferencePoint], y_limits: float | np.ndarray | None, offset_y: float, -) -> plt.Axes: +) -> matplotlib.axes.Axes: """ Apply post-plotting transformations to the axis object. @@ -540,8 +541,8 @@ def monotonic_history( def sacess_history( histories: list[HistoryBase], - ax: plt.Axes | None = None, -) -> plt.Axes: + ax: matplotlib.axes.Axes | None = None, +) -> matplotlib.axes.Axes: """Plot `SacessOptimizer` history. Plot the history of the best objective values for each @@ -560,7 +561,7 @@ def sacess_history( ------- The plot axes. `ax` or a new axes if `ax` was `None`. """ - ax = ax or plt.subplot() + ax = get_ax(ax) if len(histories) == 0: warnings.warn("No histories to plot.", stacklevel=2) diff --git a/pypesto/visualize/parameters.py b/pypesto/visualize/parameters.py index 3865ff177..597f615f0 100644 --- a/pypesto/visualize/parameters.py +++ b/pypesto/visualize/parameters.py @@ -1,9 +1,7 @@ import logging from collections.abc import Callable, Iterable, Sequence -from typing import Optional import matplotlib.axes -import matplotlib.pyplot as plt import numpy as np import pandas as pd from matplotlib.colors import Colormap @@ -21,6 +19,9 @@ from ..result import Result from .clust_color import assign_colors from .misc import ( + get_ax, + get_axes_array, + plot_diagonal_marginal, process_parameter_indices, process_result_list, process_start_indices, @@ -197,11 +198,11 @@ def parameter_hist( result: Result, parameter_name: str, bins: int | str = "auto", - ax: Optional["matplotlib.Axes"] = None, - size: tuple[float] | None = (18.5, 10.5), + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = (18.5, 10.5), color: COLOR | None = None, start_indices: int | list[int] | None = None, -): +) -> matplotlib.axes.Axes: """ Plot parameter values as a histogram. @@ -229,10 +230,7 @@ def parameter_hist( ax: The plot axes. """ - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) xs = result.optimize_result.x @@ -305,10 +303,7 @@ def parameters_lowlevel( # 0.5 inch height per parameter size = (18.5, max(xs.shape[1], 1) / 2) - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) # assign colors colors = assign_colors( @@ -663,38 +658,45 @@ def optimization_scatter( parameter_indices: str | Sequence[int] = "free_only", start_indices: int | Iterable[int] | None = None, diag_kind: str = "kde", - suptitle: str = None, - size: tuple[float, float] = None, + suptitle: str | None = None, + size: tuple[float, float] | None = None, show_bounds: bool = False, -): + axes: np.ndarray | None = None, +) -> np.ndarray: """ - Plot a scatter plot of all pairs of parameters for the given starts. + Plot a scatter matrix of all parameter pairs for the given starts. Parameters ---------- result: - Optimization result obtained by 'optimize.py'. + Optimization result obtained by ‘optimize.py’. parameter_indices: List of integers specifying the parameters to be considered. start_indices: List of integers specifying the multistarts to be plotted or int specifying up to which start index should be plotted. diag_kind: - Visualization mode for marginal densities {‘auto’, ‘hist’, ‘kde’, - None}. + Marginal distribution shown on the diagonal: ``’kde’`` (default) + or ``’hist’``. suptitle: - Title of the plot. + Title of the figure. size: - Size of the plot. + Figure size (width, height) in inches. Defaults to + ``(2.5 * n + 0.5, 2.5 * n + 0.5)``. show_bounds: - Whether to show the parameter bounds. + Whether to draw dashed lines at the parameter bounds. + axes: + Optional axes grid to draw into. Must have shape + ``(n_params, n_params)``. Returns ------- - ax: - The plot axis. + axes: + 2-D NumPy array of shape ``(n_params, n_params)`` containing one + matplotlib Axes per panel. """ - import seaborn as sns + import matplotlib.cm as mpl_cm + from matplotlib.colors import Normalize start_indices = process_start_indices( start_indices=start_indices, result=result @@ -702,33 +704,129 @@ def optimization_scatter( parameter_indices = process_parameter_indices( parameter_indices=parameter_indices, result=result ) - # put all parameters into a dataframe, where columns are parameters - parameters = [ - result.optimize_result[i_start]["x"][parameter_indices] - for i_start in start_indices - ] - x_labels = [ - result.problem.x_names[parameter_index] - for parameter_index in parameter_indices - ] - df = pd.DataFrame(parameters, columns=x_labels) - sns.set(style="ticks") + n = len(parameter_indices) + x_labels = [result.problem.x_names[i] for i in parameter_indices] - ax = sns.pairplot( - df, - diag_kind=diag_kind, + # data matrix: rows = starts, cols = selected parameters + data = np.array( + [ + result.optimize_result[i]["x"][parameter_indices] + for i in start_indices + ] + ) + fvals = np.array([result.optimize_result[i].fval for i in start_indices]) + + # continuous colormap: viridis, low fval (best) → yellow, high fval (worst) → dark + cmap = mpl_cm.viridis_r + min_fval_range = 1.0 + fval_min = fvals.min() + fval_max = fvals.max() + fval_mid = 0.5 * (fval_min + fval_max) + fval_half_range = max(fval_max - fval_min, min_fval_range) / 2 + fval_norm = Normalize( + vmin=fval_mid - fval_half_range, + vmax=fval_mid + fval_half_range, ) - if size is not None: - ax.fig.set_size_inches(size) + if size is None and axes is None: + size = (2.5 * n + 0.5, 2.5 * n + 0.5) + + axes = get_axes_array(axes=axes, nrows=n, ncols=n, size=size) + fig = axes.flat[0].figure + fig.set_layout_engine("constrained") + + previous_colorbar_axes = [] + for ax in axes.flat: + colorbar_ax = getattr( + ax, + "_pypesto_optimization_scatter_colorbar_ax", + None, + ) + if ( + colorbar_ax is not None + and colorbar_ax not in previous_colorbar_axes + ): + previous_colorbar_axes.append(colorbar_ax) + for colorbar_ax in previous_colorbar_axes: + if colorbar_ax in fig.axes: + colorbar_ax.remove() + + for ax in axes.flat: + ax.clear() + ax.set_visible(True) + if hasattr(ax, "_pypesto_optimization_scatter_colorbar_ax"): + delattr(ax, "_pypesto_optimization_scatter_colorbar_ax") + + for row in range(n): + for col in range(n): + ax = axes[row, col] + col_vals = data[:, col] + row_vals = data[:, row] + + if row == col: + plot_diagonal_marginal( + ax=ax, values=col_vals, diag_kind=diag_kind + ) + else: + ax.scatter( + col_vals, + row_vals, + c=fvals, + cmap=cmap, + norm=fval_norm, + s=35, + alpha=0.85, + linewidths=0.6, + edgecolors="white", + zorder=3, + ) + ax.set_ylabel(x_labels[row]) + + ax.set_xlabel(x_labels[col]) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + if show_bounds: + pi_col = parameter_indices[col] + pi_row = parameter_indices[row] + for val in ( + result.problem.lb_full[pi_col], + result.problem.ub_full[pi_col], + ): + ax.axvline(val, color="k", ls="--", lw=0.8) + if row != col: + for val in ( + result.problem.lb_full[pi_row], + result.problem.ub_full[pi_row], + ): + ax.axhline(val, color="k", ls="--", lw=0.8) + + # shared x-limits per column, shared y-limits per row (non-diagonal) + for col in range(n): + vals = data[:, col] + data_range = vals.max() - vals.min() + pad = data_range * 0.1 if data_range > 0 else 0.5 + xlim = (vals.min() - pad, vals.max() + pad) + for row in range(n): + axes[row, col].set_xlim(xlim) + for row in range(n): + vals = data[:, row] + data_range = vals.max() - vals.min() + pad = data_range * 0.1 if data_range > 0 else 0.5 + ylim = (vals.min() - pad, vals.max() + pad) + for col in range(n): + if col != row: + axes[row, col].set_ylim(ylim) + + sm = mpl_cm.ScalarMappable(cmap=cmap, norm=fval_norm) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.8, pad=0.03) + cbar.set_label("Objective value") + for ax in axes.flat: + ax._pypesto_optimization_scatter_colorbar_ax = cbar.ax + if suptitle: - ax.fig.suptitle(suptitle) - if show_bounds: - # set bounds of plot to parameter bounds. Only use diagonal as - # sns.PairGrid has sharex,sharey = True by default. - for i_axis, axis in enumerate(np.diag(ax.axes)): - axis.set_xlim(result.problem.lb[i_axis], result.problem.ub[i_axis]) - axis.set_ylim(result.problem.lb[i_axis], result.problem.ub[i_axis]) + fig.suptitle(suptitle) - return ax + return axes diff --git a/pypesto/visualize/profile_cis.py b/pypesto/visualize/profile_cis.py index 8b7b0eed1..e33202375 100644 --- a/pypesto/visualize/profile_cis.py +++ b/pypesto/visualize/profile_cis.py @@ -3,13 +3,13 @@ import matplotlib.axes import matplotlib.cm as cm -import matplotlib.pyplot as plt import numpy as np from matplotlib.collections import PatchCollection from matplotlib.patches import Patch, Rectangle from ..profile import calculate_approximate_ci, chi2_quantile_to_ratio from ..result import Result +from .misc import get_ax # kwargs passed to `matplotlib.axes.Axes.errorbar` for plotting confidence levels cis_visualization_settings = { @@ -26,7 +26,7 @@ def profile_cis( profile_list: int = 0, color: str | tuple = "C0", show_bounds: bool = False, - ax: matplotlib.axes.Axes = None, + ax: matplotlib.axes.Axes | None = None, ) -> matplotlib.axes.Axes: """ Plot approximate confidence intervals based on profiles. @@ -62,8 +62,7 @@ def profile_cis( if profile_indices is None: profile_indices = [ix for ix, res in enumerate(profile_list) if res] - if ax is None: - _, ax = plt.subplots() + ax = get_ax(ax) confidence_ratio = chi2_quantile_to_ratio(confidence_level, df=df) @@ -113,9 +112,9 @@ def profile_nested_cis( profile_indices: Sequence[int] = None, profile_list: int = 0, colors: Sequence = None, - ax: matplotlib.axes.Axes = None, + ax: matplotlib.axes.Axes | None = None, orientation: Literal["v", "h"] = "v", -): +) -> matplotlib.axes.Axes: """ Plot approximate nested confidence intervals based on profiles. @@ -162,8 +161,7 @@ def profile_nested_cis( if profile_indices is None: profile_indices = [ix for ix, res in enumerate(profile_list) if res] - if ax is None: - _, ax = plt.subplots() + ax = get_ax(ax) legends = [] for i, confidence_level in enumerate(confidence_levels): diff --git a/pypesto/visualize/profiles.py b/pypesto/visualize/profiles.py index a556af3e8..d61f5bd7e 100644 --- a/pypesto/visualize/profiles.py +++ b/pypesto/visualize/profiles.py @@ -1,21 +1,116 @@ from collections.abc import Sequence from warnings import warn +import matplotlib.axes import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import is_color_like +from matplotlib.lines import Line2D from matplotlib.ticker import MaxNLocator from ..C import COLOR +from ..problem import Problem +from ..profile import chi2_quantile_to_ratio from ..result import Result from .clust_color import assign_colors -from .misc import process_result_list +from .misc import get_ax, process_result_list from .reference_points import ReferencePoint, create_references +def _parameter_label(problem: Problem, idx: int) -> str: + """Return a scale-aware axis label for parameter ``idx``.""" + name = problem.x_names[idx] + scale = problem.x_scales[idx] if problem.x_scales is not None else "lin" + if scale == "log10": + return f"log10({name})" + if scale == "log": + return f"log({name})" + return name + + +# Fraction of the bound range added to axis limits so bound lines (drawn +# at the true lb/ub) are visible without overlapping the axis spine. +_BOUND_VIEW_MARGIN = 0.03 + + +def _add_bound_lines_1d( + ax: matplotlib.axes.Axes, lb: float, ub: float +) -> None: + """Draw dashed vertical lines at the lower and upper parameter bounds.""" + for bound in (lb, ub): + ax.axvline( + bound, + color="0.5", + linestyle="--", + linewidth=1.4, + alpha=0.95, + zorder=1, + ) + + +def _add_bound_lines_2d( + ax: matplotlib.axes.Axes, + x_lb: float, + x_ub: float, + y_lb: float, + y_ub: float, +) -> None: + """Draw dashed lines at the lower and upper bounds on both axes.""" + for bound in (x_lb, x_ub): + ax.axvline( + bound, + color="0.5", + linestyle="--", + linewidth=1.4, + alpha=0.95, + zorder=1, + ) + for bound in (y_lb, y_ub): + ax.axhline( + bound, + color="0.5", + linestyle="--", + linewidth=1.4, + alpha=0.95, + zorder=1, + ) + + +def _add_panel_legend( + ax: matplotlib.axes.Axes, + handles: list[Line2D], + fontsize: int, + loc: str = "upper left", +) -> None: + """Add a compact styled legend to a subplot.""" + existing_handles, existing_labels = ax.get_legend_handles_labels() + label_to_handle = { + label: handle + for handle, label in zip( + existing_handles, existing_labels, strict=True + ) + if label + } + for handle in handles: + if handle.get_label() not in label_to_handle: + label_to_handle[handle.get_label()] = handle + ax.legend( + label_to_handle.values(), + label_to_handle.keys(), + loc=loc, + frameon=True, + framealpha=0.95, + facecolor="white", + edgecolor="0.85", + fontsize=fontsize, + handlelength=1.8, + borderpad=0.4, + ) + + def profiles( results: Result | Sequence[Result], - ax=None, + ax: matplotlib.axes.Axes | None = None, profile_indices: Sequence[int] = None, size: tuple[float, float] = (18.5, 6.5), reference: ReferencePoint | Sequence[ReferencePoint] = None, @@ -24,10 +119,11 @@ def profiles( x_labels: Sequence[str] = None, profile_list_ids: int | Sequence[int] = 0, ratio_min: float = 0.0, + confidence_level: float | None = None, show_bounds: bool = False, plot_objective_values: bool = False, quality_colors: bool = False, -) -> plt.Axes: +) -> matplotlib.axes.Axes: """ Plot classical 1D profile plot. @@ -58,7 +154,12 @@ def profiles( profile_list_ids: Index or list of indices of the profile lists to visualize. ratio_min: - Minimum ratio below which to cut off. + Minimum likelihood-ratio value below which to cut off profile points. + Mutually exclusive with ``confidence_level``. + confidence_level: + Confidence level in (0, 1) (e.g. ``0.95``). Converted to + ``ratio_min`` via :func:`pypesto.profile.chi2_quantile_to_ratio`. + Convenience alternative to specifying ``ratio_min`` directly. show_bounds: Whether to show, and extend the plot to, the lower and upper bounds. plot_objective_values: @@ -85,6 +186,13 @@ def profiles( " and `colors` provided at the same time. Please provide only one of them." ) + if confidence_level is not None: + if ratio_min != 0.0: + raise ValueError( + "Pass either `confidence_level` or `ratio_min`, not both." + ) + ratio_min = chi2_quantile_to_ratio(confidence_level) + # parse input results, profile_list_ids, colors, legends = process_result_list_profiles( results, profile_list_ids, legends, colors @@ -156,14 +264,12 @@ def profiles( # plot reference points ax = handle_reference_points(ref, ax, profile_indices) - plt.tight_layout() - return ax def profiles_lowlevel( fvals: float | Sequence[float], - ax: Sequence[plt.Axes] | None = None, + ax: Sequence[matplotlib.axes.Axes] | None = None, size: tuple[float, float] = (18.5, 6.5), color: COLOR | list[np.ndarray] | None = None, legend_text: str = None, @@ -172,7 +278,7 @@ def profiles_lowlevel( lb_full: Sequence[float] = None, ub_full: Sequence[float] = None, plot_objective_values: bool = False, -) -> list[plt.Axes]: +) -> list[matplotlib.axes.Axes]: """ Lowlevel routine for profile plotting. @@ -314,14 +420,14 @@ def profiles_lowlevel( def profile_lowlevel( fvals: Sequence[float], - ax: plt.Axes | None = None, + ax: matplotlib.axes.Axes | None = None, size: tuple[float, float] = (18.5, 6.5), color: COLOR | np.ndarray | None = None, - legend_text: str = None, + legend_text: str | None = None, show_bounds: bool = False, - lb: float = None, - ub: float = None, -) -> plt.Axes: + lb: float | None = None, + ub: float | None = None, +) -> matplotlib.axes.Axes: """ Lowlevel routine for plotting one profile, working with a numpy array only. @@ -358,13 +464,9 @@ def profile_lowlevel( else: single_color = False - # axes - if ax is None: - ax = plt.subplots()[1] - ax.set_xlabel("Parameter value") - ax.set_ylabel("Log-posterior ratio") - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) + ax.set_xlabel("Parameter value") + ax.set_ylabel("Log-posterior ratio") # plot if fvals.size != 0: @@ -421,7 +523,7 @@ def handle_reference_points(ref, ax, profile_indices): ref: list, optional List of reference points for optimization results, containing et least a function value fval - ax: matplotlib.Axes, optional + ax: matplotlib.axes.Axes, optional Axes object to use. profile_indices: list of integer values List of integer values specifying which profiles should be plotted. @@ -609,3 +711,412 @@ def process_profile_indices( ) return profile_indices_ret + + +def profile_lowlevel_2d( + result: Result, + profile_index: int, + second_par_index: int, + ax: matplotlib.axes.Axes, + profile_list_id: int = 0, + ratio_min: float = 0.0, + cmap: str = "viridis", + plot_objective_values: bool = False, + x_labels: Sequence[str] = None, + vmin: float = None, + vmax: float = None, +) -> matplotlib.axes.Axes: + """ + Lowlevel routine for plotting a two-parameter profile visualization. + + Visualizes the profile of one parameter (x-axis) while showing the values + of a second parameter (y-axis), with colors indicating the objective ratio + or function value. Axis limits are always set to the parameter bounds, + with dashed lines marking the lower and upper bounds. + Axis labels include the parameter scale (e.g. ``log10(k1)``) unless + overridden via ``x_labels``. + + Parameters + ---------- + result: + A single `pypesto.Result` after profiling. + profile_index: + Integer index specifying which profile to plot (x-axis parameter). + second_par_index: + Integer index specifying which parameter to show on y-axis. + ax: + Axes object to use for plotting. + profile_list_id: + Index of the profile list to visualize. + ratio_min: + Minimum ratio below which to cut off. + cmap: + Colormap to use for the objective ratio/value colors. + plot_objective_values: + Whether to plot the objective function values instead of the likelihood + ratio values. + x_labels: + Labels for the parameters (indexed by full parameter index). + If None, labels are auto-generated from parameter names and scales. + vmin: + Minimum value for the color scale. If None, auto-scaled to the data. + vmax: + Maximum value for the color scale. If None, auto-scaled to the data. + + Returns + ------- + The plot axes. + """ + if result.profile_result is None: + raise ValueError("Result does not contain profile results.") + + profile_list = result.profile_result.list[profile_list_id] + + if profile_list[profile_index] is None: + raise ValueError( + f"Profile for parameter {profile_index} has not been computed." + ) + + profiler_result = profile_list[profile_index] + + x_path = profiler_result.x_path + ratio_path = profiler_result.ratio_path + fval_path = profiler_result.fval_path + + x_values = x_path[profile_index, :] + y_values = x_path[second_par_index, :] + color_values = fval_path if plot_objective_values else ratio_path + + # Filter based on ratio_min + indices = np.where(ratio_path >= ratio_min) + x_values = x_values[indices] + y_values = y_values[indices] + color_values = color_values[indices] + + # Draw the connector line in profile traversal order (pre-sort) so it + # represents the actual profile path rather than a color-sorted spaghetti. + ax.plot(x_values, y_values, "k-", alpha=0.2, linewidth=0.8, zorder=0) + + # Draw best points on top: ascending for ratio (high on top), + # descending for objective value (low on top). + sort_idx = ( + np.argsort(-color_values) + if plot_objective_values + else np.argsort(color_values) + ) + ax.scatter( + x_values[sort_idx], + y_values[sort_idx], + c=color_values[sort_idx], + cmap=cmap, + s=30, + vmin=vmin, + vmax=vmax, + ) + + def _label(idx): + if x_labels is not None: + return x_labels[idx] + return _parameter_label(result.problem, idx) + + ax.set_xlabel(_label(profile_index)) + ax.set_ylabel(_label(second_par_index)) + + x_lb = result.problem.lb_full[profile_index] + x_ub = result.problem.ub_full[profile_index] + y_lb = result.problem.lb_full[second_par_index] + y_ub = result.problem.ub_full[second_par_index] + x_margin = _BOUND_VIEW_MARGIN * (x_ub - x_lb) + y_margin = _BOUND_VIEW_MARGIN * (y_ub - y_lb) + ax.set_xlim([x_lb - x_margin, x_ub + x_margin]) + ax.set_ylim([y_lb - y_margin, y_ub + y_margin]) + _add_bound_lines_2d(ax, x_lb, x_ub, y_lb, y_ub) + + return ax + + +def visualize_2d_profile( + result: Result, + profile_indices: Sequence[int] = None, + size: tuple[float, float] = None, + profile_list_id: int = 0, + ratio_min: float = 0.0, + cmap: str = "viridis", + plot_objective_values: bool = False, + x_labels: Sequence[str] = None, + profile_color: COLOR | np.ndarray | None = None, + reference: ReferencePoint | Sequence[ReferencePoint] = None, + label_fontsize: int = 14, +) -> tuple[plt.Figure, np.ndarray]: + """ + Create an n×n grid of profile plots. + + Diagonal plots show 1D profiles (likelihood ratio vs. parameter value). + Off-diagonal plots show the path of one parameter while another is + profiled, with color indicating the likelihood ratio or objective value. + Legend panels summarizing profile points and bound lines are drawn on + the top-left diagonal and the first off-diagonal subplot. + + Parameters + ---------- + result: + A single `pypesto.Result` after profiling. + profile_indices: + List of integer indices specifying which parameters to include. + If None, all parameters with computed profiles are included. + size: + Figure size (width, height) in inches. If None, automatically sized + based on number of parameters (3.5 inches per parameter). + profile_list_id: + Index of the profile list to visualize. + ratio_min: + Minimum ratio below which to cut off. + cmap: + Colormap to use for the 2D off-diagonal scatter plots. + plot_objective_values: + Whether to plot the objective function values instead of the likelihood + ratio values. + x_labels: + Labels for the parameters (indexed by full parameter index). + If None, labels are auto-generated from parameter names and scales. + profile_color: + Color for the diagonal 1D profile lines. Passed directly to + :func:`profile_lowlevel`. If None, the default color is used. + reference: + List of reference points for optimization results, shown on diagonal + 1D plots. + label_fontsize: + Font size for axis labels and the colorbar label. Tick labels are + drawn two points smaller. + + Returns + ------- + fig: + The figure object. + axes: + Array of axes objects (n×n grid). + """ + if result.profile_result is None: + raise ValueError("Result does not contain profile results.") + + profile_list = result.profile_result.list[profile_list_id] + + if profile_indices is None: + profile_indices = [ + i for i, prof in enumerate(profile_list) if prof is not None + ] + + n_params = len(profile_indices) + + if n_params == 0: + raise ValueError("No profiles available to plot.") + + if size is None: + # +1 inch of extra width reserves space for the colorbar so that + # each subplot cell remains approximately square. + size = (n_params * 3.5 + 1, n_params * 3.5) + + fig, axes = plt.subplots( + n_params, + n_params, + figsize=size, + constrained_layout=True, + ) + fig.get_layout_engine().set(wspace=0.1, hspace=0.1) + + if n_params == 1: + axes = np.array([[axes]]) + + ref = create_references(references=reference) + + def _label(idx): + if x_labels is not None: + return x_labels[idx] + return _parameter_label(result.problem, idx) + + # Compute global color range across all 2D off-diagonal subplots so the + # shared colorbar is accurate for every panel. + all_color_values = [] + for row_idx in profile_indices: + for col_idx in profile_indices: + if row_idx == col_idx or profile_list[col_idx] is None: + continue + profiler = profile_list[col_idx] + mask = profiler.ratio_path >= ratio_min + vals = ( + profiler.fval_path[mask] + if plot_objective_values + else profiler.ratio_path[mask] + ) + if vals.size > 0: + all_color_values.append(vals) + if all_color_values: + all_vals = np.concatenate(all_color_values) + color_vmin, color_vmax = float(all_vals.min()), float(all_vals.max()) + else: + color_vmin, color_vmax = None, None + + first_2d_ax = None + last_2d_ax = None + + for i, row_param_idx in enumerate(profile_indices): + for j, col_param_idx in enumerate(profile_indices): + ax = axes[i, j] + + if i == j: + # Diagonal: 1D profile + fvals, _ = handle_inputs( + result, + profile_indices=[row_param_idx], + profile_list=profile_list_id, + ratio_min=ratio_min, + plot_objective_values=plot_objective_values, + ) + + if fvals[row_param_idx] is not None: + profile_lowlevel( + fvals[row_param_idx], + ax, + show_bounds=True, + color=profile_color, + lb=result.problem.lb_full[row_param_idx], + ub=result.problem.ub_full[row_param_idx], + ) + # Fix integer tick locator from profile_lowlevel for float params + ax.xaxis.set_major_locator(plt.AutoLocator()) + ax.set_xlabel(_label(row_param_idx)) + ax.set_ylabel( + "Objective value" + if plot_objective_values + else "Log-posterior ratio" + ) + diag_lb = result.problem.lb_full[row_param_idx] + diag_ub = result.problem.ub_full[row_param_idx] + diag_margin = _BOUND_VIEW_MARGIN * (diag_ub - diag_lb) + ax.set_xlim([diag_lb - diag_margin, diag_ub + diag_margin]) + _add_bound_lines_1d(ax, diag_lb, diag_ub) + + if len(ref) > 0: + for i_ref in ref: + current_x = i_ref["x"][row_param_idx] + ax.plot( + [current_x, current_x], + [0.0, 1.0], + color=i_ref.color, + label=i_ref.legend + if i == 0 and j == 0 + else None, + ) + if i == 0 and j == 0: + profile_legend_color = ( + profile_color + if profile_color is not None + else "red" + ) + _add_panel_legend( + ax, + handles=[ + Line2D( + [0], + [0], + color=profile_legend_color, + linewidth=2.0, + label="Profile", + ), + Line2D( + [0], + [0], + color="0.65", + linestyle="--", + linewidth=1.0, + label="Bounds", + ), + ], + fontsize=label_fontsize - 3, + ) + + else: + # Off-diagonal: 2D profile + # subplot (i, j): x-axis = col_param_idx, y-axis = row_param_idx + try: + profile_lowlevel_2d( + result=result, + profile_index=col_param_idx, + second_par_index=row_param_idx, + ax=ax, + profile_list_id=profile_list_id, + ratio_min=ratio_min, + cmap=cmap, + plot_objective_values=plot_objective_values, + x_labels=x_labels, + vmin=color_vmin, + vmax=color_vmax, + ) + if first_2d_ax is None: + first_2d_ax = ax + last_2d_ax = ax + except (ValueError, IndexError): + ax.text( + 0.5, + 0.5, + "No profile", + ha="center", + va="center", + transform=ax.transAxes, + ) + ax.set_xticks([]) + ax.set_yticks([]) + + # yaxis.labelpad is tightened so y-labels stay close to their axis + # rather than floating in the gap between columns. + for ax in axes.flat: + ax.xaxis.label.set_size(label_fontsize) + ax.yaxis.label.set_size(label_fontsize) + ax.xaxis.label.set_weight("bold") + ax.yaxis.label.set_weight("bold") + ax.yaxis.labelpad = 2 + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.tick_params(axis="both", labelsize=label_fontsize - 2) + + if first_2d_ax is not None: + _add_panel_legend( + first_2d_ax, + handles=[ + Line2D( + [0], + [0], + marker="o", + linestyle="None", + markerfacecolor="0.35", + markeredgecolor="none", + markersize=6, + label="Profile points", + ), + Line2D( + [0], + [0], + color="0.65", + linestyle="--", + linewidth=1.0, + label="Bounds", + ), + ], + fontsize=label_fontsize - 3, + ) + + if last_2d_ax is not None: + scatter = last_2d_ax.collections[-1] + cbar = fig.colorbar(scatter, ax=axes) + cbar.set_label( + "Objective value" + if plot_objective_values + else "Log-posterior ratio", + rotation=270, + labelpad=20, + fontsize=label_fontsize, + fontweight="bold", + ) + cbar.ax.tick_params(labelsize=label_fontsize - 2) + + return fig, axes diff --git a/pypesto/visualize/sampling.py b/pypesto/visualize/sampling.py index 6e27b946d..bb520d352 100644 --- a/pypesto/visualize/sampling.py +++ b/pypesto/visualize/sampling.py @@ -25,9 +25,17 @@ from ..ensemble import EnsemblePrediction, get_percentile_label from ..result import McmcPtResult, PredictionResult, Result from ..sample import calculate_ci_mcmc_sample -from .misc import rgba2rgb +from .misc import ( + _UNSET, + get_ax, + get_axes_array, + hide_unused_axes, + make_grid_shape, + plot_diagonal_marginal, + process_deprecated_kwarg, + rgba2rgb, +) -cmap = matplotlib.cm.viridis logger = logging.getLogger(__name__) @@ -43,10 +51,10 @@ def sampling_fval_traces( i_chain: int = 0, full_trace: bool = False, stepsize: int = 1, - title: str = None, - size: tuple[float, float] = None, - ax: matplotlib.axes.Axes = None, -): + title: str | None = None, + size: tuple[float, float] | None = None, + ax: matplotlib.axes.Axes | None = None, +) -> matplotlib.axes.Axes: """ Plot log-posterior (=function value) over iterations. @@ -82,11 +90,8 @@ def sampling_fval_traces( full_trace=full_trace, ) - # set axes and figure - if ax is None: - _, ax = plt.subplots(figsize=size) + ax = get_ax(ax, size) - sns.set(style="ticks") kwargs = {"edgecolor": "w", "linewidth": 0.3, "s": 10} # for edge color if full_trace: kwargs["hue"] = "converged" @@ -707,7 +712,7 @@ def _handle_colors( # define colormap variable_colors = [ - list(cmap(v))[:LEN_RGB] + list(matplotlib.cm.viridis(v))[:LEN_RGB] for v in np.linspace(cmap_min, cmap_max, n_variables) ] @@ -717,10 +722,10 @@ def _handle_colors( def sampling_prediction_trajectories( ensemble_prediction: EnsemblePrediction, levels: float | Sequence[float], - title: str = None, - size: tuple[float, float] = None, - axes: matplotlib.axes.Axes = None, - labels: dict[str, str] = None, + title: str | None = None, + size: tuple[float, float] | None = None, + axes: matplotlib.axes.Axes | np.ndarray | None = None, + labels: dict[str, str] | None = None, axis_label_padding: int = 50, groupby: str = CONDITION, condition_gap: float = 0.01, @@ -730,8 +735,8 @@ def sampling_prediction_trajectories( reverse_opacities: bool = False, average: str = MEDIAN, add_sd: bool = False, - measurement_df: pd.DataFrame = None, -) -> matplotlib.axes.Axes: + measurement_df: pd.DataFrame | None = None, +) -> np.ndarray: """ Visualize prediction trajectory of an EnsemblePrediction. @@ -785,7 +790,7 @@ def sampling_prediction_trajectories( Returns ------- axes: - The plot axes. + 2-D NumPy array containing one matplotlib Axes per panel. """ if labels is None: labels = {} @@ -865,21 +870,12 @@ def sampling_prediction_trajectories( reverse=reverse_opacities, ) - if axes is None: - n_row = int(np.round(np.sqrt(n_subplots))) - n_col = int(np.ceil(n_subplots / n_row)) - fig, axes = plt.subplots(n_row, n_col, figsize=size, squeeze=False) - for ax in axes.flat[n_subplots:]: - ax.remove() - else: - fig = axes.get_figure() - if not isinstance(axes, np.ndarray): - axes = np.array([[axes]]) - if len(axes.flat) < n_subplots: - raise ValueError( - "Provided `axes` contains insufficient subplots. At least " - f"{n_subplots} are required." - ) + n_row = int(np.round(np.sqrt(n_subplots))) + n_col = int(np.ceil(n_subplots / n_row)) + + axes = get_axes_array(axes=axes, nrows=n_row, ncols=n_col, size=size) + fig = axes.flat[0].figure + axes = hide_unused_axes(axes=axes, n_used=n_subplots, clear=True) artist_padding = axis_label_padding / (fig.get_size_inches() * fig.dpi)[0] if groupby == CONDITION: @@ -931,8 +927,9 @@ def sampling_prediction_trajectories( ) # X and Y labels - xmin = min(ax.get_position().xmin for ax in axes.flat) - ymin = min(ax.get_position().ymin for ax in axes.flat) + visible_axes = [ax for ax in axes.flat if ax.get_visible()] + xmin = min(ax.get_position().xmin for ax in visible_axes) + ymin = min(ax.get_position().ymin for ax in visible_axes) xlabel = ( "Cumulative time across all conditions" if groupby == OUTPUT @@ -962,12 +959,13 @@ def sampling_prediction_trajectories( def sampling_parameter_cis( result: Result, - alpha: Sequence[int] = None, + confidence_levels: Sequence[float] = None, step: float = 0.05, show_median: bool = True, - title: str = None, - size: tuple[float, float] = None, - ax: matplotlib.axes.Axes = None, + title: str | None = None, + size: tuple[float, float] | None = None, + ax: matplotlib.axes.Axes | None = None, + alpha: Sequence[int] = None, ) -> matplotlib.axes.Axes: """ Plot MCMC-based parameter credibility intervals. @@ -976,8 +974,13 @@ def sampling_parameter_cis( ---------- result: The pyPESTO result object with filled sample result. + confidence_levels: + Credibility levels as fractions in (0, 1), e.g. ``[0.95]`` for a + 95% credibility interval. Defaults to ``[0.95]``. alpha: - List of lower tail probabilities, defaults to 95% interval. + Deprecated. Use ``confidence_levels`` instead. + Previously accepted integer percentages (e.g. ``[95]``); values + are divided by 100 automatically during the transition. step: Height of boxes for projectile plot, defaults to 0.05. show_median: @@ -994,31 +997,46 @@ def sampling_parameter_cis( ax: The plot axes. """ - if alpha is None: - alpha = [95] + if alpha is not None: + if confidence_levels is not None: + raise ValueError( + "Pass either `confidence_levels` or the deprecated `alpha`, not both." + ) + import warnings + + warnings.warn( + "`alpha` is deprecated; use `confidence_levels` instead. " + "Note: units have changed — pass fractions in (0, 1) " + "(e.g. `confidence_levels=[0.95]`) instead of integer percentages " + "(e.g. `alpha=[95]`). Your values have been divided by 100 automatically.", + DeprecationWarning, + stacklevel=2, + ) + confidence_levels = [a / 100 for a in alpha] + + if confidence_levels is None: + confidence_levels = [0.95] # automatically sort values in decreasing order - alpha_sorted = sorted(alpha, reverse=True) + levels_sorted = sorted(confidence_levels, reverse=True) # define colormap - evenly_spaced_interval = np.linspace(0, 1, len(alpha_sorted)) + evenly_spaced_interval = np.linspace(0, 1, len(levels_sorted)) colors = [plt.cm.tab20c_r(x) for x in evenly_spaced_interval] # number of sampled parameters n_pars = result.sample_result.trace_x.shape[-1] - # set axes and figure - if ax is None: - _, ax = plt.subplots(figsize=size) + ax = get_ax(ax, size) # loop over parameters for npar in range(n_pars): # initialize height of boxes _step = step # loop over confidence levels - for n, level in enumerate(alpha_sorted): + for n, level in enumerate(levels_sorted): # extract percentile-based confidence intervals lb, ub = calculate_ci_mcmc_sample( result=result, - ci_level=level / 100, + ci_level=level, ) # assemble boxes for projectile plot @@ -1030,11 +1048,11 @@ def sampling_parameter_cis( np.append(x1, x1[::-1]), np.append(y1, y2[::-1]), color=colors[n], - label=str(level) + "% CI", + label=f"{level:.0%} CI", ) if show_median: - if n == len(alpha_sorted) - 1: + if n == len(levels_sorted) - 1: burn_in = result.sample_result.burn_in converged = result.sample_result.trace_x[0, burn_in:, npar] _median = np.median(converged) @@ -1070,14 +1088,16 @@ def sampling_parameter_cis( def sampling_parameter_traces( result: Result, i_chain: int = 0, - par_indices: Sequence[int] = None, + parameter_indices: Sequence[int] = None, full_trace: bool = False, stepsize: int = 1, use_problem_bounds: bool = True, - suptitle: str = None, - size: tuple[float, float] = None, - ax: matplotlib.axes.Axes = None, -): + suptitle: str | None = None, + size: tuple[float, float] | None = None, + axes: np.ndarray | None = None, + ax: np.ndarray | None = _UNSET, + par_indices: Sequence[int] = _UNSET, +) -> np.ndarray: """ Plot parameter values over iterations. @@ -1087,7 +1107,7 @@ def sampling_parameter_traces( The pyPESTO result object with filled sample result. i_chain: Which chain to plot. Default: First chain. - par_indices: list of integer values + parameter_indices: list of integer values List of integer values specifying which parameters to plot. Default: All parameters are shown. full_trace: @@ -1101,14 +1121,26 @@ def sampling_parameter_traces( Figure suptitle. size: Figure size in inches. + axes: + Axes grid to use. Must match the computed subplot layout. ax: - Axes object to use. + Deprecated. Use ``axes`` instead. + par_indices: + Deprecated. Use ``parameter_indices`` instead. Returns ------- - ax: - The plot axes. + axes: + 2-D NumPy array containing one matplotlib Axes per panel. """ + parameter_indices = process_deprecated_kwarg( + "parameter_indices", + parameter_indices, + "par_indices", + par_indices, + ) + axes = process_deprecated_kwarg("axes", axes, "ax", ax) + import seaborn as sns # get data which should be plotted @@ -1117,22 +1149,18 @@ def sampling_parameter_traces( i_chain=i_chain, stepsize=stepsize, full_trace=full_trace, - par_indices=par_indices, + parameter_indices=parameter_indices, ) - # compute, how many rows and columns we need for the subplots - num_row = int(np.round(np.sqrt(nr_params))) - num_col = int(np.ceil(nr_params / num_row)) + num_row, num_col = make_grid_shape(nr_params) + if size is None and axes is None: + size = (3.5 * num_col, 2.5 * num_row) + axes = get_axes_array(axes=axes, nrows=num_row, ncols=num_col, size=size) + fig = axes.flat[0].figure + axes = hide_unused_axes(axes=axes, n_used=nr_params, clear=True) - # set axes and figure - if ax is None: - fig, ax = plt.subplots(num_row, num_col, squeeze=False, figsize=size) - else: - fig = ax.get_figure() + par_ax = dict(zip(param_names, axes.flat, strict=True)) - par_ax = dict(zip(param_names, ax.flat, strict=True)) - - sns.set(style="ticks") kwargs = {"edgecolor": "w", "linewidth": 0.3, "s": 10} # for edge color if full_trace: @@ -1174,22 +1202,21 @@ def sampling_parameter_traces( if suptitle: fig.suptitle(suptitle) - - fig.tight_layout() sns.despine() - return ax + return axes def sampling_scatter( result: Result, i_chain: int = 0, stepsize: int = 1, - suptitle: str = None, + suptitle: str | None = None, diag_kind: str = "kde", - size: tuple[float, float] = None, + size: tuple[float, float] | None = None, show_bounds: bool = True, -): + axes: np.ndarray | None = None, +) -> np.ndarray: """ Parameter scatter plot. @@ -1212,50 +1239,84 @@ def sampling_scatter( Returns ------- - ax: - The plot axes. + axes: + 2-D NumPy array containing one matplotlib Axes per panel. """ - import seaborn as sns - # get data which should be plotted - nr_params, params_fval, theta_lb, theta_ub, _ = get_data_to_plot( + nr_params, params_fval, theta_lb, theta_ub, param_names = get_data_to_plot( result=result, i_chain=i_chain, stepsize=stepsize ) - sns.set(style="ticks") + if size is None and axes is None: + size = (2.5 * nr_params + 0.5, 2.5 * nr_params + 0.5) - # TODO: Think this throws the axis errors in seaborn. - ax = sns.pairplot( - params_fval.drop(["logPosterior", "iteration"], axis=1), - diag_kind=diag_kind, + axes = get_axes_array( + axes=axes, nrows=nr_params, ncols=nr_params, size=size ) + fig = axes.flat[0].figure + for ax in axes.flat: + ax.clear() + ax.set_visible(True) + + data = params_fval[param_names] + for row in range(nr_params): + for col in range(nr_params): + ax = axes[row, col] + col_name = param_names[col] + row_name = param_names[row] + col_vals = data[col_name] + row_vals = data[row_name] + + if row == col: + plot_diagonal_marginal( + ax=ax, values=col_vals, diag_kind=diag_kind + ) + else: + ax.scatter( + col_vals, + row_vals, + color="C0", + alpha=0.85, + s=35, + linewidths=0.6, + edgecolors="white", + zorder=3, + ) + ax.set_ylabel(row_name) - if size is not None: - ax.fig.set_size_inches(size) - - if suptitle: - ax.fig.suptitle(suptitle) + ax.set_xlabel(col_name) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) if show_bounds: - # set bounds of plot to parameter bounds. Only use diagonal as - # sns.PairGrid has sharex,sharey = True by default. - for i_axis, axis in enumerate(np.diag(ax.axes)): - axis.set_xlim(result.problem.lb[i_axis], result.problem.ub[i_axis]) - axis.set_ylim(result.problem.lb[i_axis], result.problem.ub[i_axis]) + for col in range(nr_params): + xlim = (theta_lb[col], theta_ub[col]) + for row in range(nr_params): + axes[row, col].set_xlim(xlim) + for row in range(nr_params): + ylim = (theta_lb[row], theta_ub[row]) + for col in range(nr_params): + if row != col: + axes[row, col].set_ylim(ylim) - return ax + if suptitle: + fig.suptitle(suptitle) + + return axes def sampling_1d_marginals( result: Result, i_chain: int = 0, - par_indices: Sequence[int] = None, + parameter_indices: Sequence[int] = None, stepsize: int = 1, plot_type: str = "both", bw_method: str = "scott", - suptitle: str = None, - size: tuple[float, float] = None, -): + suptitle: str | None = None, + size: tuple[float, float] | None = None, + axes: np.ndarray | None = None, + par_indices: Sequence[int] = _UNSET, +) -> np.ndarray: """ Plot marginals. @@ -1265,7 +1326,7 @@ def sampling_1d_marginals( The pyPESTO result object with filled sample result. i_chain: Which chain to plot. Default: First chain. - par_indices: list of integer values + parameter_indices: list of integer values List of integer values specifying which parameters to plot. Default: All parameters are shown. stepsize: @@ -1279,12 +1340,23 @@ def sampling_1d_marginals( Figure super title. size: Figure size in inches. + axes: + Axes grid to use. Must match the computed subplot layout. + par_indices: + Deprecated. Use ``parameter_indices`` instead. Return -------- - ax: - matplotlib-axes + axes: + 2-D NumPy array containing one matplotlib Axes per panel. """ + parameter_indices = process_deprecated_kwarg( + "parameter_indices", + parameter_indices, + "par_indices", + par_indices, + ) + import seaborn as sns # get data which should be plotted @@ -1292,17 +1364,17 @@ def sampling_1d_marginals( result=result, i_chain=i_chain, stepsize=stepsize, - par_indices=par_indices, + parameter_indices=parameter_indices, ) - # compute, how many rows and columns we need for the subplots - num_row = int(np.round(np.sqrt(nr_params))) - num_col = int(np.ceil(nr_params / num_row)) - - fig, ax = plt.subplots(num_row, num_col, squeeze=False, figsize=size) + num_row, num_col = make_grid_shape(nr_params) + if size is None and axes is None: + size = (3.5 * num_col, 2.5 * num_row) + axes = get_axes_array(axes=axes, nrows=num_row, ncols=num_col, size=size) + fig = axes.flat[0].figure + axes = hide_unused_axes(axes=axes, n_used=nr_params, clear=True) - par_ax = dict(zip(param_names, ax.flat, strict=True)) - sns.set(style="ticks") + par_ax = dict(zip(param_names, axes.flat, strict=True)) # fig, ax = plt.subplots(nr_params, figsize=size)[1] for idx, par_id in enumerate(param_names): @@ -1334,9 +1406,7 @@ def sampling_1d_marginals( if suptitle: fig.suptitle(suptitle) - fig.tight_layout() - - return ax + return axes def get_data_to_plot( @@ -1344,7 +1414,7 @@ def get_data_to_plot( i_chain: int, stepsize: int, full_trace: bool = False, - par_indices: Sequence[int] = None, + parameter_indices: Sequence[int] = None, ): """Get the data which should be plotted as a pandas.DataFrame. @@ -1358,7 +1428,7 @@ def get_data_to_plot( Only one in `stepsize` values is plotted. full_trace: Keep the full length of the chain. Default: False. - par_indices: list of integer values + parameter_indices: list of integer values List of integer values specifying which parameters to plot. Default: All parameters are shown. @@ -1432,9 +1502,9 @@ def get_data_to_plot( # some global parameters nr_params = arr_param.shape[1] # number of parameters - if par_indices is not None: - param_names = params_fval.columns.values[par_indices] - nr_params = len(par_indices) + if parameter_indices is not None: + param_names = params_fval.columns.values[parameter_indices] + nr_params = len(parameter_indices) else: param_names = params_fval.columns.values[0:nr_params] diff --git a/pypesto/visualize/waterfall.py b/pypesto/visualize/waterfall.py index 1d6a5b60d..c25158363 100644 --- a/pypesto/visualize/waterfall.py +++ b/pypesto/visualize/waterfall.py @@ -1,6 +1,6 @@ from collections.abc import Sequence -import matplotlib.pyplot as plt +import matplotlib.axes import numpy as np from matplotlib.ticker import MaxNLocator from mpl_toolkits.axes_grid1 import inset_locator @@ -11,6 +11,7 @@ from ..result import Result from .clust_color import assign_colors from .misc import ( + get_ax, process_offset_y, process_result_list, process_start_indices, @@ -21,7 +22,7 @@ def waterfall( results: Result | Sequence[Result], - ax: plt.Axes | None = None, + ax: matplotlib.axes.Axes | None = None, size: tuple[float, float] | None = (18.5, 10.5), y_limits: tuple[float] | None = None, scale_y: str | None = "log10", @@ -32,7 +33,7 @@ def waterfall( colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: Sequence[str] | str | None = None, order_by_id: bool = False, -): +) -> matplotlib.axes.Axes: """ Plot waterfall plot. @@ -76,11 +77,7 @@ def waterfall( ax: matplotlib.Axes The plot axes. """ - # axes - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) if n_starts_to_zoom: # create zoom in @@ -200,13 +197,13 @@ def waterfall( def waterfall_lowlevel( fvals, - ax: plt.Axes | None = None, - size: tuple[float] | None = (18.5, 10.5), + ax: matplotlib.axes.Axes | None = None, + size: tuple[float, float] | None = (18.5, 10.5), scale_y: str = "log10", offset_y: float = 0.0, colors: COLOR | list[COLOR] | np.ndarray | None = None, legend_text: str | None = None, -): +) -> matplotlib.axes.Axes: """ Plot waterfall plot using list of function values. @@ -235,11 +232,7 @@ def waterfall_lowlevel( ax: matplotlib.Axes The plot axes. """ - # axes - if ax is None: - ax = plt.subplots()[1] - fig = plt.gcf() - fig.set_size_inches(*size) + ax = get_ax(ax, size) start_indices = [i for i, fval in enumerate(fvals) if fval is not None] fvals = [fvals[i] for i in start_indices] diff --git a/pyproject.toml b/pyproject.toml index 0e336b3be..7b84f5576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ authors = [ maintainers = [ { name = "Paul Jonas Jost", email = "paul.jost@uni-bonn.de" }, { name = "Domagoj Dorešić", email = "domagoj.doresic@uni-bonn.de" }, - { name = "Vincent Wieland", email = "vwieland@uni-bonn.de" }, + { name = "Moritz Richter", email = "moritz.richter@uni-bonn.de" }, { name = "Fabian Fröhlich", email = "fabian.frohlich@crick.ac.uk" }, ] license-files = ["LICENSE"] @@ -106,7 +106,9 @@ mpi = [ ] pymc = [ - "arviz>=0.12.1", + # TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0. + "arviz>=0.12.1,<1.0; python_version < '3.12'", + "arviz>=1.1.0; python_version >= '3.12'", "pymc>=4.2.1", ] @@ -129,7 +131,7 @@ mltools = [ ] julia = [ - "julia>=0.5.7", + "juliacall>=0.9.31", "ipython>=8.4.0", "pygments>=2.12.0", ] diff --git a/test/base/test_workflow.py b/test/base/test_workflow.py index 89ccf2bf6..b1572e2e7 100644 --- a/test/base/test_workflow.py +++ b/test/base/test_workflow.py @@ -3,31 +3,16 @@ These tests are not for correctness, but for basic functionality. """ -from functools import wraps - -import matplotlib.pyplot as plt - import pypesto import pypesto.optimize as optimize import pypesto.profile as profile import pypesto.sample as sample import pypesto.visualize as visualize +from ..conftest import close_fig from ..util import CRProblem -def close_fig(fun): - """Close figure.""" - - @wraps(fun) - def wrapped_fun(*args, **kwargs): - ret = fun(*args, **kwargs) - plt.close("all") - return ret - - return wrapped_fun - - def test_objective(): """Test a simple objective function.""" crproblem = CRProblem() diff --git a/test/conftest.py b/test/conftest.py index 755a48a7c..f3188d023 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,8 @@ import os import tempfile +from functools import wraps +import matplotlib.pyplot as plt import numpy as np import pytest import scipy.optimize as so @@ -10,6 +12,19 @@ from pypesto.store import write_result +def close_fig(fun): + """Close all figures after a test, even on failure.""" + + @wraps(fun) + def wrapped_fun(*args, **kwargs): + try: + return fun(*args, **kwargs) + finally: + plt.close("all") + + return wrapped_fun + + @pytest.fixture def hdf5_file(): """Generate a temporary hdf5 file.""" diff --git a/test/julia/test_pyjulia.py b/test/julia/test_pyjulia.py index 8325a44e8..20ea9d251 100644 --- a/test/julia/test_pyjulia.py +++ b/test/julia/test_pyjulia.py @@ -7,8 +7,6 @@ from pypesto.objective.julia import JuliaObjective, display_source_ipython from pypesto.objective.julia.petab_jl_importer import PetabJlImporter -# The pyjulia wrapper appears to ignore global noqas, thus per line here - def test_pyjulia_pipeline(): """Test that a pipeline with julia objective works.""" diff --git a/test/profile/test_profile.py b/test/profile/test_profile.py index aea7556f0..5de231925 100644 --- a/test/profile/test_profile.py +++ b/test/profile/test_profile.py @@ -15,9 +15,14 @@ import pypesto.profile as profile import pypesto.visualize as visualize from pypesto import ObjectiveBase +from pypesto.profile.util import ( + precheck_profile_step_size, + resolve_profile_step_sizes, + resolve_profile_step_sizes_for_parameters, +) +from ..conftest import close_fig from ..util import rosen_for_sensi -from ..visualize import close_fig class ProfilerTest(unittest.TestCase): @@ -145,9 +150,9 @@ def test_engine_profiling(self): def test_selected_profiling(self): # create options in order to ensure a short computation time options = profile.ProfileOptions( - default_step_size=0.02, - min_step_size=0.005, - max_step_size=1.0, + default_step_size_absolute=0.02, + min_step_size_absolute=0.005, + max_step_size_absolute=1.0, step_size_factor=1.5, delta_ratio_max=0.2, ratio_min=0.3, @@ -280,9 +285,9 @@ def test_profile_with_history(): ) profile_options = profile.ProfileOptions( - min_step_size=0.0005, + min_step_size_absolute=0.0005, delta_ratio_max=0.05, - default_step_size=0.005, + default_step_size_absolute=0.005, ratio_min=0.03, ) @@ -352,6 +357,10 @@ def test_profile_with_fixed_parameters(): # test profiling with all parameters fixed but one problem.fix_parameters([2, 3, 4], result.optimize_result.list[0]["x"][2:5]) + resolved_steps_by_par = resolve_profile_step_sizes_for_parameters( + problem, problem.x_free_indices, profile.ProfileOptions() + ) + assert set(resolved_steps_by_par) == set(problem.x_free_indices) profile.parameter_profile( problem=problem, result=result, @@ -426,23 +435,198 @@ def test_options_valid(): """Test ProfileOptions validity checks.""" # default settings are valid profile.ProfileOptions() + profile.ProfileOptions( + min_step_size_relative=0.0025, + default_step_size_relative=0.005, + max_step_size_relative=0.02, + ) # try to set invalid values with pytest.raises(ValueError): - profile.ProfileOptions(default_step_size=-1) - with pytest.raises(ValueError): - profile.ProfileOptions(default_step_size=1, min_step_size=2) + profile.ProfileOptions(default_step_size_absolute=-1) with pytest.raises(ValueError): - profile.ProfileOptions( - default_step_size=2, - min_step_size=1, + profile.ProfileOptions(default_step_size_relative=-0.01) + with pytest.warns(DeprecationWarning, match="`default_step_size`"): + options = profile.ProfileOptions(default_step_size=0.05) + assert options.default_step_size_absolute == 0.05 + # the deprecated argument overrides the new one + with pytest.warns(DeprecationWarning, match="`default_step_size`"): + options = profile.ProfileOptions( + default_step_size=0.01, + default_step_size_absolute=0.03, ) - with pytest.raises(ValueError): - profile.ProfileOptions( - min_step_size=2, - max_step_size=1, + assert options.default_step_size_absolute == 0.01 + # the deprecated attribute is still readable + with pytest.warns(DeprecationWarning, match="`default_step_size`"): + assert options.default_step_size == 0.01 + for kwargs in ( + { + "default_step_size_absolute": 1, + "min_step_size_absolute": 2, + }, + { + "default_step_size_absolute": 2, + "min_step_size_absolute": 1, + "max_step_size_absolute": 1, + }, + { + "min_step_size_relative": 0.006, + "default_step_size_relative": 0.005, + }, + { + "default_step_size_relative": 0.03, + "max_step_size_relative": 0.02, + }, + { + "default_step_size_absolute": 0.0, + "default_step_size_relative": 0.0, + }, + {"step_size_precheck_mode": "invalid"}, + ): + with pytest.raises(ValueError): + profile.ProfileOptions(**kwargs) + + +@pytest.mark.parametrize( + ( + "scale", + "lb", + "ub", + "profile_options", + "expected_min", + "expected_default", + "expected_max", + "expected_mode", + ), + [ + ("lin", 0.0, 100.0, None, 0.125, 0.25, 2.5, "relative"), + ("lin", 0.0, 1.0, None, 0.01, 0.02, 0.2, "absolute"), + ("log10", -6.0, 6.0, None, 0.015, 0.03, 0.3, "relative"), + ( + "lin", + 0.0, + 100.0, + profile.ProfileOptions( + min_step_size_absolute=0.1, + default_step_size_absolute=0.5, + max_step_size_absolute=10.0, + min_step_size_relative=0.002, + default_step_size_relative=0.005, + max_step_size_relative=0.006, + ), + 0.2, + 0.5, + 0.6, + "relative", + ), + ], +) +def test_resolve_profile_step_sizes( + scale, + lb, + ub, + profile_options, + expected_min, + expected_default, + expected_max, + expected_mode, +): + """Resolved step sizes should pick one family on the optimization scale.""" + problem = pypesto.Problem( + objective=pypesto.Objective(fun=lambda x: np.sum(x**2)), + lb=np.array([lb]), + ub=np.array([ub]), + x_scales=[scale], + x_names=["x0"], + ) + resolved_steps = resolve_profile_step_sizes( + problem, + 0, + profile_options or profile.ProfileOptions(), + ) + + assert np.isclose(resolved_steps.min_step_size, expected_min) + assert np.isclose(resolved_steps.default_step_size, expected_default) + assert np.isclose(resolved_steps.max_step_size, expected_max) + assert resolved_steps.mode == expected_mode + assert np.isclose(resolved_steps.span, ub - lb) + assert ( + resolve_profile_step_sizes_for_parameters( + problem, + [0], + profile_options or profile.ProfileOptions(), + )[0] + == resolved_steps + ) + + +@pytest.mark.parametrize( + ("mode", "expect_warning", "expect_raise"), + [ + ("off", False, False), + ("warn", True, False), + ("raise", False, True), + ], +) +def test_profile_step_size_precheck_modes(mode, expect_warning, expect_raise): + """Precheck modes should suppress, warn, or raise on large spans.""" + problem = pypesto.Problem( + objective=pypesto.Objective(fun=lambda x: np.sum(x**2)), + lb=np.array([-5.0]), + ub=np.array([15.0]), + x_scales=["log10"], + x_names=["x0"], + ) + current_profile = pypesto.ProfilerResult( + x_path=np.array([[0.0]]), + fval_path=np.array([0.0]), + ratio_path=np.array([1.0]), + ) + profile_options = profile.ProfileOptions( + min_step_size_relative=0.0005, + default_step_size_relative=0.001, + max_step_size_relative=0.01, + step_size_precheck_mode=mode, + whole_path=True, + ) + resolved_steps = resolve_profile_step_sizes(problem, 0, profile_options) + + if expect_raise: + with pytest.raises(ValueError, match="may require many steps"): + precheck_profile_step_size( + current_profile=current_profile, + problem=problem, + i_par=0, + par_direction=1, + options=profile_options, + resolved_steps=resolved_steps, + ) + return + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + precheck_profile_step_size( + current_profile=current_profile, + problem=problem, + i_par=0, + par_direction=1, + options=profile_options, + resolved_steps=resolved_steps, ) + precheck_warnings = [ + warning + for warning in caught + if "may require many steps" in str(warning.message) + ] + if expect_warning: + assert precheck_warnings + message = str(precheck_warnings[0].message) + assert "default step size" in message + assert "minimum step size" in message + else: + assert not precheck_warnings + @pytest.mark.parametrize( "lb,ub", @@ -478,10 +662,10 @@ def test_gh1165(lb, ub): profile_index=[par_idx], progress_bar=False, profile_options=profile.ProfileOptions( - min_step_size=0.1, - max_step_size=1.0, + min_step_size_absolute=0.1, + max_step_size_absolute=1.0, delta_ratio_max=0.05, - default_step_size=0.5, + default_step_size_absolute=0.5, ratio_min=0.01, whole_path=True, ), diff --git a/test/visualize/__init__.py b/test/visualize/__init__.py index deb30067c..82ff3e935 100644 --- a/test/visualize/__init__.py +++ b/test/visualize/__init__.py @@ -1,7 +1,7 @@ """Visualization tests.""" +from ..conftest import close_fig from .test_visualize import ( - close_fig, create_optimization_result, create_petab_problem, create_problem, diff --git a/test/visualize/test_misc.py b/test/visualize/test_misc.py new file mode 100644 index 000000000..db904f5d3 --- /dev/null +++ b/test/visualize/test_misc.py @@ -0,0 +1,75 @@ +"""Tests for visualize utility helpers in :mod:`pypesto.visualize.misc`.""" + +import matplotlib.pyplot as plt +import pytest + +from pypesto.visualize.misc import ( + _UNSET, + get_ax, + get_axes_array, + hide_unused_axes, + process_deprecated_kwarg, +) + +from ..conftest import close_fig + + +@close_fig +def test_get_ax(): + """Returns the given Axes; otherwise creates one with ``size``.""" + _, given = plt.subplots() + assert get_ax(given) is given + + custom = get_ax(size=(4.0, 3.0)) + assert tuple(custom.get_figure().get_size_inches()) == (4.0, 3.0) + + +@close_fig +def test_get_axes_array(): + """Normalizes existing grids and creates new ones with ``size``.""" + _, given = plt.subplots(1, 2) + normalized = get_axes_array(given, nrows=1, ncols=2) + assert normalized.shape == (1, 2) + + created = get_axes_array(nrows=2, ncols=1, size=(4.0, 3.0)) + assert created.shape == (2, 1) + assert tuple(created.flat[0].figure.get_size_inches()) == (4.0, 3.0) + + with pytest.raises(ValueError, match="shape"): + get_axes_array(given, nrows=2, ncols=2) + + +@close_fig +def test_hide_unused_axes(): + """Hides unused panels and re-shows reused ones.""" + _, axes = plt.subplots(2, 2, squeeze=False) + + axes = hide_unused_axes(axes=axes, n_used=3, clear=True) + assert axes[0, 0].get_visible() + assert axes[1, 0].get_visible() + assert axes[1, 1].get_visible() is False + + axes = hide_unused_axes(axes=axes, used_indices=(0, 3)) + assert axes[0, 0].get_visible() + assert axes[0, 1].get_visible() is False + assert axes[1, 1].get_visible() + + with pytest.raises(ValueError, match="exactly one"): + hide_unused_axes(axes=axes) + + +def test_process_deprecated_kwarg(): + """Resolves rename: canonical wins, deprecated warns, both raises.""" + # deprecated not passed (_UNSET) → return canonical + assert process_deprecated_kwarg("new", 1, "old", _UNSET) == 1 + assert process_deprecated_kwarg("new", None, "old", _UNSET) is None + + # explicit old=None still warns (distinguishable from "not passed") + with pytest.warns(DeprecationWarning, match="old.*deprecated.*new"): + assert process_deprecated_kwarg("new", None, "old", None) is None + + with pytest.warns(DeprecationWarning, match="old.*deprecated.*new"): + assert process_deprecated_kwarg("new", None, "old", 2) == 2 + + with pytest.raises(ValueError, match="not both"): + process_deprecated_kwarg("new", 1, "old", 2) diff --git a/test/visualize/test_visualize.py b/test/visualize/test_visualize.py index 421015b3d..77dd2cc00 100644 --- a/test/visualize/test_visualize.py +++ b/test/visualize/test_visualize.py @@ -3,7 +3,6 @@ import os from collections.abc import Sequence from copy import deepcopy -from functools import wraps from pathlib import Path import matplotlib.pyplot as plt @@ -25,22 +24,12 @@ get_Boehm_JProteomeRes2014_hierarchical_petab_corrected_bounds, ) from pypesto.visualize.model_fit import ( + _get_simulation_rdatas, time_trajectory_model, visualize_optimized_model_fit, ) - -def close_fig(fun): - """Close figure.""" - - @wraps(fun) - def wrapped_fun(*args, **kwargs): - ret = fun(*args, **kwargs) - plt.close("all") - return ret - - return wrapped_fun - +from ..conftest import close_fig # Define some helper functions, to have the test code more readable @@ -549,7 +538,30 @@ def test_parameters_hierarchical(scale_to_interval): @close_fig def test_optimization_scatter(): result = create_optimization_result() - visualize.optimization_scatter(result) + axes = visualize.optimization_scatter(result) + assert axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.optimization_scatter(result, axes=custom_axes) + assert returned_axes is custom_axes + + nrows, ncols = axes.shape + fig = plt.figure() + grid_spec = fig.add_gridspec(nrows, ncols + 1) + custom_axes = np.empty((nrows, ncols), dtype=object) + for row in range(nrows): + for col in range(ncols): + custom_axes[row, col] = fig.add_subplot(grid_spec[row, col]) + extra_ax = fig.add_subplot(grid_spec[:, -1]) + visualize.optimization_scatter(result, axes=custom_axes) + assert extra_ax in fig.axes + + for i, optimizer_result in enumerate(result.optimize_result.list): + optimizer_result.fval = 1.0 + i * 1e-12 + + axes = visualize.optimization_scatter(result) + colorbar_ax = axes[0, 0].figure.axes[-1] + + assert colorbar_ax.get_ylim()[1] - colorbar_ax.get_ylim()[0] >= 1.0 @close_fig @@ -629,6 +641,55 @@ def _test_ensemble_dimension_reduction(): visualize.ensemble_scatter_lowlevel(pca_components[:, 0:2]) +@close_fig +def test_ensemble_scatter_lowlevel(): + dataset = np.array([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]) + + ax = visualize.ensemble_scatter_lowlevel( + dataset, x_label="component x", y_label="component y" + ) + + assert ax.get_xlabel() == "component x" + assert ax.get_ylabel() == "component y" + + +@close_fig +def test_projection_scatter_umap_original(monkeypatch): + import sys + import types + + plot_module = types.ModuleType("umap.plot") + + def fake_points(umap_object, values=None, theme=None, **kwargs): + assert values == [0.0, 1.0] + assert theme == "viridis" + return kwargs["ax"] + + plot_module.points = fake_points + umap_module = types.ModuleType("umap") + umap_module.plot = plot_module + + monkeypatch.setitem(sys.modules, "umap", umap_module) + monkeypatch.setitem(sys.modules, "umap.plot", plot_module) + + class DummyUmap: + def __init__(self): + self.embedding_ = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + + dummy_umap = DummyUmap() + _, ax = plt.subplots() + + returned_ax = visualize.projection_scatter_umap_original( + dummy_umap, + color_by=[0.0, 1.0], + components=(0, 2), + ax=ax, + ) + + assert returned_ax is ax + assert dummy_umap.embedding_.shape == (2, 2) + + @close_fig def test_ensemble_identifiability(): # creates a test problem @@ -761,6 +822,13 @@ def test_profile_lowlevel(): ] ) visualize.profile_lowlevel(fvals=fvals, color="m") + _, ax = plt.subplots() + + returned_ax = visualize.profile_lowlevel(fvals=fvals, color="m", ax=ax) + + assert returned_ax is ax + assert ax.get_xlabel() == "Parameter value" + assert ax.get_ylabel() == "Log-posterior ratio" @close_fig @@ -779,6 +847,55 @@ def test_nested_profile_cis(): visualize.profile_nested_cis(result, colors=["#5F9ED1", "#007ACC"]) +@close_fig +def test_visualize_2d_profile(): + result = create_profile_result() + # basic call — all profiles, default settings + _, axes = visualize.visualize_2d_profile(result) + assert axes[0, 1].yaxis.labelpad == 2 + assert axes[0, 1].yaxis.label.get_size() == 14 + assert not axes[0, 1].spines["top"].get_visible() + assert not axes[0, 1].spines["right"].get_visible() + assert axes[0, 0].get_legend() is not None + assert axes[0, 1].get_legend() is not None + # explicit profile indices and ratio cutoff + visualize.visualize_2d_profile( + result, profile_indices=[0, 1], ratio_min=0.1 + ) + # objective values instead of ratio + visualize.visualize_2d_profile(result, plot_objective_values=True) + # custom figure size + visualize.visualize_2d_profile(result, size=(8, 8)) + + +@close_fig +def test_profile_lowlevel_2d(): + result = create_profile_result() + _, ax = plt.subplots() + visualize.profile_lowlevel_2d( + result, profile_index=0, second_par_index=1, ax=ax + ) + assert ( + len([line for line in ax.lines if line.get_linestyle() == "--"]) == 4 + ) + _, ax = plt.subplots() + visualize.profile_lowlevel_2d( + result, + profile_index=0, + second_par_index=1, + ax=ax, + plot_objective_values=True, + ) + _, ax = plt.subplots() + visualize.profile_lowlevel_2d( + result, + profile_index=0, + second_par_index=1, + ax=ax, + ratio_min=0.1, + ) + + @close_fig def test_optimizer_history(): # create the necessary results @@ -915,7 +1032,15 @@ def test_optimization_stats(): plot_type="hist", ) - visualize.optimization_run_properties_per_multistart([result_1, result_2]) + axes = visualize.optimization_run_properties_per_multistart( + [result_1, result_2] + ) + assert axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.optimization_run_properties_per_multistart( + [result_1, result_2], axes=custom_axes + ) + assert returned_axes is custom_axes visualize.optimization_run_properties_one_plot( result_1, ["time"], colors="C0" @@ -925,13 +1050,24 @@ def test_optimization_stats(): result_1, ["n_fval", "n_grad", "n_hess"] ) - visualize.optimization_run_property_per_multistart( + axes = visualize.optimization_run_property_per_multistart( + [result_1, result_2], + "time", + colors=["g", "C1"], + legends=["result1", "result2"], + plot_type="both", + ) + assert axes.shape == (1, 2) + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.optimization_run_property_per_multistart( [result_1, result_2], "time", colors=["g", "C1"], legends=["result1", "result2"], plot_type="both", + axes=custom_axes, ) + assert returned_axes is custom_axes @close_fig @@ -1095,31 +1231,56 @@ def test_sampling_fval_traces(): def test_sampling_parameter_traces(): """Test pypesto.visualize.sampling_parameter_traces""" result = create_sampling_result() - visualize.sampling_parameter_traces(result) + axes = visualize.sampling_parameter_traces(result) + assert axes.ndim == 2 # call with custom arguments - visualize.sampling_parameter_traces( + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.sampling_parameter_traces( result, i_chain=1, stepsize=5, size=(10, 10), use_problem_bounds=False ) + assert returned_axes.ndim == 2 + returned_axes = visualize.sampling_parameter_traces( + result, + i_chain=1, + stepsize=5, + use_problem_bounds=False, + axes=custom_axes, + ) + assert returned_axes is custom_axes @close_fig def test_sampling_scatter(): """Test pypesto.visualize.sampling_scatter""" result = create_sampling_result() - visualize.sampling_scatter(result) + axes = visualize.sampling_scatter(result) + assert axes.ndim == 2 # call with custom arguments - visualize.sampling_scatter(result, i_chain=1, stepsize=5, size=(10, 10)) + returned_axes = visualize.sampling_scatter( + result, i_chain=1, stepsize=5, size=(10, 10), diag_kind="hist" + ) + assert returned_axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.sampling_scatter(result, axes=custom_axes) + assert returned_axes is custom_axes @close_fig def test_sampling_1d_marginals(): """Test pypesto.visualize.sampling_1d_marginals""" result = create_sampling_result() - visualize.sampling_1d_marginals(result) + axes = visualize.sampling_1d_marginals(result) + assert axes.ndim == 2 # call with custom arguments - visualize.sampling_1d_marginals( + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.sampling_1d_marginals( result, i_chain=1, stepsize=5, size=(10, 10) ) + assert returned_axes.ndim == 2 + returned_axes = visualize.sampling_1d_marginals( + result, i_chain=1, stepsize=5, axes=custom_axes + ) + assert returned_axes is custom_axes # call with other modes visualize.sampling_1d_marginals(result, plot_type="hist") visualize.sampling_1d_marginals( @@ -1132,9 +1293,9 @@ def test_sampling_parameter_cis(): """Test pypesto.visualize.sampling_parameter_cis""" result = create_sampling_result() visualize.sampling_parameter_cis(result) - # call with custom arguments + # call with canonical kwarg visualize.sampling_parameter_cis( - result, alpha=[99, 68], step=0.1, size=(10, 10) + result, confidence_levels=[0.99, 0.68], step=0.1, size=(10, 10) ) @@ -1169,17 +1330,46 @@ def test_sampling_prediction_trajectories(): ) # Plot by - visualize.sampling_prediction_trajectories( + axes = visualize.sampling_prediction_trajectories( + ensemble_prediction, + levels=credibility_interval_levels, + groupby=pypesto.C.CONDITION, + ) + assert axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.sampling_prediction_trajectories( ensemble_prediction, levels=credibility_interval_levels, groupby=pypesto.C.CONDITION, + axes=custom_axes, ) - visualize.sampling_prediction_trajectories( + assert returned_axes is custom_axes + returned_axes = visualize.sampling_prediction_trajectories( ensemble_prediction, levels=credibility_interval_levels, size=(10, 10), groupby=pypesto.C.OUTPUT, ) + assert returned_axes.ndim == 2 + + +@close_fig +def test_get_simulation_rdatas_default_timepoints(): + """Regression test for default AMICI timepoint access in model-fit plots.""" + problem = create_petab_problem() + + result = optimize.minimize( + problem=problem, + n_starts=1, + optimizer=optimize.ScipyOptimizer( + method="L-BFGS-B", options={"maxiter": 1} + ), + progress_bar=False, + ) + + rdatas = _get_simulation_rdatas(result=result, problem=problem) + + assert len(rdatas) == len(problem.objective.edatas) @close_fig @@ -1291,7 +1481,48 @@ def test_time_trajectory_model(): ) # test call of time_trajectory_model - time_trajectory_model(result=result) + axes = time_trajectory_model(result=result) + assert axes is not None + + # test call of time_trajectory_model for hierarchical problems + hierarchical_petab_problem = ( + get_Boehm_JProteomeRes2014_hierarchical_petab_corrected_bounds() + ) + importer = pypesto.petab.PetabImporter( + hierarchical_petab_problem, hierarchical=True + ) + problem = importer.create_problem() + + # Set nominal values as start point, mapped by name to avoid ordering + # assumptions about where inner parameters sit in x_nominal_scaled + x_nominal_by_id = dict( + zip( + hierarchical_petab_problem.x_ids, + hierarchical_petab_problem.x_nominal_scaled, + strict=True, + ) + ) + x_guess = np.array([x_nominal_by_id[xid] for xid in problem.x_names]) + problem.set_x_guesses([x_guess]) + + result = optimize.minimize( + problem=problem, + n_starts=1, + optimizer=optimize.ScipyOptimizer( + method="L-BFGS-B", options={"maxiter": 1} + ), + progress_bar=False, + ) + + first_state_name = problem.objective.amici_model.get_state_names()[0] + + axes = time_trajectory_model( + result=result, + problem=problem, + state_names=[first_state_name], + ) + + assert axes is not None def test_monotonic_history(): @@ -1427,7 +1658,16 @@ def test_visualize_estimated_observable_mapping(): result = pypesto.optimize.minimize( problem=problem, n_starts=1, optimizer=optimizer ) - visualize.visualize_estimated_observable_mapping(result, problem) + axes = visualize.visualize_estimated_observable_mapping(result, problem) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.visualize_estimated_observable_mapping( + result, + problem, + axes=custom_axes, + ) + assert returned_axes is custom_axes @close_fig @@ -1493,6 +1733,14 @@ def test_projection_scatter_pca_parameters(): # Test visualization with specific components visualize.projection_scatter_pca(pca_repr, components=(0, 1)) + dummy_pca = np.random.randn(20, 4) + axes = visualize.projection_scatter_pca(dummy_pca, components=range(4)) + assert axes.ndim == 2 + custom_axes = plt.subplots(*axes.shape, squeeze=False)[1] + returned_axes = visualize.projection_scatter_pca( + dummy_pca, components=range(4), axes=custom_axes + ) + assert returned_axes is custom_axes @close_fig diff --git a/tox.ini b/tox.ini index 8942f8cf4..2a4ff983d 100644 --- a/tox.ini +++ b/tox.ini @@ -106,8 +106,7 @@ description = [testenv:julia] extras = test,julia commands = - python -c "import julia; julia.install()" - python-jl -m pytest --cov=pypesto --cov-report=xml --cov-append \ + pytest --cov=pypesto --cov-report=xml --cov-append \ test/julia description = Test Julia interface