diff --git a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb
index 1eb2a59c5..069d68660 100644
--- a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb
+++ b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb
@@ -32,8 +32,9 @@
"\n",
"import pybop\n",
"\n",
- "go = pybop.plot.PlotlyManager().go\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "go = pybop.plot.plotly.PlotlyManager().go\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb
index dce5640a7..23433be03 100644
--- a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb
+++ b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb
@@ -46,7 +46,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"matplotlib\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
@@ -1103,7 +1104,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv",
+ "display_name": "env-py-3-13",
"language": "python",
"name": "python3"
},
@@ -1117,7 +1118,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.3"
+ "version": "3.13.11"
}
},
"nbformat": 4,
diff --git a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb
index 65f4623b6..8140c73c1 100644
--- a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb
+++ b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb
@@ -48,7 +48,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb
index ce327fe6c..70915b6e9 100644
--- a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb
+++ b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb
@@ -31,7 +31,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb
index adc764f25..94b981c61 100644
--- a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb
+++ b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb
@@ -29,7 +29,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb
index 7cd0c65b0..d722d52af 100644
--- a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb
+++ b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb
@@ -32,8 +32,9 @@
"\n",
"import pybop\n",
"\n",
- "go = pybop.plot.PlotlyManager().go\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "go = pybop.plot.plotly.PlotlyManager().go\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb
index ce32e337b..1a73079e4 100644
--- a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb
+++ b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb
@@ -30,8 +30,9 @@
"\n",
"import pybop\n",
"\n",
- "go = pybop.plot.PlotlyManager().go\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "go = pybop.plot.plotly.PlotlyManager().go\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb
index 3cab87a0d..d5a7ecd4f 100644
--- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb
+++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb
@@ -29,6 +29,7 @@
"\n",
"import pybop\n",
"\n",
+ "pybop.plot.set_backend(\"matplotlib\")\n",
"np.random.seed(8) # users can remove this line"
]
},
@@ -210,708 +211,708 @@
"showlegend": false,
"type": "scatter",
"x": [
- 0.0,
- 10.0,
- 20.0,
- 30.0,
- 40.0,
- 50.0,
- 60.0,
- 70.0,
- 80.0,
- 90.0,
- 100.0,
- 110.0,
- 120.0,
- 130.0,
- 140.0,
- 150.0,
- 160.0,
- 170.0,
- 180.0,
- 190.0,
- 200.0,
- 210.0,
- 220.0,
- 230.0,
- 240.0,
- 250.0,
- 260.0,
- 270.0,
- 280.0,
- 290.0,
- 300.0,
- 310.0,
- 320.0,
- 330.0,
- 340.0,
- 350.0,
- 360.0,
- 370.0,
- 380.0,
- 390.0,
- 400.0,
- 410.0,
- 420.0,
- 430.0,
- 440.0,
- 450.0,
- 460.0,
- 470.0,
- 480.0,
- 490.0,
- 500.0,
- 510.0,
- 520.0,
- 530.0,
- 540.0,
- 550.0,
- 560.0,
- 570.0,
- 580.0,
- 590.0,
- 600.0,
- 610.0,
- 620.0,
- 630.0,
- 640.0,
- 650.0,
- 660.0,
- 670.0,
- 680.0,
- 690.0,
- 700.0,
- 710.0,
- 720.0,
- 730.0,
- 740.0,
- 750.0,
- 760.0,
- 770.0,
- 780.0,
- 790.0,
- 800.0,
- 810.0,
- 820.0,
- 830.0,
- 840.0,
- 850.0,
- 860.0,
- 870.0,
- 880.0,
- 890.0,
- 900.0,
- 910.0,
- 920.0,
- 930.0,
- 940.0,
- 950.0,
- 960.0,
- 970.0,
- 980.0,
- 990.0,
- 1000.0,
- 1010.0,
- 1020.0,
- 1030.0,
- 1040.0,
- 1050.0,
- 1060.0,
- 1070.0,
- 1080.0,
- 1090.0,
- 1100.0,
- 1110.0,
- 1120.0,
- 1130.0,
- 1140.0,
- 1150.0,
- 1160.0,
- 1170.0,
- 1180.0,
- 1190.0,
- 1200.0,
- 1210.0,
- 1220.0,
- 1230.0,
- 1240.0,
- 1250.0,
- 1260.0,
- 1270.0,
- 1280.0,
- 1290.0,
- 1300.0,
- 1310.0,
- 1320.0,
- 1330.0,
- 1340.0,
- 1350.0,
- 1360.0,
- 1370.0,
- 1380.0,
- 1390.0,
- 1400.0,
- 1410.0,
- 1420.0,
- 1430.0,
- 1440.0,
- 1450.0,
- 1460.0,
- 1470.0,
- 1480.0,
- 1490.0,
- 1500.0,
- 1510.0,
- 1520.0,
- 1530.0,
- 1540.0,
- 1550.0,
- 1560.0,
- 1570.0,
- 1580.0,
- 1590.0,
- 1600.0,
- 1610.0,
- 1620.0,
- 1630.0,
- 1640.0,
- 1650.0,
- 1660.0,
- 1670.0,
- 1680.0,
- 1690.0,
- 1700.0,
- 1710.0,
- 1720.0,
- 1730.0,
- 1740.0,
- 1750.0,
- 1760.0,
- 1770.0,
- 1780.0,
- 1790.0,
- 1800.0,
- 1810.0,
- 1820.0,
- 1830.0,
- 1840.0,
- 1850.0,
- 1860.0,
- 1870.0,
- 1880.0,
- 1890.0,
- 1900.0,
- 1910.0,
- 1920.0,
- 1930.0,
- 1940.0,
- 1950.0,
- 1960.0,
- 1970.0,
- 1980.0,
- 1990.0,
- 2000.0,
- 2010.0,
- 2020.0,
- 2030.0,
- 2040.0,
- 2050.0,
- 2060.0,
- 2070.0,
- 2080.0,
- 2090.0,
- 2100.0,
- 2110.0,
- 2120.0,
- 2130.0,
- 2140.0,
- 2150.0,
- 2160.0,
- 2170.0,
- 2180.0,
- 2190.0,
- 2200.0,
- 2210.0,
- 2220.0,
- 2230.0,
- 2240.0,
- 2250.0,
- 2260.0,
- 2270.0,
- 2280.0,
- 2290.0,
- 2300.0,
- 2310.0,
- 2320.0,
- 2330.0,
- 2340.0,
- 2350.0,
- 2360.0,
- 2370.0,
- 2380.0,
- 2390.0,
- 2400.0,
- 2410.0,
- 2420.0,
- 2430.0,
- 2440.0,
- 2450.0,
- 2460.0,
- 2470.0,
- 2480.0,
- 2490.0,
- 2500.0,
- 2510.0,
- 2520.0,
- 2530.0,
- 2540.0,
- 2550.0,
- 2560.0,
- 2570.0,
- 2580.0,
- 2590.0,
- 2600.0,
- 2610.0,
- 2620.0,
- 2630.0,
- 2640.0,
- 2650.0,
- 2660.0,
- 2670.0,
- 2680.0,
- 2690.0,
- 2700.0,
- 2710.0,
- 2720.0,
- 2730.0,
- 2740.0,
- 2750.0,
- 2760.0,
- 2770.0,
- 2780.0,
- 2790.0,
- 2800.0,
- 2810.0,
- 2820.0,
- 2830.0,
- 2840.0,
- 2850.0,
- 2860.0,
- 2870.0,
- 2880.0,
- 2890.0,
- 2900.0,
- 2910.0,
- 2920.0,
- 2930.0,
- 2940.0,
- 2950.0,
- 2960.0,
- 2970.0,
- 2980.0,
- 2990.0,
- 3000.0,
- 3010.0,
- 3020.0,
- 3030.0,
- 3040.0,
- 3050.0,
- 3060.0,
- 3070.0,
- 3080.0,
- 3090.0,
- 3100.0,
- 3110.0,
- 3120.0,
- 3130.0,
- 3140.0,
- 3150.0,
- 3160.0,
- 3170.0,
- 3180.0,
- 3190.0,
- 3200.0,
- 3210.0,
- 3220.0,
- 3230.0,
- 3240.0,
- 3250.0,
- 3260.0,
- 3270.0,
- 3280.0,
- 3290.0,
- 3300.0,
- 3310.0,
- 3320.0,
- 3330.0,
- 3340.0,
- 3350.0,
- 3360.0,
- 3370.0,
- 3380.0,
- 3390.0,
- 3400.0,
- 3410.0,
- 3420.0,
- 3430.0,
- 3440.0,
- 3450.0,
- 3460.0,
- 3470.0,
- 3480.0,
- 3490.0,
- 3500.0,
- 3500.0,
- 3490.0,
- 3480.0,
- 3470.0,
- 3460.0,
- 3450.0,
- 3440.0,
- 3430.0,
- 3420.0,
- 3410.0,
- 3400.0,
- 3390.0,
- 3380.0,
- 3370.0,
- 3360.0,
- 3350.0,
- 3340.0,
- 3330.0,
- 3320.0,
- 3310.0,
- 3300.0,
- 3290.0,
- 3280.0,
- 3270.0,
- 3260.0,
- 3250.0,
- 3240.0,
- 3230.0,
- 3220.0,
- 3210.0,
- 3200.0,
- 3190.0,
- 3180.0,
- 3170.0,
- 3160.0,
- 3150.0,
- 3140.0,
- 3130.0,
- 3120.0,
- 3110.0,
- 3100.0,
- 3090.0,
- 3080.0,
- 3070.0,
- 3060.0,
- 3050.0,
- 3040.0,
- 3030.0,
- 3020.0,
- 3010.0,
- 3000.0,
- 2990.0,
- 2980.0,
- 2970.0,
- 2960.0,
- 2950.0,
- 2940.0,
- 2930.0,
- 2920.0,
- 2910.0,
- 2900.0,
- 2890.0,
- 2880.0,
- 2870.0,
- 2860.0,
- 2850.0,
- 2840.0,
- 2830.0,
- 2820.0,
- 2810.0,
- 2800.0,
- 2790.0,
- 2780.0,
- 2770.0,
- 2760.0,
- 2750.0,
- 2740.0,
- 2730.0,
- 2720.0,
- 2710.0,
- 2700.0,
- 2690.0,
- 2680.0,
- 2670.0,
- 2660.0,
- 2650.0,
- 2640.0,
- 2630.0,
- 2620.0,
- 2610.0,
- 2600.0,
- 2590.0,
- 2580.0,
- 2570.0,
- 2560.0,
- 2550.0,
- 2540.0,
- 2530.0,
- 2520.0,
- 2510.0,
- 2500.0,
- 2490.0,
- 2480.0,
- 2470.0,
- 2460.0,
- 2450.0,
- 2440.0,
- 2430.0,
- 2420.0,
- 2410.0,
- 2400.0,
- 2390.0,
- 2380.0,
- 2370.0,
- 2360.0,
- 2350.0,
- 2340.0,
- 2330.0,
- 2320.0,
- 2310.0,
- 2300.0,
- 2290.0,
- 2280.0,
- 2270.0,
- 2260.0,
- 2250.0,
- 2240.0,
- 2230.0,
- 2220.0,
- 2210.0,
- 2200.0,
- 2190.0,
- 2180.0,
- 2170.0,
- 2160.0,
- 2150.0,
- 2140.0,
- 2130.0,
- 2120.0,
- 2110.0,
- 2100.0,
- 2090.0,
- 2080.0,
- 2070.0,
- 2060.0,
- 2050.0,
- 2040.0,
- 2030.0,
- 2020.0,
- 2010.0,
- 2000.0,
- 1990.0,
- 1980.0,
- 1970.0,
- 1960.0,
- 1950.0,
- 1940.0,
- 1930.0,
- 1920.0,
- 1910.0,
- 1900.0,
- 1890.0,
- 1880.0,
- 1870.0,
- 1860.0,
- 1850.0,
- 1840.0,
- 1830.0,
- 1820.0,
- 1810.0,
- 1800.0,
- 1790.0,
- 1780.0,
- 1770.0,
- 1760.0,
- 1750.0,
- 1740.0,
- 1730.0,
- 1720.0,
- 1710.0,
- 1700.0,
- 1690.0,
- 1680.0,
- 1670.0,
- 1660.0,
- 1650.0,
- 1640.0,
- 1630.0,
- 1620.0,
- 1610.0,
- 1600.0,
- 1590.0,
- 1580.0,
- 1570.0,
- 1560.0,
- 1550.0,
- 1540.0,
- 1530.0,
- 1520.0,
- 1510.0,
- 1500.0,
- 1490.0,
- 1480.0,
- 1470.0,
- 1460.0,
- 1450.0,
- 1440.0,
- 1430.0,
- 1420.0,
- 1410.0,
- 1400.0,
- 1390.0,
- 1380.0,
- 1370.0,
- 1360.0,
- 1350.0,
- 1340.0,
- 1330.0,
- 1320.0,
- 1310.0,
- 1300.0,
- 1290.0,
- 1280.0,
- 1270.0,
- 1260.0,
- 1250.0,
- 1240.0,
- 1230.0,
- 1220.0,
- 1210.0,
- 1200.0,
- 1190.0,
- 1180.0,
- 1170.0,
- 1160.0,
- 1150.0,
- 1140.0,
- 1130.0,
- 1120.0,
- 1110.0,
- 1100.0,
- 1090.0,
- 1080.0,
- 1070.0,
- 1060.0,
- 1050.0,
- 1040.0,
- 1030.0,
- 1020.0,
- 1010.0,
- 1000.0,
- 990.0,
- 980.0,
- 970.0,
- 960.0,
- 950.0,
- 940.0,
- 930.0,
- 920.0,
- 910.0,
- 900.0,
- 890.0,
- 880.0,
- 870.0,
- 860.0,
- 850.0,
- 840.0,
- 830.0,
- 820.0,
- 810.0,
- 800.0,
- 790.0,
- 780.0,
- 770.0,
- 760.0,
- 750.0,
- 740.0,
- 730.0,
- 720.0,
- 710.0,
- 700.0,
- 690.0,
- 680.0,
- 670.0,
- 660.0,
- 650.0,
- 640.0,
- 630.0,
- 620.0,
- 610.0,
- 600.0,
- 590.0,
- 580.0,
- 570.0,
- 560.0,
- 550.0,
- 540.0,
- 530.0,
- 520.0,
- 510.0,
- 500.0,
- 490.0,
- 480.0,
- 470.0,
- 460.0,
- 450.0,
- 440.0,
- 430.0,
- 420.0,
- 410.0,
- 400.0,
- 390.0,
- 380.0,
- 370.0,
- 360.0,
- 350.0,
- 340.0,
- 330.0,
- 320.0,
- 310.0,
- 300.0,
- 290.0,
- 280.0,
- 270.0,
- 260.0,
- 250.0,
- 240.0,
- 230.0,
- 220.0,
- 210.0,
- 200.0,
- 190.0,
- 180.0,
- 170.0,
- 160.0,
- 150.0,
- 140.0,
- 130.0,
- 120.0,
- 110.0,
- 100.0,
- 90.0,
- 80.0,
- 70.0,
- 60.0,
- 50.0,
- 40.0,
- 30.0,
- 20.0,
- 10.0,
- 0.0
+ 0,
+ 10,
+ 20,
+ 30,
+ 40,
+ 50,
+ 60,
+ 70,
+ 80,
+ 90,
+ 100,
+ 110,
+ 120,
+ 130,
+ 140,
+ 150,
+ 160,
+ 170,
+ 180,
+ 190,
+ 200,
+ 210,
+ 220,
+ 230,
+ 240,
+ 250,
+ 260,
+ 270,
+ 280,
+ 290,
+ 300,
+ 310,
+ 320,
+ 330,
+ 340,
+ 350,
+ 360,
+ 370,
+ 380,
+ 390,
+ 400,
+ 410,
+ 420,
+ 430,
+ 440,
+ 450,
+ 460,
+ 470,
+ 480,
+ 490,
+ 500,
+ 510,
+ 520,
+ 530,
+ 540,
+ 550,
+ 560,
+ 570,
+ 580,
+ 590,
+ 600,
+ 610,
+ 620,
+ 630,
+ 640,
+ 650,
+ 660,
+ 670,
+ 680,
+ 690,
+ 700,
+ 710,
+ 720,
+ 730,
+ 740,
+ 750,
+ 760,
+ 770,
+ 780,
+ 790,
+ 800,
+ 810,
+ 820,
+ 830,
+ 840,
+ 850,
+ 860,
+ 870,
+ 880,
+ 890,
+ 900,
+ 910,
+ 920,
+ 930,
+ 940,
+ 950,
+ 960,
+ 970,
+ 980,
+ 990,
+ 1000,
+ 1010,
+ 1020,
+ 1030,
+ 1040,
+ 1050,
+ 1060,
+ 1070,
+ 1080,
+ 1090,
+ 1100,
+ 1110,
+ 1120,
+ 1130,
+ 1140,
+ 1150,
+ 1160,
+ 1170,
+ 1180,
+ 1190,
+ 1200,
+ 1210,
+ 1220,
+ 1230,
+ 1240,
+ 1250,
+ 1260,
+ 1270,
+ 1280,
+ 1290,
+ 1300,
+ 1310,
+ 1320,
+ 1330,
+ 1340,
+ 1350,
+ 1360,
+ 1370,
+ 1380,
+ 1390,
+ 1400,
+ 1410,
+ 1420,
+ 1430,
+ 1440,
+ 1450,
+ 1460,
+ 1470,
+ 1480,
+ 1490,
+ 1500,
+ 1510,
+ 1520,
+ 1530,
+ 1540,
+ 1550,
+ 1560,
+ 1570,
+ 1580,
+ 1590,
+ 1600,
+ 1610,
+ 1620,
+ 1630,
+ 1640,
+ 1650,
+ 1660,
+ 1670,
+ 1680,
+ 1690,
+ 1700,
+ 1710,
+ 1720,
+ 1730,
+ 1740,
+ 1750,
+ 1760,
+ 1770,
+ 1780,
+ 1790,
+ 1800,
+ 1810,
+ 1820,
+ 1830,
+ 1840,
+ 1850,
+ 1860,
+ 1870,
+ 1880,
+ 1890,
+ 1900,
+ 1910,
+ 1920,
+ 1930,
+ 1940,
+ 1950,
+ 1960,
+ 1970,
+ 1980,
+ 1990,
+ 2000,
+ 2010,
+ 2020,
+ 2030,
+ 2040,
+ 2050,
+ 2060,
+ 2070,
+ 2080,
+ 2090,
+ 2100,
+ 2110,
+ 2120,
+ 2130,
+ 2140,
+ 2150,
+ 2160,
+ 2170,
+ 2180,
+ 2190,
+ 2200,
+ 2210,
+ 2220,
+ 2230,
+ 2240,
+ 2250,
+ 2260,
+ 2270,
+ 2280,
+ 2290,
+ 2300,
+ 2310,
+ 2320,
+ 2330,
+ 2340,
+ 2350,
+ 2360,
+ 2370,
+ 2380,
+ 2390,
+ 2400,
+ 2410,
+ 2420,
+ 2430,
+ 2440,
+ 2450,
+ 2460,
+ 2470,
+ 2480,
+ 2490,
+ 2500,
+ 2510,
+ 2520,
+ 2530,
+ 2540,
+ 2550,
+ 2560,
+ 2570,
+ 2580,
+ 2590,
+ 2600,
+ 2610,
+ 2620,
+ 2630,
+ 2640,
+ 2650,
+ 2660,
+ 2670,
+ 2680,
+ 2690,
+ 2700,
+ 2710,
+ 2720,
+ 2730,
+ 2740,
+ 2750,
+ 2760,
+ 2770,
+ 2780,
+ 2790,
+ 2800,
+ 2810,
+ 2820,
+ 2830,
+ 2840,
+ 2850,
+ 2860,
+ 2870,
+ 2880,
+ 2890,
+ 2900,
+ 2910,
+ 2920,
+ 2930,
+ 2940,
+ 2950,
+ 2960,
+ 2970,
+ 2980,
+ 2990,
+ 3000,
+ 3010,
+ 3020,
+ 3030,
+ 3040,
+ 3050,
+ 3060,
+ 3070,
+ 3080,
+ 3090,
+ 3100,
+ 3110,
+ 3120,
+ 3130,
+ 3140,
+ 3150,
+ 3160,
+ 3170,
+ 3180,
+ 3190,
+ 3200,
+ 3210,
+ 3220,
+ 3230,
+ 3240,
+ 3250,
+ 3260,
+ 3270,
+ 3280,
+ 3290,
+ 3300,
+ 3310,
+ 3320,
+ 3330,
+ 3340,
+ 3350,
+ 3360,
+ 3370,
+ 3380,
+ 3390,
+ 3400,
+ 3410,
+ 3420,
+ 3430,
+ 3440,
+ 3450,
+ 3460,
+ 3470,
+ 3480,
+ 3490,
+ 3500,
+ 3500,
+ 3490,
+ 3480,
+ 3470,
+ 3460,
+ 3450,
+ 3440,
+ 3430,
+ 3420,
+ 3410,
+ 3400,
+ 3390,
+ 3380,
+ 3370,
+ 3360,
+ 3350,
+ 3340,
+ 3330,
+ 3320,
+ 3310,
+ 3300,
+ 3290,
+ 3280,
+ 3270,
+ 3260,
+ 3250,
+ 3240,
+ 3230,
+ 3220,
+ 3210,
+ 3200,
+ 3190,
+ 3180,
+ 3170,
+ 3160,
+ 3150,
+ 3140,
+ 3130,
+ 3120,
+ 3110,
+ 3100,
+ 3090,
+ 3080,
+ 3070,
+ 3060,
+ 3050,
+ 3040,
+ 3030,
+ 3020,
+ 3010,
+ 3000,
+ 2990,
+ 2980,
+ 2970,
+ 2960,
+ 2950,
+ 2940,
+ 2930,
+ 2920,
+ 2910,
+ 2900,
+ 2890,
+ 2880,
+ 2870,
+ 2860,
+ 2850,
+ 2840,
+ 2830,
+ 2820,
+ 2810,
+ 2800,
+ 2790,
+ 2780,
+ 2770,
+ 2760,
+ 2750,
+ 2740,
+ 2730,
+ 2720,
+ 2710,
+ 2700,
+ 2690,
+ 2680,
+ 2670,
+ 2660,
+ 2650,
+ 2640,
+ 2630,
+ 2620,
+ 2610,
+ 2600,
+ 2590,
+ 2580,
+ 2570,
+ 2560,
+ 2550,
+ 2540,
+ 2530,
+ 2520,
+ 2510,
+ 2500,
+ 2490,
+ 2480,
+ 2470,
+ 2460,
+ 2450,
+ 2440,
+ 2430,
+ 2420,
+ 2410,
+ 2400,
+ 2390,
+ 2380,
+ 2370,
+ 2360,
+ 2350,
+ 2340,
+ 2330,
+ 2320,
+ 2310,
+ 2300,
+ 2290,
+ 2280,
+ 2270,
+ 2260,
+ 2250,
+ 2240,
+ 2230,
+ 2220,
+ 2210,
+ 2200,
+ 2190,
+ 2180,
+ 2170,
+ 2160,
+ 2150,
+ 2140,
+ 2130,
+ 2120,
+ 2110,
+ 2100,
+ 2090,
+ 2080,
+ 2070,
+ 2060,
+ 2050,
+ 2040,
+ 2030,
+ 2020,
+ 2010,
+ 2000,
+ 1990,
+ 1980,
+ 1970,
+ 1960,
+ 1950,
+ 1940,
+ 1930,
+ 1920,
+ 1910,
+ 1900,
+ 1890,
+ 1880,
+ 1870,
+ 1860,
+ 1850,
+ 1840,
+ 1830,
+ 1820,
+ 1810,
+ 1800,
+ 1790,
+ 1780,
+ 1770,
+ 1760,
+ 1750,
+ 1740,
+ 1730,
+ 1720,
+ 1710,
+ 1700,
+ 1690,
+ 1680,
+ 1670,
+ 1660,
+ 1650,
+ 1640,
+ 1630,
+ 1620,
+ 1610,
+ 1600,
+ 1590,
+ 1580,
+ 1570,
+ 1560,
+ 1550,
+ 1540,
+ 1530,
+ 1520,
+ 1510,
+ 1500,
+ 1490,
+ 1480,
+ 1470,
+ 1460,
+ 1450,
+ 1440,
+ 1430,
+ 1420,
+ 1410,
+ 1400,
+ 1390,
+ 1380,
+ 1370,
+ 1360,
+ 1350,
+ 1340,
+ 1330,
+ 1320,
+ 1310,
+ 1300,
+ 1290,
+ 1280,
+ 1270,
+ 1260,
+ 1250,
+ 1240,
+ 1230,
+ 1220,
+ 1210,
+ 1200,
+ 1190,
+ 1180,
+ 1170,
+ 1160,
+ 1150,
+ 1140,
+ 1130,
+ 1120,
+ 1110,
+ 1100,
+ 1090,
+ 1080,
+ 1070,
+ 1060,
+ 1050,
+ 1040,
+ 1030,
+ 1020,
+ 1010,
+ 1000,
+ 990,
+ 980,
+ 970,
+ 960,
+ 950,
+ 940,
+ 930,
+ 920,
+ 910,
+ 900,
+ 890,
+ 880,
+ 870,
+ 860,
+ 850,
+ 840,
+ 830,
+ 820,
+ 810,
+ 800,
+ 790,
+ 780,
+ 770,
+ 760,
+ 750,
+ 740,
+ 730,
+ 720,
+ 710,
+ 700,
+ 690,
+ 680,
+ 670,
+ 660,
+ 650,
+ 640,
+ 630,
+ 620,
+ 610,
+ 600,
+ 590,
+ 580,
+ 570,
+ 560,
+ 550,
+ 540,
+ 530,
+ 520,
+ 510,
+ 500,
+ 490,
+ 480,
+ 470,
+ 460,
+ 450,
+ 440,
+ 430,
+ 420,
+ 410,
+ 400,
+ 390,
+ 380,
+ 370,
+ 360,
+ 350,
+ 340,
+ 330,
+ 320,
+ 310,
+ 300,
+ 290,
+ 280,
+ 270,
+ 260,
+ 250,
+ 240,
+ 230,
+ 220,
+ 210,
+ 200,
+ 190,
+ 180,
+ 170,
+ 160,
+ 150,
+ 140,
+ 130,
+ 120,
+ 110,
+ 100,
+ 90,
+ 80,
+ 70,
+ 60,
+ 50,
+ 40,
+ 30,
+ 20,
+ 10,
+ 0
],
"y": [
4.0531258862154065,
@@ -1744,7 +1745,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -1780,7 +1781,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -1804,7 +1805,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -1840,7 +1841,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -1867,7 +1868,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -1903,7 +1904,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -1918,7 +1919,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -1954,7 +1955,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -2110,7 +2111,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -2146,7 +2147,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -2237,7 +2238,7 @@
],
"sequential": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -2273,13 +2274,13 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
"sequentialminus": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -2315,7 +2316,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
]
@@ -2503,7 +2504,7 @@
{
"colorscale": [
[
- 0.0,
+ 0,
"#440154"
],
[
@@ -2539,7 +2540,7 @@
"#b5de2b"
],
[
- 1.0,
+ 1,
"#fde725"
]
],
@@ -15963,7 +15964,7 @@
"hoverinfo": "text",
"marker": {
"color": [
- 0.0,
+ 0,
0.001968503937007874,
0.003937007874015748,
0.005905511811023622,
@@ -16474,7 +16475,7 @@
],
"colorscale": [
[
- 0.0,
+ 0,
"rgb(255,255,255)"
],
[
@@ -16506,7 +16507,7 @@
"rgb(37,37,37)"
],
[
- 1.0,
+ 1,
"rgb(0,0,0)"
]
],
@@ -18165,7 +18166,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18201,7 +18202,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -18225,7 +18226,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18261,7 +18262,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -18288,7 +18289,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18324,7 +18325,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -18339,7 +18340,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18375,7 +18376,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -18531,7 +18532,7 @@
},
"colorscale": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18567,7 +18568,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
@@ -18658,7 +18659,7 @@
],
"sequential": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18694,13 +18695,13 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
],
"sequentialminus": [
[
- 0.0,
+ 0,
"#0d0887"
],
[
@@ -18736,7 +18737,7 @@
"#fdca26"
],
[
- 1.0,
+ 1,
"#f0f921"
]
]
diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb
index e7a1a308c..4324ce855 100644
--- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb
+++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb
@@ -58,7 +58,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"matplotlib\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb
index 71dc80155..b6a4251bd 100644
--- a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb
+++ b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb
@@ -26,8 +26,8 @@
"\n",
"import pybop\n",
"\n",
- "go = pybop.plot.PlotlyManager().go\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "go = pybop.plot.plotly.PlotlyManager().go\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb
index 9ee9e69d5..c23fdf369 100644
--- a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb
+++ b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb
@@ -32,7 +32,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb
index 6250b85e9..e4541c8d8 100644
--- a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb
+++ b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb
@@ -35,7 +35,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/getting_started/cost_compute_methods.ipynb b/examples/notebooks/getting_started/cost_compute_methods.ipynb
index 3a424f38c..909bcddd5 100644
--- a/examples/notebooks/getting_started/cost_compute_methods.ipynb
+++ b/examples/notebooks/getting_started/cost_compute_methods.ipynb
@@ -28,7 +28,7 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb
index ccd95b2ea..7da706f7f 100644
--- a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb
+++ b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb
@@ -51,7 +51,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/getting_started/optimising_with_adamw.ipynb b/examples/notebooks/getting_started/optimising_with_adamw.ipynb
index d6fc110af..ab708410e 100644
--- a/examples/notebooks/getting_started/optimising_with_adamw.ipynb
+++ b/examples/notebooks/getting_started/optimising_with_adamw.ipynb
@@ -34,7 +34,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/getting_started/setting_optimiser_options.ipynb b/examples/notebooks/getting_started/setting_optimiser_options.ipynb
index a790b26b2..65760ef97 100644
--- a/examples/notebooks/getting_started/setting_optimiser_options.ipynb
+++ b/examples/notebooks/getting_started/setting_optimiser_options.ipynb
@@ -32,7 +32,7 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/notebooks/getting_started/using_transformations.ipynb b/examples/notebooks/getting_started/using_transformations.ipynb
index 236dc8bba..5d5d9a47f 100644
--- a/examples/notebooks/getting_started/using_transformations.ipynb
+++ b/examples/notebooks/getting_started/using_transformations.ipynb
@@ -29,7 +29,8 @@
"\n",
"import pybop\n",
"\n",
- "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
+ "pybop.plot.set_backend(\"plotly\")\n",
+ "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n",
"\n",
"np.random.seed(8) # users can remove this line"
]
diff --git a/examples/scripts/comparison_examples/grouped_SPMe.py b/examples/scripts/comparison_examples/grouped_SPMe.py
index de4c897ab..cb1d060e2 100644
--- a/examples/scripts/comparison_examples/grouped_SPMe.py
+++ b/examples/scripts/comparison_examples/grouped_SPMe.py
@@ -11,11 +11,10 @@
"""
# Prepare figure
-layout_options = dict(
- xaxis_title="Time / s",
- yaxis_title="Voltage / V",
-)
-plot_dict = pybop.plot.StandardPlot(layout_options=layout_options)
+pybop.plot.set_backend("matplotlib")
+plot_dict = pybop.plot.StandardPlot()
+plt.xlabel("Time / s")
+plt.ylabel("Voltage / V")
# Use the Chen2020 parameters
parameter_values = pybamm.ParameterValues("Chen2020")
@@ -46,18 +45,18 @@
)
SPMe_model = pybamm.lithium_ion.SPMe(options=model_options)
grouped_SPMe_model = pybop.lithium_ion.GroupedSPMe(options=model_options)
-for model, param, line_style in zip(
+for model, param, linestyle in zip(
[SPMe_model, grouped_SPMe_model],
[parameter_values, grouped_parameter_values],
- ["solid", "dash"],
+ ["-", "--"],
strict=False,
):
solution = pybamm.Simulation(
model, parameter_values=param, experiment=experiment, cache_esoh=False
).solve(initial_soc=init_soc)
dataset = pybop.import_pybamm_solution(solution)
- plot_dict.add_traces(
- dataset["Time [s]"], dataset["Voltage [V]"], line_dash=line_style
+ plot_dict.create_trace(
+ dataset["Time [s]"], dataset["Voltage [V]"], label=None, linestyle=linestyle
)
plot_dict()
diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py
index f58db2016..da5e84d44 100644
--- a/pybop/plot/__init__.py
+++ b/pybop/plot/__init__.py
@@ -1,13 +1,24 @@
+# Plotting backend default
+DEFAULT_BACKEND = 'matplotlib'
+backend=DEFAULT_BACKEND
+
+from .util import set_backend, call_plotting_function, get_default_options
+
#
# Import plots
#
-from .plotly_manager import PlotlyManager
+from .plots import (
+ surface
+ )
+
from .standard_plots import StandardPlot, StandardSubplot, trajectories
from .contour import contour
-from .dataset import dataset
from .convergence import convergence
+from .dataset import dataset
+from .nyquist import nyquist
from .parameters import parameters
from .problem import problem
-from .nyquist import nyquist
-from .voronoi import surface
-from .samples import trace, chains, posterior, summary_table
+from .samples import chains, posterior, summary_table, trace
+from .voronoi import voronoi_data, _voronoi_regions
+from . import matplotlib
+from . import plotly
diff --git a/pybop/plot/contour.py b/pybop/plot/contour.py
index da2d6a1a6..cdd5b8b59 100644
--- a/pybop/plot/contour.py
+++ b/pybop/plot/contour.py
@@ -4,7 +4,8 @@
import numpy as np
-from pybop.plot.plotly_manager import PlotlyManager
+from pybop.plot.standard_plots import StandardPlot
+from pybop.plot.util import call_plotting_function, get_default_options, update_and_show
from pybop.problems.problem import Problem
if TYPE_CHECKING:
@@ -18,7 +19,7 @@ def contour(
transformed: bool = False,
steps: int = 10,
show: bool = True,
- **layout_kwargs,
+ backend=None,
):
"""
Plot a 2D visualisation of a cost landscape using Plotly.
@@ -143,121 +144,86 @@ def transform_array_of_values(list_of_values, parameter):
bounds[0] = transform_array_of_values(bounds[0], parameters[names[0]])
bounds[1] = transform_array_of_values(bounds[1], parameters[names[1]])
- # Import plotly only when needed
- go = PlotlyManager().go
-
- # Set default layout properties
- layout_options = dict(
- title="Cost Landscape",
- title_x=0.5,
- title_y=0.905,
- width=600,
- height=600,
- xaxis=dict(range=bounds[0], showexponent="last", exponentformat="e"),
- yaxis=dict(range=bounds[1], showexponent="last", exponentformat="e"),
- legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1),
+ # Get options
+ options = get_default_options("contour", backend)
+ plot_options = options.get("plot_options") or {}
+ trace_options_initial = options.get("trace_options_initial") or {}
+ trace_options_optim = options.get("trace_options_optim") or {}
+ trace_options_contour = options.get("trace_options_contour") or {}
+
+ plot_dict = StandardPlot(
+ xaxis_title="Transformed " + names[0] if transformed else names[0],
+ yaxis_title="Transformed " + names[1] if transformed else names[1],
+ xaxis_range=bounds[0],
+ yaxis_range=bounds[1],
+ backend=backend,
+ **plot_options,
)
- layout_options["xaxis_title"] = (
- "Transformed " + names[0] if transformed else names[0]
- )
- layout_options["yaxis_title"] = (
- "Transformed " + names[1] if transformed else names[1]
- )
- layout = go.Layout(layout_options)
# Create contour plot and update the layout
- fig = go.Figure(
- data=[go.Contour(x=x, y=y, z=costs, colorscale="Viridis", connectgaps=True)],
- layout=layout,
- )
+ plot_dict.create_contour(x=x, y=y, z=costs, **trace_options_contour)
if plot_optim:
# Plot the optimisation trace
optim_trace = np.asarray([item[:2] for item in result.x_model])
optim_trace = optim_trace.reshape(-1, 2)
-
- fig.add_trace(
- go.Scatter(
- x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]),
- y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]),
- mode="markers",
- marker=dict(
- color=[i / len(optim_trace) for i in range(len(optim_trace))],
- colorscale="Greys",
- size=8,
- showscale=False,
- ),
- showlegend=False,
- )
+ call_plotting_function(
+ "plot_optimisation_path",
+ backend=backend,
+ plot_dict=plot_dict,
+ x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]),
+ y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]),
)
# Plot the initial guess
if len(result.x_model) > 0:
x0 = result.x_model[0]
- fig.add_trace(
- go.Scatter(
+ plot_dict.traces.append(
+ plot_dict.create_trace(
x=transform_array_of_values([x0[0]], parameters[names[0]]),
y=transform_array_of_values([x0[1]], parameters[names[1]]),
- mode="markers",
- marker_symbol="x",
- marker=dict(
- color="white",
- line_color="black",
- line_width=1,
- size=14,
- showscale=False,
- ),
- name="Initial values",
+ **trace_options_initial,
)
)
# Plot optimised value
if result.x is not None:
x_best = result.x
- fig.add_trace(
- go.Scatter(
+ plot_dict.traces.append(
+ plot_dict.create_trace(
x=transform_array_of_values([x_best[0]], parameters[names[0]]),
y=transform_array_of_values([x_best[1]], parameters[names[1]]),
- mode="markers",
- marker_symbol="cross",
- marker=dict(
- color="black",
- line_color="white",
- line_width=1,
- size=14,
- showscale=False,
- ),
- name="Final values",
+ **trace_options_optim,
)
)
# Update the layout and display the figure
- fig.update_layout(**layout_kwargs)
+ fig = plot_dict(show=False)
if show:
- fig.show()
+ update_and_show(fig)
- if gradient:
- grad_figs = []
- for i, grad_costs in enumerate(grad_parameter_costs):
- # Update title for gradient plots
- updated_layout_options = layout_options.copy()
- updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}"
-
- # Create contour plot with updated layout options
- grad_layout = go.Layout(updated_layout_options)
-
- # Create fig
- grad_fig = go.Figure(
- data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout
- )
- grad_fig.update_layout(**layout_kwargs)
+ # if gradient:
+ # grad_figs = []
+ # for i, grad_costs in enumerate(grad_parameter_costs):
+ # # Update title for gradient plots
+ # updated_layout_options = layout_options.copy()
+ # updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}"
+
+ # # Create contour plot with updated layout options
+ # grad_layout = go.Layout(updated_layout_options)
+
+ # # Create fig
+ # grad_fig = go.Figure(
+ # data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout
+ # )
+ # grad_fig.update_layout(**layout_kwargs)
- if show:
- grad_fig.show()
+ # if show:
+ # grad_fig.show()
- # append grad_fig to list
- grad_figs.append(grad_fig)
+ # # append grad_fig to list
+ # grad_figs.append(grad_fig)
- return fig, grad_figs
+ # return fig, grad_figs
return fig
diff --git a/pybop/plot/convergence.py b/pybop/plot/convergence.py
index deabf66d5..5d3390959 100644
--- a/pybop/plot/convergence.py
+++ b/pybop/plot/convergence.py
@@ -1,12 +1,13 @@
from typing import TYPE_CHECKING
from pybop.plot.standard_plots import StandardPlot
+from pybop.plot.util import update_and_show
if TYPE_CHECKING:
from pybop._result import Result
-def convergence(result: "Result", show=True, **layout_kwargs):
+def convergence(result: "Result", show=True, backend=None):
"""
Plot the convergence of the optimisation algorithm.
@@ -37,18 +38,15 @@ def convergence(result: "Result", show=True, **layout_kwargs):
plot_dict = StandardPlot(
x=iteration_numbers,
y=cost_log,
- layout_options=dict(
- xaxis_title="Evaluation",
- yaxis_title="Cost",
- title="Convergence",
- ),
+ xaxis_title="Evaluation",
+ yaxis_title="Cost",
+ title="Convergence",
trace_names=result.method_name,
+ backend=backend,
)
# Generate and display the figure
fig = plot_dict(show=False)
- fig.update_layout(**layout_kwargs)
if show:
- fig.show()
-
+ update_and_show(fig, backend=backend)
return fig
diff --git a/pybop/plot/dataset.py b/pybop/plot/dataset.py
index 24257a732..b92d2b78f 100644
--- a/pybop/plot/dataset.py
+++ b/pybop/plot/dataset.py
@@ -1,7 +1,10 @@
from pybop.plot.standard_plots import StandardPlot, trajectories
+from pybop.plot.util import update_and_show
-def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs):
+def dataset(
+ dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs
+):
"""
Quickly plot a PyBOP Dataset using Plotly.
@@ -50,9 +53,9 @@ def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs):
show=False,
xaxis_title=StandardPlot.remove_brackets(dataset.domain),
yaxis_title=yaxis_title,
+ backend=backend,
)
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
+
+ fig = update_and_show(fig, backend=backend, show=show, **layout_kwargs)
return fig
diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py
new file mode 100644
index 000000000..2730122df
--- /dev/null
+++ b/pybop/plot/matplotlib/__init__.py
@@ -0,0 +1,3 @@
+from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path
+from .voronoi import surface
+from .util import update_and_show, DEFAULT_PLOT_OPTIONS
diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py
new file mode 100644
index 000000000..bd49cd6d5
--- /dev/null
+++ b/pybop/plot/matplotlib/standard_plots.py
@@ -0,0 +1,385 @@
+import warnings
+
+from matplotlib import pyplot as plt
+
+from pybop.plot import StandardPlot
+
+DEFAULT_TRACE_OPTIONS = dict(linewidth=2.0)
+
+
+class Plotter:
+ """
+ A class for creating and displaying interactive Plotly figures.
+
+ Parameters
+ ----------
+ x : list or np.ndarray, optional
+ X-axis data points.
+ y : list or np.ndarray, optional
+ Primary Y-axis data points for simulated model output.
+ trace_options : dict, optional
+ Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS).
+ trace_names : str, optional
+ Name(s) for the primary trace(s) (default: None).
+ trace_name_width : int, optional
+ Maximum length of the trace names before text wrapping is used (default: 40).
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The generated Plotly figure.
+ """
+
+ def __init__(
+ self,
+ trace_options=None,
+ figsize=(8, 6),
+ title=None,
+ xaxis_title=None,
+ yaxis_title=None,
+ xaxis_range=None,
+ yaxis_range=None,
+ grid=None,
+ axis_bg_color=None,
+ **kwargs,
+ ):
+ self.backend = "matplotlib"
+ self.title = title
+ # Warning if layout arguments ignored
+ if len(kwargs) > 0:
+ warnings.warn(
+ "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n"
+ f"{list(kwargs.keys())}",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ # Set default trace options and update if provided
+ self.trace_options = DEFAULT_TRACE_OPTIONS.copy()
+ if trace_options:
+ self.trace_options.update(trace_options)
+
+ self.fig = plt.figure(figsize=figsize, dpi=100)
+ if title is not None:
+ plt.suptitle(self.title)
+ if xaxis_title is not None:
+ plt.xlabel(xaxis_title)
+ if yaxis_title is not None:
+ plt.ylabel(yaxis_title)
+ if grid is not None:
+ plt.grid(**grid)
+ if axis_bg_color is not None:
+ ax = plt.gca()
+ ax.set_facecolor(axis_bg_color)
+ ax.set_axisbelow(True)
+ if xaxis_range is not None:
+ plt.xlim(xaxis_range)
+ if yaxis_range is not None:
+ plt.ylim(yaxis_range)
+ self.traces = []
+
+ def __call__(self, show=True):
+ """
+ Generate and show the figure.
+
+ Parameters
+ ----------
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ """
+ # Add traces
+ for trace in self.traces:
+ self._plot_trace(**trace)
+
+ plt.tick_params(axis="both", labelsize=12)
+ plt.ticklabel_format(axis="both", style="sci", scilimits=(-4, 4))
+
+ labels_in_fig = True
+ for ax in self.fig.axes:
+ if not ax.get_legend_handles_labels() == ([], []):
+ break
+ else:
+ labels_in_fig = False
+ if labels_in_fig:
+ plt.legend(
+ **dict(loc="best", fontsize=12),
+ )
+
+ if show:
+ plt.show()
+ else:
+ return self.fig
+
+ def add_traces(self, x, y, trace_names=None, **trace_options):
+ """
+ Add a set of traces to the plot dictionary.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Primary Y-axis data points for simulated model output.
+ trace_names : str or list[str], optional
+ Name(s) for the primary trace(s) (default: None).
+ """
+
+ options = self.trace_options.copy()
+ options.update(trace_options)
+
+ # Create a trace for each trajectory
+ xi = x[0]
+ for i in range(0, len(y)):
+ trace_options = options.copy()
+ if len(x) > 1:
+ xi = x[i]
+
+ label = None
+ if trace_names is not None:
+ label = trace_names[i]
+
+ self.traces.append(self.create_trace(xi, y[i], label, **trace_options))
+
+ def create_trace(self, x=None, y=None, label=None, ax=None, **trace_options):
+ """
+ Add line to plot.
+
+ Returns
+ -------
+ plotly.graph_objs.Scatter
+ A trace for a Plotly figure.
+ """
+ if x is not None and y is not None:
+ size = min(len(x), len(y))
+ trace = dict(positional_args=[x[:size], y[:size]], label=label, ax=ax)
+ elif y is not None:
+ trace = dict(positional_args=[y], label=label, ax=ax)
+
+ trace.update(trace_options)
+ return trace
+
+ def create_fill_trace(self, x, y_upper, y_lower, **options):
+ trace = dict(positional_args=(x, y_upper, y_lower), plot_type="fill_between")
+ trace.update(options)
+ return trace
+
+ def create_histogram(self, x, name, **trace_options):
+ trace = dict(positional_args=[x], label=name, plot_type="hist")
+ trace.update(trace_options)
+ return trace
+
+ def create_vline(self, fig, x, **trace_options):
+ fig.gca()
+ plt.axvline(x, **trace_options)
+
+ def create_contour(self, x, y, z, **trace_options):
+ contour = dict(positional_args=[x, y, z], plot_type="contourf")
+ contour.update(**trace_options)
+ self.traces.append(contour)
+ self.traces.append(
+ dict(
+ positional_args=[x, y, z],
+ colors=("k"),
+ linestyles="solid",
+ linewidths=0.2,
+ plot_type="contour",
+ )
+ )
+
+ def create_scatter(self, x, y, **trace_options):
+ scatter = dict(positional_args=[x, y], plot_type="scatter")
+ scatter.update(**trace_options)
+ self.traces.append(scatter)
+
+ def _plot_trace(
+ self, ax=None, plot_type="plot", positional_args=None, **trace_options
+ ):
+ if positional_args is None:
+ positional_args = []
+ if ax is None:
+ ax = plt.gca()
+ try:
+ plot_function = getattr(ax, plot_type)
+ except ValueError:
+ print("Plot type not recognised")
+
+ obj = plot_function(*positional_args, **trace_options)
+ if plot_type == "contourf":
+ plt.colorbar(obj)
+
+ return obj
+
+
+class SubplotPlotter(Plotter):
+ """
+ A class for creating and displaying a set of interactive Plotly figures in a grid layout.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Primary Y-axis data points for simulated model output.
+ num_rows : int, optional
+ Number of rows of subplots, can be set automatically (default: None).
+ num_cols : int, optional
+ Number of columns of subplots, can be set automatically (default: None).
+ layout : Plotly layout, optional
+ A layout for the figure, overrides the layout options (default: None).
+ trace_options : dict, optional
+ Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS).
+ trace_names : str, optional
+ Name(s) for the primary trace(s) (default: None).
+ trace_name_width : int, optional
+ Maximum length of the trace names before text wrapping is used (default: 40).
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The generated Plotly figure.
+ """
+
+ def __init__(
+ self,
+ axis_titles=None,
+ trace_options=DEFAULT_TRACE_OPTIONS,
+ figsize=(8, 6),
+ **kwargs,
+ ):
+ super().__init__(trace_options, figsize, **kwargs)
+ self.axis_titles = axis_titles
+
+ def __call__(self, show=True, num_rows=1, num_cols=1):
+ """
+ Generate and show the set of figures.
+
+ Parameters
+ ----------
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ """
+
+ color_cycle = plt.rcParams["axes.prop_cycle"]()
+
+ lines = []
+ show_legend = False
+ for idx, trace in enumerate(self.traces):
+ ax = self.fig.add_subplot(num_rows, num_cols, idx + 1)
+ trace["ax"] = ax
+ if self.axis_titles and idx < len(self.axis_titles):
+ x_title, y_title = self.axis_titles[idx]
+ ax.set_xlabel(x_title)
+ ax.set_ylabel(y_title)
+ if "label" in trace.keys() and trace["label"] is not None:
+ show_legend = True
+
+ lines.append(self._plot_trace(**trace, **next(color_cycle)))
+
+ lines_labels = [ax.get_legend_handles_labels() for ax in self.fig.axes]
+ lines, labels = [sum(lol, []) for lol in zip(*lines_labels, strict=False)]
+ if show_legend:
+ self.fig.legend(
+ lines,
+ labels,
+ loc="upper right",
+ ncol=len(lines),
+ bbox_to_anchor=(0.99, 0.95),
+ )
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
+
+ if self.title is not None:
+ plt.suptitle(self.title)
+
+ if show:
+ plt.show()
+
+ return self.fig
+
+
+def trajectories(
+ x,
+ y,
+ trace_names=None,
+ show=True,
+ xaxis_title="",
+ yaxis_title="",
+ title="",
+ **layout_kwargs,
+):
+ """
+ Quickly plot one or more trajectories using Plotly.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Y-axis data points for each trajectory.
+ trace_names : list or str, optional
+ Name(s) for the trace(s) (default: None).
+ **layout_kwargs : optional
+ This argument is ignored for the matplotlib backend.
+ Valid Plotly layout keys and their values,
+ e.g. `xaxis_title="Time / s"` or
+ `xaxis={"title": "Time [s]", font={"size":14}}`
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The Plotly figure object for the scatter plot.
+ """
+
+ if len(layout_kwargs) > 0:
+ warnings.warn(
+ "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n"
+ f"{list(layout_kwargs.keys())}",
+ UserWarning,
+ stacklevel=2,
+ )
+ # Create a plot dictionary
+ plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="matplotlib")
+
+ # Generate the figure and update the layout
+ fig = plot_dict(show=False)
+ plt.title(title)
+ plt.xlabel(xaxis_title, fontsize=12)
+ plt.ylabel(yaxis_title, fontsize=12)
+ plt.tight_layout()
+ if show:
+ fig.show()
+
+ return fig
+
+
+def show_table(header, values, title):
+ """
+ Display data in a table.
+ """
+ for i, val in enumerate(values):
+ values[i] = [val[0], ", ".join(val[1].astype(str))]
+
+ fig, ax = plt.subplots(figsize=(6, 2), dpi=100)
+
+ # hide axes
+ ax.axis("off")
+ ax.axis("tight")
+ ax.table(
+ cellText=values,
+ colLabels=header,
+ loc="center",
+ cellLoc="center",
+ colColours=["lightsteelblue", "lightsteelblue"],
+ )
+ ax.set_title(title)
+ fig.tight_layout()
+ plt.show()
+
+
+def plot_optimisation_path(plot_dict: StandardPlot, x, y):
+ plot_dict.plotter.create_scatter(
+ x,
+ y,
+ c=[i / len(x) for i in range(len(x))],
+ cmap="Grays",
+ zorder=1,
+ )
diff --git a/pybop/plot/matplotlib/util.py b/pybop/plot/matplotlib/util.py
new file mode 100644
index 000000000..b6a7311d8
--- /dev/null
+++ b/pybop/plot/matplotlib/util.py
@@ -0,0 +1,92 @@
+import warnings
+
+import matplotlib.pyplot as plt
+
+
+def update_and_show(fig, show=True, **layout_kwargs):
+ if len(layout_kwargs) > 0:
+ warnings.warn(
+ "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n"
+ f"{list(layout_kwargs.keys())}",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ if show:
+ plt.show()
+
+ return fig
+
+
+DEFAULT_PLOT_OPTIONS = {
+ "contour": {
+ "plot_options": dict(title="Cost Landscape"),
+ "trace_options_contour": dict(extend="both", cmap="viridis"),
+ "trace_options_initial": dict(
+ marker="X",
+ markersize=14,
+ markerfacecolor="w",
+ markeredgecolor="k",
+ label="Initial values",
+ linestyle="None",
+ ),
+ "trace_options_optim": dict(
+ marker="P",
+ markersize=14,
+ markerfacecolor="k",
+ markeredgecolor="w",
+ label="Final values",
+ linestyle="None",
+ ),
+ },
+ "nyquist": {
+ "plot_options": dict(
+ xaxis_title=r"$Z_{re} / \Omega$", yaxis_title=r"$-Z_{im} / \Omega$"
+ ),
+ "trace_options_model": dict(
+ label="Model", color="#00CC96", linewidth=2, marker=".", markersize=8
+ ),
+ "trace_options_reference": dict(
+ label="Reference",
+ linestyle="none",
+ marker="o",
+ fillstyle="none",
+ markersize=8,
+ markeredgecolor="#636EFA",
+ ),
+ },
+ "parameters": dict(figsize=(18, 8), title="Parameter Convergence"),
+ "problem": {
+ "default_trace_options": dict(label="Model", marker=None, linestyle="-"),
+ "design_cost_options": dict(label="Optimised"),
+ "meta_problem_options": dict(marker=".", linestyle="none"),
+ "reference_options": dict(label="Reference", marker=".", linestyle="none"),
+ "fill_options": dict(color=[(1.0, 0.898, 0.800, 0.8)]),
+ },
+ "posterior": {
+ "plot_options": {
+ "figsize": (15, 8),
+ "grid": dict(axis="y", zorder=-1, color="w"),
+ "axis_bg_color": (
+ 0.6784313725490196,
+ 0.8470588235294118,
+ 0.9019607843137255,
+ 0.3,
+ ),
+ },
+ "trace_options": dict(alpha=0.75),
+ "trace_options_vline": dict(linewidth=3, linestyle="--", color="k"),
+ },
+ "trace": {
+ "plot_options": {
+ "figsize": (15, 8),
+ "grid": dict(axis="y", zorder=-10, color="w"),
+ "axis_bg_color": (
+ 0.6784313725490196,
+ 0.8470588235294118,
+ 0.9019607843137255,
+ 0.3,
+ ),
+ },
+ },
+}
diff --git a/pybop/plot/matplotlib/voronoi.py b/pybop/plot/matplotlib/voronoi.py
new file mode 100644
index 000000000..e7502b941
--- /dev/null
+++ b/pybop/plot/matplotlib/voronoi.py
@@ -0,0 +1,141 @@
+import warnings
+from typing import TYPE_CHECKING
+
+import matplotlib as mpl
+import numpy as np
+from matplotlib import pyplot as plt
+
+if TYPE_CHECKING:
+ from pybop._result import Result
+
+from pybop.plot.voronoi import voronoi_data
+
+
+def surface(
+ result: "Result",
+ bounds=None,
+ normalise=True,
+ title="Voronoi Cost Landscape",
+ show=True,
+ **layout_kwargs,
+):
+ """
+ Plot a 2D representation of the Voronoi diagram with color-coded regions.
+
+ Parameters:
+ -----------
+ result : pybop.Result
+ Optimisation result containing the history of parameter values and associated cost.
+ bounds : numpy.ndarray, optional
+ A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses
+ `cost.parameters.get_bounds_for_plotly`.
+ normalise : bool, optional
+ If True, the voronoi regions are computed using the Euclidean distance between
+ points normalised with respect to the bounds (default: True).
+ resolution : int, optional
+ Resolution of the plot. Default is 500.
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ """
+
+ if len(layout_kwargs) > 0:
+ warnings.warn(
+ "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n"
+ f"{list(layout_kwargs.keys())}",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ points = result.x_model
+ parameters = result.problem.parameters
+
+ if points[0].shape[0] != 2:
+ raise ValueError("This plot method requires two parameters.")
+
+ x_optim, y_optim = map(list, zip(*points, strict=False))
+ f = result.cost
+
+ # Translate bounds, taking only the first two elements
+ xlim, ylim = (
+ bounds if bounds is not None else [param.bounds for param in parameters]
+ )[:2]
+
+ _, _, f, regions, relative_sizes = voronoi_data(
+ xlim, ylim, x_optim, y_optim, f, normalise
+ )
+
+ # Construct figure
+ plt.figure(figsize=(7, 6), dpi=100)
+
+ # normalise cost
+ f_min = np.nanmin(f[np.isfinite(f)])
+ f_max = np.nanmax(f[np.isfinite(f)])
+ norm = mpl.colors.Normalize(vmin=f_min, vmax=f_max, clip=True)
+ norm_f = norm(f, clip=True)
+
+ # get colours
+ cmap = mpl.colormaps["viridis"]
+ colors = cmap(norm_f)
+
+ # Add Voronoi edges and fill Voronoi regions
+ for j, (region, size) in enumerate(zip(regions, relative_sizes, strict=False)):
+ x_region = region[:, 0].tolist() + [region[0, 0]]
+ y_region = region[:, 1].tolist() + [region[0, 1]]
+
+ plt.fill(x_region, y_region, color=colors[j])
+ plt.plot(x_region, y_region, color="w", linewidth=0.5 + size * 0.1)
+
+ plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca())
+
+ # Add original points
+ plt.scatter(
+ x_optim,
+ y_optim,
+ c=[i / len(x_optim) for i in range(len(x_optim))],
+ cmap="Grays",
+ zorder=2.5,
+ )
+
+ # Plot the initial guess
+ if len(result.x_model) > 0:
+ x0 = result.x_model[0]
+ plt.plot(
+ [x0[0]],
+ [x0[1]],
+ "X",
+ markersize=14,
+ markerfacecolor="w",
+ markeredgecolor="k",
+ label="Initial values",
+ linestyle="None",
+ zorder=2.6,
+ )
+
+ # Plot optimised value
+ if result.x is not None:
+ x_best = result.x
+ plt.plot(
+ [x_best[0]],
+ [x_best[1]],
+ "P",
+ markersize=14,
+ markerfacecolor="k",
+ markeredgecolor="w",
+ label="Final values",
+ linestyle="None",
+ zorder=2.6,
+ )
+
+ # Layout
+ names = result.problem.parameters.names
+ plt.xlabel(names[0], labelpad=15)
+ plt.ticklabel_format(axis="both", **dict(style="sci", scilimits=(-4, 4)))
+ plt.ylabel(names[1], labelpad=15)
+ plt.title(title, pad=40)
+ plt.legend(ncols=2, loc="lower center", bbox_to_anchor=(0.5, 1.0))
+ plt.xlim(xlim[0], xlim[1])
+ plt.ylim(ylim[0], ylim[1])
+ plt.tight_layout()
+
+ if show:
+ plt.show()
diff --git a/pybop/plot/nyquist.py b/pybop/plot/nyquist.py
index 80f7eb77a..5d1452f2e 100644
--- a/pybop/plot/nyquist.py
+++ b/pybop/plot/nyquist.py
@@ -1,8 +1,11 @@
from pybop.parameters.parameter import Inputs
from pybop.plot.standard_plots import StandardPlot
+from pybop.plot.util import get_default_options, update_and_show
-def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs):
+def nyquist(
+ problem, inputs: Inputs = None, show=True, title="Nyquist Plot", backend=None
+):
"""
Generates Nyquist plots for the given problem by evaluating the model's output and target values.
@@ -39,81 +42,41 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs):
>>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)")
>>> # The plots will be displayed and nyquist_figures will contain the list of figure objects.
"""
+
+ options = get_default_options("nyquist", backend)
+ plot_options = options.get("plot_options") or {}
+ trace_options_model = options.get("trace_options_model") or {}
+ trace_options_reference = options.get("trace_options_reference") or {}
+
if not isinstance(inputs, dict):
inputs = problem.parameters.to_dict(inputs)
model_output = problem.simulate(inputs)
domain_data = model_output["Impedance"].data.real
target_output = problem.target_data
-
figure_list = []
for var in problem.target:
- default_layout_options = dict(
- title="Nyquist Plot",
- font=dict(family="Arial", size=14),
- plot_bgcolor="white",
- paper_bgcolor="white",
- xaxis=dict(
- title=dict(text="Zre / Ω", font=dict(size=16), standoff=15),
- showline=True,
- linewidth=2,
- linecolor="black",
- mirror=True,
- ticks="outside",
- tickwidth=2,
- tickcolor="black",
- ticklen=5,
- ),
- yaxis=dict(
- title=dict(text="-Zim / Ω", font=dict(size=16), standoff=15),
- showline=True,
- linewidth=2,
- linecolor="black",
- mirror=True,
- ticks="outside",
- tickwidth=2,
- tickcolor="black",
- ticklen=5,
- scaleanchor="x",
- scaleratio=1,
- ),
- legend=dict(
- orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
- ),
- width=600,
- height=600,
- )
-
plot_dict = StandardPlot(
x=domain_data,
y=-model_output[var].data.imag,
- layout_options=default_layout_options,
trace_names="Model",
+ title=title,
+ **plot_options,
)
- plot_dict.traces[0].update(
- mode="lines+markers",
- line=dict(color="#00CC96", width=2),
- marker=dict(size=8, color="#00CC96", symbol="circle"),
- )
+ plot_dict.traces[0].update(trace_options_model)
target_trace = plot_dict.create_trace(
x=target_output[var].real,
y=-target_output[var].imag,
- name="Reference",
- mode="markers",
- marker=dict(size=8, color="#636EFA", symbol="circle-open"),
- showlegend=True,
+ **trace_options_reference,
)
plot_dict.traces.append(target_trace)
fig = plot_dict(show=False)
-
- # Overwrite with user-kwargs
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
-
figure_list.append(fig)
+ if show:
+ update_and_show(figure_list)
+
return figure_list
diff --git a/pybop/plot/parameters.py b/pybop/plot/parameters.py
index a13d9e3fb..934ec67a3 100644
--- a/pybop/plot/parameters.py
+++ b/pybop/plot/parameters.py
@@ -2,12 +2,13 @@
from pybop.costs.log_likelihoods import GaussianLogLikelihood
from pybop.plot.standard_plots import StandardSubplot
+from pybop.plot.util import get_default_options, update_and_show
if TYPE_CHECKING:
from pybop._result import Result
-def parameters(result: "Result", show=True, **layout_kwargs):
+def parameters(result: "Result", show=True, backend=None, **layout_kwargs):
"""
Plot the evolution of parameters during the optimisation process using Plotly.
@@ -44,27 +45,21 @@ def parameters(result: "Result", show=True, **layout_kwargs):
trace_names.append("Sigma")
# Set subplot layout options
- layout_options = dict(
- title="Parameter Convergence",
- width=1024,
- height=576,
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
- )
+ plot_options = get_default_options("paramters", backend)
# Create a plot dictionary
plot_dict = StandardSubplot(
x=x,
y=y,
axis_titles=axis_titles,
- layout_options=layout_options,
trace_names=trace_names,
trace_name_width=50,
+ backend=backend,
+ **plot_options,
)
# Generate the figure and update the layout
fig = plot_dict(show=False)
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
+ fig = update_and_show(fig, backend, **layout_kwargs)
return fig
diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py
new file mode 100644
index 000000000..4de178480
--- /dev/null
+++ b/pybop/plot/plotly/__init__.py
@@ -0,0 +1,5 @@
+from .plotly_manager import PlotlyManager
+from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path
+from .voronoi import surface
+
+from .util import update_and_show, DEFAULT_PLOT_OPTIONS
diff --git a/pybop/plot/plotly_manager.py b/pybop/plot/plotly/plotly_manager.py
similarity index 100%
rename from pybop/plot/plotly_manager.py
rename to pybop/plot/plotly/plotly_manager.py
diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py
new file mode 100644
index 000000000..80429f0c4
--- /dev/null
+++ b/pybop/plot/plotly/standard_plots.py
@@ -0,0 +1,354 @@
+from pybop.plot import StandardPlot
+from pybop.plot.plotly.plotly_manager import PlotlyManager
+
+DEFAULT_LAYOUT_OPTIONS = dict(
+ title=None,
+ title_x=0.5,
+ xaxis=dict(
+ title=dict(font={"size": 14}),
+ showexponent="last",
+ exponentformat="e",
+ tickfont=dict(size=12),
+ ),
+ yaxis=dict(
+ title=dict(font={"size": 14}),
+ showexponent="last",
+ exponentformat="e",
+ tickfont=dict(size=12),
+ ),
+ legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12),
+ showlegend=True,
+ autosize=False,
+ width=600,
+ height=600,
+ margin=dict(l=10, r=10, b=10, t=75, pad=4),
+ plot_bgcolor="white",
+)
+DEFAULT_SUBPLOT_OPTIONS = dict(
+ start_cell="bottom-left",
+)
+DEFAULT_TRACE_OPTIONS = dict(line=dict(width=4), mode="lines")
+DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines")
+
+
+class Plotter:
+ """
+ A class for creating and displaying interactive Plotly figures.
+
+ Parameters
+ ----------
+ x : list or np.ndarray, optional
+ X-axis data points.
+ y : list or np.ndarray, optional
+ Primary Y-axis data points for simulated model output.
+ layout : Plotly layout, optional
+ A layout for the figure, overrides the layout options (default: None).
+ layout_options : dict, optional
+ Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS).
+ trace_options : dict, optional
+ Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS).
+ trace_names : str, optional
+ Name(s) for the primary trace(s) (default: None).
+ trace_name_width : int, optional
+ Maximum length of the trace names before text wrapping is used (default: 40).
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The generated Plotly figure.
+ """
+
+ def __init__(
+ self,
+ title=None,
+ xaxis_title=None,
+ yaxis_title=None,
+ xaxis_range=None,
+ yaxis_range=None,
+ layout=None,
+ layout_options=None,
+ trace_options=None,
+ ):
+ self.backend = "plotly"
+
+ self.traces = []
+ self.layout = layout
+
+ # Set default layout options and update if provided
+ if self.layout is None:
+ self.layout_options = DEFAULT_LAYOUT_OPTIONS.copy()
+ if layout_options:
+ self.layout_options.update(layout_options)
+ if title is not None:
+ self.layout_options.update({"title": title})
+ if xaxis_title is not None:
+ self.layout_options.update({"xaxis_title": xaxis_title})
+ if yaxis_title is not None:
+ self.layout_options.update({"yaxis_title": yaxis_title})
+ if xaxis_range is not None:
+ self.layout_options["xaxis"].update({"range": xaxis_range})
+ if yaxis_range is not None:
+ self.layout_options["yaxis"].update({"range": yaxis_range})
+
+ # Set default trace options and update if provided
+ self.trace_options = DEFAULT_TRACE_OPTIONS.copy()
+ if trace_options:
+ self.trace_options.update(trace_options)
+
+ # Attempt to import plotly when an instance is created
+ self.go = PlotlyManager().go
+
+ # Create layout
+ if self.layout is None:
+ self.layout = self.go.Layout(**self.layout_options)
+
+ def __call__(self, show=True):
+ """
+ Generate and show the figure.
+
+ Parameters
+ ----------
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ """
+ fig = self.go.Figure(data=self.traces, layout=self.layout)
+ if show:
+ fig.show()
+
+ return fig
+
+ def add_traces(self, x, y, trace_names=None, **trace_options):
+ """
+ Add a set of traces to the plot dictionary.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Primary Y-axis data points for simulated model output.
+ trace_names : str or list[str], optional
+ Name(s) for the primary trace(s) (default: None).
+ """
+ options = self.trace_options.copy()
+ options.update(trace_options)
+
+ # Create a trace for each trajectory
+ xi = x[0]
+ for i in range(0, len(y)):
+ trace_options = options.copy()
+ if len(x) > 1:
+ xi = x[i]
+ if trace_names is not None:
+ trace_options["name"] = trace_names[i]
+ else:
+ trace_options["showlegend"] = False
+ trace = self.create_trace(xi, y[i], **trace_options)
+ self.traces.append(trace)
+
+ def create_trace(self, x=None, y=None, label=None, **trace_options):
+ """
+ Create a trace for the Plotly figure.
+
+ Returns
+ -------
+ plotly.graph_objs.Scatter
+ A trace for a Plotly figure.
+ """
+ if label is not None:
+ trace_options.update({"name": label})
+ if x is not None and y is not None:
+ return self.go.Scatter(
+ x=x,
+ y=y,
+ **trace_options,
+ )
+ if x is None and y is not None:
+ return self.go.Scatter(
+ y=y,
+ **trace_options,
+ )
+
+ def create_fill_trace(self, x, y_upper, y_lower, **options):
+ return self.create_trace(
+ x=x + x[::-1],
+ y=y_upper + y_lower[::-1],
+ fill="toself",
+ line=dict(color="rgba(255,255,255,0)"),
+ hoverinfo="skip",
+ showlegend=False,
+ **options,
+ )
+
+ def create_histogram(self, x, name, **trace_options):
+ return self.go.Histogram(x=x, name=name, **trace_options)
+
+ def create_vline(self, fig, x, **trace_options):
+ fig.add_vline(x=x, **trace_options)
+
+ def create_contour(self, x, y, z, **trace_options):
+ self.traces.append(self.go.Contour(x=x, y=y, z=z, **trace_options))
+
+
+class SubplotPlotter(Plotter):
+ """
+ A class for creating and displaying a set of interactive Plotly figures in a grid layout.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Primary Y-axis data points for simulated model output.
+ num_rows : int, optional
+ Number of rows of subplots, can be set automatically (default: None).
+ num_cols : int, optional
+ Number of columns of subplots, can be set automatically (default: None).
+ layout : Plotly layout, optional
+ A layout for the figure, overrides the layout options (default: None).
+ layout_options : dict, optional
+ Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS).
+ trace_options : dict, optional
+ Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS).
+ trace_names : str, optional
+ Name(s) for the primary trace(s) (default: None).
+ trace_name_width : int, optional
+ Maximum length of the trace names before text wrapping is used (default: 40).
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The generated Plotly figure.
+ """
+
+ def __init__(
+ self,
+ axis_titles=None,
+ layout=None,
+ layout_options=DEFAULT_LAYOUT_OPTIONS,
+ subplot_options=DEFAULT_SUBPLOT_OPTIONS,
+ trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS,
+ ):
+ super().__init__(
+ layout, layout_options=layout_options, trace_options=trace_options
+ )
+ self.subplot_options = subplot_options.copy()
+ if subplot_options is not None:
+ for arg, value in subplot_options.items():
+ self.subplot_options[arg] = value
+
+ # Attempt to import plotly when an instance is created
+ self.make_subplots = PlotlyManager().make_subplots
+
+ self.axis_titles = axis_titles
+
+ def __call__(self, show=True, num_rows=1, num_cols=1):
+ """
+ Generate and show the set of figures.
+
+ Parameters
+ ----------
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ """
+ fig = self.make_subplots(
+ rows=num_rows,
+ cols=num_cols,
+ horizontal_spacing=0.1,
+ vertical_spacing=0.15,
+ **self.subplot_options,
+ )
+ fig.update_layout(self.layout_options)
+
+ for idx, trace in enumerate(self.traces):
+ row = (idx // num_cols) + 1
+ col = (idx % num_cols) + 1
+ fig.add_trace(trace, row=row, col=col)
+
+ if self.axis_titles and idx < len(self.axis_titles):
+ x_title, y_title = self.axis_titles[idx]
+ fig.update_xaxes(title_text=x_title, row=row, col=col)
+ fig.update_yaxes(
+ title_text=y_title,
+ row=row,
+ col=col,
+ showexponent="last",
+ exponentformat="e",
+ )
+
+ if show:
+ fig.show()
+
+ return fig
+
+
+def trajectories(x, y, trace_names=None, show=True, **layout_kwargs):
+ """
+ Quickly plot one or more trajectories using Plotly.
+
+ Parameters
+ ----------
+ x : list or np.ndarray
+ X-axis data points.
+ y : list or np.ndarray
+ Y-axis data points for each trajectory.
+ trace_names : list or str, optional
+ Name(s) for the trace(s) (default: None).
+ **layout_kwargs : optional
+ Valid Plotly layout keys and their values,
+ e.g. `xaxis_title="Time / s"` or
+ `xaxis={"title": "Time [s]", font={"size":14}}`
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The Plotly figure object for the scatter plot.
+ """
+ # Create a plot dictionary
+ plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="plotly")
+
+ # Generate the figure and update the layout
+ fig = plot_dict(show=False)
+ fig.update_layout(**layout_kwargs)
+ if show:
+ fig.show()
+
+ return fig
+
+
+def show_table(header, values, title):
+ """
+ Display data in a table.
+ """
+ # Import plotly only when needed
+ go = PlotlyManager().go
+ fig = go.Figure(
+ data=[
+ go.Table(
+ header=dict(values=header),
+ cells=dict(
+ values=[[row[0] for row in values], [row[1] for row in values]]
+ ),
+ )
+ ]
+ )
+
+ fig.update_layout(title=title)
+ fig.show()
+
+
+def plot_optimisation_path(plot_dict: StandardPlot, x, y):
+ plot_dict.traces.append(
+ plot_dict.create_trace(
+ x,
+ y,
+ mode="markers",
+ marker=dict(
+ color=[i / len(x) for i in range(len(x))],
+ colorscale="Greys",
+ size=8,
+ showscale=False,
+ ),
+ showlegend=False,
+ )
+ )
diff --git a/pybop/plot/plotly/util.py b/pybop/plot/plotly/util.py
new file mode 100644
index 000000000..9a913b3aa
--- /dev/null
+++ b/pybop/plot/plotly/util.py
@@ -0,0 +1,150 @@
+def update_and_show(fig, show=True, **layout_kwargs):
+ if hasattr(fig, "__len__") and len(fig) > 0:
+ for f in fig:
+ f.update_layout(**layout_kwargs)
+ if show:
+ f.show()
+ else:
+ fig.update_layout(**layout_kwargs)
+ if show:
+ fig.show()
+ return fig
+
+
+DEFAULT_PLOT_OPTIONS = {
+ "contour": {
+ "plot_options": dict(
+ layout_options=dict(
+ title="Cost Landscape",
+ title_x=0.5,
+ title_y=0.905,
+ width=600,
+ height=600,
+ xaxis=dict(showexponent="last", exponentformat="e"),
+ yaxis=dict(showexponent="last", exponentformat="e"),
+ legend=dict(
+ orientation="h", yanchor="bottom", y=1, xanchor="right", x=1
+ ),
+ autosize=None,
+ showlegend=None,
+ margin=None,
+ )
+ ),
+ "trace_options_contour": dict(colorscale="Viridis", connectgaps=True),
+ "trace_options_initial": dict(
+ mode="markers",
+ marker_symbol="x",
+ marker=dict(
+ color="white",
+ line_color="black",
+ line_width=1,
+ size=14,
+ showscale=False,
+ ),
+ name="Initial values",
+ ),
+ "trace_options_optim": dict(
+ mode="markers",
+ marker_symbol="cross",
+ marker=dict(
+ color="black",
+ line_color="white",
+ line_width=1,
+ size=14,
+ showscale=False,
+ ),
+ name="Final values",
+ ),
+ },
+ "nyquist": {
+ "plot_options": dict(
+ layout_options=dict(
+ title="Nyquist Plot",
+ font=dict(family="Arial", size=14),
+ plot_bgcolor="white",
+ paper_bgcolor="white",
+ xaxis=dict(
+ title=dict(font=dict(size=16), standoff=15),
+ showline=True,
+ linewidth=2,
+ linecolor="black",
+ mirror=True,
+ ticks="outside",
+ tickwidth=2,
+ tickcolor="black",
+ ticklen=5,
+ ),
+ yaxis=dict(
+ title=dict(font=dict(size=16), standoff=15),
+ showline=True,
+ linewidth=2,
+ linecolor="black",
+ mirror=True,
+ ticks="outside",
+ tickwidth=2,
+ tickcolor="black",
+ ticklen=5,
+ scaleanchor="x",
+ scaleratio=1,
+ ),
+ legend=dict(
+ orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+ ),
+ width=600,
+ height=600,
+ ),
+ xaxis_title="Zre / Ω",
+ yaxis_title="-Zim / Ω",
+ ),
+ "trace_options_model": dict(
+ mode="lines+markers",
+ line=dict(color="#00CC96", width=2),
+ marker=dict(size=8, color="#00CC96", symbol="circle"),
+ ),
+ "trace_options_reference": dict(
+ name="Reference",
+ mode="markers",
+ marker=dict(size=8, color="#636EFA", symbol="circle-open"),
+ showlegend=True,
+ ),
+ },
+ "parameters": dict(
+ layout_options=dict(
+ title="Parameter Convergence",
+ width=1024,
+ height=576,
+ legend=dict(
+ orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+ ),
+ )
+ ),
+ "posterior": {
+ "plot_options": {
+ "layout_options": dict(
+ barmode="overlay",
+ width=None,
+ height=None,
+ plot_bgcolor=None,
+ autosize=None,
+ legend=None,
+ )
+ },
+ "trace_options": dict(opacity=0.75),
+ "trace_options_vline": dict(line_width=3, line_dash="dash", line_color="black"),
+ },
+ "problem": {
+ "default_trace_options": dict(name="Model", mode="lines", showlegend=True),
+ "design_cost_options": dict(name="Optimised"),
+ "meta_problem_options": dict(mode="lines"),
+ "reference_options": dict(name="Reference", mode="markers", showlegend=True),
+ "fill_options": dict(fillcolor="rgba(255,229,204,0.8)"),
+ },
+ "trace": {
+ "plot_options": {
+ "layout_options": dict(
+ width=None, height=None, plot_bgcolor=None, autosize=None, legend=None
+ )
+ },
+ "trace_options": dict(mode="lines"),
+ },
+}
diff --git a/pybop/plot/plotly/voronoi.py b/pybop/plot/plotly/voronoi.py
new file mode 100644
index 000000000..d60b9597e
--- /dev/null
+++ b/pybop/plot/plotly/voronoi.py
@@ -0,0 +1,216 @@
+from typing import TYPE_CHECKING
+
+import numpy as np
+from scipy.spatial import cKDTree
+
+if TYPE_CHECKING:
+ from pybop._result import Result
+from pybop.plot import voronoi_data
+from pybop.plot.plotly.plotly_manager import PlotlyManager
+
+
+def assign_nearest_value(x, y, f, xi, yi):
+ """
+ Computes an array of values given by the score of the nearest point.
+
+ Parameters
+ ----------
+ x : array-like
+ The x coordinates of points with known scores.
+ y : array-like
+ The y coordinates of points with known scores.
+ f : array-like
+ The score function at the given x and y coordinates.
+ xi : array-like
+ The x coordinates of grid points.
+ yi : array-like
+ The y coordinates of grid points.
+
+ Returns
+ -------
+ A numpy array containing the scores corresponding to the grid points.
+ """
+ # Create a KD-tree for efficient nearest neighbor search
+ tree = cKDTree(np.column_stack((x, y)))
+
+ # Find the nearest point for each grid point
+ _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel())))
+ zi = f[indices].reshape(xi.shape)
+
+ return zi
+
+
+def surface(
+ result: "Result",
+ bounds=None,
+ normalise=True,
+ resolution=250,
+ show=True,
+ **layout_kwargs,
+):
+ """
+ Plot a 2D representation of the Voronoi diagram with color-coded regions.
+
+ Parameters:
+ -----------
+ result : pybop.Result
+ Optimisation result containing the history of parameter values and associated cost.
+ bounds : numpy.ndarray, optional
+ A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses
+ `cost.parameters.get_bounds_for_plotly`.
+ normalise : bool, optional
+ If True, the voronoi regions are computed using the Euclidean distance between
+ points normalised with respect to the bounds (default: True).
+ resolution : int, optional
+ Resolution of the plot. Default is 500.
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ **layout_kwargs : optional
+ Valid Plotly layout keys and their values,
+ e.g. `xaxis_title="Time [s]"` or
+ `xaxis={"title": "Time [s]", font={"size":14}}`
+ """
+ points = result.x_model
+ parameters = result.problem.parameters
+
+ if points[0].shape[0] != 2:
+ raise ValueError("This plot method requires two parameters.")
+
+ x_optim, y_optim = map(list, zip(*points, strict=False))
+ f = result.cost
+
+ # Translate bounds, taking only the first two elements
+ xlim, ylim = (
+ bounds if bounds is not None else [param.bounds for param in parameters]
+ )[:2]
+
+ x, y, f, regions, relative_sizes = voronoi_data(
+ xlim, ylim, x_optim, y_optim, f, normalise
+ )
+
+ # Create a grid for plot
+ xi = np.linspace(xlim[0], xlim[1], resolution)
+ yi = np.linspace(ylim[0], ylim[1], resolution)
+ xi, yi = np.meshgrid(xi, yi)
+
+ if normalise:
+ # Create a normalised grid
+ norm_xi = np.linspace(0, 1, resolution)
+ norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi)
+
+ # Assign a value to each point in the grid
+ zi = assign_nearest_value(x, y, f, norm_xi, norm_yi)
+ else:
+ # Assign a value to each point in the grid
+ zi = assign_nearest_value(x, y, f, xi, yi)
+
+ # Calculate the size of each Voronoi region
+ region_sizes = np.array([len(region) for region in regions])
+ relative_sizes = (region_sizes - region_sizes.min()) / (
+ region_sizes.max() - region_sizes.min()
+ )
+
+ # Construct figure
+ go = PlotlyManager().go
+ fig = go.Figure()
+
+ # Heatmap
+ fig.add_trace(
+ go.Heatmap(
+ x=xi[0],
+ y=yi[:, 0],
+ z=zi,
+ colorscale="Viridis",
+ zsmooth="best",
+ )
+ )
+
+ # Add Voronoi edges
+ for region, size in zip(regions, relative_sizes, strict=False):
+ x_region = region[:, 0].tolist() + [region[0, 0]]
+ y_region = region[:, 1].tolist() + [region[0, 1]]
+
+ fig.add_trace(
+ go.Scatter(
+ x=x_region,
+ y=y_region,
+ mode="lines",
+ line=dict(color="white", width=0.5 + size * 0.1),
+ showlegend=False,
+ )
+ )
+
+ # Add original points
+ fig.add_trace(
+ go.Scatter(
+ x=x_optim,
+ y=y_optim,
+ mode="markers",
+ marker=dict(
+ color=[i / len(x_optim) for i in range(len(x_optim))],
+ colorscale="Greys",
+ size=8,
+ showscale=False,
+ ),
+ text=[f"f={val:.2f}" for val in f],
+ hoverinfo="text",
+ showlegend=False,
+ )
+ )
+
+ # Plot the initial guess
+ if len(result.x_model) > 0:
+ x0 = result.x_model[0]
+ fig.add_trace(
+ go.Scatter(
+ x=[x0[0]],
+ y=[x0[1]],
+ mode="markers",
+ marker_symbol="x",
+ marker=dict(
+ color="white",
+ line_color="black",
+ line_width=1,
+ size=14,
+ showscale=False,
+ ),
+ name="Initial values",
+ )
+ )
+
+ # Plot optimised value
+ if result.x is not None:
+ x_best = result.x
+ fig.add_trace(
+ go.Scatter(
+ x=[x_best[0]],
+ y=[x_best[1]],
+ mode="markers",
+ marker_symbol="cross",
+ marker=dict(
+ color="black",
+ line_color="white",
+ line_width=1,
+ size=14,
+ showscale=False,
+ ),
+ name="Final values",
+ )
+ )
+
+ names = parameters.names
+ fig.update_layout(
+ title="Voronoi Cost Landscape",
+ title_x=0.5,
+ title_y=0.905,
+ xaxis_title=names[0],
+ yaxis_title=names[1],
+ width=600,
+ height=600,
+ xaxis=dict(range=xlim, showexponent="last", exponentformat="e"),
+ yaxis=dict(range=ylim, showexponent="last", exponentformat="e"),
+ legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1),
+ )
+ fig.update_layout(**layout_kwargs)
+ if show:
+ fig.show()
diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py
new file mode 100644
index 000000000..873af49fe
--- /dev/null
+++ b/pybop/plot/plots.py
@@ -0,0 +1,113 @@
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+ from pybop._result import Result
+
+from pybop.plot.util import call_plotting_function
+from pybop.problems.problem import Problem
+
+
+def contour(
+ call_object: "Problem | Result",
+ gradient: bool = False,
+ bounds: np.ndarray | None = None,
+ transformed: bool = False,
+ steps: int = 10,
+ show: bool = True,
+ backend=None,
+ **layout_kwargs,
+):
+ """
+ Plot a 2D visualisation of a cost landscape using Plotly.
+
+ This function generates a contour plot representing the cost landscape for a provided
+ callable cost function over a grid of parameter values within the specified bounds.
+
+ Parameters
+ ----------
+ call_object : pybop.Problem | pybop.Result
+ Either:
+ - the cost function to be evaluated. Must accept a list of parameter values and return a cost value.
+ - an optimiser result which provides a specific optimisation trace overlaid on the cost landscape.
+ gradient : bool, optional
+ If True, the gradient is shown (default: False).
+ bounds : numpy.ndarray | list[list[float]], optional
+ A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses
+ `parameters.get_bounds_for_plotly`.
+ transformed : bool, optional
+ Uses the transformed parameter values (as seen by the optimiser) for plotting.
+ steps : int, optional
+ The number of grid points to divide the parameter space into along each dimension (default: 10).
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ **layout_kwargs : optional
+ Valid Plotly layout keys and their values,
+ e.g. `xaxis_title="Time [s]"` or
+ `xaxis={"title": "Time [s]", font={"size":14}}`
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ The Plotly figure object containing the cost landscape plot.
+
+ Raises
+ ------
+ ValueError
+ If the cost function does not return a valid cost when called with a parameter list.
+ """
+ return call_plotting_function(
+ "contour",
+ backend,
+ call_object=call_object,
+ gradient=gradient,
+ bounds=bounds,
+ transformed=transformed,
+ steps=steps,
+ show=show,
+ **layout_kwargs,
+ )
+
+
+def surface(
+ result: "Result",
+ bounds=None,
+ normalise=True,
+ resolution=250,
+ show=True,
+ backend=None,
+ **layout_kwargs,
+):
+ """
+ Plot a 2D representation of the Voronoi diagram with color-coded regions.
+
+ Parameters:
+ -----------
+ result : pybop.Result
+ Optimisation result containing the history of parameter values and associated cost.
+ bounds : numpy.ndarray, optional
+ A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses
+ `cost.parameters.get_bounds_for_plotly`.
+ normalise : bool, optional
+ If True, the voronoi regions are computed using the Euclidean distance between
+ points normalised with respect to the bounds (default: True).
+ resolution : int, optional
+ Resolution of the plot. Default is 500.
+ show : bool, optional
+ If True, the figure is shown upon creation (default: True).
+ **layout_kwargs : optional
+ Valid Plotly layout keys and their values,
+ e.g. `xaxis_title="Time [s]"` or
+ `xaxis={"title": "Time [s]", font={"size":14}}`
+ """
+ return call_plotting_function(
+ "surface",
+ backend,
+ result=result,
+ bounds=bounds,
+ normalise=normalise,
+ resolution=resolution,
+ show=show,
+ **layout_kwargs,
+ )
diff --git a/pybop/plot/problem.py b/pybop/plot/problem.py
index 0372601e6..27b573641 100644
--- a/pybop/plot/problem.py
+++ b/pybop/plot/problem.py
@@ -4,6 +4,7 @@
from pybop.costs.error_measures import ErrorMeasure
from pybop.parameters.parameter import Inputs
from pybop.plot.standard_plots import StandardPlot
+from pybop.plot.util import get_default_options, update_and_show
from pybop.problems.meta_problem import MetaProblem
from pybop.problems.problem import Problem
from pybop.simulators.solution import Solution
@@ -13,7 +14,8 @@ def problem(
problem: Problem,
inputs: Inputs = None,
show: bool = True,
- **layout_kwargs,
+ title="Scatter Plot",
+ backend=None,
):
"""
Produce a quick plot of the target dataset against optimised model output.
@@ -29,10 +31,6 @@ def problem(
Optimised (or example) parameter values.
show : bool, optional
If True, the figure is shown upon creation (default: True).
- **layout_kwargs : optional
- Valid Plotly layout keys and their values,
- e.g. `xaxis_title="Time / s"` or
- `xaxis={"title": "Time [s]", font={"size":14}}`
Returns
-------
@@ -66,33 +64,38 @@ def problem(
model_output = problem.simulate(inputs)
model_domain = target_domain[: len(model_output[target].data)]
+ # Retrieve default layout options
+ plot_options = get_default_options("problem", backend)
+ trace_options = plot_options.get("default_trace_options") or {}
+ design_cost_options = plot_options.get("design_cost_options") or {}
+ meta_problem_options = plot_options.get("meta_problem_options") or {}
+ reference_options = plot_options.get("reference_options") or {}
+ fill_options = plot_options.get("fill_options") or {}
+
# Create a plot for each output
figure_list = []
for var in problem.target:
+ options = trace_options.copy()
+ if isinstance(problem, MetaProblem):
+ options.update(meta_problem_options)
+ if isinstance(problem.cost, DesignCost):
+ options.update(design_cost_options)
+
# Create a plot dictionary
plot_dict = StandardPlot(
- layout_options=dict(
- title="Scatter Plot",
- xaxis_title=StandardPlot.remove_brackets(domain),
- yaxis_title=StandardPlot.remove_brackets(var),
- )
+ title=title,
+ xaxis_title=StandardPlot.remove_brackets(domain),
+ yaxis_title=StandardPlot.remove_brackets(var),
+ backend=backend,
)
model_trace = plot_dict.create_trace(
- x=model_domain,
- y=model_output[var].data,
- name="Optimised" if isinstance(problem.cost, DesignCost) else "Model",
- mode="markers" if isinstance(problem, MetaProblem) else "lines",
- showlegend=True,
+ x=model_domain, y=model_output[var].data, **options
)
plot_dict.traces.append(model_trace)
target_trace = plot_dict.create_trace(
- x=target_domain,
- y=target_output[var].data,
- name="Reference",
- mode="markers",
- showlegend=True,
+ x=target_domain, y=target_output[var].data, **reference_options
)
plot_dict.traces.append(target_trace)
@@ -107,14 +110,8 @@ def problem(
y_upper = (model_output[var].data + plot_dict.sigma).tolist()
y_lower = (model_output[var].data - plot_dict.sigma).tolist()
- fill_trace = plot_dict.create_trace(
- x=x + x[::-1],
- y=y_upper + y_lower[::-1],
- fill="toself",
- fillcolor="rgba(255,229,204,0.8)",
- line=dict(color="rgba(255,255,255,0)"),
- hoverinfo="skip",
- showlegend=False,
+ fill_trace = plot_dict.create_fill_trace(
+ x, y_upper, y_lower, **fill_options
)
plot_dict.traces.append(fill_trace)
@@ -123,9 +120,7 @@ def problem(
# Generate the figure and update the layout
fig = plot_dict(show=False)
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
+ fig = update_and_show(fig, backend=backend)
figure_list.append(fig)
diff --git a/pybop/plot/samples.py b/pybop/plot/samples.py
index 55ee77cd0..3ea022cce 100644
--- a/pybop/plot/samples.py
+++ b/pybop/plot/samples.py
@@ -1,109 +1,107 @@
from typing import TYPE_CHECKING
-from pybop.plot import PlotlyManager
+from pybop.plot import StandardPlot
+from pybop.plot.util import call_plotting_function, get_default_options, update_and_show
if TYPE_CHECKING:
from pybop.samplers.base_pints_sampler import SamplingResult
-def trace(result: "SamplingResult", **kwargs):
+def chains(result: "SamplingResult", show=True, backend=None):
"""
- Plot trace plots for the posterior samples.
+ Plot posterior distributions for each chain.
"""
- # Import plotly only when needed
- go = PlotlyManager().go
+ options = get_default_options("posterior", backend)
+ plot_options = options.get("plot_options") or {}
+ trace_options = options.get("trace_options") or {}
+ trace_options_vline = options.get("trace_options_vline") or {}
- for i in range(result.n_parameters):
- fig = go.Figure()
+ plot_dict = StandardPlot(
+ backend=backend,
+ title="Posterior Distribution",
+ xaxis_title="Value",
+ yaxis_title="Density",
+ **plot_options,
+ )
- for j, chain in enumerate(result.chains):
- fig.add_trace(go.Scatter(y=chain[:, i], mode="lines", name=f"Chain {j}"))
+ for i, chain in enumerate(result.chains):
+ for j in range(chain.shape[1]):
+ hist = plot_dict.create_histogram(
+ x=chain[:, j], name=f"Chain {i} - Parameter {j}", **trace_options
+ )
+ plot_dict.traces.append(hist)
- fig.update_layout(
- title=f"Parameter {i} Trace Plot",
- xaxis_title="Sample Index",
- yaxis_title="Value",
- )
- fig.update_layout(**kwargs)
- fig.show()
+ fig = plot_dict(show=False)
+ for j in range(chain.shape[1]):
+ plot_dict.create_vline(fig, result.mean[j], **trace_options_vline)
+ update_and_show(fig, backend=backend)
-def chains(result: "SamplingResult", **kwargs):
+
+def trace(result: "SamplingResult", show=True, backend=None):
"""
- Plot posterior distributions for each chain.
+ Plot trace plots for the posterior samples.
"""
- # Import plotly only when needed
- go = PlotlyManager().go
-
- fig = go.Figure()
+ figlist = []
+ options = get_default_options("trace", backend)
+ plot_options = options.get("plot_options") or {}
+ trace_options = options.get("trace_options") or {}
+ for i in range(result.n_parameters):
+ plots = StandardPlot(
+ title=f"Parameter {i} Trace Plot",
+ xaxis_title="Sample Index",
+ yaxis_title="Value",
+ backend=backend,
+ **plot_options,
+ )
- for i, chain in enumerate(result.chains):
- for j in range(chain.shape[1]):
- fig.add_trace(
- go.Histogram(
- x=chain[:, j],
- name=f"Chain {i} - Parameter {j}",
- opacity=0.75,
- )
+ for j, chain in enumerate(result.chains):
+ plots.traces.append(
+ plots.create_trace(y=chain[:, i], label=f"Chain {j}", **trace_options)
)
+ fig = plots(show=False)
+ figlist.append(fig)
- fig.add_shape(
- type="line",
- x0=result.mean[j],
- y0=0,
- x1=result.mean[j],
- y1=result.max[j],
- name=f"Mean - Parameter {j}",
- line=dict(color="Black", width=1.5, dash="dash"),
- )
+ update_and_show(figlist, show=show, backend=backend)
- fig.update_layout(
- barmode="overlay",
- title="Posterior Distribution",
- xaxis_title="Value",
- yaxis_title="Density",
- )
- fig.update_layout(**kwargs)
- fig.show()
+ return figlist
-def posterior(result: "SamplingResult", **kwargs):
+def posterior(result: "SamplingResult", backend=None):
"""
Plot the summed posterior distribution across chains.
"""
+ options = get_default_options("posterior", backend)
+ plot_options = options.get("plot_options") or {}
+ trace_options = options.get("trace_options") or {}
+ trace_options_vline = options.get("trace_options_vline") or {}
# Import plotly only when needed
- go = PlotlyManager().go
-
- fig = go.Figure()
-
- for j in range(result.all_samples.shape[1]):
- histogram = go.Histogram(
- x=result.all_samples[:, j],
- name=f"Parameter {j}",
- opacity=0.75,
- )
- fig.add_trace(histogram)
- fig.add_vline(
- x=result.mean[j], line_width=3, line_dash="dash", line_color="black"
- )
-
- fig.update_layout(
- barmode="overlay",
+ plot_dict = StandardPlot(
+ backend=backend,
title="Posterior Distribution",
xaxis_title="Value",
yaxis_title="Density",
+ **plot_options,
)
- fig.update_layout(**kwargs)
- fig.show()
- return fig
+
+ for j in range(result.all_samples.shape[1]):
+ hist = plot_dict.create_histogram(
+ x=result.all_samples[:, j], name=f"Parameter {j}", **trace_options
+ )
+
+ plot_dict.traces.append(hist)
+
+ fig = plot_dict(show=False)
+ for j in range(result.all_samples.shape[1]):
+ plot_dict.create_vline(fig, result.mean[j], **trace_options_vline)
+
+ update_and_show(fig, backend=backend)
-def summary_table(result: "SamplingResult"):
+def summary_table(result: "SamplingResult", backend=None):
"""
Display summary statistics in a table.
"""
- # Import plotly only when needed
- go = PlotlyManager().go
summary_stats = result.get_summary_statistics()
@@ -116,16 +114,10 @@ def summary_table(result: "SamplingResult"):
["95% CI Upper", summary_stats["ci_upper"]],
]
- fig = go.Figure(
- data=[
- go.Table(
- header=dict(values=header),
- cells=dict(
- values=[[row[0] for row in values], [row[1] for row in values]]
- ),
- )
- ]
+ call_plotting_function(
+ "show_table",
+ backend=backend,
+ header=header,
+ values=values,
+ title="Summary Statistics",
)
-
- fig.update_layout(title="Summary Statistics")
- fig.show()
diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py
index 4422516b8..3287789c5 100644
--- a/pybop/plot/standard_plots.py
+++ b/pybop/plot/standard_plots.py
@@ -3,154 +3,66 @@
import numpy as np
-from pybop.plot.plotly_manager import PlotlyManager
-
-DEFAULT_LAYOUT_OPTIONS = dict(
- title=None,
- title_x=0.5,
- xaxis=dict(
- title=dict(font={"size": 14}),
- showexponent="last",
- exponentformat="e",
- tickfont=dict(size=12),
- ),
- yaxis=dict(
- title=dict(font={"size": 14}),
- showexponent="last",
- exponentformat="e",
- tickfont=dict(size=12),
- ),
- legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12),
- showlegend=True,
- autosize=False,
- width=600,
- height=600,
- margin=dict(l=10, r=10, b=10, t=75, pad=4),
- plot_bgcolor="white",
-)
-DEFAULT_SUBPLOT_OPTIONS = dict(
- start_cell="bottom-left",
-)
-DEFAULT_TRACE_OPTIONS = dict(line=dict(width=4), mode="lines")
-DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines")
+from pybop.plot.util import call_plotting_function
class StandardPlot:
- """
- A class for creating and displaying interactive Plotly figures.
-
- Parameters
- ----------
- x : list or np.ndarray, optional
- X-axis data points.
- y : list or np.ndarray, optional
- Primary Y-axis data points for simulated model output.
- layout : Plotly layout, optional
- A layout for the figure, overrides the layout options (default: None).
- layout_options : dict, optional
- Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS).
- trace_options : dict, optional
- Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS).
- trace_names : str, optional
- Name(s) for the primary trace(s) (default: None).
- trace_name_width : int, optional
- Maximum length of the trace names before text wrapping is used (default: 40).
-
- Returns
- -------
- plotly.graph_objs.Figure
- The generated Plotly figure.
- """
-
def __init__(
self,
x=None,
y=None,
- layout=None,
- layout_options=None,
+ title=None,
+ xaxis_title=None,
+ yaxis_title=None,
trace_options=None,
trace_names=None,
- trace_name_width=40,
+ trace_name_width=20,
+ backend=None,
+ **kwargs,
):
- self.traces = []
- self.layout = layout
- self.trace_name_width = trace_name_width
-
- # Set default layout options and update if provided
- if self.layout is None:
- self.layout_options = DEFAULT_LAYOUT_OPTIONS.copy()
- if layout_options:
- self.layout_options.update(layout_options)
- # Set default trace options and update if provided
- self.trace_options = DEFAULT_TRACE_OPTIONS.copy()
- if trace_options:
- self.trace_options.update(trace_options)
-
- # Attempt to import plotly when an instance is created
- self.go = PlotlyManager().go
+ self.plotter = call_plotting_function(
+ "Plotter",
+ backend,
+ title=title,
+ xaxis_title=xaxis_title,
+ yaxis_title=yaxis_title,
+ trace_options=trace_options,
+ **kwargs,
+ )
- # Create layout
- if self.layout is None:
- self.layout = self.go.Layout(**self.layout_options)
+ self.trace_name_width = trace_name_width
# Add traces
if x is not None and y is not None:
self.add_traces(x, y, trace_names)
def __call__(self, show=True):
- """
- Generate and show the figure.
+ return self.plotter(show=show)
- Parameters
- ----------
- show : bool, optional
- If True, the figure is shown upon creation (default: True).
- """
- fig = self.go.Figure(data=self.traces, layout=self.layout)
- if show:
- fig.show()
+ @property
+ def traces(self):
+ return self.plotter.traces
- return fig
-
- def add_traces(self, x, y, trace_names=None, **trace_options):
- """
- Add a set of traces to the plot dictionary.
-
- Parameters
- ----------
- x : list or np.ndarray
- X-axis data points.
- y : list or np.ndarray
- Primary Y-axis data points for simulated model output.
- trace_names : str or list[str], optional
- Name(s) for the primary trace(s) (default: None).
- """
- options = self.trace_options.copy()
- options.update(trace_options)
+ @traces.setter
+ def traces(self, value):
+ self.plotter.traces = value
+ def add_traces(self, x, y, trace_names):
# Check and wrap trace names
if trace_names is not None:
if isinstance(trace_names, str):
trace_names = [trace_names]
for i, name in enumerate(trace_names):
- trace_names[i] = self.wrap_text(name, width=self.trace_name_width)
+ trace_names[i] = self.wrap_text(
+ name, width=self.trace_name_width, backend=self.plotter.backend
+ )
# Parse the data
x, y = self.parse_data(x, y)
- # Create a trace for each trajectory
- xi = x[0]
- for i in range(0, len(y)):
- trace_options = options.copy()
- if len(x) > 1:
- xi = x[i]
- if trace_names is not None:
- trace_options["name"] = trace_names[i]
- else:
- trace_options["showlegend"] = False
- trace = self.create_trace(xi, y[i], **trace_options)
- self.traces.append(trace)
+ # Add traces
+ self.plotter.add_traces(x, y, trace_names)
def parse_data(self, x, y):
"""
@@ -191,24 +103,23 @@ def parse_data(self, x, y):
)
return x, y
- def create_trace(self, x, y, **trace_options):
- """
- Create a trace for the Plotly figure.
+ def create_trace(self, x=None, y=None, label=None, **trace_options):
+ return self.plotter.create_trace(x, y, label, **trace_options)
- Returns
- -------
- plotly.graph_objs.Scatter
- A trace for a Plotly figure.
- """
+ def create_fill_trace(self, x, y_upper, y_lower, **options):
+ return self.plotter.create_fill_trace(x, y_upper, y_lower, **options)
- return self.go.Scatter(
- x=x,
- y=y,
- **trace_options,
- )
+ def create_histogram(self, x, name, **trace_options):
+ return self.plotter.create_histogram(x, name, **trace_options)
+
+ def create_vline(self, fig, x, **trace_options):
+ return self.plotter.create_vline(fig, x, **trace_options)
+
+ def create_contour(self, x, y, z, **trace_options):
+ return self.plotter.create_contour(x, y, z, **trace_options)
@staticmethod
- def wrap_text(text, width):
+ def wrap_text(text, width, backend="matplotlib"):
"""
Wrap text to a specified width with HTML line breaks.
@@ -225,7 +136,10 @@ def wrap_text(text, width):
The wrapped text.
"""
wrapped_text = textwrap.fill(text, width=width, break_long_words=False)
- return wrapped_text.replace("\n", "
")
+ if backend == "plotly":
+ return wrapped_text.replace("\n", "
")
+ else:
+ return wrapped_text
@staticmethod
def remove_brackets(s):
@@ -280,20 +194,30 @@ def __init__(
self,
x,
y,
+ backend=None,
num_rows=None,
num_cols=None,
axis_titles=None,
- layout=None,
- layout_options=DEFAULT_LAYOUT_OPTIONS,
- subplot_options=DEFAULT_SUBPLOT_OPTIONS,
- trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS,
+ trace_options=None,
trace_names=None,
trace_name_width=40,
+ **kwargs,
):
- super().__init__(
- x, y, layout, layout_options, trace_options, trace_names, trace_name_width
+ self.plotter = call_plotting_function(
+ "SubplotPlotter",
+ backend,
+ axis_titles=axis_titles,
+ trace_options=trace_options,
+ **kwargs,
)
- self.num_traces = len(self.traces)
+
+ self.trace_name_width = trace_name_width
+
+ # Add traces
+ if x is not None and y is not None:
+ self.add_traces(x, y, trace_names)
+
+ self.num_traces = len(self.plotter.traces)
self.num_rows = num_rows
self.num_cols = num_cols
if self.num_rows is None and self.num_cols is None:
@@ -304,56 +228,12 @@ def __init__(
self.num_rows = int(math.ceil(self.num_traces / self.num_cols))
elif self.num_cols is None:
self.num_cols = int(math.ceil(self.num_traces / self.num_rows))
- self.axis_titles = axis_titles
- self.subplot_options = subplot_options.copy()
- if subplot_options is not None:
- for arg, value in subplot_options.items():
- self.subplot_options[arg] = value
-
- # Attempt to import plotly when an instance is created
- self.make_subplots = PlotlyManager().make_subplots
-
- def __call__(self, show):
- """
- Generate and show the set of figures.
-
- Parameters
- ----------
- show : bool, optional
- If True, the figure is shown upon creation (default: True).
- """
- fig = self.make_subplots(
- rows=self.num_rows,
- cols=self.num_cols,
- horizontal_spacing=0.1,
- vertical_spacing=0.15,
- **self.subplot_options,
- )
- fig.update_layout(self.layout_options)
-
- for idx, trace in enumerate(self.traces):
- row = (idx // self.num_cols) + 1
- col = (idx % self.num_cols) + 1
- fig.add_trace(trace, row=row, col=col)
-
- if self.axis_titles and idx < len(self.axis_titles):
- x_title, y_title = self.axis_titles[idx]
- fig.update_xaxes(title_text=x_title, row=row, col=col)
- fig.update_yaxes(
- title_text=y_title,
- row=row,
- col=col,
- showexponent="last",
- exponentformat="e",
- )
-
- if show:
- fig.show()
- return fig
+ def __call__(self, show=True):
+ return self.plotter(show=show, num_rows=self.num_rows, num_cols=self.num_cols)
-def trajectories(x, y, trace_names=None, show=True, **layout_kwargs):
+def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs):
"""
Quickly plot one or more trajectories using Plotly.
@@ -375,17 +255,13 @@ def trajectories(x, y, trace_names=None, show=True, **layout_kwargs):
plotly.graph_objs.Figure
The Plotly figure object for the scatter plot.
"""
- # Create a plot dictionary
- plot_dict = StandardPlot(
+
+ return call_plotting_function(
+ "trajectories",
+ backend,
x=x,
y=y,
trace_names=trace_names,
+ show=show,
+ **layout_kwargs,
)
-
- # Generate the figure and update the layout
- fig = plot_dict(show=False)
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
-
- return fig
diff --git a/pybop/plot/util.py b/pybop/plot/util.py
new file mode 100644
index 000000000..a30b0cee1
--- /dev/null
+++ b/pybop/plot/util.py
@@ -0,0 +1,57 @@
+import importlib.util
+
+import pybop.plot
+
+
+def set_backend(backend):
+ err_msg = (
+ f"Plotting backend {backend} is not available. The default backend has not been updated. \n"
+ f"The default backend is set to {pybop.plot.backend}"
+ )
+ try:
+ importlib.import_module("pybop.plot." + backend)
+ pybop.plot.backend = backend
+
+ except ModuleNotFoundError as error:
+ # Raise an ModuleNotFoundError if the module or attribute is not available
+ raise ModuleNotFoundError(err_msg) from error
+
+
+def call_plotting_function(function_name, backend, **kwargs):
+ if backend is None:
+ backend = pybop.plot.backend
+ err_msg = f"Plotting backend {backend} is not available."
+ try:
+ module = importlib.import_module("pybop.plot." + backend)
+ if hasattr(module, function_name):
+ plotting_function = getattr(module, function_name)
+ # Return the imported attribute
+ return plotting_function(**kwargs)
+ else:
+ err_msg = f"Plotting backend {backend} has no attribute {function_name}."
+ raise ModuleNotFoundError(err_msg)
+
+ except ModuleNotFoundError as error:
+ # Raise an ModuleNotFoundError if the module or attribute is not available
+ raise ModuleNotFoundError(err_msg) from error
+
+
+def update_and_show(fig, backend=None, **kwargs):
+ return call_plotting_function("update_and_show", backend, fig=fig, **kwargs)
+
+
+def get_default_options(plot_type, backend):
+ if backend is None:
+ backend = pybop.plot.backend
+
+ if backend == "plotly":
+ opts = pybop.plot.plotly.DEFAULT_PLOT_OPTIONS
+ elif backend == "matplotlib":
+ opts = pybop.plot.matplotlib.DEFAULT_PLOT_OPTIONS
+ else:
+ opts = {}
+
+ if plot_type in opts.keys():
+ return opts[plot_type]
+ else:
+ return {}
diff --git a/pybop/plot/voronoi.py b/pybop/plot/voronoi.py
index 29be3c096..2ac57fada 100644
--- a/pybop/plot/voronoi.py
+++ b/pybop/plot/voronoi.py
@@ -1,11 +1,10 @@
from typing import TYPE_CHECKING
import numpy as np
-from scipy.spatial import Voronoi, cKDTree
+from scipy.spatial import Voronoi
if TYPE_CHECKING:
- from pybop._result import Result
-from pybop.plot.plotly_manager import PlotlyManager
+ pass
def _voronoi_regions(x, y, f, xlim, ylim):
@@ -195,85 +194,7 @@ def interpolate_point(p, q, axis, boundary_val):
return np.array([boundary_val, s]) if axis == 0 else np.array([s, boundary_val])
-def assign_nearest_value(x, y, f, xi, yi):
- """
- Computes an array of values given by the score of the nearest point.
-
- Parameters
- ----------
- x : array-like
- The x coordinates of points with known scores.
- y : array-like
- The y coordinates of points with known scores.
- f : array-like
- The score function at the given x and y coordinates.
- xi : array-like
- The x coordinates of grid points.
- yi : array-like
- The y coordinates of grid points.
-
- Returns
- -------
- A numpy array containing the scores corresponding to the grid points.
- """
- # Create a KD-tree for efficient nearest neighbor search
- tree = cKDTree(np.column_stack((x, y)))
-
- # Find the nearest point for each grid point
- _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel())))
- zi = f[indices].reshape(xi.shape)
-
- return zi
-
-
-def surface(
- result: "Result",
- bounds=None,
- normalise=True,
- resolution=250,
- show=True,
- **layout_kwargs,
-):
- """
- Plot a 2D representation of the Voronoi diagram with color-coded regions.
-
- Parameters:
- -----------
- result : pybop.Result
- Optimisation result containing the history of parameter values and associated cost.
- bounds : numpy.ndarray, optional
- A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses
- `cost.parameters.get_bounds_for_plotly`.
- normalise : bool, optional
- If True, the voronoi regions are computed using the Euclidean distance between
- points normalised with respect to the bounds (default: True).
- resolution : int, optional
- Resolution of the plot. Default is 500.
- show : bool, optional
- If True, the figure is shown upon creation (default: True).
- **layout_kwargs : optional
- Valid Plotly layout keys and their values,
- e.g. `xaxis_title="Time [s]"` or
- `xaxis={"title": "Time [s]", font={"size":14}}`
- """
- points = result.x_model
- parameters = result.problem.parameters
-
- if points[0].shape[0] != 2:
- raise ValueError("This plot method requires two parameters.")
-
- x_optim, y_optim = map(list, zip(*points, strict=False))
- f = result.cost
-
- # Translate bounds, taking only the first two elements
- xlim, ylim = (
- bounds if bounds is not None else [param.bounds for param in parameters]
- )[:2]
-
- # Create a grid for plot
- xi = np.linspace(xlim[0], xlim[1], resolution)
- yi = np.linspace(ylim[0], ylim[1], resolution)
- xi, yi = np.meshgrid(xi, yi)
+def voronoi_data(xlim, ylim, pts_x, pts_y, f, normalise=True):
if normalise:
if xlim[1] <= xlim[0] or ylim[1] <= ylim[0]:
@@ -282,21 +203,14 @@ def surface(
# Normalise the region
x_range = xlim[1] - xlim[0]
y_range = ylim[1] - ylim[0]
- norm_x_optim = (np.asarray(x_optim) - xlim[0]) / x_range
- norm_y_optim = (np.asarray(y_optim) - ylim[0]) / y_range
+ norm_x_optim = (np.asarray(pts_x) - xlim[0]) / x_range
+ norm_y_optim = (np.asarray(pts_y) - ylim[0]) / y_range
# Compute regions
- norm_x, norm_y, f, norm_regions = _voronoi_regions(
+ x, y, f, norm_regions = _voronoi_regions(
norm_x_optim, norm_y_optim, f, (0, 1), (0, 1)
)
- # Create a normalised grid
- norm_xi = np.linspace(0, 1, resolution)
- norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi)
-
- # Assign a value to each point in the grid
- zi = assign_nearest_value(norm_x, norm_y, f, norm_xi, norm_yi)
-
# Rescale for plotting
regions = []
for norm_region in norm_regions:
@@ -307,10 +221,7 @@ def surface(
else:
# Compute regions
- x, y, f, regions = _voronoi_regions(x_optim, y_optim, f, xlim, ylim)
-
- # Assign a value to each point in the grid
- zi = assign_nearest_value(x, y, f, xi, yi)
+ x, y, f, regions = _voronoi_regions(pts_x, pts_y, f, xlim, ylim)
# Calculate the size of each Voronoi region
region_sizes = np.array([len(region) for region in regions])
@@ -318,107 +229,4 @@ def surface(
region_sizes.max() - region_sizes.min()
)
- # Construct figure
- go = PlotlyManager().go
- fig = go.Figure()
-
- # Heatmap
- fig.add_trace(
- go.Heatmap(
- x=xi[0],
- y=yi[:, 0],
- z=zi,
- colorscale="Viridis",
- zsmooth="best",
- )
- )
-
- # Add Voronoi edges
- for region, size in zip(regions, relative_sizes, strict=False):
- x_region = region[:, 0].tolist() + [region[0, 0]]
- y_region = region[:, 1].tolist() + [region[0, 1]]
-
- fig.add_trace(
- go.Scatter(
- x=x_region,
- y=y_region,
- mode="lines",
- line=dict(color="white", width=0.5 + size * 0.1),
- showlegend=False,
- )
- )
-
- # Add original points
- fig.add_trace(
- go.Scatter(
- x=x_optim,
- y=y_optim,
- mode="markers",
- marker=dict(
- color=[i / len(x_optim) for i in range(len(x_optim))],
- colorscale="Greys",
- size=8,
- showscale=False,
- ),
- text=[f"f={val:.2f}" for val in f],
- hoverinfo="text",
- showlegend=False,
- )
- )
-
- # Plot the initial guess
- if len(result.x_model) > 0:
- x0 = result.x_model[0]
- fig.add_trace(
- go.Scatter(
- x=[x0[0]],
- y=[x0[1]],
- mode="markers",
- marker_symbol="x",
- marker=dict(
- color="white",
- line_color="black",
- line_width=1,
- size=14,
- showscale=False,
- ),
- name="Initial values",
- )
- )
-
- # Plot optimised value
- if result.x is not None:
- x_best = result.x
- fig.add_trace(
- go.Scatter(
- x=[x_best[0]],
- y=[x_best[1]],
- mode="markers",
- marker_symbol="cross",
- marker=dict(
- color="black",
- line_color="white",
- line_width=1,
- size=14,
- showscale=False,
- ),
- name="Final values",
- )
- )
-
- names = parameters.names
- fig.update_layout(
- title="Voronoi Cost Landscape",
- title_x=0.5,
- title_y=0.905,
- xaxis_title=names[0],
- yaxis_title=names[1],
- width=600,
- height=600,
- xaxis=dict(range=xlim, showexponent="last", exponentformat="e"),
- yaxis=dict(range=ylim, showexponent="last", exponentformat="e"),
- legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1),
- )
- fig.update_layout(**layout_kwargs)
- if show:
- fig.show()
+ return x, y, f, regions, relative_sizes
diff --git a/tests/plotting/test_plotly_manager.py b/tests/plotting/test_plotly_manager.py
index f64e608f1..b39df37cd 100644
--- a/tests/plotting/test_plotly_manager.py
+++ b/tests/plotting/test_plotly_manager.py
@@ -8,7 +8,7 @@
import pytest
import pybop
-from pybop.plot import PlotlyManager
+from pybop.plot.plotly import PlotlyManager
# Find the Python executable
python_executable = which("python")
@@ -128,6 +128,9 @@ def dataset(plotly_installed):
@pytest.mark.unit
def test_standard_plot(dataset, plotly_installed):
+ # Set plotting backend
+ pybop.plot.set_backend("plotly")
+
# Check the StandardPlot class
pybop.plot.StandardPlot(dataset["Time [s]"], dataset["Voltage [V]"])
@@ -167,6 +170,9 @@ def test_standard_plot(dataset, plotly_installed):
@pytest.mark.unit
def test_plot_dataset(dataset, plotly_installed):
+ # Set plotting backend
+ pybop.plot.set_backend("plotly")
+
# Test plot of a dataset
pybop.plot.dataset(dataset, signal=["Voltage [V]"])
pybop.plot.dataset(dataset, signal=["Voltage [V]", "Current [A]"])
diff --git a/tests/unit/test_plots.py b/tests/unit/test_plots.py
index 58b707900..db9e0c6c1 100644
--- a/tests/unit/test_plots.py
+++ b/tests/unit/test_plots.py
@@ -6,6 +6,7 @@
import pybop
+@pytest.mark.parametrize("backend", ["plotly", "matplotlib"])
class TestPlots:
"""
A class to test the plot classes.
@@ -13,7 +14,8 @@ class TestPlots:
pytestmark = pytest.mark.unit
- def test_standard_plot(self):
+ def test_standard_plot(self, backend):
+ pybop.plot.set_backend(backend)
# Test standard plot
trace_names = pybop.plot.StandardPlot.remove_brackets(
["Trace [1]", "Trace [2]"]
@@ -69,7 +71,8 @@ def dataset(self, model):
solution = pybamm.Simulation(model).solve(t_eval=t_eval, t_interp=t_eval)
return pybop.import_pybamm_solution(solution)
- def test_dataset_plots(self, dataset):
+ def test_dataset_plots(self, dataset, backend):
+ pybop.plot.set_backend(backend)
# Test plot of Dataset objects
pybop.plot.trajectories(
dataset["Time [s]"],
@@ -112,7 +115,8 @@ def design_problem(self, model, parameters, experiment):
)
return pybop.Problem(simulator)
- def test_problem_plots(self, fitting_problem, design_problem):
+ def test_problem_plots(self, fitting_problem, design_problem, backend):
+ pybop.plot.set_backend(backend)
# Test plot of Problem objects
pybop.plot.problem(fitting_problem, title="Optimised Comparison")
pybop.plot.problem(design_problem)
@@ -122,7 +126,8 @@ def test_problem_plots(self, fitting_problem, design_problem):
fitting_problem, inputs=fitting_problem.parameters.to_dict([0.6, 0.6])
)
- def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds):
+ def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds, backend):
+ pybop.plot.set_backend(backend)
# Test plot of Cost objects
pybop.plot.contour(fitting_problem, gradient=True, steps=5)
@@ -143,7 +148,8 @@ def result(self, fitting_problem):
optim = pybop.XNES(fitting_problem)
return optim.run()
- def test_optim_plots(self, result):
+ def test_optim_plots(self, result, backend):
+ pybop.plot.set_backend(backend)
bounds = np.asarray([[0.5, 0.8], [0.4, 0.7]])
# Plot convergence
@@ -185,7 +191,8 @@ def sampling_result(self, model, parameters, dataset):
sampler = pybop.SliceStepoutMCMC(log_pdf, options=options)
return sampler.run()
- def test_posterior_plots(self, sampling_result):
+ def test_posterior_plots(self, sampling_result, backend):
+ pybop.plot.set_backend(backend)
sampling_result.get_summary_statistics()
# Plot trace
@@ -200,9 +207,11 @@ def test_posterior_plots(self, sampling_result):
# Plot summary table
sampling_result.summary_table()
- def test_with_ipykernel(self, dataset, fitting_problem, result):
+ def test_with_ipykernel(self, dataset, fitting_problem, result, backend):
import ipykernel
+ pybop.plot.set_backend(backend)
+
assert version.parse(ipykernel.__version__) >= version.parse("0.6")
pybop.plot.dataset(dataset, signal=["Voltage [V]"])
pybop.plot.contour(fitting_problem, gradient=True, steps=5)
@@ -210,7 +219,8 @@ def test_with_ipykernel(self, dataset, fitting_problem, result):
result.plot_parameters()
result.plot_contour(steps=5)
- def test_contour_incorrect_number_of_parameters(self, model, dataset):
+ def test_contour_incorrect_number_of_parameters(self, model, dataset, backend):
+ pybop.plot.set_backend(backend)
parameter_values = model.default_parameter_values
# Test with less than two paramters
@@ -251,7 +261,8 @@ def test_contour_incorrect_number_of_parameters(self, model, dataset):
fitting_problem = pybop.Problem(simulator, cost)
pybop.plot.contour(fitting_problem)
- def test_nyquist(self):
+ def test_nyquist(self, backend):
+ pybop.plot.set_backend(backend)
# Define model
model = pybamm.lithium_ion.SPM(options={"surface form": "differential"})
parameter_values = model.default_parameter_values