From b24c64c62cf1634e49cc35d2762c873e61817a97 Mon Sep 17 00:00:00 2001 From: u2370093 Date: Thu, 2 Apr 2026 12:44:32 +0100 Subject: [PATCH 01/10] add plotting with matplotlib and allow user to choose backend --- .../echem_identification_pitfalls.ipynb | 5 +- .../ecm_monte_carlo_sampling.ipynb | 3 +- .../ecm_multipulse_identification.ipynb | 5 +- .../ecm_scipy_constraints.ipynb | 3 +- .../electrode_balancing.ipynb | 3 +- .../lgm50_pulse_validation.ipynb | 5 +- .../pouch_cell_identification.ipynb | 5 +- .../sensitivity_analysis_hessian.ipynb | 1471 +++++++++-------- .../sensitivity_analysis_salib.ipynb | 3 +- .../comparing_cost_functions.ipynb | 4 +- .../optimiser_calibration.ipynb | 3 +- .../energy_based_electrode_design.ipynb | 3 +- .../cost_compute_methods.ipynb | 2 +- .../maximum_a_posteriori.ipynb | 3 +- .../optimising_with_adamw.ipynb | 3 +- .../setting_optimiser_options.ipynb | 2 +- .../using_transformations.ipynb | 1 + .../comparison_examples/grouped_SPMe.py | 17 +- pybop/plot/__init__.py | 37 +- pybop/plot/matplotlib/__init__.py | 9 + pybop/plot/matplotlib/contour.py | 237 +++ pybop/plot/matplotlib/convergence.py | 50 + pybop/plot/matplotlib/dataset.py | 54 + pybop/plot/matplotlib/nyquist.py | 86 + pybop/plot/matplotlib/parameters.py | 71 + pybop/plot/matplotlib/problem.py | 114 ++ pybop/plot/matplotlib/samples.py | 138 ++ pybop/plot/matplotlib/standard_plots.py | 386 +++++ pybop/plot/matplotlib/voronoi.py | 142 ++ pybop/plot/plotly/__init__.py | 10 + pybop/plot/{ => plotly}/contour.py | 2 +- pybop/plot/{ => plotly}/convergence.py | 2 +- pybop/plot/{ => plotly}/dataset.py | 2 +- pybop/plot/{ => plotly}/nyquist.py | 2 +- pybop/plot/{ => plotly}/parameters.py | 2 +- pybop/plot/{ => plotly}/plotly_manager.py | 0 pybop/plot/{ => plotly}/problem.py | 2 +- pybop/plot/{ => plotly}/samples.py | 2 +- pybop/plot/{ => plotly}/standard_plots.py | 12 +- pybop/plot/plotly/voronoi.py | 216 +++ pybop/plot/plots.py | 287 ++++ pybop/plot/util.py | 50 + pybop/plot/voronoi.py | 208 +-- tests/plotting/test_plotly_manager.py | 2 +- tests/unit/test_plots.py | 29 +- 45 files changed, 2701 insertions(+), 992 deletions(-) create mode 100644 pybop/plot/matplotlib/__init__.py create mode 100644 pybop/plot/matplotlib/contour.py create mode 100644 pybop/plot/matplotlib/convergence.py create mode 100644 pybop/plot/matplotlib/dataset.py create mode 100644 pybop/plot/matplotlib/nyquist.py create mode 100644 pybop/plot/matplotlib/parameters.py create mode 100644 pybop/plot/matplotlib/problem.py create mode 100644 pybop/plot/matplotlib/samples.py create mode 100644 pybop/plot/matplotlib/standard_plots.py create mode 100644 pybop/plot/matplotlib/voronoi.py create mode 100644 pybop/plot/plotly/__init__.py rename pybop/plot/{ => plotly}/contour.py (99%) rename pybop/plot/{ => plotly}/convergence.py (96%) rename pybop/plot/{ => plotly}/dataset.py (95%) rename pybop/plot/{ => plotly}/nyquist.py (98%) rename pybop/plot/{ => plotly}/parameters.py (97%) rename pybop/plot/{ => plotly}/plotly_manager.py (100%) rename pybop/plot/{ => plotly}/problem.py (98%) rename pybop/plot/{ => plotly}/samples.py (98%) rename pybop/plot/{ => plotly}/standard_plots.py (96%) create mode 100644 pybop/plot/plotly/voronoi.py create mode 100644 pybop/plot/plots.py create mode 100644 pybop/plot/util.py diff --git a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb index 1eb2a59c5..069d68660 100644 --- a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb +++ b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb @@ -32,8 +32,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb index dce5640a7..addca5b86 100644 --- a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb @@ -46,7 +46,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb index 51c852358..e17669e50 100644 --- a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb @@ -48,7 +48,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] @@ -505,7 +506,7 @@ " [3600 3]], duration=3600)]" ] }, - "execution_count": null, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb index ce327fe6c..70915b6e9 100644 --- a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb @@ -31,7 +31,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb index adc764f25..94b981c61 100644 --- a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb +++ b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb @@ -29,7 +29,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb index ae8a0d3ba..3d63a9ef3 100644 --- a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb +++ b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb @@ -32,8 +32,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb index ce32e337b..1a73079e4 100644 --- a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb +++ b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb @@ -30,8 +30,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb index 3cab87a0d..d5a7ecd4f 100644 --- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb +++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb @@ -29,6 +29,7 @@ "\n", "import pybop\n", "\n", + "pybop.plot.set_backend(\"matplotlib\")\n", "np.random.seed(8) # users can remove this line" ] }, @@ -210,708 +211,708 @@ "showlegend": false, "type": "scatter", "x": [ - 0.0, - 10.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, - 70.0, - 80.0, - 90.0, - 100.0, - 110.0, - 120.0, - 130.0, - 140.0, - 150.0, - 160.0, - 170.0, - 180.0, - 190.0, - 200.0, - 210.0, - 220.0, - 230.0, - 240.0, - 250.0, - 260.0, - 270.0, - 280.0, - 290.0, - 300.0, - 310.0, - 320.0, - 330.0, - 340.0, - 350.0, - 360.0, - 370.0, - 380.0, - 390.0, - 400.0, - 410.0, - 420.0, - 430.0, - 440.0, - 450.0, - 460.0, - 470.0, - 480.0, - 490.0, - 500.0, - 510.0, - 520.0, - 530.0, - 540.0, - 550.0, - 560.0, - 570.0, - 580.0, - 590.0, - 600.0, - 610.0, - 620.0, - 630.0, - 640.0, - 650.0, - 660.0, - 670.0, - 680.0, - 690.0, - 700.0, - 710.0, - 720.0, - 730.0, - 740.0, - 750.0, - 760.0, - 770.0, - 780.0, - 790.0, - 800.0, - 810.0, - 820.0, - 830.0, - 840.0, - 850.0, - 860.0, - 870.0, - 880.0, - 890.0, - 900.0, - 910.0, - 920.0, - 930.0, - 940.0, - 950.0, - 960.0, - 970.0, - 980.0, - 990.0, - 1000.0, - 1010.0, - 1020.0, - 1030.0, - 1040.0, - 1050.0, - 1060.0, - 1070.0, - 1080.0, - 1090.0, - 1100.0, - 1110.0, - 1120.0, - 1130.0, - 1140.0, - 1150.0, - 1160.0, - 1170.0, - 1180.0, - 1190.0, - 1200.0, - 1210.0, - 1220.0, - 1230.0, - 1240.0, - 1250.0, - 1260.0, - 1270.0, - 1280.0, - 1290.0, - 1300.0, - 1310.0, - 1320.0, - 1330.0, - 1340.0, - 1350.0, - 1360.0, - 1370.0, - 1380.0, - 1390.0, - 1400.0, - 1410.0, - 1420.0, - 1430.0, - 1440.0, - 1450.0, - 1460.0, - 1470.0, - 1480.0, - 1490.0, - 1500.0, - 1510.0, - 1520.0, - 1530.0, - 1540.0, - 1550.0, - 1560.0, - 1570.0, - 1580.0, - 1590.0, - 1600.0, - 1610.0, - 1620.0, - 1630.0, - 1640.0, - 1650.0, - 1660.0, - 1670.0, - 1680.0, - 1690.0, - 1700.0, - 1710.0, - 1720.0, - 1730.0, - 1740.0, - 1750.0, - 1760.0, - 1770.0, - 1780.0, - 1790.0, - 1800.0, - 1810.0, - 1820.0, - 1830.0, - 1840.0, - 1850.0, - 1860.0, - 1870.0, - 1880.0, - 1890.0, - 1900.0, - 1910.0, - 1920.0, - 1930.0, - 1940.0, - 1950.0, - 1960.0, - 1970.0, - 1980.0, - 1990.0, - 2000.0, - 2010.0, - 2020.0, - 2030.0, - 2040.0, - 2050.0, - 2060.0, - 2070.0, - 2080.0, - 2090.0, - 2100.0, - 2110.0, - 2120.0, - 2130.0, - 2140.0, - 2150.0, - 2160.0, - 2170.0, - 2180.0, - 2190.0, - 2200.0, - 2210.0, - 2220.0, - 2230.0, - 2240.0, - 2250.0, - 2260.0, - 2270.0, - 2280.0, - 2290.0, - 2300.0, - 2310.0, - 2320.0, - 2330.0, - 2340.0, - 2350.0, - 2360.0, - 2370.0, - 2380.0, - 2390.0, - 2400.0, - 2410.0, - 2420.0, - 2430.0, - 2440.0, - 2450.0, - 2460.0, - 2470.0, - 2480.0, - 2490.0, - 2500.0, - 2510.0, - 2520.0, - 2530.0, - 2540.0, - 2550.0, - 2560.0, - 2570.0, - 2580.0, - 2590.0, - 2600.0, - 2610.0, - 2620.0, - 2630.0, - 2640.0, - 2650.0, - 2660.0, - 2670.0, - 2680.0, - 2690.0, - 2700.0, - 2710.0, - 2720.0, - 2730.0, - 2740.0, - 2750.0, - 2760.0, - 2770.0, - 2780.0, - 2790.0, - 2800.0, - 2810.0, - 2820.0, - 2830.0, - 2840.0, - 2850.0, - 2860.0, - 2870.0, - 2880.0, - 2890.0, - 2900.0, - 2910.0, - 2920.0, - 2930.0, - 2940.0, - 2950.0, - 2960.0, - 2970.0, - 2980.0, - 2990.0, - 3000.0, - 3010.0, - 3020.0, - 3030.0, - 3040.0, - 3050.0, - 3060.0, - 3070.0, - 3080.0, - 3090.0, - 3100.0, - 3110.0, - 3120.0, - 3130.0, - 3140.0, - 3150.0, - 3160.0, - 3170.0, - 3180.0, - 3190.0, - 3200.0, - 3210.0, - 3220.0, - 3230.0, - 3240.0, - 3250.0, - 3260.0, - 3270.0, - 3280.0, - 3290.0, - 3300.0, - 3310.0, - 3320.0, - 3330.0, - 3340.0, - 3350.0, - 3360.0, - 3370.0, - 3380.0, - 3390.0, - 3400.0, - 3410.0, - 3420.0, - 3430.0, - 3440.0, - 3450.0, - 3460.0, - 3470.0, - 3480.0, - 3490.0, - 3500.0, - 3500.0, - 3490.0, - 3480.0, - 3470.0, - 3460.0, - 3450.0, - 3440.0, - 3430.0, - 3420.0, - 3410.0, - 3400.0, - 3390.0, - 3380.0, - 3370.0, - 3360.0, - 3350.0, - 3340.0, - 3330.0, - 3320.0, - 3310.0, - 3300.0, - 3290.0, - 3280.0, - 3270.0, - 3260.0, - 3250.0, - 3240.0, - 3230.0, - 3220.0, - 3210.0, - 3200.0, - 3190.0, - 3180.0, - 3170.0, - 3160.0, - 3150.0, - 3140.0, - 3130.0, - 3120.0, - 3110.0, - 3100.0, - 3090.0, - 3080.0, - 3070.0, - 3060.0, - 3050.0, - 3040.0, - 3030.0, - 3020.0, - 3010.0, - 3000.0, - 2990.0, - 2980.0, - 2970.0, - 2960.0, - 2950.0, - 2940.0, - 2930.0, - 2920.0, - 2910.0, - 2900.0, - 2890.0, - 2880.0, - 2870.0, - 2860.0, - 2850.0, - 2840.0, - 2830.0, - 2820.0, - 2810.0, - 2800.0, - 2790.0, - 2780.0, - 2770.0, - 2760.0, - 2750.0, - 2740.0, - 2730.0, - 2720.0, - 2710.0, - 2700.0, - 2690.0, - 2680.0, - 2670.0, - 2660.0, - 2650.0, - 2640.0, - 2630.0, - 2620.0, - 2610.0, - 2600.0, - 2590.0, - 2580.0, - 2570.0, - 2560.0, - 2550.0, - 2540.0, - 2530.0, - 2520.0, - 2510.0, - 2500.0, - 2490.0, - 2480.0, - 2470.0, - 2460.0, - 2450.0, - 2440.0, - 2430.0, - 2420.0, - 2410.0, - 2400.0, - 2390.0, - 2380.0, - 2370.0, - 2360.0, - 2350.0, - 2340.0, - 2330.0, - 2320.0, - 2310.0, - 2300.0, - 2290.0, - 2280.0, - 2270.0, - 2260.0, - 2250.0, - 2240.0, - 2230.0, - 2220.0, - 2210.0, - 2200.0, - 2190.0, - 2180.0, - 2170.0, - 2160.0, - 2150.0, - 2140.0, - 2130.0, - 2120.0, - 2110.0, - 2100.0, - 2090.0, - 2080.0, - 2070.0, - 2060.0, - 2050.0, - 2040.0, - 2030.0, - 2020.0, - 2010.0, - 2000.0, - 1990.0, - 1980.0, - 1970.0, - 1960.0, - 1950.0, - 1940.0, - 1930.0, - 1920.0, - 1910.0, - 1900.0, - 1890.0, - 1880.0, - 1870.0, - 1860.0, - 1850.0, - 1840.0, - 1830.0, - 1820.0, - 1810.0, - 1800.0, - 1790.0, - 1780.0, - 1770.0, - 1760.0, - 1750.0, - 1740.0, - 1730.0, - 1720.0, - 1710.0, - 1700.0, - 1690.0, - 1680.0, - 1670.0, - 1660.0, - 1650.0, - 1640.0, - 1630.0, - 1620.0, - 1610.0, - 1600.0, - 1590.0, - 1580.0, - 1570.0, - 1560.0, - 1550.0, - 1540.0, - 1530.0, - 1520.0, - 1510.0, - 1500.0, - 1490.0, - 1480.0, - 1470.0, - 1460.0, - 1450.0, - 1440.0, - 1430.0, - 1420.0, - 1410.0, - 1400.0, - 1390.0, - 1380.0, - 1370.0, - 1360.0, - 1350.0, - 1340.0, - 1330.0, - 1320.0, - 1310.0, - 1300.0, - 1290.0, - 1280.0, - 1270.0, - 1260.0, - 1250.0, - 1240.0, - 1230.0, - 1220.0, - 1210.0, - 1200.0, - 1190.0, - 1180.0, - 1170.0, - 1160.0, - 1150.0, - 1140.0, - 1130.0, - 1120.0, - 1110.0, - 1100.0, - 1090.0, - 1080.0, - 1070.0, - 1060.0, - 1050.0, - 1040.0, - 1030.0, - 1020.0, - 1010.0, - 1000.0, - 990.0, - 980.0, - 970.0, - 960.0, - 950.0, - 940.0, - 930.0, - 920.0, - 910.0, - 900.0, - 890.0, - 880.0, - 870.0, - 860.0, - 850.0, - 840.0, - 830.0, - 820.0, - 810.0, - 800.0, - 790.0, - 780.0, - 770.0, - 760.0, - 750.0, - 740.0, - 730.0, - 720.0, - 710.0, - 700.0, - 690.0, - 680.0, - 670.0, - 660.0, - 650.0, - 640.0, - 630.0, - 620.0, - 610.0, - 600.0, - 590.0, - 580.0, - 570.0, - 560.0, - 550.0, - 540.0, - 530.0, - 520.0, - 510.0, - 500.0, - 490.0, - 480.0, - 470.0, - 460.0, - 450.0, - 440.0, - 430.0, - 420.0, - 410.0, - 400.0, - 390.0, - 380.0, - 370.0, - 360.0, - 350.0, - 340.0, - 330.0, - 320.0, - 310.0, - 300.0, - 290.0, - 280.0, - 270.0, - 260.0, - 250.0, - 240.0, - 230.0, - 220.0, - 210.0, - 200.0, - 190.0, - 180.0, - 170.0, - 160.0, - 150.0, - 140.0, - 130.0, - 120.0, - 110.0, - 100.0, - 90.0, - 80.0, - 70.0, - 60.0, - 50.0, - 40.0, - 30.0, - 20.0, - 10.0, - 0.0 + 0, + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500, + 510, + 520, + 530, + 540, + 550, + 560, + 570, + 580, + 590, + 600, + 610, + 620, + 630, + 640, + 650, + 660, + 670, + 680, + 690, + 700, + 710, + 720, + 730, + 740, + 750, + 760, + 770, + 780, + 790, + 800, + 810, + 820, + 830, + 840, + 850, + 860, + 870, + 880, + 890, + 900, + 910, + 920, + 930, + 940, + 950, + 960, + 970, + 980, + 990, + 1000, + 1010, + 1020, + 1030, + 1040, + 1050, + 1060, + 1070, + 1080, + 1090, + 1100, + 1110, + 1120, + 1130, + 1140, + 1150, + 1160, + 1170, + 1180, + 1190, + 1200, + 1210, + 1220, + 1230, + 1240, + 1250, + 1260, + 1270, + 1280, + 1290, + 1300, + 1310, + 1320, + 1330, + 1340, + 1350, + 1360, + 1370, + 1380, + 1390, + 1400, + 1410, + 1420, + 1430, + 1440, + 1450, + 1460, + 1470, + 1480, + 1490, + 1500, + 1510, + 1520, + 1530, + 1540, + 1550, + 1560, + 1570, + 1580, + 1590, + 1600, + 1610, + 1620, + 1630, + 1640, + 1650, + 1660, + 1670, + 1680, + 1690, + 1700, + 1710, + 1720, + 1730, + 1740, + 1750, + 1760, + 1770, + 1780, + 1790, + 1800, + 1810, + 1820, + 1830, + 1840, + 1850, + 1860, + 1870, + 1880, + 1890, + 1900, + 1910, + 1920, + 1930, + 1940, + 1950, + 1960, + 1970, + 1980, + 1990, + 2000, + 2010, + 2020, + 2030, + 2040, + 2050, + 2060, + 2070, + 2080, + 2090, + 2100, + 2110, + 2120, + 2130, + 2140, + 2150, + 2160, + 2170, + 2180, + 2190, + 2200, + 2210, + 2220, + 2230, + 2240, + 2250, + 2260, + 2270, + 2280, + 2290, + 2300, + 2310, + 2320, + 2330, + 2340, + 2350, + 2360, + 2370, + 2380, + 2390, + 2400, + 2410, + 2420, + 2430, + 2440, + 2450, + 2460, + 2470, + 2480, + 2490, + 2500, + 2510, + 2520, + 2530, + 2540, + 2550, + 2560, + 2570, + 2580, + 2590, + 2600, + 2610, + 2620, + 2630, + 2640, + 2650, + 2660, + 2670, + 2680, + 2690, + 2700, + 2710, + 2720, + 2730, + 2740, + 2750, + 2760, + 2770, + 2780, + 2790, + 2800, + 2810, + 2820, + 2830, + 2840, + 2850, + 2860, + 2870, + 2880, + 2890, + 2900, + 2910, + 2920, + 2930, + 2940, + 2950, + 2960, + 2970, + 2980, + 2990, + 3000, + 3010, + 3020, + 3030, + 3040, + 3050, + 3060, + 3070, + 3080, + 3090, + 3100, + 3110, + 3120, + 3130, + 3140, + 3150, + 3160, + 3170, + 3180, + 3190, + 3200, + 3210, + 3220, + 3230, + 3240, + 3250, + 3260, + 3270, + 3280, + 3290, + 3300, + 3310, + 3320, + 3330, + 3340, + 3350, + 3360, + 3370, + 3380, + 3390, + 3400, + 3410, + 3420, + 3430, + 3440, + 3450, + 3460, + 3470, + 3480, + 3490, + 3500, + 3500, + 3490, + 3480, + 3470, + 3460, + 3450, + 3440, + 3430, + 3420, + 3410, + 3400, + 3390, + 3380, + 3370, + 3360, + 3350, + 3340, + 3330, + 3320, + 3310, + 3300, + 3290, + 3280, + 3270, + 3260, + 3250, + 3240, + 3230, + 3220, + 3210, + 3200, + 3190, + 3180, + 3170, + 3160, + 3150, + 3140, + 3130, + 3120, + 3110, + 3100, + 3090, + 3080, + 3070, + 3060, + 3050, + 3040, + 3030, + 3020, + 3010, + 3000, + 2990, + 2980, + 2970, + 2960, + 2950, + 2940, + 2930, + 2920, + 2910, + 2900, + 2890, + 2880, + 2870, + 2860, + 2850, + 2840, + 2830, + 2820, + 2810, + 2800, + 2790, + 2780, + 2770, + 2760, + 2750, + 2740, + 2730, + 2720, + 2710, + 2700, + 2690, + 2680, + 2670, + 2660, + 2650, + 2640, + 2630, + 2620, + 2610, + 2600, + 2590, + 2580, + 2570, + 2560, + 2550, + 2540, + 2530, + 2520, + 2510, + 2500, + 2490, + 2480, + 2470, + 2460, + 2450, + 2440, + 2430, + 2420, + 2410, + 2400, + 2390, + 2380, + 2370, + 2360, + 2350, + 2340, + 2330, + 2320, + 2310, + 2300, + 2290, + 2280, + 2270, + 2260, + 2250, + 2240, + 2230, + 2220, + 2210, + 2200, + 2190, + 2180, + 2170, + 2160, + 2150, + 2140, + 2130, + 2120, + 2110, + 2100, + 2090, + 2080, + 2070, + 2060, + 2050, + 2040, + 2030, + 2020, + 2010, + 2000, + 1990, + 1980, + 1970, + 1960, + 1950, + 1940, + 1930, + 1920, + 1910, + 1900, + 1890, + 1880, + 1870, + 1860, + 1850, + 1840, + 1830, + 1820, + 1810, + 1800, + 1790, + 1780, + 1770, + 1760, + 1750, + 1740, + 1730, + 1720, + 1710, + 1700, + 1690, + 1680, + 1670, + 1660, + 1650, + 1640, + 1630, + 1620, + 1610, + 1600, + 1590, + 1580, + 1570, + 1560, + 1550, + 1540, + 1530, + 1520, + 1510, + 1500, + 1490, + 1480, + 1470, + 1460, + 1450, + 1440, + 1430, + 1420, + 1410, + 1400, + 1390, + 1380, + 1370, + 1360, + 1350, + 1340, + 1330, + 1320, + 1310, + 1300, + 1290, + 1280, + 1270, + 1260, + 1250, + 1240, + 1230, + 1220, + 1210, + 1200, + 1190, + 1180, + 1170, + 1160, + 1150, + 1140, + 1130, + 1120, + 1110, + 1100, + 1090, + 1080, + 1070, + 1060, + 1050, + 1040, + 1030, + 1020, + 1010, + 1000, + 990, + 980, + 970, + 960, + 950, + 940, + 930, + 920, + 910, + 900, + 890, + 880, + 870, + 860, + 850, + 840, + 830, + 820, + 810, + 800, + 790, + 780, + 770, + 760, + 750, + 740, + 730, + 720, + 710, + 700, + 690, + 680, + 670, + 660, + 650, + 640, + 630, + 620, + 610, + 600, + 590, + 580, + 570, + 560, + 550, + 540, + 530, + 520, + 510, + 500, + 490, + 480, + 470, + 460, + 450, + 440, + 430, + 420, + 410, + 400, + 390, + 380, + 370, + 360, + 350, + 340, + 330, + 320, + 310, + 300, + 290, + 280, + 270, + 260, + 250, + 240, + 230, + 220, + 210, + 200, + 190, + 180, + 170, + 160, + 150, + 140, + 130, + 120, + 110, + 100, + 90, + 80, + 70, + 60, + 50, + 40, + 30, + 20, + 10, + 0 ], "y": [ 4.0531258862154065, @@ -1744,7 +1745,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1780,7 +1781,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1804,7 +1805,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1840,7 +1841,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1867,7 +1868,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1903,7 +1904,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1918,7 +1919,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1954,7 +1955,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -2110,7 +2111,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2146,7 +2147,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -2237,7 +2238,7 @@ ], "sequential": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2273,13 +2274,13 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], "sequentialminus": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2315,7 +2316,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ] @@ -2503,7 +2504,7 @@ { "colorscale": [ [ - 0.0, + 0, "#440154" ], [ @@ -2539,7 +2540,7 @@ "#b5de2b" ], [ - 1.0, + 1, "#fde725" ] ], @@ -15963,7 +15964,7 @@ "hoverinfo": "text", "marker": { "color": [ - 0.0, + 0, 0.001968503937007874, 0.003937007874015748, 0.005905511811023622, @@ -16474,7 +16475,7 @@ ], "colorscale": [ [ - 0.0, + 0, "rgb(255,255,255)" ], [ @@ -16506,7 +16507,7 @@ "rgb(37,37,37)" ], [ - 1.0, + 1, "rgb(0,0,0)" ] ], @@ -18165,7 +18166,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18201,7 +18202,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18225,7 +18226,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18261,7 +18262,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18288,7 +18289,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18324,7 +18325,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18339,7 +18340,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18375,7 +18376,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18531,7 +18532,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18567,7 +18568,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18658,7 +18659,7 @@ ], "sequential": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18694,13 +18695,13 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], "sequentialminus": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18736,7 +18737,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ] diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb index e7a1a308c..4324ce855 100644 --- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb +++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb @@ -58,7 +58,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"matplotlib\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb index 71dc80155..b6a4251bd 100644 --- a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb +++ b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb @@ -26,8 +26,8 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb index 9ee9e69d5..c23fdf369 100644 --- a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb +++ b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb @@ -32,7 +32,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb index 6250b85e9..e4541c8d8 100644 --- a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb +++ b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb @@ -35,7 +35,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/cost_compute_methods.ipynb b/examples/notebooks/getting_started/cost_compute_methods.ipynb index 3a424f38c..909bcddd5 100644 --- a/examples/notebooks/getting_started/cost_compute_methods.ipynb +++ b/examples/notebooks/getting_started/cost_compute_methods.ipynb @@ -28,7 +28,7 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb index ccd95b2ea..7da706f7f 100644 --- a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb +++ b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb @@ -51,7 +51,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/optimising_with_adamw.ipynb b/examples/notebooks/getting_started/optimising_with_adamw.ipynb index d6fc110af..ab708410e 100644 --- a/examples/notebooks/getting_started/optimising_with_adamw.ipynb +++ b/examples/notebooks/getting_started/optimising_with_adamw.ipynb @@ -34,7 +34,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/setting_optimiser_options.ipynb b/examples/notebooks/getting_started/setting_optimiser_options.ipynb index a790b26b2..65760ef97 100644 --- a/examples/notebooks/getting_started/setting_optimiser_options.ipynb +++ b/examples/notebooks/getting_started/setting_optimiser_options.ipynb @@ -32,7 +32,7 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/using_transformations.ipynb b/examples/notebooks/getting_started/using_transformations.ipynb index 236dc8bba..2de75b771 100644 --- a/examples/notebooks/getting_started/using_transformations.ipynb +++ b/examples/notebooks/getting_started/using_transformations.ipynb @@ -29,6 +29,7 @@ "\n", "import pybop\n", "\n", + "pybop.plot.set_backend(\"plotly\")\n", "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" diff --git a/examples/scripts/comparison_examples/grouped_SPMe.py b/examples/scripts/comparison_examples/grouped_SPMe.py index 012104f32..0a21a1a42 100644 --- a/examples/scripts/comparison_examples/grouped_SPMe.py +++ b/examples/scripts/comparison_examples/grouped_SPMe.py @@ -11,11 +11,10 @@ """ # Prepare figure -layout_options = dict( - xaxis_title="Time / s", - yaxis_title="Voltage / V", -) -plot_dict = pybop.plot.StandardPlot(layout_options=layout_options) +pybop.plot.set_backend('matplotlib') +plot_dict = pybop.plot.StandardPlot() +plt.xlabel("Time / s") +plt.ylabel("Voltage / V") # Use the Chen2020 parameters parameter_values = pybamm.ParameterValues("Chen2020") @@ -46,18 +45,18 @@ ) SPMe_model = pybamm.lithium_ion.SPMe(options=model_options) grouped_SPMe_model = pybop.lithium_ion.GroupedSPMe(options=model_options) -for model, param, line_style in zip( +for model, param, linestyle in zip( [SPMe_model, grouped_SPMe_model], [parameter_values, grouped_parameter_values], - ["solid", "dash"], + ["-", "--"], strict=False, ): solution = pybamm.Simulation( model, parameter_values=param, experiment=experiment ).solve(initial_soc=init_soc) dataset = pybop.import_pybamm_solution(solution) - plot_dict.add_traces( - dataset["Time [s]"], dataset["Voltage [V]"], line_dash=line_style + plot_dict.create_trace( + dataset["Time [s]"], dataset["Voltage [V]"], label=None, linestyle=linestyle ) plot_dict() diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index f58db2016..744e749ab 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -1,13 +1,30 @@ +# Plotting backend default +DEFAULT_BACKEND = 'matplotlib' +backend=DEFAULT_BACKEND + +from .util import set_backend, call_plotting_function, get_class + # # Import plots # -from .plotly_manager import PlotlyManager -from .standard_plots import StandardPlot, StandardSubplot, trajectories -from .contour import contour -from .dataset import dataset -from .convergence import convergence -from .parameters import parameters -from .problem import problem -from .nyquist import nyquist -from .voronoi import surface -from .samples import trace, chains, posterior, summary_table +from .plots import ( + chains, + contour, + convergence, + dataset, + nyquist, + parameters, + posterior, + problem, + summary_table, + surface, + trace, + trajectories + ) + +from .voronoi import voronoi_data, _voronoi_regions +from . import matplotlib +from . import plotly + +StandardPlot = matplotlib.StandardPlot +StandardSubplot = matplotlib.StandardSubplot diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py new file mode 100644 index 000000000..863ad89ec --- /dev/null +++ b/pybop/plot/matplotlib/__init__.py @@ -0,0 +1,9 @@ +from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .dataset import dataset +from .convergence import convergence +from .parameters import parameters +from .problem import problem +from .contour import contour +from .voronoi import surface +from .nyquist import nyquist +from .samples import chains, posterior, summary_table, trace \ No newline at end of file diff --git a/pybop/plot/matplotlib/contour.py b/pybop/plot/matplotlib/contour.py new file mode 100644 index 000000000..eb7ee77f3 --- /dev/null +++ b/pybop/plot/matplotlib/contour.py @@ -0,0 +1,237 @@ +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING + +import numpy as np +from matplotlib import pyplot as plt +from scipy.interpolate import griddata + +from pybop.problems.problem import Problem + +if TYPE_CHECKING: + from pybop._result import Result + + +def contour( + call_object: "Problem | Result", + gradient: bool = False, + bounds: np.ndarray | None = None, + transformed: bool = False, + steps: int = 10, + show: bool = True, + title: str = 'Cost Landscape', +): + """ + Plot a 2D visualisation of a cost landscape using Plotly. + + This function generates a contour plot representing the cost landscape for a provided + callable cost function over a grid of parameter values within the specified bounds. + + Parameters + ---------- + call_object : pybop.Problem | pybop.Result + Either: + - the cost function to be evaluated. Must accept a list of parameter values and return a cost value. + - an optimiser result which provides a specific optimisation trace overlaid on the cost landscape. + gradient : bool, optional + If True, the gradient is shown (default: False). + bounds : numpy.ndarray | list[list[float]], optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `parameters.get_bounds_for_plotly`. + transformed : bool, optional + Uses the transformed parameter values (as seen by the optimiser) for plotting. + steps : int, optional + The number of grid points to divide the parameter space into along each dimension (default: 10). + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object containing the cost landscape plot. + + Raises + ------ + ValueError + If the cost function does not return a valid cost when called with a parameter list. + """ + plot_optim = False + problem = call_object + + # Assign input as a cost or optimisation result + if not isinstance(call_object, Callable): + plot_optim = True + result = call_object + problem = result.problem + + parameters = problem.parameters + names = parameters.names + additional_values = [] + + if len(parameters) < 2: + raise ValueError("This cost function takes fewer than 2 parameters.") + + if len(parameters) > 2: + warnings.warn( + "This cost function requires more than 2 parameters. " + "Plotting in 2d with fixed values for the additional parameters.", + UserWarning, + stacklevel=2, + ) + for ( + i, + (name, param), + ) in enumerate(parameters.items()): + if i > 1: + # TODO: Update from the initial to the intended value + additional_values.append(param.get_initial_value()) + print(f"Fixed {name}:", param.get_initial_value()) + + # Set up parameter bounds + if bounds is None: + bounds = parameters.get_bounds_for_plotly() + else: + bounds = np.asarray(bounds) + + # Generate grid + x = np.linspace(bounds[0, 0], bounds[0, 1], steps) + y = np.linspace(bounds[1, 0], bounds[1, 1], steps) + + # Initialize cost matrix + costs = np.zeros((len(y), len(x))) + + if gradient: + grad_parameter_costs = [] + + # Create an array to hold the gradient with respect to each parameter + grads = [np.zeros((len(y), len(x))) for _ in range(len(parameters))] + + # Populate cost matrix + for i, xi in enumerate(x): + for j, yj in enumerate(y): + if gradient: + out = problem.evaluate( + np.asarray([xi, yj] + additional_values), + calculate_sensitivities=True, + ).get_values() + costs[j, i], sensitivities = out[0][0], out[1] + for k, key in enumerate(problem.parameters.names): + grads[k][j, i] = sensitivities[key].item() + else: + costs[j, i] = problem.evaluate( + np.asarray([xi, yj] + additional_values), + ).get_values()[0] + + # Append the arrays to the grad_parameter_costs list + if gradient: + grad_parameter_costs.extend(grads) + + # Apply any transformation if requested + def transform_array_of_values(list_of_values, parameter): + """Apply transformation if requested.""" + if transformed: + return np.asarray( + [parameter.transformation.to_search(value) for value in list_of_values] + ).flatten() + return list_of_values + + x = transform_array_of_values(x, parameters[names[0]]) + y = transform_array_of_values(y, parameters[names[1]]) + bounds[0] = transform_array_of_values(bounds[0], parameters[names[0]]) + bounds[1] = transform_array_of_values(bounds[1], parameters[names[1]]) + + # define levels + exponent = np.floor(np.log10(np.abs(np.max(costs)))) + levels = np.linspace(np.floor(np.min(costs)/(10**exponent))*(10**exponent), np.ceil(np.max(costs)/(10**exponent))*(10**exponent), 2 * steps - 1) + + # Create contour plot and update the layout + fig = plt.figure(figsize=(6, 6), dpi=100) + plt.contourf(x, y, costs, levels=levels, extend='both', cmap='viridis') + plt.colorbar() + plt.contour(x, y, costs, levels=levels, colors=('k',), linestyles='solid', linewidths=0.1) + + # Layout + plt.xlabel("Transformed " + names[0] if transformed else names[0], labelpad=15) + plt.ticklabel_format(axis='both', **dict(style='sci',scilimits=(-4,4))) + plt.ylabel("Transformed " + names[1] if transformed else names[1], labelpad=15) + plt.title(title, pad=40) + plt.xlim(bounds[0]) + plt.ylim(bounds[1]) + + if plot_optim: + # Plot the optimisation trace + optim_trace = np.asarray([item[:2] for item in result.x_model]) + optim_trace = optim_trace.reshape(-1, 2) + + plt.scatter( + transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), + transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), + c=[i / len(optim_trace) for i in range(len(optim_trace))], + cmap='Grays', + zorder=1, + ) + + # Plot the initial guess + if len(result.x_model) > 0: + x0 = result.x_model[0] + plt.plot( + transform_array_of_values([x0[0]], parameters[names[0]]), + transform_array_of_values([x0[1]], parameters[names[1]]), + 'X', + markersize=14, + markerfacecolor='w', + markeredgecolor='k', + label="Initial values", + linestyle='None', + ) + + # Plot optimised value + if result.x is not None: + x_best = result.x + plt.plot( + transform_array_of_values([x_best[0]], parameters[names[0]]), + transform_array_of_values([x_best[1]], parameters[names[1]]), + "P", + markersize=14, + markerfacecolor='k', + markeredgecolor='w', + label="Final values", + linestyle='None', + ) + + plt.legend(ncols=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) + + plt.tight_layout() + + if show: + plt.show() + + # if gradient: + # grad_figs = [] + # for i, grad_costs in enumerate(grad_parameter_costs): + # # Update title for gradient plots + # updated_layout_options = layout_options.copy() + # updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}" + + # # Create contour plot with updated layout options + # grad_layout = go.Layout(updated_layout_options) + + # # Create fig + # grad_fig = go.Figure( + # data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout + # ) + # grad_fig.update_layout(**layout_kwargs) + + # if show: + # grad_fig.show() + + # # append grad_fig to list + # grad_figs.append(grad_fig) + + # return fig, grad_figs + + return fig diff --git a/pybop/plot/matplotlib/convergence.py b/pybop/plot/matplotlib/convergence.py new file mode 100644 index 000000000..4ad2f31ad --- /dev/null +++ b/pybop/plot/matplotlib/convergence.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from pybop.plot.matplotlib.standard_plots import StandardPlot +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from pybop._result import Result + + +def convergence(result: "Result", show=True): + """ + Plot the convergence of the optimisation algorithm. + + Parameters + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + show : bool, optional + If True, the figure is shown upon creation (default: True). + + Returns + --------- + fig : plotly.graph_objs.Figure + The Plotly figure object for the convergence plot. + """ + + # Extract log from the optimisation object + cost_log = result.cost_convergence + + # Generate a list of iteration numbers + iteration_numbers = list(range(1, len(cost_log) + 1)) + + # Create a plot dictionary + plot_dict = StandardPlot( + x=iteration_numbers, + y=cost_log, + trace_names=result.method_name, + ) + + # Generate and display the figure + fig = plot_dict(show=False) + plt.xlabel("Evaluation") + plt.ylabel("Cost") + plt.title("Convergence") + plt.tight_layout() + + if show: + plt.show() + + return fig \ No newline at end of file diff --git a/pybop/plot/matplotlib/dataset.py b/pybop/plot/matplotlib/dataset.py new file mode 100644 index 000000000..44cd3ce25 --- /dev/null +++ b/pybop/plot/matplotlib/dataset.py @@ -0,0 +1,54 @@ +import matplotlib.pyplot as plt +from pybop.plot.matplotlib.standard_plots import StandardPlot, trajectories + + +def dataset(dataset, signal=None, trace_names=None, show=True): + """ + Quickly plot a PyBOP Dataset using Plotly. + + Parameters + ---------- + dataset : object + A PyBOP dataset. + signal : list or str, optional + The name of the time series to plot (default: "Voltage [V]"). + trace_names : list or str, optional + Name(s) for the trace(s) (default: "Data"). + show : bool, optional + If True, the figure is shown upon creation (default: True). + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + + # Get data dictionary + if signal is None: + signal = ["Voltage [V]"] + dataset.check(signal=signal) + + # Compile ydata and labels or legend + y = [dataset[s] for s in signal] + if len(signal) == 1: + yaxis_title = signal[0] + if trace_names is None: + trace_names = ["Data"] + else: + yaxis_title = "Output" + if trace_names is None: + trace_names = StandardPlot.remove_brackets(signal) + + # Create the figure + fig = trajectories( + x=dataset[dataset.domain], + y=y, + trace_names=trace_names, + show=False, + xaxis_title=StandardPlot.remove_brackets(dataset.domain), + yaxis_title=yaxis_title, + ) + if show: + plt.show() + + return fig diff --git a/pybop/plot/matplotlib/nyquist.py b/pybop/plot/matplotlib/nyquist.py new file mode 100644 index 000000000..c45b29beb --- /dev/null +++ b/pybop/plot/matplotlib/nyquist.py @@ -0,0 +1,86 @@ +from pybop.parameters.parameter import Inputs +from pybop.plot.matplotlib.standard_plots import StandardPlot +from matplotlib import pyplot as plt + + +def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): + """ + Generates Nyquist plots for the given problem by evaluating the model's output and target values. + + Parameters + ---------- + problem : pybop.Problem + An instance of a problem class that contains the parameters and methods + for evaluation and target retrieval. + inputs : Inputs, optional + Input parameters for the problem. If not provided, the default parameters from the problem + instance will be used. These parameters are verified before use (default is None). + show : bool, optional + If True, the plots will be displayed. + **layout_kwargs : dict, optional + Additional keyword arguments for customising the plot layout. These arguments are passed to + `fig.update_layout()`. + + Returns + ------- + list + A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. + + Notes + ----- + - The function extracts the real part of the impedance from the model's output and the real and imaginary parts + of the impedance from the target output. + - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. + - An additional trace for the reference (target output) is added to the plot. + - The plot layout can be customised using `layout_kwargs`. + + Example + ------- + >>> problem = pybop.EISProblem() + >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") + >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. + """ + if not isinstance(inputs, dict): + inputs = problem.parameters.to_dict(inputs) + + model_output = problem.simulate(inputs) + domain_data = model_output["Impedance"].data.real + target_output = problem.target_data + + figure_list = [] + for var in problem.target: + plot_dict = StandardPlot( + x=domain_data, + y=-model_output[var].data.imag, + trace_names="Model", + ) + + fig = plot_dict(show=False) + plot_dict.traces[0].set_color("#00CC96") + plot_dict.traces[0].set_linewidth(2) + plot_dict.traces[0].set_marker('.') + plot_dict.traces[0].set_markersize(8) + + target_trace = plot_dict.create_trace( + x=target_output[var].real, + y=-target_output[var].imag, + label="Reference", + ) + target_trace.set_linestyle('None') + target_trace.set_marker('o') + target_trace.set_fillstyle('none') + target_trace.set_markersize(8) + target_trace.set_markeredgecolor("#636EFA") + + # Layout + plt.title('Nyquist Plot', fontsize=14, x=0.2) + plt.xlabel(r"$Z_{re} / \Omega$", fontsize=16) + plt.ylabel(r"$-Z_{im} / \Omega$", fontsize=16) + plt.legend(loc='upper right', bbox_to_anchor=(1, 1.08), ncols=2) + + if show: + plt.show() + + figure_list.append(fig) + + return figure_list diff --git a/pybop/plot/matplotlib/parameters.py b/pybop/plot/matplotlib/parameters.py new file mode 100644 index 000000000..671f00cb2 --- /dev/null +++ b/pybop/plot/matplotlib/parameters.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +import warnings + +from pybop.costs.log_likelihoods import GaussianLogLikelihood +from pybop.plot.matplotlib.standard_plots import StandardSubplot +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from pybop._result import Result + + +def parameters(result: "Result", show=True, **layout_kwargs): + """ + Plot the evolution of parameters during the optimisation process using Plotly. + + Parameters + ---------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + show : bool, optional + If True, the figure is shown upon creation (default: True). + + Returns + ------- + plotly.graph_objs.Figure + A Plotly figure object showing the parameter evolution over iterations. + """ + + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + # Extract parameters and log from the optimisation object + parameters = result.problem.parameters + x = list(range(len(result.x_model))) + y = [list(item) for item in zip(*result.x_model, strict=False)] + + # Create lists of axis titles and trace names + axis_titles = [] + trace_names = parameters.names + for name in trace_names: + axis_titles.append(("Evaluation", name)) + + if isinstance(result.problem, GaussianLogLikelihood): + axis_titles.append(("Evaluation", "Sigma")) + trace_names.append("Sigma") + + + # Create a plot dictionary + plot_dict = StandardSubplot( + x=x, + y=y, + axis_titles=axis_titles, + trace_names=trace_names, + trace_name_width=50, + figsize= (18, 8), + ) + + plt.suptitle("Parameter Convergence") + + # Generate the figure and update the layout + fig = plot_dict(show=False) + if show: + plt.show() + + return fig diff --git a/pybop/plot/matplotlib/problem.py b/pybop/plot/matplotlib/problem.py new file mode 100644 index 000000000..451ed76a1 --- /dev/null +++ b/pybop/plot/matplotlib/problem.py @@ -0,0 +1,114 @@ +import numpy as np + +from pybop.costs.design_cost import DesignCost +from pybop.costs.error_measures import ErrorMeasure +from pybop.parameters.parameter import Inputs +from pybop.plot.matplotlib.standard_plots import StandardPlot +from pybop.problems.meta_problem import MetaProblem +from pybop.problems.problem import Problem +from pybop.simulators.solution import Solution + +import matplotlib.pyplot as plt + + +def problem( + problem: Problem, + inputs: Inputs = None, + show: bool = True, + title = 'Scatter Plot', +): + """ + Produce a quick plot of the target dataset against optimised model output. + + Generates an interactive plot comparing the simulated model output with + an optional target dataset and visualises uncertainty. + + Parameters + ---------- + problem : pybop.Problem + Problem object with dataset and targets attributes. + inputs : Inputs + Optimised (or example) parameter values. + show : bool, optional + If True, the figure is shown upon creation (default: True). + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + if inputs is None: + inputs = problem.parameters.to_dict() + elif not isinstance(inputs, dict): + raise TypeError(f"Expecting a dictionary, received {type(inputs)}") + + domain = problem.domain + if problem.domain_data is None: + # Simulate the model for the both the initial and the given inputs + target = problem.target + problem.set_target(target + [domain]) + initial_inputs = problem.simulator.parameters.to_dict("initial") + target_output = problem.simulate(initial_inputs) + target_domain = target_output[domain].data + model_output = problem.simulate(inputs) + model_domain = model_output[domain].data + problem.set_target(target) + else: + # Extract the time data and simulate the model for the given inputs + target_output = Solution() + for target in problem.target: + target_output.set_solution_variable( + target, data=problem.target_data[target] + ) + target_domain = problem.domain_data + model_output = problem.simulate(inputs) + model_domain = target_domain[: len(model_output[target].data)] + + # Create a plot for each output + figure_list = [] + for var in problem.target: + # Create a plot dictionary + plot_dict = StandardPlot() + + plot_dict.create_trace( + x=target_domain, + y=target_output[var].data, + label="Reference", + marker=".", + linestyle="None" + ) + + plot_dict.create_trace( + x=model_domain, + y=model_output[var].data, + label="Optimised" if isinstance(problem.cost, DesignCost) else "Model", + marker="." if isinstance(problem, MetaProblem) else None, + linestyle='None' if isinstance(problem, MetaProblem) else "-", + ) + + if isinstance(problem.cost, ErrorMeasure) and len( + model_output[var].data + ) == len(target_output[var].data): + # Compute the standard deviation as proxy for uncertainty + plot_dict.sigma = np.std(model_output[var].data - target_output[var].data) + + # Convert x and upper and lower limits into lists to create a filled trace + x = target_domain.tolist() + y_upper = (model_output[var].data + plot_dict.sigma).tolist() + y_lower = (model_output[var].data - plot_dict.sigma).tolist() + + plt.fill_between(x, y_upper, y_lower, color=[(1.0, 0.898, 0.800, 0.8)]) + + # Generate the figure and update the layout + fig = plot_dict(show=False) + plt.xlabel("Time / s") + plt.ylabel(StandardPlot.remove_brackets(var)) + plt.title(title) + plt.legend() + plt.tight_layout() + if show: + plt.show() + + figure_list.append(fig) + + return figure_list diff --git a/pybop/plot/matplotlib/samples.py b/pybop/plot/matplotlib/samples.py new file mode 100644 index 000000000..dd14cf908 --- /dev/null +++ b/pybop/plot/matplotlib/samples.py @@ -0,0 +1,138 @@ +from typing import TYPE_CHECKING +import warnings + +from matplotlib import pyplot as plt + +if TYPE_CHECKING: + from pybop.samplers.base_pints_sampler import SamplingResult + + +def trace(result: "SamplingResult", show=True, **kwargs): + """ + Plot trace plots for the posterior samples. + """ + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + figlist = [] + for i in range(result.n_parameters): + fig = plt.figure() + + for j, chain in enumerate(result.chains): + plt.plot(chain[:, i], label=f"Chain {j}") + + plt.title(f"Parameter {i} Trace Plot") + plt.xlabel("Sample Index") + plt.ylabel("Value") + plt.legend(fontsize=12) + figlist.append(fig) + + + if show: + plt.show() + else: + return figlist + + + +def chains(result: "SamplingResult", show=True, **kwargs): + """ + Plot posterior distributions for each chain. + """ + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + fig = plt.figure(figsize=(15, 8), dpi=100) + + for i, chain in enumerate(result.chains): + for j in range(chain.shape[1]): + plt.hist( + x=chain[:, j], + label=f"Chain {i} - Parameter {j}", + alpha=0.5, + rwidth=2.0 + ) + + for j in range(chain.shape[1]): + plt.plot([result.mean[j], result.mean[j]], [0, result.max[j]],"--", lw=3, label=f"Mean - Parameter {j}") + + plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) + plt.grid(axis='y', zorder=-1) + plt.title("Posterior Distribution") + plt.xlabel("Value") + plt.ylabel("Density") + plt.tight_layout() + if show: + plt.show() + else: + return fig + +def posterior(result: "SamplingResult", show=True, **kwargs): + """ + Plot the summed posterior distribution across chains. + """ + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + fig = plt.figure(figsize=(15, 8), dpi=100) + + for j in range(result.all_samples.shape[1]): + plt.hist( + x=result.all_samples[:, j], + label=f"Parameter {j}", + alpha=0.75, + ) + plt.axvline(result.mean[j], ls='--', c='k', lw=3) + + plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) + plt.grid(axis='y', zorder=-1) + plt.title("Posterior Distribution") + plt.xlabel("Value") + plt.ylabel("Density") + plt.tight_layout() + if show: + plt.show() + else: + return fig + + +def summary_table(result: "SamplingResult"): + """ + Display summary statistics in a table. + """ + + summary_stats = result.get_summary_statistics() + + header = ["Statistic", "Value"] + values = [ + ["Mean", ', '.join(summary_stats["mean"].astype(str))], + ["Median", ', '.join(summary_stats["median"].astype(str))], + ["Standard Deviation", ', '.join(summary_stats["std"].astype(str))], + ["95% CI Lower", ', '.join(summary_stats["ci_lower"].astype(str))], + ["95% CI Upper", ', '.join(summary_stats["ci_upper"].astype(str))], + ] + fig, ax = plt.subplots(figsize=(6, 2), dpi=100) + + # hide axes + ax.axis('off') + ax.axis('tight') + ax.table(cellText=values, colLabels=header, loc='center', cellLoc='center', colColours=['lightsteelblue', 'lightsteelblue']) + ax.set_title("Summary Statistics") + fig.tight_layout() + plt.show() diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py new file mode 100644 index 000000000..578ef365b --- /dev/null +++ b/pybop/plot/matplotlib/standard_plots.py @@ -0,0 +1,386 @@ +import math +import textwrap +import warnings + +import numpy as np + +from matplotlib import pyplot as plt + +DEFAULT_TRACE_OPTIONS = dict(linewidth=2.0) + +class StandardPlot: + """ + A class for creating and displaying interactive Plotly figures. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + x=None, + y=None, + trace_options=None, + trace_names=None, + trace_name_width=20, + figsize=(8, 6), + **kwargs, + ): + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + self.traces = [] + self.trace_name_width = trace_name_width + + # Set default trace options and update if provided + self.trace_options = DEFAULT_TRACE_OPTIONS.copy() + if trace_options: + self.trace_options.update(trace_options) + + + # Parse the data + x, y = self.parse_data(x, y) + self.x = x + self.y = y + # Check and wrap trace names + if trace_names is not None: + if isinstance(trace_names, str): + trace_names = [trace_names] + for i, name in enumerate(trace_names): + trace_names[i] = self.wrap_text(name, width=self.trace_name_width) + self.trace_names = trace_names + + self.fig = plt.figure(figsize=figsize, dpi=100) + + + + def __call__(self, show=True): + """ + Generate and show the figure. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + # Add traces + if self.x is not None and self.y is not None: + self.add_traces(self.x, self.y, self.trace_names) + self.default_layout() + if show: + plt.show() + + return self.fig + + def default_layout(self): + + plt.tick_params(axis='both', labelsize=12) + plt.ticklabel_format(axis='both', style='sci', scilimits=(-4, 4)) + + def add_traces(self, x, y, trace_names=None, **trace_options): + """ + Add a set of traces to the plot dictionary. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + trace_names : str or list[str], optional + Name(s) for the primary trace(s) (default: None). + """ + + options = self.trace_options.copy() + options.update(trace_options) + + # Create a trace for each trajectory + xi = x[0] + for i in range(0, len(y)): + trace_options = options.copy() + if len(x) > 1: + xi = x[i] + + label = None + if trace_names is not None: + label = trace_names[i] + + self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) + + if self.trace_names is not None: + plt.legend(**dict(loc="best", fontsize=12),) + + + def parse_data(self, x, y): + """ + Check the type and dimensions of the data and convert if necessary to a list + of 'things plotly can take', e.g. numpy arrays or lists of numbers. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + """ + if x is None or y is None: + return None, None + if isinstance(x, list): + # If it's a list of numpy arrays, it's fine + # If it's a list of lists, it's fine + # If it's neither, it's a list of numbers that we need to wrap + if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list): + x = [x] + elif isinstance(x, np.ndarray): + x = np.squeeze(x) + if x.ndim == 1: + x = [x] + else: + x = x.tolist() + if isinstance(y, list): + if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list): + y = [y] + if isinstance(y, np.ndarray): + y = np.squeeze(y) + if y.ndim == 1: + y = [y] + else: + y = y.tolist() + if len(x) > 1 and len(x) != len(y): + raise ValueError( + "Input x should have either one data series or the same number as y." + ) + return x, y + + def create_trace(self, x, y, label, ax=None, **trace_options): + """ + Create a trace for the Plotly figure. + + Returns + ------- + plotly.graph_objs.Scatter + A trace for a Plotly figure. + """ + + if ax is None: + ax = plt.gca() + + line = ax.plot( + x, + y, + label=label, + **trace_options, + ) + if len(line) > 1: + return line + else: + return line[0] + + + @staticmethod + def wrap_text(text, width): + """ + Wrap text to a specified width with HTML line breaks. + + Parameters + ---------- + text : str + The text to wrap. + width : int + The width to wrap the text to. + + Returns + ------- + str + The wrapped text. + """ + wrapped_text = textwrap.fill(text, width=width, break_long_words=False) + return wrapped_text + + @staticmethod + def remove_brackets(s): + """ + Remove square brackets from a string and replace with forward slashes + as per section 7.1 of the SI Handbook + """ + # If s is an iterable (but not a string), apply the function recursively to each element + if hasattr(s, "__iter__") and not isinstance(s, str): + return type(s)(StandardPlot.remove_brackets(i) for i in s) + elif isinstance(s, str): + start = s.find("[") + end = s.find("]") + if start != -1 and end != -1: + char_in_brackets = s[start + 1 : end] + return s[:start] + " / " + char_in_brackets + s[end + 1 :] + return s + + + +class StandardSubplot(StandardPlot): + """ + A class for creating and displaying a set of interactive Plotly figures in a grid layout. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + num_rows : int, optional + Number of rows of subplots, can be set automatically (default: None). + num_cols : int, optional + Number of columns of subplots, can be set automatically (default: None). + layout : Plotly layout, optional + A layout for the figure, overrides the layout options (default: None). + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + x, + y, + num_rows=None, + num_cols=None, + axis_titles=None, + trace_options=DEFAULT_TRACE_OPTIONS, + trace_names=None, + trace_name_width=40, + figsize=(8, 6) + ): + super().__init__( + x, y, trace_options, trace_names, trace_name_width, figsize + ) + self.num_traces = len(self.y) + self.num_rows = num_rows + self.num_cols = num_cols + if self.num_rows is None and self.num_cols is None: + # Work out the number of subplots + self.num_cols = int(math.ceil(math.sqrt(self.num_traces))) + self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) + elif self.num_rows is None: + self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) + elif self.num_cols is None: + self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) + self.axis_titles = axis_titles + + + + def __call__(self, show): + """ + Generate and show the set of figures. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + + color_cycle = plt.rcParams['axes.prop_cycle']() + + xi = self.x[0] + lines = [] + for idx, yi in enumerate(self.y): + ax = self.fig.add_subplot(self.num_rows, self.num_cols, idx+1) + if self.axis_titles and idx < len(self.axis_titles): + x_title, y_title = self.axis_titles[idx] + ax.set_xlabel(x_title) + ax.set_ylabel(y_title) + if len(self.x)>1: + xi = self.x[idx] + + label = None + if self.trace_names is not None: + label = self.trace_names[idx] + + lines.append(self.create_trace(xi, yi, label, ax = ax, **next(color_cycle))) + + + lines_labels = [ax.get_legend_handles_labels() for ax in self.fig.axes] + lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] + if self.trace_names is not None: + self.fig.legend(lines, labels, loc='upper right', ncol=len(lines), bbox_to_anchor=(0.99, 0.95)) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + if show: + plt.show() + + return self.fig + + +def trajectories(x, y, trace_names=None, show=True, xaxis_title='', yaxis_title='', title='', **layout_kwargs): + """ + Quickly plot one or more trajectories using Plotly. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Y-axis data points for each trajectory. + trace_names : list or str, optional + Name(s) for the trace(s) (default: None). + **layout_kwargs : optional + This argument is ignored for the matplotlib backend. + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + # Create a plot dictionary + plot_dict = StandardPlot( + x=x, + y=y, + trace_names=trace_names, + ) + + # Generate the figure and update the layout + fig = plot_dict(show=False) + plt.title(title) + plt.xlabel(xaxis_title, fontsize=12) + plt.ylabel(yaxis_title, fontsize=12) + plt.tight_layout() + if show: + plt.show() + + return plot_dict \ No newline at end of file diff --git a/pybop/plot/matplotlib/voronoi.py b/pybop/plot/matplotlib/voronoi.py new file mode 100644 index 000000000..1a44f0f77 --- /dev/null +++ b/pybop/plot/matplotlib/voronoi.py @@ -0,0 +1,142 @@ +from typing import TYPE_CHECKING + +import numpy as np +from scipy.spatial import cKDTree +from matplotlib import pyplot as plt +import matplotlib as mpl +import warnings + +if TYPE_CHECKING: + from pybop._result import Result + +from pybop.plot.voronoi import voronoi_data + +def surface( + result: "Result", + bounds=None, + normalise=True, + title='Voronoi Cost Landscape', + show=True, + **layout_kwargs +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + points = result.x_model + parameters = result.problem.parameters + + if points[0].shape[0] != 2: + raise ValueError("This plot method requires two parameters.") + + x_optim, y_optim = map(list, zip(*points, strict=False)) + f = result.cost + + # Translate bounds, taking only the first two elements + xlim, ylim = ( + bounds if bounds is not None else [param.bounds for param in parameters] + )[:2] + + _, _, f, regions, relative_sizes = voronoi_data(xlim, ylim, x_optim, y_optim, f, normalise) + + + # Construct figure + plt.figure(figsize=(7, 6), dpi=100) + + # normalise cost + f_min = np.nanmin(f[np.isfinite(f)]) + f_max = np.nanmax(f[np.isfinite(f)]) + norm = mpl.colors.Normalize(vmin=f_min, vmax=f_max, clip=True) + norm_f = norm(f, clip=True) + + # get colours + cmap = mpl.colormaps['viridis'] + colors = cmap(norm_f) + + # Add Voronoi edges and fill Voronoi regions + for j, (region, size) in enumerate(zip(regions, relative_sizes, strict=False)): + x_region = region[:, 0].tolist() + [region[0, 0]] + y_region = region[:, 1].tolist() + [region[0, 1]] + + plt.fill(x_region, y_region, color=colors[j]) + plt.plot(x_region, y_region, color='w', linewidth=0.5 + size*0.1) + + plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax = plt.gca()) + + + # Add original points + plt.scatter( + x_optim, + y_optim, + c=[i / len(x_optim) for i in range(len(x_optim))], + cmap='Grays', + zorder=2.5, + ) + + # Plot the initial guess + if len(result.x_model) > 0: + x0 = result.x_model[0] + plt.plot( + [x0[0]], + [x0[1]], + 'X', + markersize=14, + markerfacecolor='w', + markeredgecolor='k', + label="Initial values", + linestyle='None', + zorder=2.6, + ) + + # Plot optimised value + if result.x is not None: + x_best = result.x + plt.plot( + [x_best[0]], + [x_best[1]], + "P", + markersize=14, + markerfacecolor='k', + markeredgecolor='w', + label="Final values", + linestyle='None', + zorder=2.6, + ) + + + # Layout + names = result.problem.parameters.names + plt.xlabel(names[0], labelpad=15) + plt.ticklabel_format(axis='both', **dict(style='sci',scilimits=(-4,4))) + plt.ylabel(names[1], labelpad=15) + plt.title(title, pad=40) + plt.legend(ncols=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) + plt.xlim(xlim[0], xlim[1]) + plt.ylim(ylim[0], ylim[1]) + plt.tight_layout() + + if show: + plt.show() diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py new file mode 100644 index 000000000..0a3a1f91a --- /dev/null +++ b/pybop/plot/plotly/__init__.py @@ -0,0 +1,10 @@ +from .plotly_manager import PlotlyManager +from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .contour import contour +from .dataset import dataset +from .convergence import convergence +from .parameters import parameters +from .problem import problem +from .nyquist import nyquist +from .voronoi import surface +from .samples import chains, posterior, summary_table, trace \ No newline at end of file diff --git a/pybop/plot/contour.py b/pybop/plot/plotly/contour.py similarity index 99% rename from pybop/plot/contour.py rename to pybop/plot/plotly/contour.py index da2d6a1a6..4fb1a1156 100644 --- a/pybop/plot/contour.py +++ b/pybop/plot/plotly/contour.py @@ -4,7 +4,7 @@ import numpy as np -from pybop.plot.plotly_manager import PlotlyManager +from pybop.plot.plotly.plotly_manager import PlotlyManager from pybop.problems.problem import Problem if TYPE_CHECKING: diff --git a/pybop/plot/convergence.py b/pybop/plot/plotly/convergence.py similarity index 96% rename from pybop/plot/convergence.py rename to pybop/plot/plotly/convergence.py index deabf66d5..6b190a93c 100644 --- a/pybop/plot/convergence.py +++ b/pybop/plot/plotly/convergence.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from pybop.plot.standard_plots import StandardPlot +from pybop.plot.plotly.standard_plots import StandardPlot if TYPE_CHECKING: from pybop._result import Result diff --git a/pybop/plot/dataset.py b/pybop/plot/plotly/dataset.py similarity index 95% rename from pybop/plot/dataset.py rename to pybop/plot/plotly/dataset.py index b6b20a7d7..dc02d01dc 100644 --- a/pybop/plot/dataset.py +++ b/pybop/plot/plotly/dataset.py @@ -1,4 +1,4 @@ -from pybop.plot.standard_plots import StandardPlot, trajectories +from pybop.plot.plotly.standard_plots import StandardPlot, trajectories def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs): diff --git a/pybop/plot/nyquist.py b/pybop/plot/plotly/nyquist.py similarity index 98% rename from pybop/plot/nyquist.py rename to pybop/plot/plotly/nyquist.py index 80f7eb77a..bbc5fa49a 100644 --- a/pybop/plot/nyquist.py +++ b/pybop/plot/plotly/nyquist.py @@ -1,5 +1,5 @@ from pybop.parameters.parameter import Inputs -from pybop.plot.standard_plots import StandardPlot +from pybop.plot.plotly.standard_plots import StandardPlot def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): diff --git a/pybop/plot/parameters.py b/pybop/plot/plotly/parameters.py similarity index 97% rename from pybop/plot/parameters.py rename to pybop/plot/plotly/parameters.py index a13d9e3fb..02cc281f0 100644 --- a/pybop/plot/parameters.py +++ b/pybop/plot/plotly/parameters.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from pybop.costs.log_likelihoods import GaussianLogLikelihood -from pybop.plot.standard_plots import StandardSubplot +from pybop.plot.plotly.standard_plots import StandardSubplot if TYPE_CHECKING: from pybop._result import Result diff --git a/pybop/plot/plotly_manager.py b/pybop/plot/plotly/plotly_manager.py similarity index 100% rename from pybop/plot/plotly_manager.py rename to pybop/plot/plotly/plotly_manager.py diff --git a/pybop/plot/problem.py b/pybop/plot/plotly/problem.py similarity index 98% rename from pybop/plot/problem.py rename to pybop/plot/plotly/problem.py index f87e6f88a..820284da2 100644 --- a/pybop/plot/problem.py +++ b/pybop/plot/plotly/problem.py @@ -3,7 +3,7 @@ from pybop.costs.design_cost import DesignCost from pybop.costs.error_measures import ErrorMeasure from pybop.parameters.parameter import Inputs -from pybop.plot.standard_plots import StandardPlot +from pybop.plot.plotly.standard_plots import StandardPlot from pybop.problems.meta_problem import MetaProblem from pybop.problems.problem import Problem from pybop.simulators.solution import Solution diff --git a/pybop/plot/samples.py b/pybop/plot/plotly/samples.py similarity index 98% rename from pybop/plot/samples.py rename to pybop/plot/plotly/samples.py index 55ee77cd0..f5016c325 100644 --- a/pybop/plot/samples.py +++ b/pybop/plot/plotly/samples.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from pybop.plot import PlotlyManager +from pybop.plot.plotly import PlotlyManager if TYPE_CHECKING: from pybop.samplers.base_pints_sampler import SamplingResult diff --git a/pybop/plot/standard_plots.py b/pybop/plot/plotly/standard_plots.py similarity index 96% rename from pybop/plot/standard_plots.py rename to pybop/plot/plotly/standard_plots.py index 4422516b8..8e3961d71 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -2,8 +2,9 @@ import textwrap import numpy as np +import warnings -from pybop.plot.plotly_manager import PlotlyManager +from pybop.plot.plotly.plotly_manager import PlotlyManager DEFAULT_LAYOUT_OPTIONS = dict( title=None, @@ -71,7 +72,16 @@ def __init__( trace_options=None, trace_names=None, trace_name_width=40, + **kwargs, ): + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (plotly): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) self.traces = [] self.layout = layout self.trace_name_width = trace_name_width diff --git a/pybop/plot/plotly/voronoi.py b/pybop/plot/plotly/voronoi.py new file mode 100644 index 000000000..dc1e5ea1e --- /dev/null +++ b/pybop/plot/plotly/voronoi.py @@ -0,0 +1,216 @@ +from typing import TYPE_CHECKING + +import numpy as np +from scipy.spatial import Voronoi, cKDTree + +if TYPE_CHECKING: + from pybop._result import Result +from pybop.plot.plotly.plotly_manager import PlotlyManager +from pybop.plot import voronoi_data + + + +def assign_nearest_value(x, y, f, xi, yi): + """ + Computes an array of values given by the score of the nearest point. + + Parameters + ---------- + x : array-like + The x coordinates of points with known scores. + y : array-like + The y coordinates of points with known scores. + f : array-like + The score function at the given x and y coordinates. + xi : array-like + The x coordinates of grid points. + yi : array-like + The y coordinates of grid points. + + Returns + ------- + A numpy array containing the scores corresponding to the grid points. + """ + # Create a KD-tree for efficient nearest neighbor search + tree = cKDTree(np.column_stack((x, y))) + + # Find the nearest point for each grid point + _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel()))) + zi = f[indices].reshape(xi.shape) + + return zi + + +def surface( + result: "Result", + bounds=None, + normalise=True, + resolution=250, + show=True, + **layout_kwargs, +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + """ + points = result.x_model + parameters = result.problem.parameters + + if points[0].shape[0] != 2: + raise ValueError("This plot method requires two parameters.") + + x_optim, y_optim = map(list, zip(*points, strict=False)) + f = result.cost + + # Translate bounds, taking only the first two elements + xlim, ylim = ( + bounds if bounds is not None else [param.bounds for param in parameters] + )[:2] + + x, y, f, regions, relative_sizes = voronoi_data(xlim, ylim, x_optim, y_optim, f, normalise) + + # Create a grid for plot + xi = np.linspace(xlim[0], xlim[1], resolution) + yi = np.linspace(ylim[0], ylim[1], resolution) + xi, yi = np.meshgrid(xi, yi) + + + if normalise: + # Create a normalised grid + norm_xi = np.linspace(0, 1, resolution) + norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi) + + # Assign a value to each point in the grid + zi = assign_nearest_value(x, y, f, norm_xi, norm_yi) + else: + # Assign a value to each point in the grid + zi = assign_nearest_value(x, y, f, xi, yi) + + # Calculate the size of each Voronoi region + region_sizes = np.array([len(region) for region in regions]) + relative_sizes = (region_sizes - region_sizes.min()) / ( + region_sizes.max() - region_sizes.min() + ) + + # Construct figure + go = PlotlyManager().go + fig = go.Figure() + + # Heatmap + fig.add_trace( + go.Heatmap( + x=xi[0], + y=yi[:, 0], + z=zi, + colorscale="Viridis", + zsmooth="best", + ) + ) + + # Add Voronoi edges + for region, size in zip(regions, relative_sizes, strict=False): + x_region = region[:, 0].tolist() + [region[0, 0]] + y_region = region[:, 1].tolist() + [region[0, 1]] + + fig.add_trace( + go.Scatter( + x=x_region, + y=y_region, + mode="lines", + line=dict(color="white", width=0.5 + size * 0.1), + showlegend=False, + ) + ) + + # Add original points + fig.add_trace( + go.Scatter( + x=x_optim, + y=y_optim, + mode="markers", + marker=dict( + color=[i / len(x_optim) for i in range(len(x_optim))], + colorscale="Greys", + size=8, + showscale=False, + ), + text=[f"f={val:.2f}" for val in f], + hoverinfo="text", + showlegend=False, + ) + ) + + # Plot the initial guess + if len(result.x_model) > 0: + x0 = result.x_model[0] + fig.add_trace( + go.Scatter( + x=[x0[0]], + y=[x0[1]], + mode="markers", + marker_symbol="x", + marker=dict( + color="white", + line_color="black", + line_width=1, + size=14, + showscale=False, + ), + name="Initial values", + ) + ) + + # Plot optimised value + if result.x is not None: + x_best = result.x + fig.add_trace( + go.Scatter( + x=[x_best[0]], + y=[x_best[1]], + mode="markers", + marker_symbol="cross", + marker=dict( + color="black", + line_color="white", + line_width=1, + size=14, + showscale=False, + ), + name="Final values", + ) + ) + + names = parameters.names + fig.update_layout( + title="Voronoi Cost Landscape", + title_x=0.5, + title_y=0.905, + xaxis_title=names[0], + yaxis_title=names[1], + width=600, + height=600, + xaxis=dict(range=xlim, showexponent="last", exponentformat="e"), + yaxis=dict(range=ylim, showexponent="last", exponentformat="e"), + legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), + ) + fig.update_layout(**layout_kwargs) + if show: + fig.show() \ No newline at end of file diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py new file mode 100644 index 000000000..daadc5d11 --- /dev/null +++ b/pybop/plot/plots.py @@ -0,0 +1,287 @@ +from typing import TYPE_CHECKING +import numpy as np + +if TYPE_CHECKING: + from pybop._result import Result + from pybop.samplers.base_pints_sampler import SamplingResult + +from pybop.parameters.parameter import Inputs +from pybop.problems.problem import Problem +from pybop.plot.util import call_plotting_function, get_class + + + +def chains(result: "SamplingResult", show=True, backend=None, **kwargs): + """ + Plot posterior distributions for each chain. + """ + return call_plotting_function('chains', backend, result=result, **kwargs) + +def contour( + call_object: "Problem | Result", + gradient: bool = False, + bounds: np.ndarray | None = None, + transformed: bool = False, + steps: int = 10, + show: bool = True, + backend = None, + **layout_kwargs, +): + """ + Plot a 2D visualisation of a cost landscape using Plotly. + + This function generates a contour plot representing the cost landscape for a provided + callable cost function over a grid of parameter values within the specified bounds. + + Parameters + ---------- + call_object : pybop.Problem | pybop.Result + Either: + - the cost function to be evaluated. Must accept a list of parameter values and return a cost value. + - an optimiser result which provides a specific optimisation trace overlaid on the cost landscape. + gradient : bool, optional + If True, the gradient is shown (default: False). + bounds : numpy.ndarray | list[list[float]], optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `parameters.get_bounds_for_plotly`. + transformed : bool, optional + Uses the transformed parameter values (as seen by the optimiser) for plotting. + steps : int, optional + The number of grid points to divide the parameter space into along each dimension (default: 10). + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object containing the cost landscape plot. + + Raises + ------ + ValueError + If the cost function does not return a valid cost when called with a parameter list. + """ + return call_plotting_function('contour', backend, call_object=call_object, gradient=gradient, bounds=bounds, transformed=transformed, steps=steps, show=show, **layout_kwargs) + +def convergence(result: "Result", show=True, backend=None, **layout_kwargs): + """ + Plot the convergence of the optimisation algorithm. + + Parameters + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + --------- + fig : plotly.graph_objs.Figure + The Plotly figure object for the convergence plot. + """ + return call_plotting_function('convergence', backend, result = result, show=show, **layout_kwargs) + +def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs): + """ + Quickly plot a PyBOP Dataset using Plotly. + + Parameters + ---------- + dataset : object + A PyBOP dataset. + signal : list or str, optional + The name of the time series to plot (default: "Voltage [V]"). + trace_names : list or str, optional + Name(s) for the trace(s) (default: "Data"). + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + call_plotting_function('dataset', backend, dataset=dataset, signal=signal, trace_names=trace_names, show=show, **layout_kwargs) + +def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): + """ + Generates Nyquist plots for the given problem by evaluating the model's output and target values. + + Parameters + ---------- + problem : pybop.Problem + An instance of a problem class that contains the parameters and methods + for evaluation and target retrieval. + inputs : Inputs, optional + Input parameters for the problem. If not provided, the default parameters from the problem + instance will be used. These parameters are verified before use (default is None). + show : bool, optional + If True, the plots will be displayed. + **layout_kwargs : dict, optional + Additional keyword arguments for customising the plot layout. These arguments are passed to + `fig.update_layout()`. + + Returns + ------- + list + A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. + + Notes + ----- + - The function extracts the real part of the impedance from the model's output and the real and imaginary parts + of the impedance from the target output. + - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. + - An additional trace for the reference (target output) is added to the plot. + - The plot layout can be customised using `layout_kwargs`. + + Example + ------- + >>> problem = pybop.EISProblem() + >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") + >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. + """ + return call_plotting_function('nyquist', backend, problem=problem, inputs=inputs, show=show, **layout_kwargs) + +def parameters(result: "Result", show=True, backend=None, **layout_kwargs): + """ + Plot the evolution of parameters during the optimisation process using Plotly. + + Parameters + ---------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + A Plotly figure object showing the parameter evolution over iterations. + """ + return call_plotting_function('parameters', backend, result = result, show=show, **layout_kwargs) + +def posterior(result: "SamplingResult", show=True, backend=None, **kwargs): + """ + Plot the summed posterior distribution across chains. + """ + return call_plotting_function('posterior', backend, result=result, **kwargs) + +def problem( + problem: Problem, + inputs: Inputs = None, + show: bool = True, + backend=None, + **layout_kwargs, +): + """ + Produce a quick plot of the target dataset against optimised model output. + + Generates an interactive plot comparing the simulated model output with + an optional target dataset and visualises uncertainty. + + Parameters + ---------- + problem : pybop.Problem + Problem object with dataset and targets attributes. + inputs : Inputs + Optimised (or example) parameter values. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + return call_plotting_function('problem', backend, problem=problem, inputs=inputs, show=show, **layout_kwargs) + +def summary_table(result: "SamplingResult", backend=None): + """ + Display summary statistics in a table. + """ + + return call_plotting_function('summary_table', backend, result=result) + +def surface( + result: "Result", + bounds=None, + normalise=True, + resolution=250, + show=True, + backend=None, + **layout_kwargs, +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + """ + return call_plotting_function('surface', backend, result = result, bounds = bounds, normalise=normalise, resolution=resolution, show=show, **layout_kwargs) + +def trace(result: "SamplingResult", backend=None, **kwargs): + """ + Plot trace plots for the posterior samples. + """ + return call_plotting_function('trace', backend, result=result, **kwargs) + +def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): + """ + Quickly plot one or more trajectories using Plotly. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Y-axis data points for each trajectory. + trace_names : list or str, optional + Name(s) for the trace(s) (default: None). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + + return call_plotting_function('trajectories', backend, x=x, y=y, trace_names=trace_names, show=show, **layout_kwargs) \ No newline at end of file diff --git a/pybop/plot/util.py b/pybop/plot/util.py new file mode 100644 index 000000000..3ec299ce3 --- /dev/null +++ b/pybop/plot/util.py @@ -0,0 +1,50 @@ +import importlib.util +import pybop.plot + +def get_class(class_name): + err_msg = f"Plotting backend {pybop.plot.backend} is not available." + try: + module = importlib.import_module('pybop.plot.' + pybop.plot.backend) + if hasattr(module, class_name): + return getattr(module, class_name) + + else: + err_msg = f"Plotting backend {pybop.plot.backend} has no attribute {class_name}." + raise ModuleNotFoundError(err_msg) + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + +def set_backend(backend): + err_msg = f"Plotting backend {backend} is not available. The default backend has not been updated. \n"\ + f"The default backend is set to {pybop.plot.backend}" + try: + importlib.import_module('pybop.plot.' + backend) + pybop.plot.backend = backend + for attr in ['StandardPlot', 'StandardSubplot']: + setattr(pybop.plot, attr, get_class(attr)) + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + + +def call_plotting_function(function_name, backend, **kwargs): + if backend is None: + backend = pybop.plot.backend + err_msg = f"Plotting backend {backend} is not available." + try: + module = importlib.import_module('pybop.plot.' + backend) + if hasattr(module, function_name): + plotting_function = getattr(module, function_name) + # Return the imported attribute + return plotting_function(**kwargs) + else: + err_msg = f"Plotting backend {backend} has no attribute {function_name}." + raise ModuleNotFoundError(err_msg) + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + diff --git a/pybop/plot/voronoi.py b/pybop/plot/voronoi.py index 29be3c096..2ac57fada 100644 --- a/pybop/plot/voronoi.py +++ b/pybop/plot/voronoi.py @@ -1,11 +1,10 @@ from typing import TYPE_CHECKING import numpy as np -from scipy.spatial import Voronoi, cKDTree +from scipy.spatial import Voronoi if TYPE_CHECKING: - from pybop._result import Result -from pybop.plot.plotly_manager import PlotlyManager + pass def _voronoi_regions(x, y, f, xlim, ylim): @@ -195,85 +194,7 @@ def interpolate_point(p, q, axis, boundary_val): return np.array([boundary_val, s]) if axis == 0 else np.array([s, boundary_val]) -def assign_nearest_value(x, y, f, xi, yi): - """ - Computes an array of values given by the score of the nearest point. - - Parameters - ---------- - x : array-like - The x coordinates of points with known scores. - y : array-like - The y coordinates of points with known scores. - f : array-like - The score function at the given x and y coordinates. - xi : array-like - The x coordinates of grid points. - yi : array-like - The y coordinates of grid points. - - Returns - ------- - A numpy array containing the scores corresponding to the grid points. - """ - # Create a KD-tree for efficient nearest neighbor search - tree = cKDTree(np.column_stack((x, y))) - - # Find the nearest point for each grid point - _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel()))) - zi = f[indices].reshape(xi.shape) - - return zi - - -def surface( - result: "Result", - bounds=None, - normalise=True, - resolution=250, - show=True, - **layout_kwargs, -): - """ - Plot a 2D representation of the Voronoi diagram with color-coded regions. - - Parameters: - ----------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - bounds : numpy.ndarray, optional - A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses - `cost.parameters.get_bounds_for_plotly`. - normalise : bool, optional - If True, the voronoi regions are computed using the Euclidean distance between - points normalised with respect to the bounds (default: True). - resolution : int, optional - Resolution of the plot. Default is 500. - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time [s]"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - """ - points = result.x_model - parameters = result.problem.parameters - - if points[0].shape[0] != 2: - raise ValueError("This plot method requires two parameters.") - - x_optim, y_optim = map(list, zip(*points, strict=False)) - f = result.cost - - # Translate bounds, taking only the first two elements - xlim, ylim = ( - bounds if bounds is not None else [param.bounds for param in parameters] - )[:2] - - # Create a grid for plot - xi = np.linspace(xlim[0], xlim[1], resolution) - yi = np.linspace(ylim[0], ylim[1], resolution) - xi, yi = np.meshgrid(xi, yi) +def voronoi_data(xlim, ylim, pts_x, pts_y, f, normalise=True): if normalise: if xlim[1] <= xlim[0] or ylim[1] <= ylim[0]: @@ -282,21 +203,14 @@ def surface( # Normalise the region x_range = xlim[1] - xlim[0] y_range = ylim[1] - ylim[0] - norm_x_optim = (np.asarray(x_optim) - xlim[0]) / x_range - norm_y_optim = (np.asarray(y_optim) - ylim[0]) / y_range + norm_x_optim = (np.asarray(pts_x) - xlim[0]) / x_range + norm_y_optim = (np.asarray(pts_y) - ylim[0]) / y_range # Compute regions - norm_x, norm_y, f, norm_regions = _voronoi_regions( + x, y, f, norm_regions = _voronoi_regions( norm_x_optim, norm_y_optim, f, (0, 1), (0, 1) ) - # Create a normalised grid - norm_xi = np.linspace(0, 1, resolution) - norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi) - - # Assign a value to each point in the grid - zi = assign_nearest_value(norm_x, norm_y, f, norm_xi, norm_yi) - # Rescale for plotting regions = [] for norm_region in norm_regions: @@ -307,10 +221,7 @@ def surface( else: # Compute regions - x, y, f, regions = _voronoi_regions(x_optim, y_optim, f, xlim, ylim) - - # Assign a value to each point in the grid - zi = assign_nearest_value(x, y, f, xi, yi) + x, y, f, regions = _voronoi_regions(pts_x, pts_y, f, xlim, ylim) # Calculate the size of each Voronoi region region_sizes = np.array([len(region) for region in regions]) @@ -318,107 +229,4 @@ def surface( region_sizes.max() - region_sizes.min() ) - # Construct figure - go = PlotlyManager().go - fig = go.Figure() - - # Heatmap - fig.add_trace( - go.Heatmap( - x=xi[0], - y=yi[:, 0], - z=zi, - colorscale="Viridis", - zsmooth="best", - ) - ) - - # Add Voronoi edges - for region, size in zip(regions, relative_sizes, strict=False): - x_region = region[:, 0].tolist() + [region[0, 0]] - y_region = region[:, 1].tolist() + [region[0, 1]] - - fig.add_trace( - go.Scatter( - x=x_region, - y=y_region, - mode="lines", - line=dict(color="white", width=0.5 + size * 0.1), - showlegend=False, - ) - ) - - # Add original points - fig.add_trace( - go.Scatter( - x=x_optim, - y=y_optim, - mode="markers", - marker=dict( - color=[i / len(x_optim) for i in range(len(x_optim))], - colorscale="Greys", - size=8, - showscale=False, - ), - text=[f"f={val:.2f}" for val in f], - hoverinfo="text", - showlegend=False, - ) - ) - - # Plot the initial guess - if len(result.x_model) > 0: - x0 = result.x_model[0] - fig.add_trace( - go.Scatter( - x=[x0[0]], - y=[x0[1]], - mode="markers", - marker_symbol="x", - marker=dict( - color="white", - line_color="black", - line_width=1, - size=14, - showscale=False, - ), - name="Initial values", - ) - ) - - # Plot optimised value - if result.x is not None: - x_best = result.x - fig.add_trace( - go.Scatter( - x=[x_best[0]], - y=[x_best[1]], - mode="markers", - marker_symbol="cross", - marker=dict( - color="black", - line_color="white", - line_width=1, - size=14, - showscale=False, - ), - name="Final values", - ) - ) - - names = parameters.names - fig.update_layout( - title="Voronoi Cost Landscape", - title_x=0.5, - title_y=0.905, - xaxis_title=names[0], - yaxis_title=names[1], - width=600, - height=600, - xaxis=dict(range=xlim, showexponent="last", exponentformat="e"), - yaxis=dict(range=ylim, showexponent="last", exponentformat="e"), - legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), - ) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + return x, y, f, regions, relative_sizes diff --git a/tests/plotting/test_plotly_manager.py b/tests/plotting/test_plotly_manager.py index f64e608f1..66e9ba098 100644 --- a/tests/plotting/test_plotly_manager.py +++ b/tests/plotting/test_plotly_manager.py @@ -8,7 +8,7 @@ import pytest import pybop -from pybop.plot import PlotlyManager +from pybop.plot.plotly import PlotlyManager # Find the Python executable python_executable = which("python") diff --git a/tests/unit/test_plots.py b/tests/unit/test_plots.py index 58b707900..db9e0c6c1 100644 --- a/tests/unit/test_plots.py +++ b/tests/unit/test_plots.py @@ -6,6 +6,7 @@ import pybop +@pytest.mark.parametrize("backend", ["plotly", "matplotlib"]) class TestPlots: """ A class to test the plot classes. @@ -13,7 +14,8 @@ class TestPlots: pytestmark = pytest.mark.unit - def test_standard_plot(self): + def test_standard_plot(self, backend): + pybop.plot.set_backend(backend) # Test standard plot trace_names = pybop.plot.StandardPlot.remove_brackets( ["Trace [1]", "Trace [2]"] @@ -69,7 +71,8 @@ def dataset(self, model): solution = pybamm.Simulation(model).solve(t_eval=t_eval, t_interp=t_eval) return pybop.import_pybamm_solution(solution) - def test_dataset_plots(self, dataset): + def test_dataset_plots(self, dataset, backend): + pybop.plot.set_backend(backend) # Test plot of Dataset objects pybop.plot.trajectories( dataset["Time [s]"], @@ -112,7 +115,8 @@ def design_problem(self, model, parameters, experiment): ) return pybop.Problem(simulator) - def test_problem_plots(self, fitting_problem, design_problem): + def test_problem_plots(self, fitting_problem, design_problem, backend): + pybop.plot.set_backend(backend) # Test plot of Problem objects pybop.plot.problem(fitting_problem, title="Optimised Comparison") pybop.plot.problem(design_problem) @@ -122,7 +126,8 @@ def test_problem_plots(self, fitting_problem, design_problem): fitting_problem, inputs=fitting_problem.parameters.to_dict([0.6, 0.6]) ) - def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds): + def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds, backend): + pybop.plot.set_backend(backend) # Test plot of Cost objects pybop.plot.contour(fitting_problem, gradient=True, steps=5) @@ -143,7 +148,8 @@ def result(self, fitting_problem): optim = pybop.XNES(fitting_problem) return optim.run() - def test_optim_plots(self, result): + def test_optim_plots(self, result, backend): + pybop.plot.set_backend(backend) bounds = np.asarray([[0.5, 0.8], [0.4, 0.7]]) # Plot convergence @@ -185,7 +191,8 @@ def sampling_result(self, model, parameters, dataset): sampler = pybop.SliceStepoutMCMC(log_pdf, options=options) return sampler.run() - def test_posterior_plots(self, sampling_result): + def test_posterior_plots(self, sampling_result, backend): + pybop.plot.set_backend(backend) sampling_result.get_summary_statistics() # Plot trace @@ -200,9 +207,11 @@ def test_posterior_plots(self, sampling_result): # Plot summary table sampling_result.summary_table() - def test_with_ipykernel(self, dataset, fitting_problem, result): + def test_with_ipykernel(self, dataset, fitting_problem, result, backend): import ipykernel + pybop.plot.set_backend(backend) + assert version.parse(ipykernel.__version__) >= version.parse("0.6") pybop.plot.dataset(dataset, signal=["Voltage [V]"]) pybop.plot.contour(fitting_problem, gradient=True, steps=5) @@ -210,7 +219,8 @@ def test_with_ipykernel(self, dataset, fitting_problem, result): result.plot_parameters() result.plot_contour(steps=5) - def test_contour_incorrect_number_of_parameters(self, model, dataset): + def test_contour_incorrect_number_of_parameters(self, model, dataset, backend): + pybop.plot.set_backend(backend) parameter_values = model.default_parameter_values # Test with less than two paramters @@ -251,7 +261,8 @@ def test_contour_incorrect_number_of_parameters(self, model, dataset): fitting_problem = pybop.Problem(simulator, cost) pybop.plot.contour(fitting_problem) - def test_nyquist(self): + def test_nyquist(self, backend): + pybop.plot.set_backend(backend) # Define model model = pybamm.lithium_ion.SPM(options={"surface form": "differential"}) parameter_values = model.default_parameter_values From 8ca515855765e4b772939179b0f819a19b4f6ef5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:33:15 +0000 Subject: [PATCH 02/10] style: pre-commit fixes --- .../ecm_multipulse_identification.ipynb | 2 +- .../comparison_examples/grouped_SPMe.py | 2 +- pybop/plot/matplotlib/__init__.py | 2 +- pybop/plot/matplotlib/contour.py | 41 +++++---- pybop/plot/matplotlib/convergence.py | 5 +- pybop/plot/matplotlib/dataset.py | 1 + pybop/plot/matplotlib/nyquist.py | 17 ++-- pybop/plot/matplotlib/parameters.py | 7 +- pybop/plot/matplotlib/problem.py | 11 ++- pybop/plot/matplotlib/samples.py | 46 ++++++---- pybop/plot/matplotlib/standard_plots.py | 64 +++++++------ pybop/plot/matplotlib/voronoi.py | 49 +++++----- pybop/plot/plotly/__init__.py | 2 +- pybop/plot/plotly/standard_plots.py | 2 +- pybop/plot/plotly/voronoi.py | 12 +-- pybop/plot/plots.py | 90 +++++++++++++++---- pybop/plot/util.py | 32 ++++--- 17 files changed, 233 insertions(+), 152 deletions(-) diff --git a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb index f03a01249..8140c73c1 100644 --- a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb @@ -507,7 +507,7 @@ " [3600 3]], duration=3600)]" ] }, - "execution_count": 11, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/scripts/comparison_examples/grouped_SPMe.py b/examples/scripts/comparison_examples/grouped_SPMe.py index 85119a9da..cb1d060e2 100644 --- a/examples/scripts/comparison_examples/grouped_SPMe.py +++ b/examples/scripts/comparison_examples/grouped_SPMe.py @@ -11,7 +11,7 @@ """ # Prepare figure -pybop.plot.set_backend('matplotlib') +pybop.plot.set_backend("matplotlib") plot_dict = pybop.plot.StandardPlot() plt.xlabel("Time / s") plt.ylabel("Voltage / V") diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py index 863ad89ec..be1f45a58 100644 --- a/pybop/plot/matplotlib/__init__.py +++ b/pybop/plot/matplotlib/__init__.py @@ -6,4 +6,4 @@ from .contour import contour from .voronoi import surface from .nyquist import nyquist -from .samples import chains, posterior, summary_table, trace \ No newline at end of file +from .samples import chains, posterior, summary_table, trace diff --git a/pybop/plot/matplotlib/contour.py b/pybop/plot/matplotlib/contour.py index eb7ee77f3..d0127375a 100644 --- a/pybop/plot/matplotlib/contour.py +++ b/pybop/plot/matplotlib/contour.py @@ -4,7 +4,6 @@ import numpy as np from matplotlib import pyplot as plt -from scipy.interpolate import griddata from pybop.problems.problem import Problem @@ -19,7 +18,7 @@ def contour( transformed: bool = False, steps: int = 10, show: bool = True, - title: str = 'Cost Landscape', + title: str = "Cost Landscape", ): """ Plot a 2D visualisation of a cost landscape using Plotly. @@ -146,17 +145,23 @@ def transform_array_of_values(list_of_values, parameter): # define levels exponent = np.floor(np.log10(np.abs(np.max(costs)))) - levels = np.linspace(np.floor(np.min(costs)/(10**exponent))*(10**exponent), np.ceil(np.max(costs)/(10**exponent))*(10**exponent), 2 * steps - 1) + levels = np.linspace( + np.floor(np.min(costs) / (10**exponent)) * (10**exponent), + np.ceil(np.max(costs) / (10**exponent)) * (10**exponent), + 2 * steps - 1, + ) # Create contour plot and update the layout fig = plt.figure(figsize=(6, 6), dpi=100) - plt.contourf(x, y, costs, levels=levels, extend='both', cmap='viridis') + plt.contourf(x, y, costs, levels=levels, extend="both", cmap="viridis") plt.colorbar() - plt.contour(x, y, costs, levels=levels, colors=('k',), linestyles='solid', linewidths=0.1) + plt.contour( + x, y, costs, levels=levels, colors=("k",), linestyles="solid", linewidths=0.1 + ) - # Layout + # Layout plt.xlabel("Transformed " + names[0] if transformed else names[0], labelpad=15) - plt.ticklabel_format(axis='both', **dict(style='sci',scilimits=(-4,4))) + plt.ticklabel_format(axis="both", **dict(style="sci", scilimits=(-4, 4))) plt.ylabel("Transformed " + names[1] if transformed else names[1], labelpad=15) plt.title(title, pad=40) plt.xlim(bounds[0]) @@ -171,9 +176,9 @@ def transform_array_of_values(list_of_values, parameter): transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), c=[i / len(optim_trace) for i in range(len(optim_trace))], - cmap='Grays', + cmap="Grays", zorder=1, - ) + ) # Plot the initial guess if len(result.x_model) > 0: @@ -181,12 +186,12 @@ def transform_array_of_values(list_of_values, parameter): plt.plot( transform_array_of_values([x0[0]], parameters[names[0]]), transform_array_of_values([x0[1]], parameters[names[1]]), - 'X', + "X", markersize=14, - markerfacecolor='w', - markeredgecolor='k', + markerfacecolor="w", + markeredgecolor="k", label="Initial values", - linestyle='None', + linestyle="None", ) # Plot optimised value @@ -197,14 +202,14 @@ def transform_array_of_values(list_of_values, parameter): transform_array_of_values([x_best[1]], parameters[names[1]]), "P", markersize=14, - markerfacecolor='k', - markeredgecolor='w', + markerfacecolor="k", + markeredgecolor="w", label="Final values", - linestyle='None', + linestyle="None", ) - plt.legend(ncols=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) - + plt.legend(ncols=2, loc="lower center", bbox_to_anchor=(0.5, 1.0)) + plt.tight_layout() if show: diff --git a/pybop/plot/matplotlib/convergence.py b/pybop/plot/matplotlib/convergence.py index 4ad2f31ad..2e53e4738 100644 --- a/pybop/plot/matplotlib/convergence.py +++ b/pybop/plot/matplotlib/convergence.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING -from pybop.plot.matplotlib.standard_plots import StandardPlot import matplotlib.pyplot as plt +from pybop.plot.matplotlib.standard_plots import StandardPlot + if TYPE_CHECKING: from pybop._result import Result @@ -47,4 +48,4 @@ def convergence(result: "Result", show=True): if show: plt.show() - return fig \ No newline at end of file + return fig diff --git a/pybop/plot/matplotlib/dataset.py b/pybop/plot/matplotlib/dataset.py index 44cd3ce25..7ade00091 100644 --- a/pybop/plot/matplotlib/dataset.py +++ b/pybop/plot/matplotlib/dataset.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt + from pybop.plot.matplotlib.standard_plots import StandardPlot, trajectories diff --git a/pybop/plot/matplotlib/nyquist.py b/pybop/plot/matplotlib/nyquist.py index c45b29beb..eba00efec 100644 --- a/pybop/plot/matplotlib/nyquist.py +++ b/pybop/plot/matplotlib/nyquist.py @@ -1,6 +1,7 @@ +from matplotlib import pyplot as plt + from pybop.parameters.parameter import Inputs from pybop.plot.matplotlib.standard_plots import StandardPlot -from matplotlib import pyplot as plt def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): @@ -58,7 +59,7 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): fig = plot_dict(show=False) plot_dict.traces[0].set_color("#00CC96") plot_dict.traces[0].set_linewidth(2) - plot_dict.traces[0].set_marker('.') + plot_dict.traces[0].set_marker(".") plot_dict.traces[0].set_markersize(8) target_trace = plot_dict.create_trace( @@ -66,17 +67,17 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): y=-target_output[var].imag, label="Reference", ) - target_trace.set_linestyle('None') - target_trace.set_marker('o') - target_trace.set_fillstyle('none') + target_trace.set_linestyle("None") + target_trace.set_marker("o") + target_trace.set_fillstyle("none") target_trace.set_markersize(8) target_trace.set_markeredgecolor("#636EFA") - # Layout - plt.title('Nyquist Plot', fontsize=14, x=0.2) + # Layout + plt.title("Nyquist Plot", fontsize=14, x=0.2) plt.xlabel(r"$Z_{re} / \Omega$", fontsize=16) plt.ylabel(r"$-Z_{im} / \Omega$", fontsize=16) - plt.legend(loc='upper right', bbox_to_anchor=(1, 1.08), ncols=2) + plt.legend(loc="upper right", bbox_to_anchor=(1, 1.08), ncols=2) if show: plt.show() diff --git a/pybop/plot/matplotlib/parameters.py b/pybop/plot/matplotlib/parameters.py index 671f00cb2..1c6e24e26 100644 --- a/pybop/plot/matplotlib/parameters.py +++ b/pybop/plot/matplotlib/parameters.py @@ -1,10 +1,10 @@ +import warnings from typing import TYPE_CHECKING -import warnings +import matplotlib.pyplot as plt from pybop.costs.log_likelihoods import GaussianLogLikelihood from pybop.plot.matplotlib.standard_plots import StandardSubplot -import matplotlib.pyplot as plt if TYPE_CHECKING: from pybop._result import Result @@ -50,7 +50,6 @@ def parameters(result: "Result", show=True, **layout_kwargs): axis_titles.append(("Evaluation", "Sigma")) trace_names.append("Sigma") - # Create a plot dictionary plot_dict = StandardSubplot( x=x, @@ -58,7 +57,7 @@ def parameters(result: "Result", show=True, **layout_kwargs): axis_titles=axis_titles, trace_names=trace_names, trace_name_width=50, - figsize= (18, 8), + figsize=(18, 8), ) plt.suptitle("Parameter Convergence") diff --git a/pybop/plot/matplotlib/problem.py b/pybop/plot/matplotlib/problem.py index 451ed76a1..da080207c 100644 --- a/pybop/plot/matplotlib/problem.py +++ b/pybop/plot/matplotlib/problem.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import numpy as np from pybop.costs.design_cost import DesignCost @@ -8,14 +9,12 @@ from pybop.problems.problem import Problem from pybop.simulators.solution import Solution -import matplotlib.pyplot as plt - def problem( problem: Problem, inputs: Inputs = None, show: bool = True, - title = 'Scatter Plot', + title="Scatter Plot", ): """ Produce a quick plot of the target dataset against optimised model output. @@ -75,15 +74,15 @@ def problem( y=target_output[var].data, label="Reference", marker=".", - linestyle="None" + linestyle="None", ) plot_dict.create_trace( x=model_domain, y=model_output[var].data, label="Optimised" if isinstance(problem.cost, DesignCost) else "Model", - marker="." if isinstance(problem, MetaProblem) else None, - linestyle='None' if isinstance(problem, MetaProblem) else "-", + marker="." if isinstance(problem, MetaProblem) else None, + linestyle="None" if isinstance(problem, MetaProblem) else "-", ) if isinstance(problem.cost, ErrorMeasure) and len( diff --git a/pybop/plot/matplotlib/samples.py b/pybop/plot/matplotlib/samples.py index dd14cf908..0d2ca526a 100644 --- a/pybop/plot/matplotlib/samples.py +++ b/pybop/plot/matplotlib/samples.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING import warnings +from typing import TYPE_CHECKING from matplotlib import pyplot as plt @@ -32,14 +32,12 @@ def trace(result: "SamplingResult", show=True, **kwargs): plt.legend(fontsize=12) figlist.append(fig) - if show: plt.show() else: return figlist - def chains(result: "SamplingResult", show=True, **kwargs): """ Plot posterior distributions for each chain. @@ -57,17 +55,20 @@ def chains(result: "SamplingResult", show=True, **kwargs): for i, chain in enumerate(result.chains): for j in range(chain.shape[1]): plt.hist( - x=chain[:, j], - label=f"Chain {i} - Parameter {j}", - alpha=0.5, - rwidth=2.0 + x=chain[:, j], label=f"Chain {i} - Parameter {j}", alpha=0.5, rwidth=2.0 ) for j in range(chain.shape[1]): - plt.plot([result.mean[j], result.mean[j]], [0, result.max[j]],"--", lw=3, label=f"Mean - Parameter {j}") + plt.plot( + [result.mean[j], result.mean[j]], + [0, result.max[j]], + "--", + lw=3, + label=f"Mean - Parameter {j}", + ) plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) - plt.grid(axis='y', zorder=-1) + plt.grid(axis="y", zorder=-1) plt.title("Posterior Distribution") plt.xlabel("Value") plt.ylabel("Density") @@ -77,6 +78,7 @@ def chains(result: "SamplingResult", show=True, **kwargs): else: return fig + def posterior(result: "SamplingResult", show=True, **kwargs): """ Plot the summed posterior distribution across chains. @@ -98,10 +100,10 @@ def posterior(result: "SamplingResult", show=True, **kwargs): label=f"Parameter {j}", alpha=0.75, ) - plt.axvline(result.mean[j], ls='--', c='k', lw=3) + plt.axvline(result.mean[j], ls="--", c="k", lw=3) plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) - plt.grid(axis='y', zorder=-1) + plt.grid(axis="y", zorder=-1) plt.title("Posterior Distribution") plt.xlabel("Value") plt.ylabel("Density") @@ -121,18 +123,24 @@ def summary_table(result: "SamplingResult"): header = ["Statistic", "Value"] values = [ - ["Mean", ', '.join(summary_stats["mean"].astype(str))], - ["Median", ', '.join(summary_stats["median"].astype(str))], - ["Standard Deviation", ', '.join(summary_stats["std"].astype(str))], - ["95% CI Lower", ', '.join(summary_stats["ci_lower"].astype(str))], - ["95% CI Upper", ', '.join(summary_stats["ci_upper"].astype(str))], + ["Mean", ", ".join(summary_stats["mean"].astype(str))], + ["Median", ", ".join(summary_stats["median"].astype(str))], + ["Standard Deviation", ", ".join(summary_stats["std"].astype(str))], + ["95% CI Lower", ", ".join(summary_stats["ci_lower"].astype(str))], + ["95% CI Upper", ", ".join(summary_stats["ci_upper"].astype(str))], ] fig, ax = plt.subplots(figsize=(6, 2), dpi=100) # hide axes - ax.axis('off') - ax.axis('tight') - ax.table(cellText=values, colLabels=header, loc='center', cellLoc='center', colColours=['lightsteelblue', 'lightsteelblue']) + ax.axis("off") + ax.axis("tight") + ax.table( + cellText=values, + colLabels=header, + loc="center", + cellLoc="center", + colColours=["lightsteelblue", "lightsteelblue"], + ) ax.set_title("Summary Statistics") fig.tight_layout() plt.show() diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py index 578ef365b..31e8898cc 100644 --- a/pybop/plot/matplotlib/standard_plots.py +++ b/pybop/plot/matplotlib/standard_plots.py @@ -3,11 +3,11 @@ import warnings import numpy as np - from matplotlib import pyplot as plt DEFAULT_TRACE_OPTIONS = dict(linewidth=2.0) + class StandardPlot: """ A class for creating and displaying interactive Plotly figures. @@ -57,9 +57,8 @@ def __init__( if trace_options: self.trace_options.update(trace_options) - # Parse the data - x, y = self.parse_data(x, y) + x, y = self.parse_data(x, y) self.x = x self.y = y # Check and wrap trace names @@ -68,12 +67,10 @@ def __init__( trace_names = [trace_names] for i, name in enumerate(trace_names): trace_names[i] = self.wrap_text(name, width=self.trace_name_width) - self.trace_names = trace_names + self.trace_names = trace_names self.fig = plt.figure(figsize=figsize, dpi=100) - - def __call__(self, show=True): """ Generate and show the figure. @@ -94,8 +91,8 @@ def __call__(self, show=True): def default_layout(self): - plt.tick_params(axis='both', labelsize=12) - plt.ticklabel_format(axis='both', style='sci', scilimits=(-4, 4)) + plt.tick_params(axis="both", labelsize=12) + plt.ticklabel_format(axis="both", style="sci", scilimits=(-4, 4)) def add_traces(self, x, y, trace_names=None, **trace_options): """ @@ -123,13 +120,14 @@ def add_traces(self, x, y, trace_names=None, **trace_options): label = None if trace_names is not None: - label = trace_names[i] + label = trace_names[i] - self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) + self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) if self.trace_names is not None: - plt.legend(**dict(loc="best", fontsize=12),) - + plt.legend( + **dict(loc="best", fontsize=12), + ) def parse_data(self, x, y): """ @@ -186,7 +184,7 @@ def create_trace(self, x, y, label, ax=None, **trace_options): ax = plt.gca() line = ax.plot( - x, + x, y, label=label, **trace_options, @@ -196,7 +194,6 @@ def create_trace(self, x, y, label, ax=None, **trace_options): else: return line[0] - @staticmethod def wrap_text(text, width): """ @@ -235,7 +232,6 @@ def remove_brackets(s): return s - class StandardSubplot(StandardPlot): """ A class for creating and displaying a set of interactive Plotly figures in a grid layout. @@ -275,11 +271,9 @@ def __init__( trace_options=DEFAULT_TRACE_OPTIONS, trace_names=None, trace_name_width=40, - figsize=(8, 6) + figsize=(8, 6), ): - super().__init__( - x, y, trace_options, trace_names, trace_name_width, figsize - ) + super().__init__(x, y, trace_options, trace_names, trace_name_width, figsize) self.num_traces = len(self.y) self.num_rows = num_rows self.num_cols = num_cols @@ -293,8 +287,6 @@ def __init__( self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) self.axis_titles = axis_titles - - def __call__(self, show): """ Generate and show the set of figures. @@ -305,30 +297,35 @@ def __call__(self, show): If True, the figure is shown upon creation (default: True). """ - color_cycle = plt.rcParams['axes.prop_cycle']() + color_cycle = plt.rcParams["axes.prop_cycle"]() xi = self.x[0] lines = [] for idx, yi in enumerate(self.y): - ax = self.fig.add_subplot(self.num_rows, self.num_cols, idx+1) + ax = self.fig.add_subplot(self.num_rows, self.num_cols, idx + 1) if self.axis_titles and idx < len(self.axis_titles): x_title, y_title = self.axis_titles[idx] ax.set_xlabel(x_title) ax.set_ylabel(y_title) - if len(self.x)>1: + if len(self.x) > 1: xi = self.x[idx] label = None if self.trace_names is not None: label = self.trace_names[idx] - lines.append(self.create_trace(xi, yi, label, ax = ax, **next(color_cycle))) - + lines.append(self.create_trace(xi, yi, label, ax=ax, **next(color_cycle))) lines_labels = [ax.get_legend_handles_labels() for ax in self.fig.axes] lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] if self.trace_names is not None: - self.fig.legend(lines, labels, loc='upper right', ncol=len(lines), bbox_to_anchor=(0.99, 0.95)) + self.fig.legend( + lines, + labels, + loc="upper right", + ncol=len(lines), + bbox_to_anchor=(0.99, 0.95), + ) plt.tight_layout(rect=[0, 0, 1, 0.95]) if show: plt.show() @@ -336,7 +333,16 @@ def __call__(self, show): return self.fig -def trajectories(x, y, trace_names=None, show=True, xaxis_title='', yaxis_title='', title='', **layout_kwargs): +def trajectories( + x, + y, + trace_names=None, + show=True, + xaxis_title="", + yaxis_title="", + title="", + **layout_kwargs, +): """ Quickly plot one or more trajectories using Plotly. @@ -383,4 +389,4 @@ def trajectories(x, y, trace_names=None, show=True, xaxis_title='', yaxis_title= if show: plt.show() - return plot_dict \ No newline at end of file + return plot_dict diff --git a/pybop/plot/matplotlib/voronoi.py b/pybop/plot/matplotlib/voronoi.py index 1a44f0f77..e7502b941 100644 --- a/pybop/plot/matplotlib/voronoi.py +++ b/pybop/plot/matplotlib/voronoi.py @@ -1,23 +1,23 @@ +import warnings from typing import TYPE_CHECKING +import matplotlib as mpl import numpy as np -from scipy.spatial import cKDTree from matplotlib import pyplot as plt -import matplotlib as mpl -import warnings if TYPE_CHECKING: from pybop._result import Result from pybop.plot.voronoi import voronoi_data + def surface( result: "Result", bounds=None, normalise=True, - title='Voronoi Cost Landscape', + title="Voronoi Cost Landscape", show=True, - **layout_kwargs + **layout_kwargs, ): """ Plot a 2D representation of the Voronoi diagram with color-coded regions. @@ -60,8 +60,9 @@ def surface( bounds if bounds is not None else [param.bounds for param in parameters] )[:2] - _, _, f, regions, relative_sizes = voronoi_data(xlim, ylim, x_optim, y_optim, f, normalise) - + _, _, f, regions, relative_sizes = voronoi_data( + xlim, ylim, x_optim, y_optim, f, normalise + ) # Construct figure plt.figure(figsize=(7, 6), dpi=100) @@ -73,26 +74,25 @@ def surface( norm_f = norm(f, clip=True) # get colours - cmap = mpl.colormaps['viridis'] + cmap = mpl.colormaps["viridis"] colors = cmap(norm_f) - # Add Voronoi edges and fill Voronoi regions + # Add Voronoi edges and fill Voronoi regions for j, (region, size) in enumerate(zip(regions, relative_sizes, strict=False)): x_region = region[:, 0].tolist() + [region[0, 0]] y_region = region[:, 1].tolist() + [region[0, 1]] plt.fill(x_region, y_region, color=colors[j]) - plt.plot(x_region, y_region, color='w', linewidth=0.5 + size*0.1) + plt.plot(x_region, y_region, color="w", linewidth=0.5 + size * 0.1) + + plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca()) - plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax = plt.gca()) - - # Add original points plt.scatter( x_optim, y_optim, c=[i / len(x_optim) for i in range(len(x_optim))], - cmap='Grays', + cmap="Grays", zorder=2.5, ) @@ -102,12 +102,12 @@ def surface( plt.plot( [x0[0]], [x0[1]], - 'X', + "X", markersize=14, - markerfacecolor='w', - markeredgecolor='k', + markerfacecolor="w", + markeredgecolor="k", label="Initial values", - linestyle='None', + linestyle="None", zorder=2.6, ) @@ -119,21 +119,20 @@ def surface( [x_best[1]], "P", markersize=14, - markerfacecolor='k', - markeredgecolor='w', + markerfacecolor="k", + markeredgecolor="w", label="Final values", - linestyle='None', + linestyle="None", zorder=2.6, ) - - # Layout + # Layout names = result.problem.parameters.names plt.xlabel(names[0], labelpad=15) - plt.ticklabel_format(axis='both', **dict(style='sci',scilimits=(-4,4))) + plt.ticklabel_format(axis="both", **dict(style="sci", scilimits=(-4, 4))) plt.ylabel(names[1], labelpad=15) plt.title(title, pad=40) - plt.legend(ncols=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) + plt.legend(ncols=2, loc="lower center", bbox_to_anchor=(0.5, 1.0)) plt.xlim(xlim[0], xlim[1]) plt.ylim(ylim[0], ylim[1]) plt.tight_layout() diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py index 0a3a1f91a..d3560892a 100644 --- a/pybop/plot/plotly/__init__.py +++ b/pybop/plot/plotly/__init__.py @@ -7,4 +7,4 @@ from .problem import problem from .nyquist import nyquist from .voronoi import surface -from .samples import chains, posterior, summary_table, trace \ No newline at end of file +from .samples import chains, posterior, summary_table, trace diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py index 8e3961d71..cdd0d8341 100644 --- a/pybop/plot/plotly/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -1,8 +1,8 @@ import math import textwrap +import warnings import numpy as np -import warnings from pybop.plot.plotly.plotly_manager import PlotlyManager diff --git a/pybop/plot/plotly/voronoi.py b/pybop/plot/plotly/voronoi.py index dc1e5ea1e..d60b9597e 100644 --- a/pybop/plot/plotly/voronoi.py +++ b/pybop/plot/plotly/voronoi.py @@ -1,13 +1,12 @@ from typing import TYPE_CHECKING import numpy as np -from scipy.spatial import Voronoi, cKDTree +from scipy.spatial import cKDTree if TYPE_CHECKING: from pybop._result import Result -from pybop.plot.plotly.plotly_manager import PlotlyManager from pybop.plot import voronoi_data - +from pybop.plot.plotly.plotly_manager import PlotlyManager def assign_nearest_value(x, y, f, xi, yi): @@ -85,14 +84,15 @@ def surface( bounds if bounds is not None else [param.bounds for param in parameters] )[:2] - x, y, f, regions, relative_sizes = voronoi_data(xlim, ylim, x_optim, y_optim, f, normalise) + x, y, f, regions, relative_sizes = voronoi_data( + xlim, ylim, x_optim, y_optim, f, normalise + ) # Create a grid for plot xi = np.linspace(xlim[0], xlim[1], resolution) yi = np.linspace(ylim[0], ylim[1], resolution) xi, yi = np.meshgrid(xi, yi) - if normalise: # Create a normalised grid norm_xi = np.linspace(0, 1, resolution) @@ -213,4 +213,4 @@ def surface( ) fig.update_layout(**layout_kwargs) if show: - fig.show() \ No newline at end of file + fig.show() diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index daadc5d11..83598f4de 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -1,21 +1,22 @@ from typing import TYPE_CHECKING + import numpy as np if TYPE_CHECKING: from pybop._result import Result from pybop.samplers.base_pints_sampler import SamplingResult - + from pybop.parameters.parameter import Inputs +from pybop.plot.util import call_plotting_function from pybop.problems.problem import Problem -from pybop.plot.util import call_plotting_function, get_class - def chains(result: "SamplingResult", show=True, backend=None, **kwargs): """ Plot posterior distributions for each chain. """ - return call_plotting_function('chains', backend, result=result, **kwargs) + return call_plotting_function("chains", backend, result=result, **kwargs) + def contour( call_object: "Problem | Result", @@ -24,7 +25,7 @@ def contour( transformed: bool = False, steps: int = 10, show: bool = True, - backend = None, + backend=None, **layout_kwargs, ): """ @@ -65,7 +66,18 @@ def contour( ValueError If the cost function does not return a valid cost when called with a parameter list. """ - return call_plotting_function('contour', backend, call_object=call_object, gradient=gradient, bounds=bounds, transformed=transformed, steps=steps, show=show, **layout_kwargs) + return call_plotting_function( + "contour", + backend, + call_object=call_object, + gradient=gradient, + bounds=bounds, + transformed=transformed, + steps=steps, + show=show, + **layout_kwargs, + ) + def convergence(result: "Result", show=True, backend=None, **layout_kwargs): """ @@ -87,9 +99,14 @@ def convergence(result: "Result", show=True, backend=None, **layout_kwargs): fig : plotly.graph_objs.Figure The Plotly figure object for the convergence plot. """ - return call_plotting_function('convergence', backend, result = result, show=show, **layout_kwargs) + return call_plotting_function( + "convergence", backend, result=result, show=show, **layout_kwargs + ) + -def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs): +def dataset( + dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs +): """ Quickly plot a PyBOP Dataset using Plotly. @@ -113,7 +130,16 @@ def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **l plotly.graph_objs.Figure The Plotly figure object for the scatter plot. """ - call_plotting_function('dataset', backend, dataset=dataset, signal=signal, trace_names=trace_names, show=show, **layout_kwargs) + call_plotting_function( + "dataset", + backend, + dataset=dataset, + signal=signal, + trace_names=trace_names, + show=show, + **layout_kwargs, + ) + def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): """ @@ -152,7 +178,10 @@ def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kw >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. """ - return call_plotting_function('nyquist', backend, problem=problem, inputs=inputs, show=show, **layout_kwargs) + return call_plotting_function( + "nyquist", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs + ) + def parameters(result: "Result", show=True, backend=None, **layout_kwargs): """ @@ -174,13 +203,17 @@ def parameters(result: "Result", show=True, backend=None, **layout_kwargs): plotly.graph_objs.Figure A Plotly figure object showing the parameter evolution over iterations. """ - return call_plotting_function('parameters', backend, result = result, show=show, **layout_kwargs) + return call_plotting_function( + "parameters", backend, result=result, show=show, **layout_kwargs + ) + def posterior(result: "SamplingResult", show=True, backend=None, **kwargs): """ Plot the summed posterior distribution across chains. """ - return call_plotting_function('posterior', backend, result=result, **kwargs) + return call_plotting_function("posterior", backend, result=result, **kwargs) + def problem( problem: Problem, @@ -213,14 +246,18 @@ def problem( plotly.graph_objs.Figure The Plotly figure object for the scatter plot. """ - return call_plotting_function('problem', backend, problem=problem, inputs=inputs, show=show, **layout_kwargs) + return call_plotting_function( + "problem", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs + ) + def summary_table(result: "SamplingResult", backend=None): """ Display summary statistics in a table. """ - return call_plotting_function('summary_table', backend, result=result) + return call_plotting_function("summary_table", backend, result=result) + def surface( result: "Result", @@ -253,13 +290,24 @@ def surface( e.g. `xaxis_title="Time [s]"` or `xaxis={"title": "Time [s]", font={"size":14}}` """ - return call_plotting_function('surface', backend, result = result, bounds = bounds, normalise=normalise, resolution=resolution, show=show, **layout_kwargs) + return call_plotting_function( + "surface", + backend, + result=result, + bounds=bounds, + normalise=normalise, + resolution=resolution, + show=show, + **layout_kwargs, + ) + def trace(result: "SamplingResult", backend=None, **kwargs): """ Plot trace plots for the posterior samples. """ - return call_plotting_function('trace', backend, result=result, **kwargs) + return call_plotting_function("trace", backend, result=result, **kwargs) + def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): """ @@ -284,4 +332,12 @@ def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwarg The Plotly figure object for the scatter plot. """ - return call_plotting_function('trajectories', backend, x=x, y=y, trace_names=trace_names, show=show, **layout_kwargs) \ No newline at end of file + return call_plotting_function( + "trajectories", + backend, + x=x, + y=y, + trace_names=trace_names, + show=show, + **layout_kwargs, + ) diff --git a/pybop/plot/util.py b/pybop/plot/util.py index 3ec299ce3..12b66fc26 100644 --- a/pybop/plot/util.py +++ b/pybop/plot/util.py @@ -1,41 +1,48 @@ import importlib.util + import pybop.plot + def get_class(class_name): err_msg = f"Plotting backend {pybop.plot.backend} is not available." try: - module = importlib.import_module('pybop.plot.' + pybop.plot.backend) + module = importlib.import_module("pybop.plot." + pybop.plot.backend) if hasattr(module, class_name): return getattr(module, class_name) else: - err_msg = f"Plotting backend {pybop.plot.backend} has no attribute {class_name}." + err_msg = ( + f"Plotting backend {pybop.plot.backend} has no attribute {class_name}." + ) raise ModuleNotFoundError(err_msg) - + except ModuleNotFoundError as error: # Raise an ModuleNotFoundError if the module or attribute is not available raise ModuleNotFoundError(err_msg) from error + def set_backend(backend): - err_msg = f"Plotting backend {backend} is not available. The default backend has not been updated. \n"\ + err_msg = ( + f"Plotting backend {backend} is not available. The default backend has not been updated. \n" f"The default backend is set to {pybop.plot.backend}" + ) try: - importlib.import_module('pybop.plot.' + backend) + importlib.import_module("pybop.plot." + backend) pybop.plot.backend = backend - for attr in ['StandardPlot', 'StandardSubplot']: + for attr in ["StandardPlot", "StandardSubplot"]: setattr(pybop.plot, attr, get_class(attr)) except ModuleNotFoundError as error: - # Raise an ModuleNotFoundError if the module or attribute is not available - raise ModuleNotFoundError(err_msg) from error - + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + def call_plotting_function(function_name, backend, **kwargs): if backend is None: - backend = pybop.plot.backend + backend = pybop.plot.backend err_msg = f"Plotting backend {backend} is not available." try: - module = importlib.import_module('pybop.plot.' + backend) + module = importlib.import_module("pybop.plot." + backend) if hasattr(module, function_name): plotting_function = getattr(module, function_name) # Return the imported attribute @@ -43,8 +50,7 @@ def call_plotting_function(function_name, backend, **kwargs): else: err_msg = f"Plotting backend {backend} has no attribute {function_name}." raise ModuleNotFoundError(err_msg) - + except ModuleNotFoundError as error: # Raise an ModuleNotFoundError if the module or attribute is not available raise ModuleNotFoundError(err_msg) from error - From f0e5566fccc4adfd18843e2d7d952b7275a8188d Mon Sep 17 00:00:00 2001 From: u2370093 Date: Wed, 8 Apr 2026 14:45:31 +0100 Subject: [PATCH 03/10] start cleaning up StandardPlot --- .../using_transformations.ipynb | 2 +- pybop/plot/__init__.py | 9 +- pybop/plot/matplotlib/__init__.py | 2 +- pybop/plot/matplotlib/nyquist.py | 26 +- pybop/plot/matplotlib/parameters.py | 3 +- pybop/plot/matplotlib/standard_plots.py | 200 ++++----------- pybop/plot/plotly/__init__.py | 2 +- pybop/plot/plotly/parameters.py | 3 +- pybop/plot/plotly/standard_plots.py | 148 ++---------- pybop/plot/plots.py | 34 --- pybop/plot/standard_plots.py | 228 ++++++++++++++++++ pybop/plot/util.py | 20 -- tests/plotting/test_plotly_manager.py | 6 + 13 files changed, 321 insertions(+), 362 deletions(-) create mode 100644 pybop/plot/standard_plots.py diff --git a/examples/notebooks/getting_started/using_transformations.ipynb b/examples/notebooks/getting_started/using_transformations.ipynb index 2de75b771..5d5d9a47f 100644 --- a/examples/notebooks/getting_started/using_transformations.ipynb +++ b/examples/notebooks/getting_started/using_transformations.ipynb @@ -30,7 +30,7 @@ "import pybop\n", "\n", "pybop.plot.set_backend(\"plotly\")\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index 744e749ab..29116410e 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -2,7 +2,7 @@ DEFAULT_BACKEND = 'matplotlib' backend=DEFAULT_BACKEND -from .util import set_backend, call_plotting_function, get_class +from .util import set_backend, call_plotting_function # # Import plots @@ -18,13 +18,10 @@ problem, summary_table, surface, - trace, - trajectories + trace ) +from .standard_plots import StandardPlot, StandardSubplot, trajectories from .voronoi import voronoi_data, _voronoi_regions from . import matplotlib from . import plotly - -StandardPlot = matplotlib.StandardPlot -StandardSubplot = matplotlib.StandardSubplot diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py index be1f45a58..21da2259a 100644 --- a/pybop/plot/matplotlib/__init__.py +++ b/pybop/plot/matplotlib/__init__.py @@ -1,4 +1,4 @@ -from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .standard_plots import Plotter, SubplotPlotter, trajectories from .dataset import dataset from .convergence import convergence from .parameters import parameters diff --git a/pybop/plot/matplotlib/nyquist.py b/pybop/plot/matplotlib/nyquist.py index eba00efec..c702e9049 100644 --- a/pybop/plot/matplotlib/nyquist.py +++ b/pybop/plot/matplotlib/nyquist.py @@ -56,23 +56,27 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): trace_names="Model", ) - fig = plot_dict(show=False) - plot_dict.traces[0].set_color("#00CC96") - plot_dict.traces[0].set_linewidth(2) - plot_dict.traces[0].set_marker(".") - plot_dict.traces[0].set_markersize(8) + plotting_options = dict(color="#00CC96", linewidth=2, marker=".", markersize=8) + plot_dict.traces[0].update(plotting_options) target_trace = plot_dict.create_trace( x=target_output[var].real, y=-target_output[var].imag, label="Reference", ) - target_trace.set_linestyle("None") - target_trace.set_marker("o") - target_trace.set_fillstyle("none") - target_trace.set_markersize(8) - target_trace.set_markeredgecolor("#636EFA") + plotting_options = dict( + linestyle="none", + marker="o", + fillstyle="none", + markersize=8, + markeredgecolor="#636EFA", + ) + target_trace.update(plotting_options) + + plot_dict.traces.append(target_trace) + + fig = plot_dict(show=False) # Layout plt.title("Nyquist Plot", fontsize=14, x=0.2) plt.xlabel(r"$Z_{re} / \Omega$", fontsize=16) @@ -80,7 +84,7 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): plt.legend(loc="upper right", bbox_to_anchor=(1, 1.08), ncols=2) if show: - plt.show() + fig.show() figure_list.append(fig) diff --git a/pybop/plot/matplotlib/parameters.py b/pybop/plot/matplotlib/parameters.py index 1c6e24e26..eef582f0e 100644 --- a/pybop/plot/matplotlib/parameters.py +++ b/pybop/plot/matplotlib/parameters.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from pybop.costs.log_likelihoods import GaussianLogLikelihood -from pybop.plot.matplotlib.standard_plots import StandardSubplot +from pybop.plot.standard_plots import StandardSubplot if TYPE_CHECKING: from pybop._result import Result @@ -58,6 +58,7 @@ def parameters(result: "Result", show=True, **layout_kwargs): trace_names=trace_names, trace_name_width=50, figsize=(18, 8), + backend="matplotlib", ) plt.suptitle("Parameter Convergence") diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py index 31e8898cc..e96cf19af 100644 --- a/pybop/plot/matplotlib/standard_plots.py +++ b/pybop/plot/matplotlib/standard_plots.py @@ -1,14 +1,13 @@ -import math -import textwrap import warnings -import numpy as np from matplotlib import pyplot as plt +from pybop.plot import StandardPlot + DEFAULT_TRACE_OPTIONS = dict(linewidth=2.0) -class StandardPlot: +class Plotter: """ A class for creating and displaying interactive Plotly figures. @@ -33,14 +32,11 @@ class StandardPlot: def __init__( self, - x=None, - y=None, trace_options=None, - trace_names=None, - trace_name_width=20, figsize=(8, 6), **kwargs, ): + self.backend = "matplotlib" # Warning if layout arguments ignored if len(kwargs) > 0: warnings.warn( @@ -49,27 +45,14 @@ def __init__( UserWarning, stacklevel=2, ) - self.traces = [] - self.trace_name_width = trace_name_width # Set default trace options and update if provided self.trace_options = DEFAULT_TRACE_OPTIONS.copy() if trace_options: self.trace_options.update(trace_options) - # Parse the data - x, y = self.parse_data(x, y) - self.x = x - self.y = y - # Check and wrap trace names - if trace_names is not None: - if isinstance(trace_names, str): - trace_names = [trace_names] - for i, name in enumerate(trace_names): - trace_names[i] = self.wrap_text(name, width=self.trace_name_width) - self.trace_names = trace_names - self.fig = plt.figure(figsize=figsize, dpi=100) + self.traces = [] def __call__(self, show=True): """ @@ -81,19 +64,28 @@ def __call__(self, show=True): If True, the figure is shown upon creation (default: True). """ # Add traces - if self.x is not None and self.y is not None: - self.add_traces(self.x, self.y, self.trace_names) - self.default_layout() - if show: - plt.show() - - return self.fig - - def default_layout(self): + for trace in self.traces: + self._plot_trace(**trace) plt.tick_params(axis="both", labelsize=12) plt.ticklabel_format(axis="both", style="sci", scilimits=(-4, 4)) + labels_in_fig = True + for ax in self.fig.axes: + if not ax.get_legend_handles_labels() == ([], []): + break + else: + labels_in_fig = False + if labels_in_fig: + plt.legend( + **dict(loc="best", fontsize=12), + ) + + if show: + plt.show() + else: + return self.fig + def add_traces(self, x, y, trace_names=None, **trace_options): """ Add a set of traces to the plot dictionary. @@ -124,62 +116,21 @@ def add_traces(self, x, y, trace_names=None, **trace_options): self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) - if self.trace_names is not None: - plt.legend( - **dict(loc="best", fontsize=12), - ) - - def parse_data(self, x, y): - """ - Check the type and dimensions of the data and convert if necessary to a list - of 'things plotly can take', e.g. numpy arrays or lists of numbers. - - Parameters - ---------- - x : list or np.ndarray, optional - X-axis data points. - y : list or np.ndarray, optional - Primary Y-axis data points for simulated model output. - """ - if x is None or y is None: - return None, None - if isinstance(x, list): - # If it's a list of numpy arrays, it's fine - # If it's a list of lists, it's fine - # If it's neither, it's a list of numbers that we need to wrap - if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list): - x = [x] - elif isinstance(x, np.ndarray): - x = np.squeeze(x) - if x.ndim == 1: - x = [x] - else: - x = x.tolist() - if isinstance(y, list): - if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list): - y = [y] - if isinstance(y, np.ndarray): - y = np.squeeze(y) - if y.ndim == 1: - y = [y] - else: - y = y.tolist() - if len(x) > 1 and len(x) != len(y): - raise ValueError( - "Input x should have either one data series or the same number as y." - ) - return x, y - def create_trace(self, x, y, label, ax=None, **trace_options): """ - Create a trace for the Plotly figure. + Add line to plot. Returns ------- plotly.graph_objs.Scatter A trace for a Plotly figure. """ + size = min(len(x), len(y)) + trace = dict(x=x[:size], y=y[:size], label=label, ax=ax) + trace.update(trace_options) + return trace + def _plot_trace(self, x, y, label, ax=None, **trace_options): if ax is None: ax = plt.gca() @@ -189,50 +140,14 @@ def create_trace(self, x, y, label, ax=None, **trace_options): label=label, **trace_options, ) + if len(line) > 1: return line else: return line[0] - @staticmethod - def wrap_text(text, width): - """ - Wrap text to a specified width with HTML line breaks. - - Parameters - ---------- - text : str - The text to wrap. - width : int - The width to wrap the text to. - - Returns - ------- - str - The wrapped text. - """ - wrapped_text = textwrap.fill(text, width=width, break_long_words=False) - return wrapped_text - @staticmethod - def remove_brackets(s): - """ - Remove square brackets from a string and replace with forward slashes - as per section 7.1 of the SI Handbook - """ - # If s is an iterable (but not a string), apply the function recursively to each element - if hasattr(s, "__iter__") and not isinstance(s, str): - return type(s)(StandardPlot.remove_brackets(i) for i in s) - elif isinstance(s, str): - start = s.find("[") - end = s.find("]") - if start != -1 and end != -1: - char_in_brackets = s[start + 1 : end] - return s[:start] + " / " + char_in_brackets + s[end + 1 :] - return s - - -class StandardSubplot(StandardPlot): +class SubplotPlotter(Plotter): """ A class for creating and displaying a set of interactive Plotly figures in a grid layout. @@ -263,31 +178,15 @@ class StandardSubplot(StandardPlot): def __init__( self, - x, - y, - num_rows=None, - num_cols=None, axis_titles=None, trace_options=DEFAULT_TRACE_OPTIONS, - trace_names=None, - trace_name_width=40, figsize=(8, 6), + **kwargs, ): - super().__init__(x, y, trace_options, trace_names, trace_name_width, figsize) - self.num_traces = len(self.y) - self.num_rows = num_rows - self.num_cols = num_cols - if self.num_rows is None and self.num_cols is None: - # Work out the number of subplots - self.num_cols = int(math.ceil(math.sqrt(self.num_traces))) - self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) - elif self.num_rows is None: - self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) - elif self.num_cols is None: - self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) + super().__init__(trace_options, figsize, **kwargs) self.axis_titles = axis_titles - def __call__(self, show): + def __call__(self, show=True, num_rows=1, num_cols=1): """ Generate and show the set of figures. @@ -299,26 +198,23 @@ def __call__(self, show): color_cycle = plt.rcParams["axes.prop_cycle"]() - xi = self.x[0] lines = [] - for idx, yi in enumerate(self.y): - ax = self.fig.add_subplot(self.num_rows, self.num_cols, idx + 1) + show_legend = False + for idx, trace in enumerate(self.traces): + ax = self.fig.add_subplot(num_rows, num_cols, idx + 1) + trace["ax"] = ax if self.axis_titles and idx < len(self.axis_titles): x_title, y_title = self.axis_titles[idx] ax.set_xlabel(x_title) ax.set_ylabel(y_title) - if len(self.x) > 1: - xi = self.x[idx] - - label = None - if self.trace_names is not None: - label = self.trace_names[idx] + if "label" in trace.keys() and trace["label"] is not None: + show_legend = True - lines.append(self.create_trace(xi, yi, label, ax=ax, **next(color_cycle))) + lines.append(self._plot_trace(**trace, **next(color_cycle))) lines_labels = [ax.get_legend_handles_labels() for ax in self.fig.axes] - lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] - if self.trace_names is not None: + lines, labels = [sum(lol, []) for lol in zip(*lines_labels, strict=False)] + if show_legend: self.fig.legend( lines, labels, @@ -374,11 +270,7 @@ def trajectories( stacklevel=2, ) # Create a plot dictionary - plot_dict = StandardPlot( - x=x, - y=y, - trace_names=trace_names, - ) + plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="matplotlib") # Generate the figure and update the layout fig = plot_dict(show=False) @@ -387,6 +279,6 @@ def trajectories( plt.ylabel(yaxis_title, fontsize=12) plt.tight_layout() if show: - plt.show() + fig.show() - return plot_dict + return fig diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py index d3560892a..b3b907abd 100644 --- a/pybop/plot/plotly/__init__.py +++ b/pybop/plot/plotly/__init__.py @@ -1,5 +1,5 @@ from .plotly_manager import PlotlyManager -from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .standard_plots import Plotter, SubplotPlotter, trajectories from .contour import contour from .dataset import dataset from .convergence import convergence diff --git a/pybop/plot/plotly/parameters.py b/pybop/plot/plotly/parameters.py index 02cc281f0..e93907248 100644 --- a/pybop/plot/plotly/parameters.py +++ b/pybop/plot/plotly/parameters.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from pybop.costs.log_likelihoods import GaussianLogLikelihood -from pybop.plot.plotly.standard_plots import StandardSubplot +from pybop.plot.standard_plots import StandardSubplot if TYPE_CHECKING: from pybop._result import Result @@ -59,6 +59,7 @@ def parameters(result: "Result", show=True, **layout_kwargs): layout_options=layout_options, trace_names=trace_names, trace_name_width=50, + backend="plotly", ) # Generate the figure and update the layout diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py index cdd0d8341..62cd708a2 100644 --- a/pybop/plot/plotly/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -1,9 +1,6 @@ -import math -import textwrap import warnings -import numpy as np - +from pybop.plot import StandardPlot from pybop.plot.plotly.plotly_manager import PlotlyManager DEFAULT_LAYOUT_OPTIONS = dict( @@ -36,7 +33,7 @@ DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines") -class StandardPlot: +class Plotter: """ A class for creating and displaying interactive Plotly figures. @@ -65,13 +62,9 @@ class StandardPlot: def __init__( self, - x=None, - y=None, layout=None, layout_options=None, trace_options=None, - trace_names=None, - trace_name_width=40, **kwargs, ): # Warning if layout arguments ignored @@ -82,9 +75,11 @@ def __init__( UserWarning, stacklevel=2, ) + + self.backend = "plotly" + self.traces = [] self.layout = layout - self.trace_name_width = trace_name_width # Set default layout options and update if provided if self.layout is None: @@ -104,10 +99,6 @@ def __init__( if self.layout is None: self.layout = self.go.Layout(**self.layout_options) - # Add traces - if x is not None and y is not None: - self.add_traces(x, y, trace_names) - def __call__(self, show=True): """ Generate and show the figure. @@ -139,16 +130,6 @@ def add_traces(self, x, y, trace_names=None, **trace_options): options = self.trace_options.copy() options.update(trace_options) - # Check and wrap trace names - if trace_names is not None: - if isinstance(trace_names, str): - trace_names = [trace_names] - for i, name in enumerate(trace_names): - trace_names[i] = self.wrap_text(name, width=self.trace_name_width) - - # Parse the data - x, y = self.parse_data(x, y) - # Create a trace for each trajectory xi = x[0] for i in range(0, len(y)): @@ -162,45 +143,6 @@ def add_traces(self, x, y, trace_names=None, **trace_options): trace = self.create_trace(xi, y[i], **trace_options) self.traces.append(trace) - def parse_data(self, x, y): - """ - Check the type and dimensions of the data and convert if necessary to a list - of 'things plotly can take', e.g. numpy arrays or lists of numbers. - - Parameters - ---------- - x : list or np.ndarray, optional - X-axis data points. - y : list or np.ndarray, optional - Primary Y-axis data points for simulated model output. - """ - if isinstance(x, list): - # If it's a list of numpy arrays, it's fine - # If it's a list of lists, it's fine - # If it's neither, it's a list of numbers that we need to wrap - if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list): - x = [x] - elif isinstance(x, np.ndarray): - x = np.squeeze(x) - if x.ndim == 1: - x = [x] - else: - x = x.tolist() - if isinstance(y, list): - if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list): - y = [y] - if isinstance(y, np.ndarray): - y = np.squeeze(y) - if y.ndim == 1: - y = [y] - else: - y = y.tolist() - if len(x) > 1 and len(x) != len(y): - raise ValueError( - "Input x should have either one data series or the same number as y." - ) - return x, y - def create_trace(self, x, y, **trace_options): """ Create a trace for the Plotly figure. @@ -217,45 +159,8 @@ def create_trace(self, x, y, **trace_options): **trace_options, ) - @staticmethod - def wrap_text(text, width): - """ - Wrap text to a specified width with HTML line breaks. - - Parameters - ---------- - text : str - The text to wrap. - width : int - The width to wrap the text to. - - Returns - ------- - str - The wrapped text. - """ - wrapped_text = textwrap.fill(text, width=width, break_long_words=False) - return wrapped_text.replace("\n", "
") - @staticmethod - def remove_brackets(s): - """ - Remove square brackets from a string and replace with forward slashes - as per section 7.1 of the SI Handbook - """ - # If s is an iterable (but not a string), apply the function recursively to each element - if hasattr(s, "__iter__") and not isinstance(s, str): - return type(s)(StandardPlot.remove_brackets(i) for i in s) - elif isinstance(s, str): - start = s.find("[") - end = s.find("]") - if start != -1 and end != -1: - char_in_brackets = s[start + 1 : end] - return s[:start] + " / " + char_in_brackets + s[end + 1 :] - return s - - -class StandardSubplot(StandardPlot): +class SubplotPlotter(Plotter): """ A class for creating and displaying a set of interactive Plotly figures in a grid layout. @@ -288,33 +193,14 @@ class StandardSubplot(StandardPlot): def __init__( self, - x, - y, - num_rows=None, - num_cols=None, axis_titles=None, layout=None, layout_options=DEFAULT_LAYOUT_OPTIONS, subplot_options=DEFAULT_SUBPLOT_OPTIONS, trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS, - trace_names=None, - trace_name_width=40, + **kwargs, ): - super().__init__( - x, y, layout, layout_options, trace_options, trace_names, trace_name_width - ) - self.num_traces = len(self.traces) - self.num_rows = num_rows - self.num_cols = num_cols - if self.num_rows is None and self.num_cols is None: - # Work out the number of subplots - self.num_cols = int(math.ceil(math.sqrt(self.num_traces))) - self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) - elif self.num_rows is None: - self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) - elif self.num_cols is None: - self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) - self.axis_titles = axis_titles + super().__init__(layout, layout_options, trace_options, **kwargs) self.subplot_options = subplot_options.copy() if subplot_options is not None: for arg, value in subplot_options.items(): @@ -323,7 +209,9 @@ def __init__( # Attempt to import plotly when an instance is created self.make_subplots = PlotlyManager().make_subplots - def __call__(self, show): + self.axis_titles = axis_titles + + def __call__(self, show=True, num_rows=1, num_cols=1): """ Generate and show the set of figures. @@ -333,8 +221,8 @@ def __call__(self, show): If True, the figure is shown upon creation (default: True). """ fig = self.make_subplots( - rows=self.num_rows, - cols=self.num_cols, + rows=num_rows, + cols=num_cols, horizontal_spacing=0.1, vertical_spacing=0.15, **self.subplot_options, @@ -342,8 +230,8 @@ def __call__(self, show): fig.update_layout(self.layout_options) for idx, trace in enumerate(self.traces): - row = (idx // self.num_cols) + 1 - col = (idx % self.num_cols) + 1 + row = (idx // num_cols) + 1 + col = (idx % num_cols) + 1 fig.add_trace(trace, row=row, col=col) if self.axis_titles and idx < len(self.axis_titles): @@ -386,11 +274,7 @@ def trajectories(x, y, trace_names=None, show=True, **layout_kwargs): The Plotly figure object for the scatter plot. """ # Create a plot dictionary - plot_dict = StandardPlot( - x=x, - y=y, - trace_names=trace_names, - ) + plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="plotly") # Generate the figure and update the layout fig = plot_dict(show=False) diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index 83598f4de..95fb1b701 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -307,37 +307,3 @@ def trace(result: "SamplingResult", backend=None, **kwargs): Plot trace plots for the posterior samples. """ return call_plotting_function("trace", backend, result=result, **kwargs) - - -def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): - """ - Quickly plot one or more trajectories using Plotly. - - Parameters - ---------- - x : list or np.ndarray - X-axis data points. - y : list or np.ndarray - Y-axis data points for each trajectory. - trace_names : list or str, optional - Name(s) for the trace(s) (default: None). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time / s"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object for the scatter plot. - """ - - return call_plotting_function( - "trajectories", - backend, - x=x, - y=y, - trace_names=trace_names, - show=show, - **layout_kwargs, - ) diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py new file mode 100644 index 000000000..86111bbae --- /dev/null +++ b/pybop/plot/standard_plots.py @@ -0,0 +1,228 @@ +import math +import textwrap + +import numpy as np +import warnings + +from pybop.plot.util import call_plotting_function + +class StandardPlot: + def __init__( + self, + x=None, + y=None, + trace_options=None, + trace_names=None, + trace_name_width=20, + backend = None, + **kwargs): + + self.plotter = call_plotting_function('Plotter', backend, trace_options=trace_options, **kwargs) + + self.trace_name_width = trace_name_width + + # Add traces + if x is not None and y is not None: + self.add_traces(x, y, trace_names) + + def __call__(self, show=True): + return self.plotter(show=show) + + @property + def traces(self): + return self.plotter.traces + + @traces.setter + def traces(self, value): + self.plotter.traces = value + + def add_traces(self, x, y, trace_names): + # Check and wrap trace names + if trace_names is not None: + if isinstance(trace_names, str): + trace_names = [trace_names] + for i, name in enumerate(trace_names): + trace_names[i] = self.wrap_text(name, width=self.trace_name_width, backend=self.plotter.backend) + + # Parse the data + x, y = self.parse_data(x, y) + + # Add traces + self.plotter.add_traces(x, y, trace_names) + + + def parse_data(self, x, y): + """ + Check the type and dimensions of the data and convert if necessary to a list + of 'things plotly can take', e.g. numpy arrays or lists of numbers. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + """ + if isinstance(x, list): + # If it's a list of numpy arrays, it's fine + # If it's a list of lists, it's fine + # If it's neither, it's a list of numbers that we need to wrap + if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list): + x = [x] + elif isinstance(x, np.ndarray): + x = np.squeeze(x) + if x.ndim == 1: + x = [x] + else: + x = x.tolist() + if isinstance(y, list): + if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list): + y = [y] + if isinstance(y, np.ndarray): + y = np.squeeze(y) + if y.ndim == 1: + y = [y] + else: + y = y.tolist() + if len(x) > 1 and len(x) != len(y): + raise ValueError( + "Input x should have either one data series or the same number as y." + ) + return x, y + + def create_trace(self, x, y, **trace_options): + return self.plotter.create_trace(x, y, **trace_options) + + @staticmethod + def wrap_text(text, width, backend='matplotlib'): + """ + Wrap text to a specified width with HTML line breaks. + + Parameters + ---------- + text : str + The text to wrap. + width : int + The width to wrap the text to. + + Returns + ------- + str + The wrapped text. + """ + wrapped_text = textwrap.fill(text, width=width, break_long_words=False) + if backend == 'plotly': + return wrapped_text.replace("\n", "
") + else: + return wrapped_text + + @staticmethod + def remove_brackets(s): + """ + Remove square brackets from a string and replace with forward slashes + as per section 7.1 of the SI Handbook + """ + # If s is an iterable (but not a string), apply the function recursively to each element + if hasattr(s, "__iter__") and not isinstance(s, str): + return type(s)(StandardPlot.remove_brackets(i) for i in s) + elif isinstance(s, str): + start = s.find("[") + end = s.find("]") + if start != -1 and end != -1: + char_in_brackets = s[start + 1 : end] + return s[:start] + " / " + char_in_brackets + s[end + 1 :] + return s + + +class StandardSubplot(StandardPlot): + """ + A class for creating and displaying a set of interactive Plotly figures in a grid layout. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + num_rows : int, optional + Number of rows of subplots, can be set automatically (default: None). + num_cols : int, optional + Number of columns of subplots, can be set automatically (default: None). + layout : Plotly layout, optional + A layout for the figure, overrides the layout options (default: None). + layout_options : dict, optional + Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS). + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + x, + y, + backend=None, + num_rows=None, + num_cols=None, + axis_titles=None, + trace_options=None, + trace_names=None, + trace_name_width=40, + **kwargs, + ): + self.plotter = call_plotting_function('SubplotPlotter', backend, axis_titles=axis_titles, trace_options=trace_options, **kwargs) + + self.trace_name_width = trace_name_width + + # Add traces + if x is not None and y is not None: + self.add_traces(x, y, trace_names) + + self.num_traces = len(self.plotter.traces) + self.num_rows = num_rows + self.num_cols = num_cols + if self.num_rows is None and self.num_cols is None: + # Work out the number of subplots + self.num_cols = int(math.ceil(math.sqrt(self.num_traces))) + self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) + elif self.num_rows is None: + self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) + elif self.num_cols is None: + self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) + + def __call__(self, show=True): + return self.plotter(show=show, num_rows=self.num_rows, num_cols=self.num_cols) + + +def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): + """ + Quickly plot one or more trajectories using Plotly. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Y-axis data points for each trajectory. + trace_names : list or str, optional + Name(s) for the trace(s) (default: None). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + + return call_plotting_function('trajectories', backend, x=x, y=y, trace_names=trace_names, show=show, **layout_kwargs) \ No newline at end of file diff --git a/pybop/plot/util.py b/pybop/plot/util.py index 12b66fc26..39a1ecfd2 100644 --- a/pybop/plot/util.py +++ b/pybop/plot/util.py @@ -3,24 +3,6 @@ import pybop.plot -def get_class(class_name): - err_msg = f"Plotting backend {pybop.plot.backend} is not available." - try: - module = importlib.import_module("pybop.plot." + pybop.plot.backend) - if hasattr(module, class_name): - return getattr(module, class_name) - - else: - err_msg = ( - f"Plotting backend {pybop.plot.backend} has no attribute {class_name}." - ) - raise ModuleNotFoundError(err_msg) - - except ModuleNotFoundError as error: - # Raise an ModuleNotFoundError if the module or attribute is not available - raise ModuleNotFoundError(err_msg) from error - - def set_backend(backend): err_msg = ( f"Plotting backend {backend} is not available. The default backend has not been updated. \n" @@ -29,8 +11,6 @@ def set_backend(backend): try: importlib.import_module("pybop.plot." + backend) pybop.plot.backend = backend - for attr in ["StandardPlot", "StandardSubplot"]: - setattr(pybop.plot, attr, get_class(attr)) except ModuleNotFoundError as error: # Raise an ModuleNotFoundError if the module or attribute is not available diff --git a/tests/plotting/test_plotly_manager.py b/tests/plotting/test_plotly_manager.py index 66e9ba098..b39df37cd 100644 --- a/tests/plotting/test_plotly_manager.py +++ b/tests/plotting/test_plotly_manager.py @@ -128,6 +128,9 @@ def dataset(plotly_installed): @pytest.mark.unit def test_standard_plot(dataset, plotly_installed): + # Set plotting backend + pybop.plot.set_backend("plotly") + # Check the StandardPlot class pybop.plot.StandardPlot(dataset["Time [s]"], dataset["Voltage [V]"]) @@ -167,6 +170,9 @@ def test_standard_plot(dataset, plotly_installed): @pytest.mark.unit def test_plot_dataset(dataset, plotly_installed): + # Set plotting backend + pybop.plot.set_backend("plotly") + # Test plot of a dataset pybop.plot.dataset(dataset, signal=["Voltage [V]"]) pybop.plot.dataset(dataset, signal=["Voltage [V]", "Current [A]"]) From 58f2bcfa64c46158d0b7aaccfac271e7d5a50088 Mon Sep 17 00:00:00 2001 From: u2370093 Date: Wed, 8 Apr 2026 17:08:27 +0100 Subject: [PATCH 04/10] update nyquist --- pybop/plot/__init__.py | 2 +- pybop/plot/matplotlib/nyquist.py | 65 +++++++--------- pybop/plot/nyquist.py | 74 +++++++++++++++++++ pybop/plot/plotly/nyquist.py | 122 +++++++++++++------------------ pybop/plot/plots.py | 42 ----------- pybop/plot/standard_plots.py | 48 ++++++++---- 6 files changed, 187 insertions(+), 166 deletions(-) create mode 100644 pybop/plot/nyquist.py diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index 29116410e..656dbee6c 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -12,7 +12,6 @@ contour, convergence, dataset, - nyquist, parameters, posterior, problem, @@ -22,6 +21,7 @@ ) from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .nyquist import nyquist from .voronoi import voronoi_data, _voronoi_regions from . import matplotlib from . import plotly diff --git a/pybop/plot/matplotlib/nyquist.py b/pybop/plot/matplotlib/nyquist.py index c702e9049..5c6a64144 100644 --- a/pybop/plot/matplotlib/nyquist.py +++ b/pybop/plot/matplotlib/nyquist.py @@ -1,7 +1,9 @@ +import warnings + from matplotlib import pyplot as plt from pybop.parameters.parameter import Inputs -from pybop.plot.matplotlib.standard_plots import StandardPlot +from pybop.plot.nyquist import _nyquist def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): @@ -41,42 +43,31 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. """ - if not isinstance(inputs, dict): - inputs = problem.parameters.to_dict(inputs) - - model_output = problem.simulate(inputs) - domain_data = model_output["Impedance"].data.real - target_output = problem.target_data - - figure_list = [] - for var in problem.target: - plot_dict = StandardPlot( - x=domain_data, - y=-model_output[var].data.imag, - trace_names="Model", - ) - - plotting_options = dict(color="#00CC96", linewidth=2, marker=".", markersize=8) - plot_dict.traces[0].update(plotting_options) - target_trace = plot_dict.create_trace( - x=target_output[var].real, - y=-target_output[var].imag, - label="Reference", + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, ) - - plotting_options = dict( - linestyle="none", - marker="o", - fillstyle="none", - markersize=8, - markeredgecolor="#636EFA", - ) - target_trace.update(plotting_options) - - plot_dict.traces.append(target_trace) - - fig = plot_dict(show=False) + trace_options_model = dict( + label="Model", color="#00CC96", linewidth=2, marker=".", markersize=8 + ) + trace_options_reference = dict( + label="Reference", + linestyle="none", + marker="o", + fillstyle="none", + markersize=8, + markeredgecolor="#636EFA", + ) + figure_list = _nyquist( + problem, trace_options_model, trace_options_reference, inputs=inputs + ) + + for fig in figure_list: + plt.sca(fig.gca()) # Layout plt.title("Nyquist Plot", fontsize=14, x=0.2) plt.xlabel(r"$Z_{re} / \Omega$", fontsize=16) @@ -84,8 +75,6 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): plt.legend(loc="upper right", bbox_to_anchor=(1, 1.08), ncols=2) if show: - fig.show() - - figure_list.append(fig) + plt.show() return figure_list diff --git a/pybop/plot/nyquist.py b/pybop/plot/nyquist.py new file mode 100644 index 000000000..792991110 --- /dev/null +++ b/pybop/plot/nyquist.py @@ -0,0 +1,74 @@ +from pybop.parameters.parameter import Inputs +from pybop.plot.util import call_plotting_function +from pybop.plot.standard_plots import StandardPlot + +def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): + """ + Generates Nyquist plots for the given problem by evaluating the model's output and target values. + + Parameters + ---------- + problem : pybop.Problem + An instance of a problem class that contains the parameters and methods + for evaluation and target retrieval. + inputs : Inputs, optional + Input parameters for the problem. If not provided, the default parameters from the problem + instance will be used. These parameters are verified before use (default is None). + show : bool, optional + If True, the plots will be displayed. + **layout_kwargs : dict, optional + Additional keyword arguments for customising the plot layout. These arguments are passed to + `fig.update_layout()`. + + Returns + ------- + list + A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. + + Notes + ----- + - The function extracts the real part of the impedance from the model's output and the real and imaginary parts + of the impedance from the target output. + - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. + - An additional trace for the reference (target output) is added to the plot. + - The plot layout can be customised using `layout_kwargs`. + + Example + ------- + >>> problem = pybop.EISProblem() + >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") + >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. + """ + return call_plotting_function( + "nyquist", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs + ) + +def _nyquist(problem, trace_options_model: dict, trace_options_reference, inputs: Inputs = None): + + if not isinstance(inputs, dict): + inputs = problem.parameters.to_dict(inputs) + + model_output = problem.simulate(inputs) + domain_data = model_output["Impedance"].data.real + target_output = problem.target_data + figure_list = [] + for var in problem.target: + plot_dict = StandardPlot( + x=domain_data, + y=-model_output[var].data.imag, + trace_names="Model", + ) + + plot_dict.traces[0].update(trace_options_model) + + target_trace = plot_dict.create_trace( + x=target_output[var].real, + y=-target_output[var].imag, + **trace_options_reference, + ) + plot_dict.traces.append(target_trace) + + fig = plot_dict(show=False) + figure_list.append(fig) + + return figure_list diff --git a/pybop/plot/plotly/nyquist.py b/pybop/plot/plotly/nyquist.py index bbc5fa49a..1578d2462 100644 --- a/pybop/plot/plotly/nyquist.py +++ b/pybop/plot/plotly/nyquist.py @@ -1,5 +1,5 @@ from pybop.parameters.parameter import Inputs -from pybop.plot.plotly.standard_plots import StandardPlot +from pybop.plot.nyquist import _nyquist def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): @@ -39,81 +39,63 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. """ - if not isinstance(inputs, dict): - inputs = problem.parameters.to_dict(inputs) + default_layout_options = dict( + title="Nyquist Plot", + font=dict(family="Arial", size=14), + plot_bgcolor="white", + paper_bgcolor="white", + xaxis=dict( + title=dict(text="Zre / Ω", font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + ), + yaxis=dict( + title=dict(text="-Zim / Ω", font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + scaleanchor="x", + scaleratio=1, + ), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), + width=600, + height=600, + ) - model_output = problem.simulate(inputs) - domain_data = model_output["Impedance"].data.real - target_output = problem.target_data + # Overwrite with user-kwargs + default_layout_options.update(layout_kwargs) - figure_list = [] - for var in problem.target: - default_layout_options = dict( - title="Nyquist Plot", - font=dict(family="Arial", size=14), - plot_bgcolor="white", - paper_bgcolor="white", - xaxis=dict( - title=dict(text="Zre / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - ), - yaxis=dict( - title=dict(text="-Zim / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - scaleanchor="x", - scaleratio=1, - ), - legend=dict( - orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 - ), - width=600, - height=600, - ) + trace_options_model = dict( + mode="lines+markers", + line=dict(color="#00CC96", width=2), + marker=dict(size=8, color="#00CC96", symbol="circle"), + ) - plot_dict = StandardPlot( - x=domain_data, - y=-model_output[var].data.imag, - layout_options=default_layout_options, - trace_names="Model", - ) + trace_options_reference = dict( + name="Reference", + mode="markers", + marker=dict(size=8, color="#636EFA", symbol="circle-open"), + showlegend=True, + ) - plot_dict.traces[0].update( - mode="lines+markers", - line=dict(color="#00CC96", width=2), - marker=dict(size=8, color="#00CC96", symbol="circle"), - ) + figure_list = _nyquist( + problem, trace_options_model, trace_options_reference, inputs=inputs + ) - target_trace = plot_dict.create_trace( - x=target_output[var].real, - y=-target_output[var].imag, - name="Reference", - mode="markers", - marker=dict(size=8, color="#636EFA", symbol="circle-open"), - showlegend=True, - ) - plot_dict.traces.append(target_trace) - - fig = plot_dict(show=False) - - # Overwrite with user-kwargs - fig.update_layout(**layout_kwargs) + for fig in figure_list: + fig.update_layout(**default_layout_options) if show: fig.show() - figure_list.append(fig) - return figure_list diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index 95fb1b701..7e6425aee 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -141,48 +141,6 @@ def dataset( ) -def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): - """ - Generates Nyquist plots for the given problem by evaluating the model's output and target values. - - Parameters - ---------- - problem : pybop.Problem - An instance of a problem class that contains the parameters and methods - for evaluation and target retrieval. - inputs : Inputs, optional - Input parameters for the problem. If not provided, the default parameters from the problem - instance will be used. These parameters are verified before use (default is None). - show : bool, optional - If True, the plots will be displayed. - **layout_kwargs : dict, optional - Additional keyword arguments for customising the plot layout. These arguments are passed to - `fig.update_layout()`. - - Returns - ------- - list - A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. - - Notes - ----- - - The function extracts the real part of the impedance from the model's output and the real and imaginary parts - of the impedance from the target output. - - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. - - An additional trace for the reference (target output) is added to the plot. - - The plot layout can be customised using `layout_kwargs`. - - Example - ------- - >>> problem = pybop.EISProblem() - >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") - >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. - """ - return call_plotting_function( - "nyquist", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs - ) - - def parameters(result: "Result", show=True, backend=None, **layout_kwargs): """ Plot the evolution of parameters during the optimisation process using Plotly. diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index 86111bbae..0dc0bfddf 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -2,10 +2,10 @@ import textwrap import numpy as np -import warnings from pybop.plot.util import call_plotting_function + class StandardPlot: def __init__( self, @@ -14,10 +14,13 @@ def __init__( trace_options=None, trace_names=None, trace_name_width=20, - backend = None, - **kwargs): - - self.plotter = call_plotting_function('Plotter', backend, trace_options=trace_options, **kwargs) + backend=None, + **kwargs, + ): + + self.plotter = call_plotting_function( + "Plotter", backend, trace_options=trace_options, **kwargs + ) self.trace_name_width = trace_name_width @@ -31,7 +34,7 @@ def __call__(self, show=True): @property def traces(self): return self.plotter.traces - + @traces.setter def traces(self, value): self.plotter.traces = value @@ -42,7 +45,9 @@ def add_traces(self, x, y, trace_names): if isinstance(trace_names, str): trace_names = [trace_names] for i, name in enumerate(trace_names): - trace_names[i] = self.wrap_text(name, width=self.trace_name_width, backend=self.plotter.backend) + trace_names[i] = self.wrap_text( + name, width=self.trace_name_width, backend=self.plotter.backend + ) # Parse the data x, y = self.parse_data(x, y) @@ -50,7 +55,6 @@ def add_traces(self, x, y, trace_names): # Add traces self.plotter.add_traces(x, y, trace_names) - def parse_data(self, x, y): """ Check the type and dimensions of the data and convert if necessary to a list @@ -89,12 +93,12 @@ def parse_data(self, x, y): "Input x should have either one data series or the same number as y." ) return x, y - + def create_trace(self, x, y, **trace_options): return self.plotter.create_trace(x, y, **trace_options) @staticmethod - def wrap_text(text, width, backend='matplotlib'): + def wrap_text(text, width, backend="matplotlib"): """ Wrap text to a specified width with HTML line breaks. @@ -111,7 +115,7 @@ def wrap_text(text, width, backend='matplotlib'): The wrapped text. """ wrapped_text = textwrap.fill(text, width=width, break_long_words=False) - if backend == 'plotly': + if backend == "plotly": return wrapped_text.replace("\n", "
") else: return wrapped_text @@ -132,7 +136,7 @@ def remove_brackets(s): char_in_brackets = s[start + 1 : end] return s[:start] + " / " + char_in_brackets + s[end + 1 :] return s - + class StandardSubplot(StandardPlot): """ @@ -178,7 +182,13 @@ def __init__( trace_name_width=40, **kwargs, ): - self.plotter = call_plotting_function('SubplotPlotter', backend, axis_titles=axis_titles, trace_options=trace_options, **kwargs) + self.plotter = call_plotting_function( + "SubplotPlotter", + backend, + axis_titles=axis_titles, + trace_options=trace_options, + **kwargs, + ) self.trace_name_width = trace_name_width @@ -201,7 +211,7 @@ def __init__( def __call__(self, show=True): return self.plotter(show=show, num_rows=self.num_rows, num_cols=self.num_cols) - + def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): """ Quickly plot one or more trajectories using Plotly. @@ -225,4 +235,12 @@ def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwarg The Plotly figure object for the scatter plot. """ - return call_plotting_function('trajectories', backend, x=x, y=y, trace_names=trace_names, show=show, **layout_kwargs) \ No newline at end of file + return call_plotting_function( + "trajectories", + backend, + x=x, + y=y, + trace_names=trace_names, + show=show, + **layout_kwargs, + ) From 99b0ed0026d6f2eedf86b99cda1c87164b930532 Mon Sep 17 00:00:00 2001 From: u2370093 Date: Wed, 8 Apr 2026 17:12:39 +0100 Subject: [PATCH 05/10] minor edits --- pybop/plot/nyquist.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pybop/plot/nyquist.py b/pybop/plot/nyquist.py index 792991110..bcd609fd1 100644 --- a/pybop/plot/nyquist.py +++ b/pybop/plot/nyquist.py @@ -1,6 +1,7 @@ from pybop.parameters.parameter import Inputs -from pybop.plot.util import call_plotting_function from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import call_plotting_function + def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): """ @@ -43,7 +44,10 @@ def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kw "nyquist", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs ) -def _nyquist(problem, trace_options_model: dict, trace_options_reference, inputs: Inputs = None): + +def _nyquist( + problem, trace_options_model: dict, trace_options_reference, inputs: Inputs = None +): if not isinstance(inputs, dict): inputs = problem.parameters.to_dict(inputs) @@ -54,9 +58,9 @@ def _nyquist(problem, trace_options_model: dict, trace_options_reference, inputs figure_list = [] for var in problem.target: plot_dict = StandardPlot( - x=domain_data, - y=-model_output[var].data.imag, - trace_names="Model", + x=domain_data, + y=-model_output[var].data.imag, + trace_names="Model", ) plot_dict.traces[0].update(trace_options_model) From ea5af267d52c765086119df3244a668affb18569 Mon Sep 17 00:00:00 2001 From: u2370093 Date: Thu, 9 Apr 2026 14:42:04 +0100 Subject: [PATCH 06/10] some simplifications --- .../ecm_tau_redefined.py | 1 + pybop/plot/__init__.py | 8 +- pybop/plot/{plotly => }/dataset.py | 11 +- pybop/plot/matplotlib/__init__.py | 4 +- pybop/plot/matplotlib/dataset.py | 55 --------- pybop/plot/matplotlib/parameters.py | 71 ----------- pybop/plot/matplotlib/problem.py | 113 ------------------ pybop/plot/matplotlib/standard_plots.py | 35 ++++-- pybop/plot/matplotlib/util.py | 27 +++++ pybop/plot/{plotly => }/parameters.py | 18 +-- pybop/plot/plotly/__init__.py | 5 +- pybop/plot/plotly/standard_plots.py | 11 ++ pybop/plot/plotly/util.py | 24 ++++ pybop/plot/plots.py | 99 --------------- pybop/plot/{plotly => }/problem.py | 67 +++++------ pybop/plot/standard_plots.py | 3 + pybop/plot/util.py | 21 ++++ 17 files changed, 164 insertions(+), 409 deletions(-) rename pybop/plot/{plotly => }/dataset.py (82%) delete mode 100644 pybop/plot/matplotlib/dataset.py delete mode 100644 pybop/plot/matplotlib/parameters.py delete mode 100644 pybop/plot/matplotlib/problem.py create mode 100644 pybop/plot/matplotlib/util.py rename pybop/plot/{plotly => }/parameters.py (81%) create mode 100644 pybop/plot/plotly/util.py rename pybop/plot/{plotly => }/problem.py (70%) diff --git a/examples/scripts/battery_parameterisation/ecm_tau_redefined.py b/examples/scripts/battery_parameterisation/ecm_tau_redefined.py index 4bd754dd9..53fe089c1 100644 --- a/examples/scripts/battery_parameterisation/ecm_tau_redefined.py +++ b/examples/scripts/battery_parameterisation/ecm_tau_redefined.py @@ -138,6 +138,7 @@ print("Identified parameters:", result.x.tolist() + [result.x[2] / result.x[1]]) # Plot the timeseries output +pybop.plot.problem(problem, inputs=result.best_inputs, title="Optimised Comparison", backend='plotly') pybop.plot.problem(problem, inputs=result.best_inputs, title="Optimised Comparison") # Plot the optimisation result diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index 656dbee6c..ce95ef70d 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -2,7 +2,7 @@ DEFAULT_BACKEND = 'matplotlib' backend=DEFAULT_BACKEND -from .util import set_backend, call_plotting_function +from .util import set_backend, call_plotting_function, get_default_options # # Import plots @@ -11,17 +11,17 @@ chains, contour, convergence, - dataset, - parameters, posterior, - problem, summary_table, surface, trace ) from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .dataset import dataset from .nyquist import nyquist +from .parameters import parameters +from .problem import problem from .voronoi import voronoi_data, _voronoi_regions from . import matplotlib from . import plotly diff --git a/pybop/plot/plotly/dataset.py b/pybop/plot/dataset.py similarity index 82% rename from pybop/plot/plotly/dataset.py rename to pybop/plot/dataset.py index 899b2ed1a..639bcc88a 100644 --- a/pybop/plot/plotly/dataset.py +++ b/pybop/plot/dataset.py @@ -1,7 +1,8 @@ -from pybop.plot.plotly.standard_plots import StandardPlot, trajectories +from pybop.plot.standard_plots import StandardPlot, trajectories +from pybop.plot.util import update_and_show -def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs): +def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs): """ Quickly plot a PyBOP Dataset using Plotly. @@ -50,9 +51,9 @@ def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs): show=False, xaxis_title=StandardPlot.remove_brackets(dataset.domain), yaxis_title=yaxis_title, + backend = backend, ) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + + fig = update_and_show(fig, backend = backend, show=show, **layout_kwargs) return fig diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py index 21da2259a..7311edd07 100644 --- a/pybop/plot/matplotlib/__init__.py +++ b/pybop/plot/matplotlib/__init__.py @@ -1,9 +1,7 @@ from .standard_plots import Plotter, SubplotPlotter, trajectories -from .dataset import dataset from .convergence import convergence -from .parameters import parameters -from .problem import problem from .contour import contour from .voronoi import surface from .nyquist import nyquist from .samples import chains, posterior, summary_table, trace +from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/matplotlib/dataset.py b/pybop/plot/matplotlib/dataset.py deleted file mode 100644 index 7ade00091..000000000 --- a/pybop/plot/matplotlib/dataset.py +++ /dev/null @@ -1,55 +0,0 @@ -import matplotlib.pyplot as plt - -from pybop.plot.matplotlib.standard_plots import StandardPlot, trajectories - - -def dataset(dataset, signal=None, trace_names=None, show=True): - """ - Quickly plot a PyBOP Dataset using Plotly. - - Parameters - ---------- - dataset : object - A PyBOP dataset. - signal : list or str, optional - The name of the time series to plot (default: "Voltage [V]"). - trace_names : list or str, optional - Name(s) for the trace(s) (default: "Data"). - show : bool, optional - If True, the figure is shown upon creation (default: True). - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object for the scatter plot. - """ - - # Get data dictionary - if signal is None: - signal = ["Voltage [V]"] - dataset.check(signal=signal) - - # Compile ydata and labels or legend - y = [dataset[s] for s in signal] - if len(signal) == 1: - yaxis_title = signal[0] - if trace_names is None: - trace_names = ["Data"] - else: - yaxis_title = "Output" - if trace_names is None: - trace_names = StandardPlot.remove_brackets(signal) - - # Create the figure - fig = trajectories( - x=dataset[dataset.domain], - y=y, - trace_names=trace_names, - show=False, - xaxis_title=StandardPlot.remove_brackets(dataset.domain), - yaxis_title=yaxis_title, - ) - if show: - plt.show() - - return fig diff --git a/pybop/plot/matplotlib/parameters.py b/pybop/plot/matplotlib/parameters.py deleted file mode 100644 index eef582f0e..000000000 --- a/pybop/plot/matplotlib/parameters.py +++ /dev/null @@ -1,71 +0,0 @@ -import warnings -from typing import TYPE_CHECKING - -import matplotlib.pyplot as plt - -from pybop.costs.log_likelihoods import GaussianLogLikelihood -from pybop.plot.standard_plots import StandardSubplot - -if TYPE_CHECKING: - from pybop._result import Result - - -def parameters(result: "Result", show=True, **layout_kwargs): - """ - Plot the evolution of parameters during the optimisation process using Plotly. - - Parameters - ---------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - show : bool, optional - If True, the figure is shown upon creation (default: True). - - Returns - ------- - plotly.graph_objs.Figure - A Plotly figure object showing the parameter evolution over iterations. - """ - - if len(layout_kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" - f"{list(layout_kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - - # Extract parameters and log from the optimisation object - parameters = result.problem.parameters - x = list(range(len(result.x_model))) - y = [list(item) for item in zip(*result.x_model, strict=False)] - - # Create lists of axis titles and trace names - axis_titles = [] - trace_names = parameters.names - for name in trace_names: - axis_titles.append(("Evaluation", name)) - - if isinstance(result.problem, GaussianLogLikelihood): - axis_titles.append(("Evaluation", "Sigma")) - trace_names.append("Sigma") - - # Create a plot dictionary - plot_dict = StandardSubplot( - x=x, - y=y, - axis_titles=axis_titles, - trace_names=trace_names, - trace_name_width=50, - figsize=(18, 8), - backend="matplotlib", - ) - - plt.suptitle("Parameter Convergence") - - # Generate the figure and update the layout - fig = plot_dict(show=False) - if show: - plt.show() - - return fig diff --git a/pybop/plot/matplotlib/problem.py b/pybop/plot/matplotlib/problem.py deleted file mode 100644 index da080207c..000000000 --- a/pybop/plot/matplotlib/problem.py +++ /dev/null @@ -1,113 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np - -from pybop.costs.design_cost import DesignCost -from pybop.costs.error_measures import ErrorMeasure -from pybop.parameters.parameter import Inputs -from pybop.plot.matplotlib.standard_plots import StandardPlot -from pybop.problems.meta_problem import MetaProblem -from pybop.problems.problem import Problem -from pybop.simulators.solution import Solution - - -def problem( - problem: Problem, - inputs: Inputs = None, - show: bool = True, - title="Scatter Plot", -): - """ - Produce a quick plot of the target dataset against optimised model output. - - Generates an interactive plot comparing the simulated model output with - an optional target dataset and visualises uncertainty. - - Parameters - ---------- - problem : pybop.Problem - Problem object with dataset and targets attributes. - inputs : Inputs - Optimised (or example) parameter values. - show : bool, optional - If True, the figure is shown upon creation (default: True). - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object for the scatter plot. - """ - if inputs is None: - inputs = problem.parameters.to_dict() - elif not isinstance(inputs, dict): - raise TypeError(f"Expecting a dictionary, received {type(inputs)}") - - domain = problem.domain - if problem.domain_data is None: - # Simulate the model for the both the initial and the given inputs - target = problem.target - problem.set_target(target + [domain]) - initial_inputs = problem.simulator.parameters.to_dict("initial") - target_output = problem.simulate(initial_inputs) - target_domain = target_output[domain].data - model_output = problem.simulate(inputs) - model_domain = model_output[domain].data - problem.set_target(target) - else: - # Extract the time data and simulate the model for the given inputs - target_output = Solution() - for target in problem.target: - target_output.set_solution_variable( - target, data=problem.target_data[target] - ) - target_domain = problem.domain_data - model_output = problem.simulate(inputs) - model_domain = target_domain[: len(model_output[target].data)] - - # Create a plot for each output - figure_list = [] - for var in problem.target: - # Create a plot dictionary - plot_dict = StandardPlot() - - plot_dict.create_trace( - x=target_domain, - y=target_output[var].data, - label="Reference", - marker=".", - linestyle="None", - ) - - plot_dict.create_trace( - x=model_domain, - y=model_output[var].data, - label="Optimised" if isinstance(problem.cost, DesignCost) else "Model", - marker="." if isinstance(problem, MetaProblem) else None, - linestyle="None" if isinstance(problem, MetaProblem) else "-", - ) - - if isinstance(problem.cost, ErrorMeasure) and len( - model_output[var].data - ) == len(target_output[var].data): - # Compute the standard deviation as proxy for uncertainty - plot_dict.sigma = np.std(model_output[var].data - target_output[var].data) - - # Convert x and upper and lower limits into lists to create a filled trace - x = target_domain.tolist() - y_upper = (model_output[var].data + plot_dict.sigma).tolist() - y_lower = (model_output[var].data - plot_dict.sigma).tolist() - - plt.fill_between(x, y_upper, y_lower, color=[(1.0, 0.898, 0.800, 0.8)]) - - # Generate the figure and update the layout - fig = plot_dict(show=False) - plt.xlabel("Time / s") - plt.ylabel(StandardPlot.remove_brackets(var)) - plt.title(title) - plt.legend() - plt.tight_layout() - if show: - plt.show() - - figure_list.append(fig) - - return figure_list diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py index e96cf19af..906cda36f 100644 --- a/pybop/plot/matplotlib/standard_plots.py +++ b/pybop/plot/matplotlib/standard_plots.py @@ -34,9 +34,11 @@ def __init__( self, trace_options=None, figsize=(8, 6), + title=None, **kwargs, ): self.backend = "matplotlib" + self.title = title # Warning if layout arguments ignored if len(kwargs) > 0: warnings.warn( @@ -81,6 +83,9 @@ def __call__(self, show=True): **dict(loc="best", fontsize=12), ) + if self.title is not None: + plt.suptitle(self.title) + if show: plt.show() else: @@ -116,7 +121,7 @@ def add_traces(self, x, y, trace_names=None, **trace_options): self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) - def create_trace(self, x, y, label, ax=None, **trace_options): + def create_trace(self, x, y, label=None, ax=None, **trace_options): """ Add line to plot. @@ -129,17 +134,27 @@ def create_trace(self, x, y, label, ax=None, **trace_options): trace = dict(x=x[:size], y=y[:size], label=label, ax=ax) trace.update(trace_options) return trace + + def create_fill_trace(self, x, y_upper, y_lower, **options): + trace = dict(x=x, y=y_upper, plot_type='fill', y_lower=y_lower) + trace.update(options) + return trace - def _plot_trace(self, x, y, label, ax=None, **trace_options): + def _plot_trace(self, x, y, label=None, ax=None, plot_type='plot', **trace_options): if ax is None: ax = plt.gca() - line = ax.plot( - x, - y, - label=label, - **trace_options, - ) + if plot_type=='plot': + line = ax.plot( + x, + y, + label=label, + **trace_options, + ) + elif plot_type=='fill': + line = ax.fill_between(x, y, **trace_options) + else: + raise ValueError('Plot type not recognised') if len(line) > 1: return line @@ -223,6 +238,10 @@ def __call__(self, show=True, num_rows=1, num_cols=1): bbox_to_anchor=(0.99, 0.95), ) plt.tight_layout(rect=[0, 0, 1, 0.95]) + + if self.title is not None: + plt.suptitle(self.title) + if show: plt.show() diff --git a/pybop/plot/matplotlib/util.py b/pybop/plot/matplotlib/util.py new file mode 100644 index 000000000..6f31ce29a --- /dev/null +++ b/pybop/plot/matplotlib/util.py @@ -0,0 +1,27 @@ +import warnings +import matplotlib.pyplot as plt + +def update_and_show(fig, show=True, **layout_kwargs): + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + if show: + plt.show() + + return fig + +DEFAULT_PLOT_OPTIONS = { + 'parameters' : dict(figsize=(18, 8), title="Parameter Convergence"), + 'problem' : { + 'default_trace_options' : dict (label="Model", marker=None, linestyle='-'), + 'design_cost_options' : dict(label = "Optimised"), + 'meta_problem_options' : dict(marker=".", linestyle='none'), + 'reference_options' : dict(label="Reference", marker=".", linestyle='none'), + 'fill_options' : dict(color=[(1.0, 0.898, 0.800, 0.8)]) + } +} \ No newline at end of file diff --git a/pybop/plot/plotly/parameters.py b/pybop/plot/parameters.py similarity index 81% rename from pybop/plot/plotly/parameters.py rename to pybop/plot/parameters.py index e93907248..c1082c6ee 100644 --- a/pybop/plot/plotly/parameters.py +++ b/pybop/plot/parameters.py @@ -2,12 +2,13 @@ from pybop.costs.log_likelihoods import GaussianLogLikelihood from pybop.plot.standard_plots import StandardSubplot +from pybop.plot.util import get_default_options, update_and_show if TYPE_CHECKING: from pybop._result import Result -def parameters(result: "Result", show=True, **layout_kwargs): +def parameters(result: "Result", show=True, backend=None, **layout_kwargs): """ Plot the evolution of parameters during the optimisation process using Plotly. @@ -44,28 +45,21 @@ def parameters(result: "Result", show=True, **layout_kwargs): trace_names.append("Sigma") # Set subplot layout options - layout_options = dict( - title="Parameter Convergence", - width=1024, - height=576, - legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), - ) + plot_options = get_default_options('paramters', backend) # Create a plot dictionary plot_dict = StandardSubplot( x=x, y=y, axis_titles=axis_titles, - layout_options=layout_options, trace_names=trace_names, trace_name_width=50, - backend="plotly", + backend=backend, + **plot_options, ) # Generate the figure and update the layout fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + fig = update_and_show(fig, backend, **layout_kwargs) return fig diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py index b3b907abd..de3350225 100644 --- a/pybop/plot/plotly/__init__.py +++ b/pybop/plot/plotly/__init__.py @@ -1,10 +1,9 @@ from .plotly_manager import PlotlyManager from .standard_plots import Plotter, SubplotPlotter, trajectories from .contour import contour -from .dataset import dataset from .convergence import convergence -from .parameters import parameters -from .problem import problem from .nyquist import nyquist from .voronoi import surface from .samples import chains, posterior, summary_table, trace + +from .util import update_and_show, DEFAULT_PLOT_OPTIONS \ No newline at end of file diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py index 62cd708a2..adce31af2 100644 --- a/pybop/plot/plotly/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -158,6 +158,17 @@ def create_trace(self, x, y, **trace_options): y=y, **trace_options, ) + + def create_fill_trace(self, x, y_upper, y_lower, **options): + return self.create_trace( + x=x + x[::-1], + y=y_upper + y_lower[::-1], + fill="toself", + line=dict(color="rgba(255,255,255,0)"), + hoverinfo="skip", + showlegend=False, + **options + ) class SubplotPlotter(Plotter): diff --git a/pybop/plot/plotly/util.py b/pybop/plot/plotly/util.py new file mode 100644 index 000000000..563dde595 --- /dev/null +++ b/pybop/plot/plotly/util.py @@ -0,0 +1,24 @@ +import warnings + +def update_and_show(fig, show=True, **layout_kwargs): + fig.update_layout(**layout_kwargs) + if show: + fig.show() + return fig + + +DEFAULT_PLOT_OPTIONS = { + 'parameters' : dict(layout_options = dict( + title="Parameter Convergence", + width=1024, + height=576, + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), + )), + 'problem' : { + 'default_trace_options' : dict (name="Model", mode="lines", showlegend=True), + 'design_cost_options' : dict(name = "Optimised"), + 'meta_problem_options' : dict(mode = "lines"), + 'reference_options' : dict(name="Reference", mode="markers", showlegend=True), + 'fill_options' : dict(fillcolor ="rgba(255,229,204,0.8)" ) + } +} \ No newline at end of file diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index 7e6425aee..b51745999 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -103,69 +103,6 @@ def convergence(result: "Result", show=True, backend=None, **layout_kwargs): "convergence", backend, result=result, show=show, **layout_kwargs ) - -def dataset( - dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs -): - """ - Quickly plot a PyBOP Dataset using Plotly. - - Parameters - ---------- - dataset : object - A PyBOP dataset. - signal : list or str, optional - The name of the time series to plot (default: "Voltage [V]"). - trace_names : list or str, optional - Name(s) for the trace(s) (default: "Data"). - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time / s"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object for the scatter plot. - """ - call_plotting_function( - "dataset", - backend, - dataset=dataset, - signal=signal, - trace_names=trace_names, - show=show, - **layout_kwargs, - ) - - -def parameters(result: "Result", show=True, backend=None, **layout_kwargs): - """ - Plot the evolution of parameters during the optimisation process using Plotly. - - Parameters - ---------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time [s]"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - ------- - plotly.graph_objs.Figure - A Plotly figure object showing the parameter evolution over iterations. - """ - return call_plotting_function( - "parameters", backend, result=result, show=show, **layout_kwargs - ) - - def posterior(result: "SamplingResult", show=True, backend=None, **kwargs): """ Plot the summed posterior distribution across chains. @@ -173,42 +110,6 @@ def posterior(result: "SamplingResult", show=True, backend=None, **kwargs): return call_plotting_function("posterior", backend, result=result, **kwargs) -def problem( - problem: Problem, - inputs: Inputs = None, - show: bool = True, - backend=None, - **layout_kwargs, -): - """ - Produce a quick plot of the target dataset against optimised model output. - - Generates an interactive plot comparing the simulated model output with - an optional target dataset and visualises uncertainty. - - Parameters - ---------- - problem : pybop.Problem - Problem object with dataset and targets attributes. - inputs : Inputs - Optimised (or example) parameter values. - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time / s"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object for the scatter plot. - """ - return call_plotting_function( - "problem", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs - ) - - def summary_table(result: "SamplingResult", backend=None): """ Display summary statistics in a table. diff --git a/pybop/plot/plotly/problem.py b/pybop/plot/problem.py similarity index 70% rename from pybop/plot/plotly/problem.py rename to pybop/plot/problem.py index 9aac2571f..b2fabaaa0 100644 --- a/pybop/plot/plotly/problem.py +++ b/pybop/plot/problem.py @@ -3,7 +3,8 @@ from pybop.costs.design_cost import DesignCost from pybop.costs.error_measures import ErrorMeasure from pybop.parameters.parameter import Inputs -from pybop.plot.plotly.standard_plots import StandardPlot +from pybop.plot.matplotlib.standard_plots import StandardPlot +from pybop.plot.util import get_default_options, update_and_show from pybop.problems.meta_problem import MetaProblem from pybop.problems.problem import Problem from pybop.simulators.solution import Solution @@ -13,7 +14,8 @@ def problem( problem: Problem, inputs: Inputs = None, show: bool = True, - **layout_kwargs, + title="Scatter Plot", + backend = None, ): """ Produce a quick plot of the target dataset against optimised model output. @@ -29,10 +31,6 @@ def problem( Optimised (or example) parameter values. show : bool, optional If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time / s"` or - `xaxis={"title": "Time [s]", font={"size":14}}` Returns ------- @@ -66,33 +64,39 @@ def problem( model_output = problem.simulate(inputs) model_domain = target_domain[: len(model_output[target].data)] + # Retrieve default layout options + plot_options = get_default_options('problem', backend) + trace_options = plot_options.get('default_trace_options') or {} + design_cost_options = plot_options.get('design_cost_options') or {} + meta_problem_options = plot_options.get('meta_problem_options') or {} + reference_options = plot_options.get('reference_options') or {} + fill_options = plot_options.get('fill_options') or {} + # Create a plot for each output figure_list = [] for var in problem.target: + options = trace_options.copy() + if isinstance(problem, MetaProblem): + options.update(meta_problem_options) + if isinstance(problem.cost, DesignCost): + options.update(design_cost_options) + + print(isinstance(problem, MetaProblem), isinstance(problem.cost, DesignCost), options) + # Create a plot dictionary - plot_dict = StandardPlot( - layout_options=dict( - title="Scatter Plot", - xaxis_title=StandardPlot.remove_brackets(domain), - yaxis_title=StandardPlot.remove_brackets(var), - ) - ) + plot_dict = StandardPlot(backend=backend) model_trace = plot_dict.create_trace( x=model_domain, y=model_output[var].data, - name="Optimised" if isinstance(problem.cost, DesignCost) else "Model", - mode="markers" if isinstance(problem, MetaProblem) else "lines", - showlegend=True, + **options ) plot_dict.traces.append(model_trace) - target_trace = plot_dict.create_trace( + target_trace =plot_dict.create_trace( x=target_domain, y=target_output[var].data, - name="Reference", - mode="markers", - showlegend=True, + **reference_options ) plot_dict.traces.append(target_trace) @@ -107,25 +111,16 @@ def problem( y_upper = (model_output[var].data + plot_dict.sigma).tolist() y_lower = (model_output[var].data - plot_dict.sigma).tolist() - fill_trace = plot_dict.create_trace( - x=x + x[::-1], - y=y_upper + y_lower[::-1], - fill="toself", - fillcolor="rgba(255,229,204,0.8)", - line=dict(color="rgba(255,255,255,0)"), - hoverinfo="skip", - showlegend=False, - ) - plot_dict.traces.append(fill_trace) - - # Reverse the order of the traces to put the model on top + fill_trace = plot_dict.create_fill_trace(x, y_upper, y_lower, **fill_options) plot_dict.traces = plot_dict.traces[::-1] - # Generate the figure and update the layout fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + # plt.xlabel("Time / s") + # plt.ylabel(StandardPlot.remove_brackets(var)) + # plt.title(title) + # plt.legend() + # plt.tight_layout() + fig = update_and_show(fig, backend=backend) figure_list.append(fig) diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index 0dc0bfddf..db566a533 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -96,6 +96,9 @@ def parse_data(self, x, y): def create_trace(self, x, y, **trace_options): return self.plotter.create_trace(x, y, **trace_options) + + def create_fill_trace(self, x, y_upper, y_lower, **options): + return self.plotter.create_fill_trace(x, y_upper, y_lower, **options) @staticmethod def wrap_text(text, width, backend="matplotlib"): diff --git a/pybop/plot/util.py b/pybop/plot/util.py index 39a1ecfd2..27cf9303c 100644 --- a/pybop/plot/util.py +++ b/pybop/plot/util.py @@ -34,3 +34,24 @@ def call_plotting_function(function_name, backend, **kwargs): except ModuleNotFoundError as error: # Raise an ModuleNotFoundError if the module or attribute is not available raise ModuleNotFoundError(err_msg) from error + + +def update_and_show(fig, backend=None, **kwargs): + return call_plotting_function("update_and_show", backend, fig=fig, **kwargs) + + +def get_default_options(plot_type, backend): + if backend is None: + backend = pybop.plot.backend + + if backend == 'plotly': + opts = pybop.plot.plotly.DEFAULT_PLOT_OPTIONS + elif backend == 'matplotlib': + opts = pybop.plot.matplotlib.DEFAULT_PLOT_OPTIONS + else: + opts = {} + + if plot_type in opts.keys(): + return opts[plot_type] + else: + return {} \ No newline at end of file From 112c231722449b19b5a0020af33a35bc5dbd6eb2 Mon Sep 17 00:00:00 2001 From: u2370093 Date: Mon, 13 Apr 2026 11:44:32 +0100 Subject: [PATCH 07/10] simplify sample plotting --- .../ecm_monte_carlo_sampling.ipynb | 6 +- .../ecm_tau_redefined.py | 1 - pybop/plot/__init__.py | 7 +- pybop/plot/dataset.py | 8 +- pybop/plot/matplotlib/__init__.py | 3 +- pybop/plot/matplotlib/samples.py | 146 ------------------ pybop/plot/matplotlib/standard_plots.py | 114 +++++++++++--- pybop/plot/matplotlib/util.py | 47 ++++-- pybop/plot/parameters.py | 2 +- pybop/plot/plotly/__init__.py | 5 +- pybop/plot/plotly/samples.py | 131 ---------------- pybop/plot/plotly/standard_plots.py | 89 +++++++---- pybop/plot/plotly/util.py | 68 +++++--- pybop/plot/plots.py | 30 ---- pybop/plot/problem.py | 48 +++--- pybop/plot/samples.py | 117 ++++++++++++++ pybop/plot/standard_plots.py | 23 ++- pybop/plot/util.py | 8 +- 18 files changed, 416 insertions(+), 437 deletions(-) delete mode 100644 pybop/plot/matplotlib/samples.py delete mode 100644 pybop/plot/plotly/samples.py create mode 100644 pybop/plot/samples.py diff --git a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb index addca5b86..23433be03 100644 --- a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb @@ -46,7 +46,7 @@ "\n", "import pybop\n", "\n", - "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.set_backend(\"matplotlib\")\n", "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" @@ -1104,7 +1104,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "env-py-3-13", "language": "python", "name": "python3" }, @@ -1118,7 +1118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.11" } }, "nbformat": 4, diff --git a/examples/scripts/battery_parameterisation/ecm_tau_redefined.py b/examples/scripts/battery_parameterisation/ecm_tau_redefined.py index 53fe089c1..4bd754dd9 100644 --- a/examples/scripts/battery_parameterisation/ecm_tau_redefined.py +++ b/examples/scripts/battery_parameterisation/ecm_tau_redefined.py @@ -138,7 +138,6 @@ print("Identified parameters:", result.x.tolist() + [result.x[2] / result.x[1]]) # Plot the timeseries output -pybop.plot.problem(problem, inputs=result.best_inputs, title="Optimised Comparison", backend='plotly') pybop.plot.problem(problem, inputs=result.best_inputs, title="Optimised Comparison") # Plot the optimisation result diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index ce95ef70d..5c6eced49 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -8,13 +8,9 @@ # Import plots # from .plots import ( - chains, contour, convergence, - posterior, - summary_table, - surface, - trace + surface ) from .standard_plots import StandardPlot, StandardSubplot, trajectories @@ -22,6 +18,7 @@ from .nyquist import nyquist from .parameters import parameters from .problem import problem +from .samples import chains, posterior, summary_table, trace from .voronoi import voronoi_data, _voronoi_regions from . import matplotlib from . import plotly diff --git a/pybop/plot/dataset.py b/pybop/plot/dataset.py index 639bcc88a..b92d2b78f 100644 --- a/pybop/plot/dataset.py +++ b/pybop/plot/dataset.py @@ -2,7 +2,9 @@ from pybop.plot.util import update_and_show -def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs): +def dataset( + dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs +): """ Quickly plot a PyBOP Dataset using Plotly. @@ -51,9 +53,9 @@ def dataset(dataset, signal=None, trace_names=None, show=True, backend=None, **l show=False, xaxis_title=StandardPlot.remove_brackets(dataset.domain), yaxis_title=yaxis_title, - backend = backend, + backend=backend, ) - fig = update_and_show(fig, backend = backend, show=show, **layout_kwargs) + fig = update_and_show(fig, backend=backend, show=show, **layout_kwargs) return fig diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py index 7311edd07..ae087d5fc 100644 --- a/pybop/plot/matplotlib/__init__.py +++ b/pybop/plot/matplotlib/__init__.py @@ -1,7 +1,6 @@ -from .standard_plots import Plotter, SubplotPlotter, trajectories +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories from .convergence import convergence from .contour import contour from .voronoi import surface from .nyquist import nyquist -from .samples import chains, posterior, summary_table, trace from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/matplotlib/samples.py b/pybop/plot/matplotlib/samples.py deleted file mode 100644 index 0d2ca526a..000000000 --- a/pybop/plot/matplotlib/samples.py +++ /dev/null @@ -1,146 +0,0 @@ -import warnings -from typing import TYPE_CHECKING - -from matplotlib import pyplot as plt - -if TYPE_CHECKING: - from pybop.samplers.base_pints_sampler import SamplingResult - - -def trace(result: "SamplingResult", show=True, **kwargs): - """ - Plot trace plots for the posterior samples. - """ - # Warning if layout arguments ignored - if len(kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" - f"{list(kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - figlist = [] - for i in range(result.n_parameters): - fig = plt.figure() - - for j, chain in enumerate(result.chains): - plt.plot(chain[:, i], label=f"Chain {j}") - - plt.title(f"Parameter {i} Trace Plot") - plt.xlabel("Sample Index") - plt.ylabel("Value") - plt.legend(fontsize=12) - figlist.append(fig) - - if show: - plt.show() - else: - return figlist - - -def chains(result: "SamplingResult", show=True, **kwargs): - """ - Plot posterior distributions for each chain. - """ - # Warning if layout arguments ignored - if len(kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" - f"{list(kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - fig = plt.figure(figsize=(15, 8), dpi=100) - - for i, chain in enumerate(result.chains): - for j in range(chain.shape[1]): - plt.hist( - x=chain[:, j], label=f"Chain {i} - Parameter {j}", alpha=0.5, rwidth=2.0 - ) - - for j in range(chain.shape[1]): - plt.plot( - [result.mean[j], result.mean[j]], - [0, result.max[j]], - "--", - lw=3, - label=f"Mean - Parameter {j}", - ) - - plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) - plt.grid(axis="y", zorder=-1) - plt.title("Posterior Distribution") - plt.xlabel("Value") - plt.ylabel("Density") - plt.tight_layout() - if show: - plt.show() - else: - return fig - - -def posterior(result: "SamplingResult", show=True, **kwargs): - """ - Plot the summed posterior distribution across chains. - """ - # Warning if layout arguments ignored - if len(kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" - f"{list(kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - - fig = plt.figure(figsize=(15, 8), dpi=100) - - for j in range(result.all_samples.shape[1]): - plt.hist( - x=result.all_samples[:, j], - label=f"Parameter {j}", - alpha=0.75, - ) - plt.axvline(result.mean[j], ls="--", c="k", lw=3) - - plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0)) - plt.grid(axis="y", zorder=-1) - plt.title("Posterior Distribution") - plt.xlabel("Value") - plt.ylabel("Density") - plt.tight_layout() - if show: - plt.show() - else: - return fig - - -def summary_table(result: "SamplingResult"): - """ - Display summary statistics in a table. - """ - - summary_stats = result.get_summary_statistics() - - header = ["Statistic", "Value"] - values = [ - ["Mean", ", ".join(summary_stats["mean"].astype(str))], - ["Median", ", ".join(summary_stats["median"].astype(str))], - ["Standard Deviation", ", ".join(summary_stats["std"].astype(str))], - ["95% CI Lower", ", ".join(summary_stats["ci_lower"].astype(str))], - ["95% CI Upper", ", ".join(summary_stats["ci_upper"].astype(str))], - ] - fig, ax = plt.subplots(figsize=(6, 2), dpi=100) - - # hide axes - ax.axis("off") - ax.axis("tight") - ax.table( - cellText=values, - colLabels=header, - loc="center", - cellLoc="center", - colColours=["lightsteelblue", "lightsteelblue"], - ) - ax.set_title("Summary Statistics") - fig.tight_layout() - plt.show() diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py index 906cda36f..64c31bd66 100644 --- a/pybop/plot/matplotlib/standard_plots.py +++ b/pybop/plot/matplotlib/standard_plots.py @@ -35,6 +35,10 @@ def __init__( trace_options=None, figsize=(8, 6), title=None, + xaxis_title=None, + yaxis_title=None, + grid=None, + axis_bg_color=None, **kwargs, ): self.backend = "matplotlib" @@ -54,6 +58,19 @@ def __init__( self.trace_options.update(trace_options) self.fig = plt.figure(figsize=figsize, dpi=100) + if title is not None: + plt.suptitle(self.title) + if xaxis_title is not None: + plt.xlabel(xaxis_title) + if yaxis_title is not None: + plt.ylabel(yaxis_title) + if grid is not None: + plt.grid(**grid) + if axis_bg_color is not None: + ax = plt.gca() + ax.set_facecolor(axis_bg_color) + ax.set_axisbelow(True) + self.traces = [] def __call__(self, show=True): @@ -83,9 +100,6 @@ def __call__(self, show=True): **dict(loc="best", fontsize=12), ) - if self.title is not None: - plt.suptitle(self.title) - if show: plt.show() else: @@ -121,7 +135,7 @@ def add_traces(self, x, y, trace_names=None, **trace_options): self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) - def create_trace(self, x, y, label=None, ax=None, **trace_options): + def create_trace(self, x=None, y=None, label=None, ax=None, **trace_options): """ Add line to plot. @@ -130,36 +144,62 @@ def create_trace(self, x, y, label=None, ax=None, **trace_options): plotly.graph_objs.Scatter A trace for a Plotly figure. """ - size = min(len(x), len(y)) - trace = dict(x=x[:size], y=y[:size], label=label, ax=ax) + if x is not None and y is not None: + size = min(len(x), len(y)) + trace = dict(x=x[:size], y=y[:size], label=label, ax=ax) + elif y is not None: + trace = dict(y=y, label=label, ax=ax) + trace.update(trace_options) return trace - + def create_fill_trace(self, x, y_upper, y_lower, **options): - trace = dict(x=x, y=y_upper, plot_type='fill', y_lower=y_lower) + trace = dict(x=x, y=y_upper, plot_type="fill", y_lower=y_lower) trace.update(options) return trace - def _plot_trace(self, x, y, label=None, ax=None, plot_type='plot', **trace_options): + def create_histogram(self, x, name, **trace_options): + trace = dict(x=x, label=name, plot_type="hist") + trace.update(trace_options) + return trace + + def create_vline(self, fig, x, **trace_options): + fig.gca() + plt.axvline(x, **trace_options) + + def _plot_trace( + self, x=None, y=None, label=None, ax=None, plot_type="plot", **trace_options + ): if ax is None: ax = plt.gca() - if plot_type=='plot': - line = ax.plot( - x, - y, - label=label, - **trace_options, - ) - elif plot_type=='fill': - line = ax.fill_between(x, y, **trace_options) - else: - raise ValueError('Plot type not recognised') - - if len(line) > 1: - return line + if plot_type == "plot": + if x is not None: + line = ax.plot( + x, + y, + label=label, + **trace_options, + ) + else: + line = ax.plot( + y, + label=label, + **trace_options, + ) + if len(line) > 1: + return line + else: + return line[0] + elif plot_type == "fill": + y_upper = y + y_lower = trace_options["y_lower"] + del trace_options["y_lower"] + return ax.fill_between(x, y_upper, y_lower, **trace_options) + elif plot_type == "hist": + return ax.hist(x=x, label=label, **trace_options) else: - return line[0] + raise ValueError("Plot type not recognised") class SubplotPlotter(Plotter): @@ -241,7 +281,7 @@ def __call__(self, show=True, num_rows=1, num_cols=1): if self.title is not None: plt.suptitle(self.title) - + if show: plt.show() @@ -301,3 +341,27 @@ def trajectories( fig.show() return fig + + +def show_table(header, values, title): + """ + Display data in a table. + """ + for i, val in enumerate(values): + values[i] = [val[0], ", ".join(val[1].astype(str))] + + fig, ax = plt.subplots(figsize=(6, 2), dpi=100) + + # hide axes + ax.axis("off") + ax.axis("tight") + ax.table( + cellText=values, + colLabels=header, + loc="center", + cellLoc="center", + colColours=["lightsteelblue", "lightsteelblue"], + ) + ax.set_title(title) + fig.tight_layout() + plt.show() diff --git a/pybop/plot/matplotlib/util.py b/pybop/plot/matplotlib/util.py index 6f31ce29a..2226c2966 100644 --- a/pybop/plot/matplotlib/util.py +++ b/pybop/plot/matplotlib/util.py @@ -1,6 +1,8 @@ import warnings + import matplotlib.pyplot as plt + def update_and_show(fig, show=True, **layout_kwargs): if len(layout_kwargs) > 0: warnings.warn( @@ -15,13 +17,40 @@ def update_and_show(fig, show=True, **layout_kwargs): return fig + DEFAULT_PLOT_OPTIONS = { - 'parameters' : dict(figsize=(18, 8), title="Parameter Convergence"), - 'problem' : { - 'default_trace_options' : dict (label="Model", marker=None, linestyle='-'), - 'design_cost_options' : dict(label = "Optimised"), - 'meta_problem_options' : dict(marker=".", linestyle='none'), - 'reference_options' : dict(label="Reference", marker=".", linestyle='none'), - 'fill_options' : dict(color=[(1.0, 0.898, 0.800, 0.8)]) - } -} \ No newline at end of file + "parameters": dict(figsize=(18, 8), title="Parameter Convergence"), + "problem": { + "default_trace_options": dict(label="Model", marker=None, linestyle="-"), + "design_cost_options": dict(label="Optimised"), + "meta_problem_options": dict(marker=".", linestyle="none"), + "reference_options": dict(label="Reference", marker=".", linestyle="none"), + "fill_options": dict(color=[(1.0, 0.898, 0.800, 0.8)]), + }, + "posterior": { + "plot_options": { + "figsize": (15, 8), + "grid": dict(axis="y", zorder=-1, color="w"), + "axis_bg_color": ( + 0.6784313725490196, + 0.8470588235294118, + 0.9019607843137255, + 0.3, + ), + }, + "trace_options": dict(alpha=0.75), + "trace_options_vline": dict(linewidth=3, linestyle="--", color="k"), + }, + "trace": { + "plot_options": { + "figsize": (15, 8), + "grid": dict(axis="y", zorder=-10, color="w"), + "axis_bg_color": ( + 0.6784313725490196, + 0.8470588235294118, + 0.9019607843137255, + 0.3, + ), + }, + }, +} diff --git a/pybop/plot/parameters.py b/pybop/plot/parameters.py index c1082c6ee..934ec67a3 100644 --- a/pybop/plot/parameters.py +++ b/pybop/plot/parameters.py @@ -45,7 +45,7 @@ def parameters(result: "Result", show=True, backend=None, **layout_kwargs): trace_names.append("Sigma") # Set subplot layout options - plot_options = get_default_options('paramters', backend) + plot_options = get_default_options("paramters", backend) # Create a plot dictionary plot_dict = StandardSubplot( diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py index de3350225..ad894f9bb 100644 --- a/pybop/plot/plotly/__init__.py +++ b/pybop/plot/plotly/__init__.py @@ -1,9 +1,8 @@ from .plotly_manager import PlotlyManager -from .standard_plots import Plotter, SubplotPlotter, trajectories +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories from .contour import contour from .convergence import convergence from .nyquist import nyquist from .voronoi import surface -from .samples import chains, posterior, summary_table, trace -from .util import update_and_show, DEFAULT_PLOT_OPTIONS \ No newline at end of file +from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/plotly/samples.py b/pybop/plot/plotly/samples.py deleted file mode 100644 index f5016c325..000000000 --- a/pybop/plot/plotly/samples.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import TYPE_CHECKING - -from pybop.plot.plotly import PlotlyManager - -if TYPE_CHECKING: - from pybop.samplers.base_pints_sampler import SamplingResult - - -def trace(result: "SamplingResult", **kwargs): - """ - Plot trace plots for the posterior samples. - """ - # Import plotly only when needed - go = PlotlyManager().go - - for i in range(result.n_parameters): - fig = go.Figure() - - for j, chain in enumerate(result.chains): - fig.add_trace(go.Scatter(y=chain[:, i], mode="lines", name=f"Chain {j}")) - - fig.update_layout( - title=f"Parameter {i} Trace Plot", - xaxis_title="Sample Index", - yaxis_title="Value", - ) - fig.update_layout(**kwargs) - fig.show() - - -def chains(result: "SamplingResult", **kwargs): - """ - Plot posterior distributions for each chain. - """ - # Import plotly only when needed - go = PlotlyManager().go - - fig = go.Figure() - - for i, chain in enumerate(result.chains): - for j in range(chain.shape[1]): - fig.add_trace( - go.Histogram( - x=chain[:, j], - name=f"Chain {i} - Parameter {j}", - opacity=0.75, - ) - ) - - fig.add_shape( - type="line", - x0=result.mean[j], - y0=0, - x1=result.mean[j], - y1=result.max[j], - name=f"Mean - Parameter {j}", - line=dict(color="Black", width=1.5, dash="dash"), - ) - - fig.update_layout( - barmode="overlay", - title="Posterior Distribution", - xaxis_title="Value", - yaxis_title="Density", - ) - fig.update_layout(**kwargs) - fig.show() - - -def posterior(result: "SamplingResult", **kwargs): - """ - Plot the summed posterior distribution across chains. - """ - # Import plotly only when needed - go = PlotlyManager().go - - fig = go.Figure() - - for j in range(result.all_samples.shape[1]): - histogram = go.Histogram( - x=result.all_samples[:, j], - name=f"Parameter {j}", - opacity=0.75, - ) - fig.add_trace(histogram) - fig.add_vline( - x=result.mean[j], line_width=3, line_dash="dash", line_color="black" - ) - - fig.update_layout( - barmode="overlay", - title="Posterior Distribution", - xaxis_title="Value", - yaxis_title="Density", - ) - fig.update_layout(**kwargs) - fig.show() - return fig - - -def summary_table(result: "SamplingResult"): - """ - Display summary statistics in a table. - """ - # Import plotly only when needed - go = PlotlyManager().go - - summary_stats = result.get_summary_statistics() - - header = ["Statistic", "Value"] - values = [ - ["Mean", summary_stats["mean"]], - ["Median", summary_stats["median"]], - ["Standard Deviation", summary_stats["std"]], - ["95% CI Lower", summary_stats["ci_lower"]], - ["95% CI Upper", summary_stats["ci_upper"]], - ] - - fig = go.Figure( - data=[ - go.Table( - header=dict(values=header), - cells=dict( - values=[[row[0] for row in values], [row[1] for row in values]] - ), - ) - ] - ) - - fig.update_layout(title="Summary Statistics") - fig.show() diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py index adce31af2..537a24286 100644 --- a/pybop/plot/plotly/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -1,5 +1,3 @@ -import warnings - from pybop.plot import StandardPlot from pybop.plot.plotly.plotly_manager import PlotlyManager @@ -62,20 +60,13 @@ class Plotter: def __init__( self, + title=None, + xaxis_title=None, + yaxis_title=None, layout=None, layout_options=None, trace_options=None, - **kwargs, ): - # Warning if layout arguments ignored - if len(kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (plotly): \n" - f"{list(kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - self.backend = "plotly" self.traces = [] @@ -99,6 +90,16 @@ def __init__( if self.layout is None: self.layout = self.go.Layout(**self.layout_options) + title_options = {} + if title is not None: + title_options.update({"title": title}) + if title is not None: + title_options.update({"xaxis_title": xaxis_title}) + if title is not None: + title_options.update({"yaxis_title": yaxis_title}) + + self.layout.update(**title_options) + def __call__(self, show=True): """ Generate and show the figure. @@ -143,7 +144,7 @@ def add_traces(self, x, y, trace_names=None, **trace_options): trace = self.create_trace(xi, y[i], **trace_options) self.traces.append(trace) - def create_trace(self, x, y, **trace_options): + def create_trace(self, x=None, y=None, label=None, **trace_options): """ Create a trace for the Plotly figure. @@ -152,23 +153,36 @@ def create_trace(self, x, y, **trace_options): plotly.graph_objs.Scatter A trace for a Plotly figure. """ + if label is not None: + trace_options.update({"name": label}) + if x is not None and y is not None: + return self.go.Scatter( + x=x, + y=y, + **trace_options, + ) + if x is None and y is not None: + return self.go.Scatter( + y=y, + **trace_options, + ) - return self.go.Scatter( - x=x, - y=y, - **trace_options, - ) - def create_fill_trace(self, x, y_upper, y_lower, **options): return self.create_trace( - x=x + x[::-1], - y=y_upper + y_lower[::-1], - fill="toself", - line=dict(color="rgba(255,255,255,0)"), - hoverinfo="skip", - showlegend=False, - **options - ) + x=x + x[::-1], + y=y_upper + y_lower[::-1], + fill="toself", + line=dict(color="rgba(255,255,255,0)"), + hoverinfo="skip", + showlegend=False, + **options, + ) + + def create_histogram(self, x, name, **trace_options): + return self.go.Histogram(x=x, name=name, **trace_options) + + def create_vline(self, fig, x, **trace_options): + fig.add_vline(x=x, **trace_options) class SubplotPlotter(Plotter): @@ -294,3 +308,24 @@ def trajectories(x, y, trace_names=None, show=True, **layout_kwargs): fig.show() return fig + + +def show_table(header, values, title): + """ + Display data in a table. + """ + # Import plotly only when needed + go = PlotlyManager().go + fig = go.Figure( + data=[ + go.Table( + header=dict(values=header), + cells=dict( + values=[[row[0] for row in values], [row[1] for row in values]] + ), + ) + ] + ) + + fig.update_layout(title=title) + fig.show() diff --git a/pybop/plot/plotly/util.py b/pybop/plot/plotly/util.py index 563dde595..f622890fb 100644 --- a/pybop/plot/plotly/util.py +++ b/pybop/plot/plotly/util.py @@ -1,24 +1,54 @@ -import warnings - def update_and_show(fig, show=True, **layout_kwargs): - fig.update_layout(**layout_kwargs) - if show: - fig.show() + if hasattr(fig, "__len__") and len(fig) > 0: + for f in fig: + f.update_layout(**layout_kwargs) + if show: + f.show() + else: + fig.update_layout(**layout_kwargs) + if show: + fig.show() return fig DEFAULT_PLOT_OPTIONS = { - 'parameters' : dict(layout_options = dict( - title="Parameter Convergence", - width=1024, - height=576, - legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), - )), - 'problem' : { - 'default_trace_options' : dict (name="Model", mode="lines", showlegend=True), - 'design_cost_options' : dict(name = "Optimised"), - 'meta_problem_options' : dict(mode = "lines"), - 'reference_options' : dict(name="Reference", mode="markers", showlegend=True), - 'fill_options' : dict(fillcolor ="rgba(255,229,204,0.8)" ) - } -} \ No newline at end of file + "parameters": dict( + layout_options=dict( + title="Parameter Convergence", + width=1024, + height=576, + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + ) + ), + "problem": { + "default_trace_options": dict(name="Model", mode="lines", showlegend=True), + "design_cost_options": dict(name="Optimised"), + "meta_problem_options": dict(mode="lines"), + "reference_options": dict(name="Reference", mode="markers", showlegend=True), + "fill_options": dict(fillcolor="rgba(255,229,204,0.8)"), + }, + "trace": { + "plot_options": { + "layout_options": dict( + width=None, height=None, plot_bgcolor=None, autosize=None, legend=None + ) + }, + "trace_options": dict(mode="lines"), + }, + "posterior": { + "plot_options": { + "layout_options": dict( + barmode="overlay", + width=None, + height=None, + plot_bgcolor=None, + autosize=None, + legend=None, + ) + }, + "trace_options": dict(opacity=0.75), + "trace_options_vline": dict(line_width=3, line_dash="dash", line_color="black"), + }, +} diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index b51745999..99dc699b7 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -4,20 +4,11 @@ if TYPE_CHECKING: from pybop._result import Result - from pybop.samplers.base_pints_sampler import SamplingResult -from pybop.parameters.parameter import Inputs from pybop.plot.util import call_plotting_function from pybop.problems.problem import Problem -def chains(result: "SamplingResult", show=True, backend=None, **kwargs): - """ - Plot posterior distributions for each chain. - """ - return call_plotting_function("chains", backend, result=result, **kwargs) - - def contour( call_object: "Problem | Result", gradient: bool = False, @@ -103,20 +94,6 @@ def convergence(result: "Result", show=True, backend=None, **layout_kwargs): "convergence", backend, result=result, show=show, **layout_kwargs ) -def posterior(result: "SamplingResult", show=True, backend=None, **kwargs): - """ - Plot the summed posterior distribution across chains. - """ - return call_plotting_function("posterior", backend, result=result, **kwargs) - - -def summary_table(result: "SamplingResult", backend=None): - """ - Display summary statistics in a table. - """ - - return call_plotting_function("summary_table", backend, result=result) - def surface( result: "Result", @@ -159,10 +136,3 @@ def surface( show=show, **layout_kwargs, ) - - -def trace(result: "SamplingResult", backend=None, **kwargs): - """ - Plot trace plots for the posterior samples. - """ - return call_plotting_function("trace", backend, result=result, **kwargs) diff --git a/pybop/plot/problem.py b/pybop/plot/problem.py index b2fabaaa0..27b573641 100644 --- a/pybop/plot/problem.py +++ b/pybop/plot/problem.py @@ -3,7 +3,7 @@ from pybop.costs.design_cost import DesignCost from pybop.costs.error_measures import ErrorMeasure from pybop.parameters.parameter import Inputs -from pybop.plot.matplotlib.standard_plots import StandardPlot +from pybop.plot.standard_plots import StandardPlot from pybop.plot.util import get_default_options, update_and_show from pybop.problems.meta_problem import MetaProblem from pybop.problems.problem import Problem @@ -15,7 +15,7 @@ def problem( inputs: Inputs = None, show: bool = True, title="Scatter Plot", - backend = None, + backend=None, ): """ Produce a quick plot of the target dataset against optimised model output. @@ -65,12 +65,12 @@ def problem( model_domain = target_domain[: len(model_output[target].data)] # Retrieve default layout options - plot_options = get_default_options('problem', backend) - trace_options = plot_options.get('default_trace_options') or {} - design_cost_options = plot_options.get('design_cost_options') or {} - meta_problem_options = plot_options.get('meta_problem_options') or {} - reference_options = plot_options.get('reference_options') or {} - fill_options = plot_options.get('fill_options') or {} + plot_options = get_default_options("problem", backend) + trace_options = plot_options.get("default_trace_options") or {} + design_cost_options = plot_options.get("design_cost_options") or {} + meta_problem_options = plot_options.get("meta_problem_options") or {} + reference_options = plot_options.get("reference_options") or {} + fill_options = plot_options.get("fill_options") or {} # Create a plot for each output figure_list = [] @@ -81,22 +81,21 @@ def problem( if isinstance(problem.cost, DesignCost): options.update(design_cost_options) - print(isinstance(problem, MetaProblem), isinstance(problem.cost, DesignCost), options) - # Create a plot dictionary - plot_dict = StandardPlot(backend=backend) + plot_dict = StandardPlot( + title=title, + xaxis_title=StandardPlot.remove_brackets(domain), + yaxis_title=StandardPlot.remove_brackets(var), + backend=backend, + ) model_trace = plot_dict.create_trace( - x=model_domain, - y=model_output[var].data, - **options + x=model_domain, y=model_output[var].data, **options ) plot_dict.traces.append(model_trace) - target_trace =plot_dict.create_trace( - x=target_domain, - y=target_output[var].data, - **reference_options + target_trace = plot_dict.create_trace( + x=target_domain, y=target_output[var].data, **reference_options ) plot_dict.traces.append(target_trace) @@ -111,15 +110,16 @@ def problem( y_upper = (model_output[var].data + plot_dict.sigma).tolist() y_lower = (model_output[var].data - plot_dict.sigma).tolist() - fill_trace = plot_dict.create_fill_trace(x, y_upper, y_lower, **fill_options) + fill_trace = plot_dict.create_fill_trace( + x, y_upper, y_lower, **fill_options + ) + plot_dict.traces.append(fill_trace) + + # Reverse the order of the traces to put the model on top plot_dict.traces = plot_dict.traces[::-1] + # Generate the figure and update the layout fig = plot_dict(show=False) - # plt.xlabel("Time / s") - # plt.ylabel(StandardPlot.remove_brackets(var)) - # plt.title(title) - # plt.legend() - # plt.tight_layout() fig = update_and_show(fig, backend=backend) figure_list.append(fig) diff --git a/pybop/plot/samples.py b/pybop/plot/samples.py new file mode 100644 index 000000000..fe80746cc --- /dev/null +++ b/pybop/plot/samples.py @@ -0,0 +1,117 @@ +from typing import TYPE_CHECKING + +from pybop.plot import StandardPlot +from pybop.plot.util import update_and_show, get_default_options, call_plotting_function + +if TYPE_CHECKING: + from pybop.samplers.base_pints_sampler import SamplingResult + + +def chains(result: "SamplingResult", show=True, backend=None): + """ + Plot posterior distributions for each chain. + """ + options = get_default_options('posterior', backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + trace_options_vline = options.get("trace_options_vline") or {} + + plot_dict = StandardPlot( + backend=backend, + title="Posterior Distribution", + xaxis_title="Value", + yaxis_title="Density", + **plot_options + ) + + for i, chain in enumerate(result.chains): + for j in range(chain.shape[1]): + hist = plot_dict.create_histogram( + x=chain[:, j], + name=f"Chain {i} - Parameter {j}", + **trace_options + ) + plot_dict.traces.append(hist) + + fig = plot_dict(show=False) + for j in range(chain.shape[1]): + plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) + + update_and_show(fig, backend=backend) + +def trace(result: "SamplingResult", show = True, backend=None): + """ + Plot trace plots for the posterior samples. + """ + figlist = [] + options = get_default_options('trace', backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + for i in range(result.n_parameters): + plots = StandardPlot( + title=f"Parameter {i} Trace Plot", + xaxis_title="Sample Index", + yaxis_title="Value", + backend=backend, + **plot_options + ) + + for j, chain in enumerate(result.chains): + plots.traces.append(plots.create_trace(y=chain[:, i], label=f"Chain {j}", **trace_options)) + fig = plots(show=False) + figlist.append(fig) + + update_and_show(figlist, show=show, backend=backend) + + return figlist + +def posterior(result: "SamplingResult", backend=None): + """ + Plot the summed posterior distribution across chains. + """ + options = get_default_options('posterior', backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + trace_options_vline = options.get("trace_options_vline") or {} + # Import plotly only when needed + plot_dict = StandardPlot( + backend=backend, + title="Posterior Distribution", + xaxis_title="Value", + yaxis_title="Density", + **plot_options + ) + + for j in range(result.all_samples.shape[1]): + hist = plot_dict.create_histogram( + x=result.all_samples[:, j], + name=f"Parameter {j}", + **trace_options + ) + + plot_dict.traces.append(hist) + + + fig = plot_dict(show=False) + for j in range(result.all_samples.shape[1]): + plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) + + update_and_show(fig, backend=backend) + +def summary_table(result: "SamplingResult", backend=None): + """ + Display summary statistics in a table. + """ + + summary_stats = result.get_summary_statistics() + + header = ["Statistic", "Value"] + values = [ + ["Mean", summary_stats["mean"]], + ["Median", summary_stats["median"]], + ["Standard Deviation", summary_stats["std"]], + ["95% CI Lower", summary_stats["ci_lower"]], + ["95% CI Upper", summary_stats["ci_upper"]], + ] + + call_plotting_function('show_table', backend=backend, header=header, values=values, title="Summary Statistics") \ No newline at end of file diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index db566a533..b1a6f9e33 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -11,6 +11,9 @@ def __init__( self, x=None, y=None, + title=None, + xaxis_title=None, + yaxis_title=None, trace_options=None, trace_names=None, trace_name_width=20, @@ -19,7 +22,13 @@ def __init__( ): self.plotter = call_plotting_function( - "Plotter", backend, trace_options=trace_options, **kwargs + "Plotter", + backend, + title=title, + xaxis_title=xaxis_title, + yaxis_title=yaxis_title, + trace_options=trace_options, + **kwargs, ) self.trace_name_width = trace_name_width @@ -94,12 +103,18 @@ def parse_data(self, x, y): ) return x, y - def create_trace(self, x, y, **trace_options): - return self.plotter.create_trace(x, y, **trace_options) - + def create_trace(self, x=None, y=None, label=None, **trace_options): + return self.plotter.create_trace(x, y, label, **trace_options) + def create_fill_trace(self, x, y_upper, y_lower, **options): return self.plotter.create_fill_trace(x, y_upper, y_lower, **options) + def create_histogram(self, x, name, **trace_options): + return self.plotter.create_histogram(x, name, **trace_options) + + def create_vline(self, fig, x, **trace_options): + return self.plotter.create_vline(fig, x, **trace_options) + @staticmethod def wrap_text(text, width, backend="matplotlib"): """ diff --git a/pybop/plot/util.py b/pybop/plot/util.py index 27cf9303c..a30b0cee1 100644 --- a/pybop/plot/util.py +++ b/pybop/plot/util.py @@ -37,16 +37,16 @@ def call_plotting_function(function_name, backend, **kwargs): def update_and_show(fig, backend=None, **kwargs): - return call_plotting_function("update_and_show", backend, fig=fig, **kwargs) + return call_plotting_function("update_and_show", backend, fig=fig, **kwargs) def get_default_options(plot_type, backend): if backend is None: backend = pybop.plot.backend - if backend == 'plotly': + if backend == "plotly": opts = pybop.plot.plotly.DEFAULT_PLOT_OPTIONS - elif backend == 'matplotlib': + elif backend == "matplotlib": opts = pybop.plot.matplotlib.DEFAULT_PLOT_OPTIONS else: opts = {} @@ -54,4 +54,4 @@ def get_default_options(plot_type, backend): if plot_type in opts.keys(): return opts[plot_type] else: - return {} \ No newline at end of file + return {} From 072e8fab12a4843056628c3e6b189f4286c881a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:45:22 +0000 Subject: [PATCH 08/10] style: pre-commit fixes --- pybop/plot/samples.py | 44 ++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pybop/plot/samples.py b/pybop/plot/samples.py index fe80746cc..3ea022cce 100644 --- a/pybop/plot/samples.py +++ b/pybop/plot/samples.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from pybop.plot import StandardPlot -from pybop.plot.util import update_and_show, get_default_options, call_plotting_function +from pybop.plot.util import call_plotting_function, get_default_options, update_and_show if TYPE_CHECKING: from pybop.samplers.base_pints_sampler import SamplingResult @@ -11,7 +11,7 @@ def chains(result: "SamplingResult", show=True, backend=None): """ Plot posterior distributions for each chain. """ - options = get_default_options('posterior', backend) + options = get_default_options("posterior", backend) plot_options = options.get("plot_options") or {} trace_options = options.get("trace_options") or {} trace_options_vline = options.get("trace_options_vline") or {} @@ -21,30 +21,29 @@ def chains(result: "SamplingResult", show=True, backend=None): title="Posterior Distribution", xaxis_title="Value", yaxis_title="Density", - **plot_options + **plot_options, ) for i, chain in enumerate(result.chains): for j in range(chain.shape[1]): hist = plot_dict.create_histogram( - x=chain[:, j], - name=f"Chain {i} - Parameter {j}", - **trace_options + x=chain[:, j], name=f"Chain {i} - Parameter {j}", **trace_options ) plot_dict.traces.append(hist) fig = plot_dict(show=False) for j in range(chain.shape[1]): plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) - + update_and_show(fig, backend=backend) -def trace(result: "SamplingResult", show = True, backend=None): + +def trace(result: "SamplingResult", show=True, backend=None): """ Plot trace plots for the posterior samples. """ figlist = [] - options = get_default_options('trace', backend) + options = get_default_options("trace", backend) plot_options = options.get("plot_options") or {} trace_options = options.get("trace_options") or {} for i in range(result.n_parameters): @@ -53,11 +52,13 @@ def trace(result: "SamplingResult", show = True, backend=None): xaxis_title="Sample Index", yaxis_title="Value", backend=backend, - **plot_options + **plot_options, ) for j, chain in enumerate(result.chains): - plots.traces.append(plots.create_trace(y=chain[:, i], label=f"Chain {j}", **trace_options)) + plots.traces.append( + plots.create_trace(y=chain[:, i], label=f"Chain {j}", **trace_options) + ) fig = plots(show=False) figlist.append(fig) @@ -65,11 +66,12 @@ def trace(result: "SamplingResult", show = True, backend=None): return figlist + def posterior(result: "SamplingResult", backend=None): """ Plot the summed posterior distribution across chains. """ - options = get_default_options('posterior', backend) + options = get_default_options("posterior", backend) plot_options = options.get("plot_options") or {} trace_options = options.get("trace_options") or {} trace_options_vline = options.get("trace_options_vline") or {} @@ -79,25 +81,23 @@ def posterior(result: "SamplingResult", backend=None): title="Posterior Distribution", xaxis_title="Value", yaxis_title="Density", - **plot_options + **plot_options, ) for j in range(result.all_samples.shape[1]): hist = plot_dict.create_histogram( - x=result.all_samples[:, j], - name=f"Parameter {j}", - **trace_options + x=result.all_samples[:, j], name=f"Parameter {j}", **trace_options ) plot_dict.traces.append(hist) - fig = plot_dict(show=False) for j in range(result.all_samples.shape[1]): plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) update_and_show(fig, backend=backend) + def summary_table(result: "SamplingResult", backend=None): """ Display summary statistics in a table. @@ -113,5 +113,11 @@ def summary_table(result: "SamplingResult", backend=None): ["95% CI Lower", summary_stats["ci_lower"]], ["95% CI Upper", summary_stats["ci_upper"]], ] - - call_plotting_function('show_table', backend=backend, header=header, values=values, title="Summary Statistics") \ No newline at end of file + + call_plotting_function( + "show_table", + backend=backend, + header=header, + values=values, + title="Summary Statistics", + ) From d439edfe35fd0aef08640f568caa9333be665e1b Mon Sep 17 00:00:00 2001 From: u2370093 Date: Mon, 13 Apr 2026 17:45:16 +0100 Subject: [PATCH 09/10] simplify contour, nyquist, convergence --- pybop/plot/__init__.py | 4 +- pybop/plot/{matplotlib => }/contour.py | 93 ++++----- pybop/plot/{plotly => }/convergence.py | 18 +- pybop/plot/matplotlib/__init__.py | 5 +- pybop/plot/matplotlib/convergence.py | 51 ----- pybop/plot/matplotlib/nyquist.py | 80 ------- pybop/plot/matplotlib/standard_plots.py | 84 +++++--- pybop/plot/matplotlib/util.py | 36 ++++ pybop/plot/nyquist.py | 22 +- pybop/plot/plotly/__init__.py | 5 +- pybop/plot/plotly/contour.py | 263 ------------------------ pybop/plot/plotly/nyquist.py | 101 --------- pybop/plot/plotly/standard_plots.py | 47 +++-- pybop/plot/plotly/util.py | 124 +++++++++-- pybop/plot/plots.py | 25 --- pybop/plot/standard_plots.py | 3 + 16 files changed, 297 insertions(+), 664 deletions(-) rename pybop/plot/{matplotlib => }/contour.py (75%) rename pybop/plot/{plotly => }/convergence.py (78%) delete mode 100644 pybop/plot/matplotlib/convergence.py delete mode 100644 pybop/plot/matplotlib/nyquist.py delete mode 100644 pybop/plot/plotly/contour.py delete mode 100644 pybop/plot/plotly/nyquist.py diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index 5c6eced49..da5e84d44 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -8,12 +8,12 @@ # Import plots # from .plots import ( - contour, - convergence, surface ) from .standard_plots import StandardPlot, StandardSubplot, trajectories +from .contour import contour +from .convergence import convergence from .dataset import dataset from .nyquist import nyquist from .parameters import parameters diff --git a/pybop/plot/matplotlib/contour.py b/pybop/plot/contour.py similarity index 75% rename from pybop/plot/matplotlib/contour.py rename to pybop/plot/contour.py index d0127375a..c6ff92786 100644 --- a/pybop/plot/matplotlib/contour.py +++ b/pybop/plot/contour.py @@ -3,8 +3,9 @@ from typing import TYPE_CHECKING import numpy as np -from matplotlib import pyplot as plt +from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import call_plotting_function, get_default_options, update_and_show from pybop.problems.problem import Problem if TYPE_CHECKING: @@ -18,7 +19,7 @@ def contour( transformed: bool = False, steps: int = 10, show: bool = True, - title: str = "Cost Landscape", + backend=None, ): """ Plot a 2D visualisation of a cost landscape using Plotly. @@ -143,77 +144,57 @@ def transform_array_of_values(list_of_values, parameter): bounds[0] = transform_array_of_values(bounds[0], parameters[names[0]]) bounds[1] = transform_array_of_values(bounds[1], parameters[names[1]]) - # define levels - exponent = np.floor(np.log10(np.abs(np.max(costs)))) - levels = np.linspace( - np.floor(np.min(costs) / (10**exponent)) * (10**exponent), - np.ceil(np.max(costs) / (10**exponent)) * (10**exponent), - 2 * steps - 1, - ) + # Get options + options = get_default_options("contour", backend) + plot_options = options.get("plot_options") or {} + trace_options_initial = options.get("trace_options_initial") or {} + trace_options_optim = options.get("trace_options_optim") or {} + trace_options_contour = options.get("trace_options_contour") or {} + + plot_dict = StandardPlot( + xaxis_title="Transformed " + names[0] if transformed else names[0], + yaxis_title="Transformed " + names[1] if transformed else names[1], + xaxis_range=bounds[0], + yaxis_range=bounds[1], + backend=backend, + **plot_options + ) # Create contour plot and update the layout - fig = plt.figure(figsize=(6, 6), dpi=100) - plt.contourf(x, y, costs, levels=levels, extend="both", cmap="viridis") - plt.colorbar() - plt.contour( - x, y, costs, levels=levels, colors=("k",), linestyles="solid", linewidths=0.1 - ) - - # Layout - plt.xlabel("Transformed " + names[0] if transformed else names[0], labelpad=15) - plt.ticklabel_format(axis="both", **dict(style="sci", scilimits=(-4, 4))) - plt.ylabel("Transformed " + names[1] if transformed else names[1], labelpad=15) - plt.title(title, pad=40) - plt.xlim(bounds[0]) - plt.ylim(bounds[1]) + plot_dict.create_contour(x=x, y=y, z=costs, **trace_options_contour) if plot_optim: # Plot the optimisation trace optim_trace = np.asarray([item[:2] for item in result.x_model]) optim_trace = optim_trace.reshape(-1, 2) - - plt.scatter( - transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), - transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), - c=[i / len(optim_trace) for i in range(len(optim_trace))], - cmap="Grays", - zorder=1, + call_plotting_function('plot_optimisation_path', backend=backend, + plot_dict=plot_dict, + x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), + y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), ) # Plot the initial guess if len(result.x_model) > 0: x0 = result.x_model[0] - plt.plot( - transform_array_of_values([x0[0]], parameters[names[0]]), - transform_array_of_values([x0[1]], parameters[names[1]]), - "X", - markersize=14, - markerfacecolor="w", - markeredgecolor="k", - label="Initial values", - linestyle="None", - ) + plot_dict.traces.append(plot_dict.create_trace( + x=transform_array_of_values([x0[0]], parameters[names[0]]), + y=transform_array_of_values([x0[1]], parameters[names[1]]), + **trace_options_initial + )) # Plot optimised value if result.x is not None: x_best = result.x - plt.plot( - transform_array_of_values([x_best[0]], parameters[names[0]]), - transform_array_of_values([x_best[1]], parameters[names[1]]), - "P", - markersize=14, - markerfacecolor="k", - markeredgecolor="w", - label="Final values", - linestyle="None", - ) - - plt.legend(ncols=2, loc="lower center", bbox_to_anchor=(0.5, 1.0)) - - plt.tight_layout() - + plot_dict.traces.append(plot_dict.create_trace( + x=transform_array_of_values([x_best[0]], parameters[names[0]]), + y=transform_array_of_values([x_best[1]], parameters[names[1]]), + **trace_options_optim + )) + + # Update the layout and display the figure + fig = plot_dict(show=False) if show: - plt.show() + update_and_show(fig) # if gradient: # grad_figs = [] diff --git a/pybop/plot/plotly/convergence.py b/pybop/plot/convergence.py similarity index 78% rename from pybop/plot/plotly/convergence.py rename to pybop/plot/convergence.py index 6b190a93c..9e3b1bbb7 100644 --- a/pybop/plot/plotly/convergence.py +++ b/pybop/plot/convergence.py @@ -1,12 +1,13 @@ from typing import TYPE_CHECKING -from pybop.plot.plotly.standard_plots import StandardPlot +from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import update_and_show if TYPE_CHECKING: from pybop._result import Result -def convergence(result: "Result", show=True, **layout_kwargs): +def convergence(result: "Result", show=True, backend=None): """ Plot the convergence of the optimisation algorithm. @@ -37,18 +38,15 @@ def convergence(result: "Result", show=True, **layout_kwargs): plot_dict = StandardPlot( x=iteration_numbers, y=cost_log, - layout_options=dict( - xaxis_title="Evaluation", - yaxis_title="Cost", - title="Convergence", - ), + xaxis_title="Evaluation", + yaxis_title="Cost", + title="Convergence", trace_names=result.method_name, + backend=backend ) # Generate and display the figure fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) if show: - fig.show() - + update_and_show(fig, backend=backend) return fig diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py index ae087d5fc..2730122df 100644 --- a/pybop/plot/matplotlib/__init__.py +++ b/pybop/plot/matplotlib/__init__.py @@ -1,6 +1,3 @@ -from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories -from .convergence import convergence -from .contour import contour +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path from .voronoi import surface -from .nyquist import nyquist from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/matplotlib/convergence.py b/pybop/plot/matplotlib/convergence.py deleted file mode 100644 index 2e53e4738..000000000 --- a/pybop/plot/matplotlib/convergence.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import TYPE_CHECKING - -import matplotlib.pyplot as plt - -from pybop.plot.matplotlib.standard_plots import StandardPlot - -if TYPE_CHECKING: - from pybop._result import Result - - -def convergence(result: "Result", show=True): - """ - Plot the convergence of the optimisation algorithm. - - Parameters - ----------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - show : bool, optional - If True, the figure is shown upon creation (default: True). - - Returns - --------- - fig : plotly.graph_objs.Figure - The Plotly figure object for the convergence plot. - """ - - # Extract log from the optimisation object - cost_log = result.cost_convergence - - # Generate a list of iteration numbers - iteration_numbers = list(range(1, len(cost_log) + 1)) - - # Create a plot dictionary - plot_dict = StandardPlot( - x=iteration_numbers, - y=cost_log, - trace_names=result.method_name, - ) - - # Generate and display the figure - fig = plot_dict(show=False) - plt.xlabel("Evaluation") - plt.ylabel("Cost") - plt.title("Convergence") - plt.tight_layout() - - if show: - plt.show() - - return fig diff --git a/pybop/plot/matplotlib/nyquist.py b/pybop/plot/matplotlib/nyquist.py deleted file mode 100644 index 5c6a64144..000000000 --- a/pybop/plot/matplotlib/nyquist.py +++ /dev/null @@ -1,80 +0,0 @@ -import warnings - -from matplotlib import pyplot as plt - -from pybop.parameters.parameter import Inputs -from pybop.plot.nyquist import _nyquist - - -def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): - """ - Generates Nyquist plots for the given problem by evaluating the model's output and target values. - - Parameters - ---------- - problem : pybop.Problem - An instance of a problem class that contains the parameters and methods - for evaluation and target retrieval. - inputs : Inputs, optional - Input parameters for the problem. If not provided, the default parameters from the problem - instance will be used. These parameters are verified before use (default is None). - show : bool, optional - If True, the plots will be displayed. - **layout_kwargs : dict, optional - Additional keyword arguments for customising the plot layout. These arguments are passed to - `fig.update_layout()`. - - Returns - ------- - list - A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. - - Notes - ----- - - The function extracts the real part of the impedance from the model's output and the real and imaginary parts - of the impedance from the target output. - - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. - - An additional trace for the reference (target output) is added to the plot. - - The plot layout can be customised using `layout_kwargs`. - - Example - ------- - >>> problem = pybop.EISProblem() - >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") - >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. - """ - - if len(layout_kwargs) > 0: - warnings.warn( - "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" - f"{list(layout_kwargs.keys())}", - UserWarning, - stacklevel=2, - ) - trace_options_model = dict( - label="Model", color="#00CC96", linewidth=2, marker=".", markersize=8 - ) - trace_options_reference = dict( - label="Reference", - linestyle="none", - marker="o", - fillstyle="none", - markersize=8, - markeredgecolor="#636EFA", - ) - figure_list = _nyquist( - problem, trace_options_model, trace_options_reference, inputs=inputs - ) - - for fig in figure_list: - plt.sca(fig.gca()) - # Layout - plt.title("Nyquist Plot", fontsize=14, x=0.2) - plt.xlabel(r"$Z_{re} / \Omega$", fontsize=16) - plt.ylabel(r"$-Z_{im} / \Omega$", fontsize=16) - plt.legend(loc="upper right", bbox_to_anchor=(1, 1.08), ncols=2) - - if show: - plt.show() - - return figure_list diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py index 64c31bd66..bd49cd6d5 100644 --- a/pybop/plot/matplotlib/standard_plots.py +++ b/pybop/plot/matplotlib/standard_plots.py @@ -37,6 +37,8 @@ def __init__( title=None, xaxis_title=None, yaxis_title=None, + xaxis_range=None, + yaxis_range=None, grid=None, axis_bg_color=None, **kwargs, @@ -70,7 +72,10 @@ def __init__( ax = plt.gca() ax.set_facecolor(axis_bg_color) ax.set_axisbelow(True) - + if xaxis_range is not None: + plt.xlim(xaxis_range) + if yaxis_range is not None: + plt.ylim(yaxis_range) self.traces = [] def __call__(self, show=True): @@ -146,20 +151,20 @@ def create_trace(self, x=None, y=None, label=None, ax=None, **trace_options): """ if x is not None and y is not None: size = min(len(x), len(y)) - trace = dict(x=x[:size], y=y[:size], label=label, ax=ax) + trace = dict(positional_args=[x[:size], y[:size]], label=label, ax=ax) elif y is not None: - trace = dict(y=y, label=label, ax=ax) + trace = dict(positional_args=[y], label=label, ax=ax) trace.update(trace_options) return trace def create_fill_trace(self, x, y_upper, y_lower, **options): - trace = dict(x=x, y=y_upper, plot_type="fill", y_lower=y_lower) + trace = dict(positional_args=(x, y_upper, y_lower), plot_type="fill_between") trace.update(options) return trace def create_histogram(self, x, name, **trace_options): - trace = dict(x=x, label=name, plot_type="hist") + trace = dict(positional_args=[x], label=name, plot_type="hist") trace.update(trace_options) return trace @@ -167,39 +172,42 @@ def create_vline(self, fig, x, **trace_options): fig.gca() plt.axvline(x, **trace_options) + def create_contour(self, x, y, z, **trace_options): + contour = dict(positional_args=[x, y, z], plot_type="contourf") + contour.update(**trace_options) + self.traces.append(contour) + self.traces.append( + dict( + positional_args=[x, y, z], + colors=("k"), + linestyles="solid", + linewidths=0.2, + plot_type="contour", + ) + ) + + def create_scatter(self, x, y, **trace_options): + scatter = dict(positional_args=[x, y], plot_type="scatter") + scatter.update(**trace_options) + self.traces.append(scatter) + def _plot_trace( - self, x=None, y=None, label=None, ax=None, plot_type="plot", **trace_options + self, ax=None, plot_type="plot", positional_args=None, **trace_options ): + if positional_args is None: + positional_args = [] if ax is None: ax = plt.gca() + try: + plot_function = getattr(ax, plot_type) + except ValueError: + print("Plot type not recognised") - if plot_type == "plot": - if x is not None: - line = ax.plot( - x, - y, - label=label, - **trace_options, - ) - else: - line = ax.plot( - y, - label=label, - **trace_options, - ) - if len(line) > 1: - return line - else: - return line[0] - elif plot_type == "fill": - y_upper = y - y_lower = trace_options["y_lower"] - del trace_options["y_lower"] - return ax.fill_between(x, y_upper, y_lower, **trace_options) - elif plot_type == "hist": - return ax.hist(x=x, label=label, **trace_options) - else: - raise ValueError("Plot type not recognised") + obj = plot_function(*positional_args, **trace_options) + if plot_type == "contourf": + plt.colorbar(obj) + + return obj class SubplotPlotter(Plotter): @@ -365,3 +373,13 @@ def show_table(header, values, title): ax.set_title(title) fig.tight_layout() plt.show() + + +def plot_optimisation_path(plot_dict: StandardPlot, x, y): + plot_dict.plotter.create_scatter( + x, + y, + c=[i / len(x) for i in range(len(x))], + cmap="Grays", + zorder=1, + ) diff --git a/pybop/plot/matplotlib/util.py b/pybop/plot/matplotlib/util.py index 2226c2966..b6a7311d8 100644 --- a/pybop/plot/matplotlib/util.py +++ b/pybop/plot/matplotlib/util.py @@ -19,6 +19,42 @@ def update_and_show(fig, show=True, **layout_kwargs): DEFAULT_PLOT_OPTIONS = { + "contour": { + "plot_options": dict(title="Cost Landscape"), + "trace_options_contour": dict(extend="both", cmap="viridis"), + "trace_options_initial": dict( + marker="X", + markersize=14, + markerfacecolor="w", + markeredgecolor="k", + label="Initial values", + linestyle="None", + ), + "trace_options_optim": dict( + marker="P", + markersize=14, + markerfacecolor="k", + markeredgecolor="w", + label="Final values", + linestyle="None", + ), + }, + "nyquist": { + "plot_options": dict( + xaxis_title=r"$Z_{re} / \Omega$", yaxis_title=r"$-Z_{im} / \Omega$" + ), + "trace_options_model": dict( + label="Model", color="#00CC96", linewidth=2, marker=".", markersize=8 + ), + "trace_options_reference": dict( + label="Reference", + linestyle="none", + marker="o", + fillstyle="none", + markersize=8, + markeredgecolor="#636EFA", + ), + }, "parameters": dict(figsize=(18, 8), title="Parameter Convergence"), "problem": { "default_trace_options": dict(label="Model", marker=None, linestyle="-"), diff --git a/pybop/plot/nyquist.py b/pybop/plot/nyquist.py index bcd609fd1..5d1452f2e 100644 --- a/pybop/plot/nyquist.py +++ b/pybop/plot/nyquist.py @@ -1,9 +1,11 @@ from pybop.parameters.parameter import Inputs from pybop.plot.standard_plots import StandardPlot -from pybop.plot.util import call_plotting_function +from pybop.plot.util import get_default_options, update_and_show -def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kwargs): +def nyquist( + problem, inputs: Inputs = None, show=True, title="Nyquist Plot", backend=None +): """ Generates Nyquist plots for the given problem by evaluating the model's output and target values. @@ -40,14 +42,11 @@ def nyquist(problem, inputs: Inputs = None, show=True, backend=None, **layout_kw >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. """ - return call_plotting_function( - "nyquist", backend, problem=problem, inputs=inputs, show=show, **layout_kwargs - ) - -def _nyquist( - problem, trace_options_model: dict, trace_options_reference, inputs: Inputs = None -): + options = get_default_options("nyquist", backend) + plot_options = options.get("plot_options") or {} + trace_options_model = options.get("trace_options_model") or {} + trace_options_reference = options.get("trace_options_reference") or {} if not isinstance(inputs, dict): inputs = problem.parameters.to_dict(inputs) @@ -61,6 +60,8 @@ def _nyquist( x=domain_data, y=-model_output[var].data.imag, trace_names="Model", + title=title, + **plot_options, ) plot_dict.traces[0].update(trace_options_model) @@ -75,4 +76,7 @@ def _nyquist( fig = plot_dict(show=False) figure_list.append(fig) + if show: + update_and_show(figure_list) + return figure_list diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py index ad894f9bb..4de178480 100644 --- a/pybop/plot/plotly/__init__.py +++ b/pybop/plot/plotly/__init__.py @@ -1,8 +1,5 @@ from .plotly_manager import PlotlyManager -from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories -from .contour import contour -from .convergence import convergence -from .nyquist import nyquist +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path from .voronoi import surface from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/plotly/contour.py b/pybop/plot/plotly/contour.py deleted file mode 100644 index 4fb1a1156..000000000 --- a/pybop/plot/plotly/contour.py +++ /dev/null @@ -1,263 +0,0 @@ -import warnings -from collections.abc import Callable -from typing import TYPE_CHECKING - -import numpy as np - -from pybop.plot.plotly.plotly_manager import PlotlyManager -from pybop.problems.problem import Problem - -if TYPE_CHECKING: - from pybop._result import Result - - -def contour( - call_object: "Problem | Result", - gradient: bool = False, - bounds: np.ndarray | None = None, - transformed: bool = False, - steps: int = 10, - show: bool = True, - **layout_kwargs, -): - """ - Plot a 2D visualisation of a cost landscape using Plotly. - - This function generates a contour plot representing the cost landscape for a provided - callable cost function over a grid of parameter values within the specified bounds. - - Parameters - ---------- - call_object : pybop.Problem | pybop.Result - Either: - - the cost function to be evaluated. Must accept a list of parameter values and return a cost value. - - an optimiser result which provides a specific optimisation trace overlaid on the cost landscape. - gradient : bool, optional - If True, the gradient is shown (default: False). - bounds : numpy.ndarray | list[list[float]], optional - A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses - `parameters.get_bounds_for_plotly`. - transformed : bool, optional - Uses the transformed parameter values (as seen by the optimiser) for plotting. - steps : int, optional - The number of grid points to divide the parameter space into along each dimension (default: 10). - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time [s]"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - ------- - plotly.graph_objs.Figure - The Plotly figure object containing the cost landscape plot. - - Raises - ------ - ValueError - If the cost function does not return a valid cost when called with a parameter list. - """ - plot_optim = False - problem = call_object - - # Assign input as a cost or optimisation result - if not isinstance(call_object, Callable): - plot_optim = True - result = call_object - problem = result.problem - - parameters = problem.parameters - names = parameters.names - additional_values = [] - - if len(parameters) < 2: - raise ValueError("This cost function takes fewer than 2 parameters.") - - if len(parameters) > 2: - warnings.warn( - "This cost function requires more than 2 parameters. " - "Plotting in 2d with fixed values for the additional parameters.", - UserWarning, - stacklevel=2, - ) - for ( - i, - (name, param), - ) in enumerate(parameters.items()): - if i > 1: - # TODO: Update from the initial to the intended value - additional_values.append(param.get_initial_value()) - print(f"Fixed {name}:", param.get_initial_value()) - - # Set up parameter bounds - if bounds is None: - bounds = parameters.get_bounds_for_plotly() - else: - bounds = np.asarray(bounds) - - # Generate grid - x = np.linspace(bounds[0, 0], bounds[0, 1], steps) - y = np.linspace(bounds[1, 0], bounds[1, 1], steps) - - # Initialize cost matrix - costs = np.zeros((len(y), len(x))) - - if gradient: - grad_parameter_costs = [] - - # Create an array to hold the gradient with respect to each parameter - grads = [np.zeros((len(y), len(x))) for _ in range(len(parameters))] - - # Populate cost matrix - for i, xi in enumerate(x): - for j, yj in enumerate(y): - if gradient: - out = problem.evaluate( - np.asarray([xi, yj] + additional_values), - calculate_sensitivities=True, - ).get_values() - costs[j, i], sensitivities = out[0][0], out[1] - for k, key in enumerate(problem.parameters.names): - grads[k][j, i] = sensitivities[key].item() - else: - costs[j, i] = problem.evaluate( - np.asarray([xi, yj] + additional_values), - ).get_values()[0] - - # Append the arrays to the grad_parameter_costs list - if gradient: - grad_parameter_costs.extend(grads) - - # Apply any transformation if requested - def transform_array_of_values(list_of_values, parameter): - """Apply transformation if requested.""" - if transformed: - return np.asarray( - [parameter.transformation.to_search(value) for value in list_of_values] - ).flatten() - return list_of_values - - x = transform_array_of_values(x, parameters[names[0]]) - y = transform_array_of_values(y, parameters[names[1]]) - bounds[0] = transform_array_of_values(bounds[0], parameters[names[0]]) - bounds[1] = transform_array_of_values(bounds[1], parameters[names[1]]) - - # Import plotly only when needed - go = PlotlyManager().go - - # Set default layout properties - layout_options = dict( - title="Cost Landscape", - title_x=0.5, - title_y=0.905, - width=600, - height=600, - xaxis=dict(range=bounds[0], showexponent="last", exponentformat="e"), - yaxis=dict(range=bounds[1], showexponent="last", exponentformat="e"), - legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), - ) - layout_options["xaxis_title"] = ( - "Transformed " + names[0] if transformed else names[0] - ) - layout_options["yaxis_title"] = ( - "Transformed " + names[1] if transformed else names[1] - ) - layout = go.Layout(layout_options) - - # Create contour plot and update the layout - fig = go.Figure( - data=[go.Contour(x=x, y=y, z=costs, colorscale="Viridis", connectgaps=True)], - layout=layout, - ) - - if plot_optim: - # Plot the optimisation trace - optim_trace = np.asarray([item[:2] for item in result.x_model]) - optim_trace = optim_trace.reshape(-1, 2) - - fig.add_trace( - go.Scatter( - x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), - y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), - mode="markers", - marker=dict( - color=[i / len(optim_trace) for i in range(len(optim_trace))], - colorscale="Greys", - size=8, - showscale=False, - ), - showlegend=False, - ) - ) - - # Plot the initial guess - if len(result.x_model) > 0: - x0 = result.x_model[0] - fig.add_trace( - go.Scatter( - x=transform_array_of_values([x0[0]], parameters[names[0]]), - y=transform_array_of_values([x0[1]], parameters[names[1]]), - mode="markers", - marker_symbol="x", - marker=dict( - color="white", - line_color="black", - line_width=1, - size=14, - showscale=False, - ), - name="Initial values", - ) - ) - - # Plot optimised value - if result.x is not None: - x_best = result.x - fig.add_trace( - go.Scatter( - x=transform_array_of_values([x_best[0]], parameters[names[0]]), - y=transform_array_of_values([x_best[1]], parameters[names[1]]), - mode="markers", - marker_symbol="cross", - marker=dict( - color="black", - line_color="white", - line_width=1, - size=14, - showscale=False, - ), - name="Final values", - ) - ) - - # Update the layout and display the figure - fig.update_layout(**layout_kwargs) - if show: - fig.show() - - if gradient: - grad_figs = [] - for i, grad_costs in enumerate(grad_parameter_costs): - # Update title for gradient plots - updated_layout_options = layout_options.copy() - updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}" - - # Create contour plot with updated layout options - grad_layout = go.Layout(updated_layout_options) - - # Create fig - grad_fig = go.Figure( - data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout - ) - grad_fig.update_layout(**layout_kwargs) - - if show: - grad_fig.show() - - # append grad_fig to list - grad_figs.append(grad_fig) - - return fig, grad_figs - - return fig diff --git a/pybop/plot/plotly/nyquist.py b/pybop/plot/plotly/nyquist.py deleted file mode 100644 index 1578d2462..000000000 --- a/pybop/plot/plotly/nyquist.py +++ /dev/null @@ -1,101 +0,0 @@ -from pybop.parameters.parameter import Inputs -from pybop.plot.nyquist import _nyquist - - -def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): - """ - Generates Nyquist plots for the given problem by evaluating the model's output and target values. - - Parameters - ---------- - problem : pybop.Problem - An instance of a problem class that contains the parameters and methods - for evaluation and target retrieval. - inputs : Inputs, optional - Input parameters for the problem. If not provided, the default parameters from the problem - instance will be used. These parameters are verified before use (default is None). - show : bool, optional - If True, the plots will be displayed. - **layout_kwargs : dict, optional - Additional keyword arguments for customising the plot layout. These arguments are passed to - `fig.update_layout()`. - - Returns - ------- - list - A list of plotly `Figure` objects, each representing a Nyquist plot for the model's output and target values. - - Notes - ----- - - The function extracts the real part of the impedance from the model's output and the real and imaginary parts - of the impedance from the target output. - - For each signal in the problem, a Nyquist plot is created with the model's impedance plotted as a scatter plot. - - An additional trace for the reference (target output) is added to the plot. - - The plot layout can be customised using `layout_kwargs`. - - Example - ------- - >>> problem = pybop.EISProblem() - >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") - >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. - """ - default_layout_options = dict( - title="Nyquist Plot", - font=dict(family="Arial", size=14), - plot_bgcolor="white", - paper_bgcolor="white", - xaxis=dict( - title=dict(text="Zre / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - ), - yaxis=dict( - title=dict(text="-Zim / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - scaleanchor="x", - scaleratio=1, - ), - legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), - width=600, - height=600, - ) - - # Overwrite with user-kwargs - default_layout_options.update(layout_kwargs) - - trace_options_model = dict( - mode="lines+markers", - line=dict(color="#00CC96", width=2), - marker=dict(size=8, color="#00CC96", symbol="circle"), - ) - - trace_options_reference = dict( - name="Reference", - mode="markers", - marker=dict(size=8, color="#636EFA", symbol="circle-open"), - showlegend=True, - ) - - figure_list = _nyquist( - problem, trace_options_model, trace_options_reference, inputs=inputs - ) - - for fig in figure_list: - fig.update_layout(**default_layout_options) - if show: - fig.show() - - return figure_list diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py index 537a24286..80429f0c4 100644 --- a/pybop/plot/plotly/standard_plots.py +++ b/pybop/plot/plotly/standard_plots.py @@ -63,6 +63,8 @@ def __init__( title=None, xaxis_title=None, yaxis_title=None, + xaxis_range=None, + yaxis_range=None, layout=None, layout_options=None, trace_options=None, @@ -77,6 +79,16 @@ def __init__( self.layout_options = DEFAULT_LAYOUT_OPTIONS.copy() if layout_options: self.layout_options.update(layout_options) + if title is not None: + self.layout_options.update({"title": title}) + if xaxis_title is not None: + self.layout_options.update({"xaxis_title": xaxis_title}) + if yaxis_title is not None: + self.layout_options.update({"yaxis_title": yaxis_title}) + if xaxis_range is not None: + self.layout_options["xaxis"].update({"range": xaxis_range}) + if yaxis_range is not None: + self.layout_options["yaxis"].update({"range": yaxis_range}) # Set default trace options and update if provided self.trace_options = DEFAULT_TRACE_OPTIONS.copy() @@ -90,16 +102,6 @@ def __init__( if self.layout is None: self.layout = self.go.Layout(**self.layout_options) - title_options = {} - if title is not None: - title_options.update({"title": title}) - if title is not None: - title_options.update({"xaxis_title": xaxis_title}) - if title is not None: - title_options.update({"yaxis_title": yaxis_title}) - - self.layout.update(**title_options) - def __call__(self, show=True): """ Generate and show the figure. @@ -184,6 +186,9 @@ def create_histogram(self, x, name, **trace_options): def create_vline(self, fig, x, **trace_options): fig.add_vline(x=x, **trace_options) + def create_contour(self, x, y, z, **trace_options): + self.traces.append(self.go.Contour(x=x, y=y, z=z, **trace_options)) + class SubplotPlotter(Plotter): """ @@ -223,9 +228,10 @@ def __init__( layout_options=DEFAULT_LAYOUT_OPTIONS, subplot_options=DEFAULT_SUBPLOT_OPTIONS, trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS, - **kwargs, ): - super().__init__(layout, layout_options, trace_options, **kwargs) + super().__init__( + layout, layout_options=layout_options, trace_options=trace_options + ) self.subplot_options = subplot_options.copy() if subplot_options is not None: for arg, value in subplot_options.items(): @@ -329,3 +335,20 @@ def show_table(header, values, title): fig.update_layout(title=title) fig.show() + + +def plot_optimisation_path(plot_dict: StandardPlot, x, y): + plot_dict.traces.append( + plot_dict.create_trace( + x, + y, + mode="markers", + marker=dict( + color=[i / len(x) for i in range(len(x))], + colorscale="Greys", + size=8, + showscale=False, + ), + showlegend=False, + ) + ) diff --git a/pybop/plot/plotly/util.py b/pybop/plot/plotly/util.py index f622890fb..9a913b3aa 100644 --- a/pybop/plot/plotly/util.py +++ b/pybop/plot/plotly/util.py @@ -12,6 +12,102 @@ def update_and_show(fig, show=True, **layout_kwargs): DEFAULT_PLOT_OPTIONS = { + "contour": { + "plot_options": dict( + layout_options=dict( + title="Cost Landscape", + title_x=0.5, + title_y=0.905, + width=600, + height=600, + xaxis=dict(showexponent="last", exponentformat="e"), + yaxis=dict(showexponent="last", exponentformat="e"), + legend=dict( + orientation="h", yanchor="bottom", y=1, xanchor="right", x=1 + ), + autosize=None, + showlegend=None, + margin=None, + ) + ), + "trace_options_contour": dict(colorscale="Viridis", connectgaps=True), + "trace_options_initial": dict( + mode="markers", + marker_symbol="x", + marker=dict( + color="white", + line_color="black", + line_width=1, + size=14, + showscale=False, + ), + name="Initial values", + ), + "trace_options_optim": dict( + mode="markers", + marker_symbol="cross", + marker=dict( + color="black", + line_color="white", + line_width=1, + size=14, + showscale=False, + ), + name="Final values", + ), + }, + "nyquist": { + "plot_options": dict( + layout_options=dict( + title="Nyquist Plot", + font=dict(family="Arial", size=14), + plot_bgcolor="white", + paper_bgcolor="white", + xaxis=dict( + title=dict(font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + ), + yaxis=dict( + title=dict(font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + scaleanchor="x", + scaleratio=1, + ), + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + width=600, + height=600, + ), + xaxis_title="Zre / Ω", + yaxis_title="-Zim / Ω", + ), + "trace_options_model": dict( + mode="lines+markers", + line=dict(color="#00CC96", width=2), + marker=dict(size=8, color="#00CC96", symbol="circle"), + ), + "trace_options_reference": dict( + name="Reference", + mode="markers", + marker=dict(size=8, color="#636EFA", symbol="circle-open"), + showlegend=True, + ), + }, "parameters": dict( layout_options=dict( title="Parameter Convergence", @@ -22,6 +118,20 @@ def update_and_show(fig, show=True, **layout_kwargs): ), ) ), + "posterior": { + "plot_options": { + "layout_options": dict( + barmode="overlay", + width=None, + height=None, + plot_bgcolor=None, + autosize=None, + legend=None, + ) + }, + "trace_options": dict(opacity=0.75), + "trace_options_vline": dict(line_width=3, line_dash="dash", line_color="black"), + }, "problem": { "default_trace_options": dict(name="Model", mode="lines", showlegend=True), "design_cost_options": dict(name="Optimised"), @@ -37,18 +147,4 @@ def update_and_show(fig, show=True, **layout_kwargs): }, "trace_options": dict(mode="lines"), }, - "posterior": { - "plot_options": { - "layout_options": dict( - barmode="overlay", - width=None, - height=None, - plot_bgcolor=None, - autosize=None, - legend=None, - ) - }, - "trace_options": dict(opacity=0.75), - "trace_options_vline": dict(line_width=3, line_dash="dash", line_color="black"), - }, } diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py index 99dc699b7..873af49fe 100644 --- a/pybop/plot/plots.py +++ b/pybop/plot/plots.py @@ -70,31 +70,6 @@ def contour( ) -def convergence(result: "Result", show=True, backend=None, **layout_kwargs): - """ - Plot the convergence of the optimisation algorithm. - - Parameters - ----------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time [s]"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - - Returns - --------- - fig : plotly.graph_objs.Figure - The Plotly figure object for the convergence plot. - """ - return call_plotting_function( - "convergence", backend, result=result, show=show, **layout_kwargs - ) - - def surface( result: "Result", bounds=None, diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index b1a6f9e33..3287789c5 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -115,6 +115,9 @@ def create_histogram(self, x, name, **trace_options): def create_vline(self, fig, x, **trace_options): return self.plotter.create_vline(fig, x, **trace_options) + def create_contour(self, x, y, z, **trace_options): + return self.plotter.create_contour(x, y, z, **trace_options) + @staticmethod def wrap_text(text, width, backend="matplotlib"): """ From 5c12c725c4edc61289d1d835cb803fe5c0cb63a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 16:46:54 +0000 Subject: [PATCH 10/10] style: pre-commit fixes --- pybop/plot/contour.py | 44 ++++++++++++++++++++++----------------- pybop/plot/convergence.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/pybop/plot/contour.py b/pybop/plot/contour.py index c6ff92786..cdd5b8b59 100644 --- a/pybop/plot/contour.py +++ b/pybop/plot/contour.py @@ -150,15 +150,15 @@ def transform_array_of_values(list_of_values, parameter): trace_options_initial = options.get("trace_options_initial") or {} trace_options_optim = options.get("trace_options_optim") or {} trace_options_contour = options.get("trace_options_contour") or {} - + plot_dict = StandardPlot( - xaxis_title="Transformed " + names[0] if transformed else names[0], - yaxis_title="Transformed " + names[1] if transformed else names[1], - xaxis_range=bounds[0], - yaxis_range=bounds[1], - backend=backend, - **plot_options - ) + xaxis_title="Transformed " + names[0] if transformed else names[0], + yaxis_title="Transformed " + names[1] if transformed else names[1], + xaxis_range=bounds[0], + yaxis_range=bounds[1], + backend=backend, + **plot_options, + ) # Create contour plot and update the layout plot_dict.create_contour(x=x, y=y, z=costs, **trace_options_contour) @@ -167,7 +167,9 @@ def transform_array_of_values(list_of_values, parameter): # Plot the optimisation trace optim_trace = np.asarray([item[:2] for item in result.x_model]) optim_trace = optim_trace.reshape(-1, 2) - call_plotting_function('plot_optimisation_path', backend=backend, + call_plotting_function( + "plot_optimisation_path", + backend=backend, plot_dict=plot_dict, x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), @@ -176,20 +178,24 @@ def transform_array_of_values(list_of_values, parameter): # Plot the initial guess if len(result.x_model) > 0: x0 = result.x_model[0] - plot_dict.traces.append(plot_dict.create_trace( - x=transform_array_of_values([x0[0]], parameters[names[0]]), - y=transform_array_of_values([x0[1]], parameters[names[1]]), - **trace_options_initial - )) + plot_dict.traces.append( + plot_dict.create_trace( + x=transform_array_of_values([x0[0]], parameters[names[0]]), + y=transform_array_of_values([x0[1]], parameters[names[1]]), + **trace_options_initial, + ) + ) # Plot optimised value if result.x is not None: x_best = result.x - plot_dict.traces.append(plot_dict.create_trace( - x=transform_array_of_values([x_best[0]], parameters[names[0]]), - y=transform_array_of_values([x_best[1]], parameters[names[1]]), - **trace_options_optim - )) + plot_dict.traces.append( + plot_dict.create_trace( + x=transform_array_of_values([x_best[0]], parameters[names[0]]), + y=transform_array_of_values([x_best[1]], parameters[names[1]]), + **trace_options_optim, + ) + ) # Update the layout and display the figure fig = plot_dict(show=False) diff --git a/pybop/plot/convergence.py b/pybop/plot/convergence.py index 9e3b1bbb7..5d3390959 100644 --- a/pybop/plot/convergence.py +++ b/pybop/plot/convergence.py @@ -42,7 +42,7 @@ def convergence(result: "Result", show=True, backend=None): yaxis_title="Cost", title="Convergence", trace_names=result.method_name, - backend=backend + backend=backend, ) # Generate and display the figure