diff --git a/examples/scripts/battery_parameterisation/bayesian_feature_fitting.py b/examples/scripts/battery_parameterisation/bayesian_feature_fitting.py index c84c325cb..8c68af73a 100644 --- a/examples/scripts/battery_parameterisation/bayesian_feature_fitting.py +++ b/examples/scripts/battery_parameterisation/bayesian_feature_fitting.py @@ -120,6 +120,9 @@ result.plot_convergence(yaxis={"type": "log"}) result.plot_parameters(yaxis={"type": "log"}, yaxis2={"type": "log"}) + # Plot the prior and posterior distributions + pybop.plot.distribution(result.problem.parameters, result.posterior) + # Plot predictions for a set of inputs sampled from the posterior fig = result.plot_predictive(show=False) fig[0].show() diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index 06e5a0245..8c669df92 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -12,3 +12,4 @@ from .voronoi import surface from .samples import trace, chains, posterior, summary_table from .predictive import predictive +from .distribution import distribution diff --git a/pybop/plot/distribution.py b/pybop/plot/distribution.py new file mode 100644 index 000000000..e0e0430dc --- /dev/null +++ b/pybop/plot/distribution.py @@ -0,0 +1,77 @@ +import numpy as np + +from pybop.parameters.parameter import Parameters +from pybop.plot.standard_plots import StandardSubplot + + +def distribution( + parameters: Parameters, + posterior: Parameters | None = None, + n_samples: int = 100, + transformed: bool = False, + show: bool = True, + **layout_kwargs, +): + """ + Plot the posterior on top of the prior distribution for a Bayesian optimisation result. + """ + # Create lists of axis titles and trace names + axis_titles = [] + trace_names = ( + parameters.names + if posterior is None + else ["Prior"] * len(parameters) + ["Posterior"] * len(parameters) + ) + for name in parameters.names: + axis_titles.append( + (name + " (transformed)" if transformed else name, "Probability density") + ) + + # Evaluate marginal distributions for each parameter + values = [] + probability = [] + for p in parameters: + d = p.transformed_distribution if transformed else p.distribution + samples = d.rvs(size=n_samples) + parameter_range = np.linspace(min(samples), max(samples), n_samples) + values.append(parameter_range) + probability.append([d.pdf(s) for s in values[-1]]) + + # Set subplot layout options + layout_options = dict( + width=1024, + height=576, + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), + ) + + # Create a plot dictionary + plot_dict = StandardSubplot( + x=values, + y=probability, + axis_titles=axis_titles, + layout_options=layout_options, + trace_names=trace_names, + trace_name_width=50, + ) + fig = plot_dict(show=False) + + if posterior is not None: + for idx, p in enumerate(posterior): + d = p.transformed_distribution if transformed else p.distribution + samples = d.rvs(size=n_samples) + parameter_range = np.linspace(min(samples), max(samples), n_samples) + values.append(parameter_range) + probability.append([d.pdf(s) for s in values[-1]]) + + trace = plot_dict.create_trace( + values[-1], probability[-1], **plot_dict.trace_options + ) + row = (idx // plot_dict.num_cols) + 1 + col = (idx % plot_dict.num_cols) + 1 + fig.add_trace(trace, row=row, col=col) + + fig.update_layout(**layout_kwargs) + if show: + fig.show() + + return fig diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index 4422516b8..962f9a5f8 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -200,12 +200,7 @@ def create_trace(self, x, y, **trace_options): plotly.graph_objs.Scatter A trace for a Plotly figure. """ - - return self.go.Scatter( - x=x, - y=y, - **trace_options, - ) + return self.go.Scatter(x=x, y=y, **trace_options) @staticmethod def wrap_text(text, width): diff --git a/tests/unit/test_plots.py b/tests/unit/test_plots.py index 607585540..6f9c19d1d 100644 --- a/tests/unit/test_plots.py +++ b/tests/unit/test_plots.py @@ -203,6 +203,11 @@ def test_posterior_plots(self, sampling_result): # Plot posterior predictions sampling_result.plot_predictive() + # Plot the prior and posterior distributions + pybop.plot.distribution( + sampling_result.problem.parameters, sampling_result.posterior + ) + def test_with_ipykernel(self, dataset, fitting_problem, result): import ipykernel