diff --git a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb index 1eb2a59c5..069d68660 100644 --- a/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb +++ b/examples/notebooks/battery_parameterisation/echem_identification_pitfalls.ipynb @@ -32,8 +32,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb index dce5640a7..23433be03 100644 --- a/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_monte_carlo_sampling.ipynb @@ -46,7 +46,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"matplotlib\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] @@ -1103,7 +1104,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "env-py-3-13", "language": "python", "name": "python3" }, @@ -1117,7 +1118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.11" } }, "nbformat": 4, diff --git a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb index 65f4623b6..8140c73c1 100644 --- a/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_multipulse_identification.ipynb @@ -48,7 +48,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb index ce327fe6c..70915b6e9 100644 --- a/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb +++ b/examples/notebooks/battery_parameterisation/ecm_scipy_constraints.ipynb @@ -31,7 +31,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb index adc764f25..94b981c61 100644 --- a/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb +++ b/examples/notebooks/battery_parameterisation/electrode_balancing.ipynb @@ -29,7 +29,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb index 7cd0c65b0..d722d52af 100644 --- a/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb +++ b/examples/notebooks/battery_parameterisation/lgm50_pulse_validation.ipynb @@ -32,8 +32,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb index ce32e337b..1a73079e4 100644 --- a/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb +++ b/examples/notebooks/battery_parameterisation/pouch_cell_identification.ipynb @@ -30,8 +30,9 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb index 3cab87a0d..d5a7ecd4f 100644 --- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb +++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_hessian.ipynb @@ -29,6 +29,7 @@ "\n", "import pybop\n", "\n", + "pybop.plot.set_backend(\"matplotlib\")\n", "np.random.seed(8) # users can remove this line" ] }, @@ -210,708 +211,708 @@ "showlegend": false, "type": "scatter", "x": [ - 0.0, - 10.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, - 70.0, - 80.0, - 90.0, - 100.0, - 110.0, - 120.0, - 130.0, - 140.0, - 150.0, - 160.0, - 170.0, - 180.0, - 190.0, - 200.0, - 210.0, - 220.0, - 230.0, - 240.0, - 250.0, - 260.0, - 270.0, - 280.0, - 290.0, - 300.0, - 310.0, - 320.0, - 330.0, - 340.0, - 350.0, - 360.0, - 370.0, - 380.0, - 390.0, - 400.0, - 410.0, - 420.0, - 430.0, - 440.0, - 450.0, - 460.0, - 470.0, - 480.0, - 490.0, - 500.0, - 510.0, - 520.0, - 530.0, - 540.0, - 550.0, - 560.0, - 570.0, - 580.0, - 590.0, - 600.0, - 610.0, - 620.0, - 630.0, - 640.0, - 650.0, - 660.0, - 670.0, - 680.0, - 690.0, - 700.0, - 710.0, - 720.0, - 730.0, - 740.0, - 750.0, - 760.0, - 770.0, - 780.0, - 790.0, - 800.0, - 810.0, - 820.0, - 830.0, - 840.0, - 850.0, - 860.0, - 870.0, - 880.0, - 890.0, - 900.0, - 910.0, - 920.0, - 930.0, - 940.0, - 950.0, - 960.0, - 970.0, - 980.0, - 990.0, - 1000.0, - 1010.0, - 1020.0, - 1030.0, - 1040.0, - 1050.0, - 1060.0, - 1070.0, - 1080.0, - 1090.0, - 1100.0, - 1110.0, - 1120.0, - 1130.0, - 1140.0, - 1150.0, - 1160.0, - 1170.0, - 1180.0, - 1190.0, - 1200.0, - 1210.0, - 1220.0, - 1230.0, - 1240.0, - 1250.0, - 1260.0, - 1270.0, - 1280.0, - 1290.0, - 1300.0, - 1310.0, - 1320.0, - 1330.0, - 1340.0, - 1350.0, - 1360.0, - 1370.0, - 1380.0, - 1390.0, - 1400.0, - 1410.0, - 1420.0, - 1430.0, - 1440.0, - 1450.0, - 1460.0, - 1470.0, - 1480.0, - 1490.0, - 1500.0, - 1510.0, - 1520.0, - 1530.0, - 1540.0, - 1550.0, - 1560.0, - 1570.0, - 1580.0, - 1590.0, - 1600.0, - 1610.0, - 1620.0, - 1630.0, - 1640.0, - 1650.0, - 1660.0, - 1670.0, - 1680.0, - 1690.0, - 1700.0, - 1710.0, - 1720.0, - 1730.0, - 1740.0, - 1750.0, - 1760.0, - 1770.0, - 1780.0, - 1790.0, - 1800.0, - 1810.0, - 1820.0, - 1830.0, - 1840.0, - 1850.0, - 1860.0, - 1870.0, - 1880.0, - 1890.0, - 1900.0, - 1910.0, - 1920.0, - 1930.0, - 1940.0, - 1950.0, - 1960.0, - 1970.0, - 1980.0, - 1990.0, - 2000.0, - 2010.0, - 2020.0, - 2030.0, - 2040.0, - 2050.0, - 2060.0, - 2070.0, - 2080.0, - 2090.0, - 2100.0, - 2110.0, - 2120.0, - 2130.0, - 2140.0, - 2150.0, - 2160.0, - 2170.0, - 2180.0, - 2190.0, - 2200.0, - 2210.0, - 2220.0, - 2230.0, - 2240.0, - 2250.0, - 2260.0, - 2270.0, - 2280.0, - 2290.0, - 2300.0, - 2310.0, - 2320.0, - 2330.0, - 2340.0, - 2350.0, - 2360.0, - 2370.0, - 2380.0, - 2390.0, - 2400.0, - 2410.0, - 2420.0, - 2430.0, - 2440.0, - 2450.0, - 2460.0, - 2470.0, - 2480.0, - 2490.0, - 2500.0, - 2510.0, - 2520.0, - 2530.0, - 2540.0, - 2550.0, - 2560.0, - 2570.0, - 2580.0, - 2590.0, - 2600.0, - 2610.0, - 2620.0, - 2630.0, - 2640.0, - 2650.0, - 2660.0, - 2670.0, - 2680.0, - 2690.0, - 2700.0, - 2710.0, - 2720.0, - 2730.0, - 2740.0, - 2750.0, - 2760.0, - 2770.0, - 2780.0, - 2790.0, - 2800.0, - 2810.0, - 2820.0, - 2830.0, - 2840.0, - 2850.0, - 2860.0, - 2870.0, - 2880.0, - 2890.0, - 2900.0, - 2910.0, - 2920.0, - 2930.0, - 2940.0, - 2950.0, - 2960.0, - 2970.0, - 2980.0, - 2990.0, - 3000.0, - 3010.0, - 3020.0, - 3030.0, - 3040.0, - 3050.0, - 3060.0, - 3070.0, - 3080.0, - 3090.0, - 3100.0, - 3110.0, - 3120.0, - 3130.0, - 3140.0, - 3150.0, - 3160.0, - 3170.0, - 3180.0, - 3190.0, - 3200.0, - 3210.0, - 3220.0, - 3230.0, - 3240.0, - 3250.0, - 3260.0, - 3270.0, - 3280.0, - 3290.0, - 3300.0, - 3310.0, - 3320.0, - 3330.0, - 3340.0, - 3350.0, - 3360.0, - 3370.0, - 3380.0, - 3390.0, - 3400.0, - 3410.0, - 3420.0, - 3430.0, - 3440.0, - 3450.0, - 3460.0, - 3470.0, - 3480.0, - 3490.0, - 3500.0, - 3500.0, - 3490.0, - 3480.0, - 3470.0, - 3460.0, - 3450.0, - 3440.0, - 3430.0, - 3420.0, - 3410.0, - 3400.0, - 3390.0, - 3380.0, - 3370.0, - 3360.0, - 3350.0, - 3340.0, - 3330.0, - 3320.0, - 3310.0, - 3300.0, - 3290.0, - 3280.0, - 3270.0, - 3260.0, - 3250.0, - 3240.0, - 3230.0, - 3220.0, - 3210.0, - 3200.0, - 3190.0, - 3180.0, - 3170.0, - 3160.0, - 3150.0, - 3140.0, - 3130.0, - 3120.0, - 3110.0, - 3100.0, - 3090.0, - 3080.0, - 3070.0, - 3060.0, - 3050.0, - 3040.0, - 3030.0, - 3020.0, - 3010.0, - 3000.0, - 2990.0, - 2980.0, - 2970.0, - 2960.0, - 2950.0, - 2940.0, - 2930.0, - 2920.0, - 2910.0, - 2900.0, - 2890.0, - 2880.0, - 2870.0, - 2860.0, - 2850.0, - 2840.0, - 2830.0, - 2820.0, - 2810.0, - 2800.0, - 2790.0, - 2780.0, - 2770.0, - 2760.0, - 2750.0, - 2740.0, - 2730.0, - 2720.0, - 2710.0, - 2700.0, - 2690.0, - 2680.0, - 2670.0, - 2660.0, - 2650.0, - 2640.0, - 2630.0, - 2620.0, - 2610.0, - 2600.0, - 2590.0, - 2580.0, - 2570.0, - 2560.0, - 2550.0, - 2540.0, - 2530.0, - 2520.0, - 2510.0, - 2500.0, - 2490.0, - 2480.0, - 2470.0, - 2460.0, - 2450.0, - 2440.0, - 2430.0, - 2420.0, - 2410.0, - 2400.0, - 2390.0, - 2380.0, - 2370.0, - 2360.0, - 2350.0, - 2340.0, - 2330.0, - 2320.0, - 2310.0, - 2300.0, - 2290.0, - 2280.0, - 2270.0, - 2260.0, - 2250.0, - 2240.0, - 2230.0, - 2220.0, - 2210.0, - 2200.0, - 2190.0, - 2180.0, - 2170.0, - 2160.0, - 2150.0, - 2140.0, - 2130.0, - 2120.0, - 2110.0, - 2100.0, - 2090.0, - 2080.0, - 2070.0, - 2060.0, - 2050.0, - 2040.0, - 2030.0, - 2020.0, - 2010.0, - 2000.0, - 1990.0, - 1980.0, - 1970.0, - 1960.0, - 1950.0, - 1940.0, - 1930.0, - 1920.0, - 1910.0, - 1900.0, - 1890.0, - 1880.0, - 1870.0, - 1860.0, - 1850.0, - 1840.0, - 1830.0, - 1820.0, - 1810.0, - 1800.0, - 1790.0, - 1780.0, - 1770.0, - 1760.0, - 1750.0, - 1740.0, - 1730.0, - 1720.0, - 1710.0, - 1700.0, - 1690.0, - 1680.0, - 1670.0, - 1660.0, - 1650.0, - 1640.0, - 1630.0, - 1620.0, - 1610.0, - 1600.0, - 1590.0, - 1580.0, - 1570.0, - 1560.0, - 1550.0, - 1540.0, - 1530.0, - 1520.0, - 1510.0, - 1500.0, - 1490.0, - 1480.0, - 1470.0, - 1460.0, - 1450.0, - 1440.0, - 1430.0, - 1420.0, - 1410.0, - 1400.0, - 1390.0, - 1380.0, - 1370.0, - 1360.0, - 1350.0, - 1340.0, - 1330.0, - 1320.0, - 1310.0, - 1300.0, - 1290.0, - 1280.0, - 1270.0, - 1260.0, - 1250.0, - 1240.0, - 1230.0, - 1220.0, - 1210.0, - 1200.0, - 1190.0, - 1180.0, - 1170.0, - 1160.0, - 1150.0, - 1140.0, - 1130.0, - 1120.0, - 1110.0, - 1100.0, - 1090.0, - 1080.0, - 1070.0, - 1060.0, - 1050.0, - 1040.0, - 1030.0, - 1020.0, - 1010.0, - 1000.0, - 990.0, - 980.0, - 970.0, - 960.0, - 950.0, - 940.0, - 930.0, - 920.0, - 910.0, - 900.0, - 890.0, - 880.0, - 870.0, - 860.0, - 850.0, - 840.0, - 830.0, - 820.0, - 810.0, - 800.0, - 790.0, - 780.0, - 770.0, - 760.0, - 750.0, - 740.0, - 730.0, - 720.0, - 710.0, - 700.0, - 690.0, - 680.0, - 670.0, - 660.0, - 650.0, - 640.0, - 630.0, - 620.0, - 610.0, - 600.0, - 590.0, - 580.0, - 570.0, - 560.0, - 550.0, - 540.0, - 530.0, - 520.0, - 510.0, - 500.0, - 490.0, - 480.0, - 470.0, - 460.0, - 450.0, - 440.0, - 430.0, - 420.0, - 410.0, - 400.0, - 390.0, - 380.0, - 370.0, - 360.0, - 350.0, - 340.0, - 330.0, - 320.0, - 310.0, - 300.0, - 290.0, - 280.0, - 270.0, - 260.0, - 250.0, - 240.0, - 230.0, - 220.0, - 210.0, - 200.0, - 190.0, - 180.0, - 170.0, - 160.0, - 150.0, - 140.0, - 130.0, - 120.0, - 110.0, - 100.0, - 90.0, - 80.0, - 70.0, - 60.0, - 50.0, - 40.0, - 30.0, - 20.0, - 10.0, - 0.0 + 0, + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500, + 510, + 520, + 530, + 540, + 550, + 560, + 570, + 580, + 590, + 600, + 610, + 620, + 630, + 640, + 650, + 660, + 670, + 680, + 690, + 700, + 710, + 720, + 730, + 740, + 750, + 760, + 770, + 780, + 790, + 800, + 810, + 820, + 830, + 840, + 850, + 860, + 870, + 880, + 890, + 900, + 910, + 920, + 930, + 940, + 950, + 960, + 970, + 980, + 990, + 1000, + 1010, + 1020, + 1030, + 1040, + 1050, + 1060, + 1070, + 1080, + 1090, + 1100, + 1110, + 1120, + 1130, + 1140, + 1150, + 1160, + 1170, + 1180, + 1190, + 1200, + 1210, + 1220, + 1230, + 1240, + 1250, + 1260, + 1270, + 1280, + 1290, + 1300, + 1310, + 1320, + 1330, + 1340, + 1350, + 1360, + 1370, + 1380, + 1390, + 1400, + 1410, + 1420, + 1430, + 1440, + 1450, + 1460, + 1470, + 1480, + 1490, + 1500, + 1510, + 1520, + 1530, + 1540, + 1550, + 1560, + 1570, + 1580, + 1590, + 1600, + 1610, + 1620, + 1630, + 1640, + 1650, + 1660, + 1670, + 1680, + 1690, + 1700, + 1710, + 1720, + 1730, + 1740, + 1750, + 1760, + 1770, + 1780, + 1790, + 1800, + 1810, + 1820, + 1830, + 1840, + 1850, + 1860, + 1870, + 1880, + 1890, + 1900, + 1910, + 1920, + 1930, + 1940, + 1950, + 1960, + 1970, + 1980, + 1990, + 2000, + 2010, + 2020, + 2030, + 2040, + 2050, + 2060, + 2070, + 2080, + 2090, + 2100, + 2110, + 2120, + 2130, + 2140, + 2150, + 2160, + 2170, + 2180, + 2190, + 2200, + 2210, + 2220, + 2230, + 2240, + 2250, + 2260, + 2270, + 2280, + 2290, + 2300, + 2310, + 2320, + 2330, + 2340, + 2350, + 2360, + 2370, + 2380, + 2390, + 2400, + 2410, + 2420, + 2430, + 2440, + 2450, + 2460, + 2470, + 2480, + 2490, + 2500, + 2510, + 2520, + 2530, + 2540, + 2550, + 2560, + 2570, + 2580, + 2590, + 2600, + 2610, + 2620, + 2630, + 2640, + 2650, + 2660, + 2670, + 2680, + 2690, + 2700, + 2710, + 2720, + 2730, + 2740, + 2750, + 2760, + 2770, + 2780, + 2790, + 2800, + 2810, + 2820, + 2830, + 2840, + 2850, + 2860, + 2870, + 2880, + 2890, + 2900, + 2910, + 2920, + 2930, + 2940, + 2950, + 2960, + 2970, + 2980, + 2990, + 3000, + 3010, + 3020, + 3030, + 3040, + 3050, + 3060, + 3070, + 3080, + 3090, + 3100, + 3110, + 3120, + 3130, + 3140, + 3150, + 3160, + 3170, + 3180, + 3190, + 3200, + 3210, + 3220, + 3230, + 3240, + 3250, + 3260, + 3270, + 3280, + 3290, + 3300, + 3310, + 3320, + 3330, + 3340, + 3350, + 3360, + 3370, + 3380, + 3390, + 3400, + 3410, + 3420, + 3430, + 3440, + 3450, + 3460, + 3470, + 3480, + 3490, + 3500, + 3500, + 3490, + 3480, + 3470, + 3460, + 3450, + 3440, + 3430, + 3420, + 3410, + 3400, + 3390, + 3380, + 3370, + 3360, + 3350, + 3340, + 3330, + 3320, + 3310, + 3300, + 3290, + 3280, + 3270, + 3260, + 3250, + 3240, + 3230, + 3220, + 3210, + 3200, + 3190, + 3180, + 3170, + 3160, + 3150, + 3140, + 3130, + 3120, + 3110, + 3100, + 3090, + 3080, + 3070, + 3060, + 3050, + 3040, + 3030, + 3020, + 3010, + 3000, + 2990, + 2980, + 2970, + 2960, + 2950, + 2940, + 2930, + 2920, + 2910, + 2900, + 2890, + 2880, + 2870, + 2860, + 2850, + 2840, + 2830, + 2820, + 2810, + 2800, + 2790, + 2780, + 2770, + 2760, + 2750, + 2740, + 2730, + 2720, + 2710, + 2700, + 2690, + 2680, + 2670, + 2660, + 2650, + 2640, + 2630, + 2620, + 2610, + 2600, + 2590, + 2580, + 2570, + 2560, + 2550, + 2540, + 2530, + 2520, + 2510, + 2500, + 2490, + 2480, + 2470, + 2460, + 2450, + 2440, + 2430, + 2420, + 2410, + 2400, + 2390, + 2380, + 2370, + 2360, + 2350, + 2340, + 2330, + 2320, + 2310, + 2300, + 2290, + 2280, + 2270, + 2260, + 2250, + 2240, + 2230, + 2220, + 2210, + 2200, + 2190, + 2180, + 2170, + 2160, + 2150, + 2140, + 2130, + 2120, + 2110, + 2100, + 2090, + 2080, + 2070, + 2060, + 2050, + 2040, + 2030, + 2020, + 2010, + 2000, + 1990, + 1980, + 1970, + 1960, + 1950, + 1940, + 1930, + 1920, + 1910, + 1900, + 1890, + 1880, + 1870, + 1860, + 1850, + 1840, + 1830, + 1820, + 1810, + 1800, + 1790, + 1780, + 1770, + 1760, + 1750, + 1740, + 1730, + 1720, + 1710, + 1700, + 1690, + 1680, + 1670, + 1660, + 1650, + 1640, + 1630, + 1620, + 1610, + 1600, + 1590, + 1580, + 1570, + 1560, + 1550, + 1540, + 1530, + 1520, + 1510, + 1500, + 1490, + 1480, + 1470, + 1460, + 1450, + 1440, + 1430, + 1420, + 1410, + 1400, + 1390, + 1380, + 1370, + 1360, + 1350, + 1340, + 1330, + 1320, + 1310, + 1300, + 1290, + 1280, + 1270, + 1260, + 1250, + 1240, + 1230, + 1220, + 1210, + 1200, + 1190, + 1180, + 1170, + 1160, + 1150, + 1140, + 1130, + 1120, + 1110, + 1100, + 1090, + 1080, + 1070, + 1060, + 1050, + 1040, + 1030, + 1020, + 1010, + 1000, + 990, + 980, + 970, + 960, + 950, + 940, + 930, + 920, + 910, + 900, + 890, + 880, + 870, + 860, + 850, + 840, + 830, + 820, + 810, + 800, + 790, + 780, + 770, + 760, + 750, + 740, + 730, + 720, + 710, + 700, + 690, + 680, + 670, + 660, + 650, + 640, + 630, + 620, + 610, + 600, + 590, + 580, + 570, + 560, + 550, + 540, + 530, + 520, + 510, + 500, + 490, + 480, + 470, + 460, + 450, + 440, + 430, + 420, + 410, + 400, + 390, + 380, + 370, + 360, + 350, + 340, + 330, + 320, + 310, + 300, + 290, + 280, + 270, + 260, + 250, + 240, + 230, + 220, + 210, + 200, + 190, + 180, + 170, + 160, + 150, + 140, + 130, + 120, + 110, + 100, + 90, + 80, + 70, + 60, + 50, + 40, + 30, + 20, + 10, + 0 ], "y": [ 4.0531258862154065, @@ -1744,7 +1745,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1780,7 +1781,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1804,7 +1805,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1840,7 +1841,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1867,7 +1868,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1903,7 +1904,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -1918,7 +1919,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -1954,7 +1955,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -2110,7 +2111,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2146,7 +2147,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -2237,7 +2238,7 @@ ], "sequential": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2273,13 +2274,13 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], "sequentialminus": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -2315,7 +2316,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ] @@ -2503,7 +2504,7 @@ { "colorscale": [ [ - 0.0, + 0, "#440154" ], [ @@ -2539,7 +2540,7 @@ "#b5de2b" ], [ - 1.0, + 1, "#fde725" ] ], @@ -15963,7 +15964,7 @@ "hoverinfo": "text", "marker": { "color": [ - 0.0, + 0, 0.001968503937007874, 0.003937007874015748, 0.005905511811023622, @@ -16474,7 +16475,7 @@ ], "colorscale": [ [ - 0.0, + 0, "rgb(255,255,255)" ], [ @@ -16506,7 +16507,7 @@ "rgb(37,37,37)" ], [ - 1.0, + 1, "rgb(0,0,0)" ] ], @@ -18165,7 +18166,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18201,7 +18202,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18225,7 +18226,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18261,7 +18262,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18288,7 +18289,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18324,7 +18325,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18339,7 +18340,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18375,7 +18376,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18531,7 +18532,7 @@ }, "colorscale": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18567,7 +18568,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], @@ -18658,7 +18659,7 @@ ], "sequential": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18694,13 +18695,13 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ], "sequentialminus": [ [ - 0.0, + 0, "#0d0887" ], [ @@ -18736,7 +18737,7 @@ "#fdca26" ], [ - 1.0, + 1, "#f0f921" ] ] diff --git a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb index e7a1a308c..4324ce855 100644 --- a/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb +++ b/examples/notebooks/battery_parameterisation/sensitivity_analysis_salib.ipynb @@ -58,7 +58,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"matplotlib\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb index 71dc80155..b6a4251bd 100644 --- a/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb +++ b/examples/notebooks/comparison_examples/comparing_cost_functions.ipynb @@ -26,8 +26,8 @@ "\n", "import pybop\n", "\n", - "go = pybop.plot.PlotlyManager().go\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "go = pybop.plot.plotly.PlotlyManager().go\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb index 9ee9e69d5..c23fdf369 100644 --- a/examples/notebooks/comparison_examples/optimiser_calibration.ipynb +++ b/examples/notebooks/comparison_examples/optimiser_calibration.ipynb @@ -32,7 +32,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb index 6250b85e9..e4541c8d8 100644 --- a/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb +++ b/examples/notebooks/design_optimisation/energy_based_electrode_design.ipynb @@ -35,7 +35,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/cost_compute_methods.ipynb b/examples/notebooks/getting_started/cost_compute_methods.ipynb index 3a424f38c..909bcddd5 100644 --- a/examples/notebooks/getting_started/cost_compute_methods.ipynb +++ b/examples/notebooks/getting_started/cost_compute_methods.ipynb @@ -28,7 +28,7 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb index ccd95b2ea..7da706f7f 100644 --- a/examples/notebooks/getting_started/maximum_a_posteriori.ipynb +++ b/examples/notebooks/getting_started/maximum_a_posteriori.ipynb @@ -51,7 +51,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/optimising_with_adamw.ipynb b/examples/notebooks/getting_started/optimising_with_adamw.ipynb index d6fc110af..ab708410e 100644 --- a/examples/notebooks/getting_started/optimising_with_adamw.ipynb +++ b/examples/notebooks/getting_started/optimising_with_adamw.ipynb @@ -34,7 +34,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/setting_optimiser_options.ipynb b/examples/notebooks/getting_started/setting_optimiser_options.ipynb index a790b26b2..65760ef97 100644 --- a/examples/notebooks/getting_started/setting_optimiser_options.ipynb +++ b/examples/notebooks/getting_started/setting_optimiser_options.ipynb @@ -32,7 +32,7 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/notebooks/getting_started/using_transformations.ipynb b/examples/notebooks/getting_started/using_transformations.ipynb index 236dc8bba..5d5d9a47f 100644 --- a/examples/notebooks/getting_started/using_transformations.ipynb +++ b/examples/notebooks/getting_started/using_transformations.ipynb @@ -29,7 +29,8 @@ "\n", "import pybop\n", "\n", - "pybop.plot.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", + "pybop.plot.set_backend(\"plotly\")\n", + "pybop.plot.plotly.PlotlyManager().pio.renderers.default = \"notebook_connected\"\n", "\n", "np.random.seed(8) # users can remove this line" ] diff --git a/examples/scripts/comparison_examples/grouped_SPMe.py b/examples/scripts/comparison_examples/grouped_SPMe.py index de4c897ab..cb1d060e2 100644 --- a/examples/scripts/comparison_examples/grouped_SPMe.py +++ b/examples/scripts/comparison_examples/grouped_SPMe.py @@ -11,11 +11,10 @@ """ # Prepare figure -layout_options = dict( - xaxis_title="Time / s", - yaxis_title="Voltage / V", -) -plot_dict = pybop.plot.StandardPlot(layout_options=layout_options) +pybop.plot.set_backend("matplotlib") +plot_dict = pybop.plot.StandardPlot() +plt.xlabel("Time / s") +plt.ylabel("Voltage / V") # Use the Chen2020 parameters parameter_values = pybamm.ParameterValues("Chen2020") @@ -46,18 +45,18 @@ ) SPMe_model = pybamm.lithium_ion.SPMe(options=model_options) grouped_SPMe_model = pybop.lithium_ion.GroupedSPMe(options=model_options) -for model, param, line_style in zip( +for model, param, linestyle in zip( [SPMe_model, grouped_SPMe_model], [parameter_values, grouped_parameter_values], - ["solid", "dash"], + ["-", "--"], strict=False, ): solution = pybamm.Simulation( model, parameter_values=param, experiment=experiment, cache_esoh=False ).solve(initial_soc=init_soc) dataset = pybop.import_pybamm_solution(solution) - plot_dict.add_traces( - dataset["Time [s]"], dataset["Voltage [V]"], line_dash=line_style + plot_dict.create_trace( + dataset["Time [s]"], dataset["Voltage [V]"], label=None, linestyle=linestyle ) plot_dict() diff --git a/pybop/plot/__init__.py b/pybop/plot/__init__.py index f58db2016..da5e84d44 100644 --- a/pybop/plot/__init__.py +++ b/pybop/plot/__init__.py @@ -1,13 +1,24 @@ +# Plotting backend default +DEFAULT_BACKEND = 'matplotlib' +backend=DEFAULT_BACKEND + +from .util import set_backend, call_plotting_function, get_default_options + # # Import plots # -from .plotly_manager import PlotlyManager +from .plots import ( + surface + ) + from .standard_plots import StandardPlot, StandardSubplot, trajectories from .contour import contour -from .dataset import dataset from .convergence import convergence +from .dataset import dataset +from .nyquist import nyquist from .parameters import parameters from .problem import problem -from .nyquist import nyquist -from .voronoi import surface -from .samples import trace, chains, posterior, summary_table +from .samples import chains, posterior, summary_table, trace +from .voronoi import voronoi_data, _voronoi_regions +from . import matplotlib +from . import plotly diff --git a/pybop/plot/contour.py b/pybop/plot/contour.py index da2d6a1a6..cdd5b8b59 100644 --- a/pybop/plot/contour.py +++ b/pybop/plot/contour.py @@ -4,7 +4,8 @@ import numpy as np -from pybop.plot.plotly_manager import PlotlyManager +from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import call_plotting_function, get_default_options, update_and_show from pybop.problems.problem import Problem if TYPE_CHECKING: @@ -18,7 +19,7 @@ def contour( transformed: bool = False, steps: int = 10, show: bool = True, - **layout_kwargs, + backend=None, ): """ Plot a 2D visualisation of a cost landscape using Plotly. @@ -143,121 +144,86 @@ def transform_array_of_values(list_of_values, parameter): bounds[0] = transform_array_of_values(bounds[0], parameters[names[0]]) bounds[1] = transform_array_of_values(bounds[1], parameters[names[1]]) - # Import plotly only when needed - go = PlotlyManager().go - - # Set default layout properties - layout_options = dict( - title="Cost Landscape", - title_x=0.5, - title_y=0.905, - width=600, - height=600, - xaxis=dict(range=bounds[0], showexponent="last", exponentformat="e"), - yaxis=dict(range=bounds[1], showexponent="last", exponentformat="e"), - legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), + # Get options + options = get_default_options("contour", backend) + plot_options = options.get("plot_options") or {} + trace_options_initial = options.get("trace_options_initial") or {} + trace_options_optim = options.get("trace_options_optim") or {} + trace_options_contour = options.get("trace_options_contour") or {} + + plot_dict = StandardPlot( + xaxis_title="Transformed " + names[0] if transformed else names[0], + yaxis_title="Transformed " + names[1] if transformed else names[1], + xaxis_range=bounds[0], + yaxis_range=bounds[1], + backend=backend, + **plot_options, ) - layout_options["xaxis_title"] = ( - "Transformed " + names[0] if transformed else names[0] - ) - layout_options["yaxis_title"] = ( - "Transformed " + names[1] if transformed else names[1] - ) - layout = go.Layout(layout_options) # Create contour plot and update the layout - fig = go.Figure( - data=[go.Contour(x=x, y=y, z=costs, colorscale="Viridis", connectgaps=True)], - layout=layout, - ) + plot_dict.create_contour(x=x, y=y, z=costs, **trace_options_contour) if plot_optim: # Plot the optimisation trace optim_trace = np.asarray([item[:2] for item in result.x_model]) optim_trace = optim_trace.reshape(-1, 2) - - fig.add_trace( - go.Scatter( - x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), - y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), - mode="markers", - marker=dict( - color=[i / len(optim_trace) for i in range(len(optim_trace))], - colorscale="Greys", - size=8, - showscale=False, - ), - showlegend=False, - ) + call_plotting_function( + "plot_optimisation_path", + backend=backend, + plot_dict=plot_dict, + x=transform_array_of_values(optim_trace[:, 0], parameters[names[0]]), + y=transform_array_of_values(optim_trace[:, 1], parameters[names[1]]), ) # Plot the initial guess if len(result.x_model) > 0: x0 = result.x_model[0] - fig.add_trace( - go.Scatter( + plot_dict.traces.append( + plot_dict.create_trace( x=transform_array_of_values([x0[0]], parameters[names[0]]), y=transform_array_of_values([x0[1]], parameters[names[1]]), - mode="markers", - marker_symbol="x", - marker=dict( - color="white", - line_color="black", - line_width=1, - size=14, - showscale=False, - ), - name="Initial values", + **trace_options_initial, ) ) # Plot optimised value if result.x is not None: x_best = result.x - fig.add_trace( - go.Scatter( + plot_dict.traces.append( + plot_dict.create_trace( x=transform_array_of_values([x_best[0]], parameters[names[0]]), y=transform_array_of_values([x_best[1]], parameters[names[1]]), - mode="markers", - marker_symbol="cross", - marker=dict( - color="black", - line_color="white", - line_width=1, - size=14, - showscale=False, - ), - name="Final values", + **trace_options_optim, ) ) # Update the layout and display the figure - fig.update_layout(**layout_kwargs) + fig = plot_dict(show=False) if show: - fig.show() + update_and_show(fig) - if gradient: - grad_figs = [] - for i, grad_costs in enumerate(grad_parameter_costs): - # Update title for gradient plots - updated_layout_options = layout_options.copy() - updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}" - - # Create contour plot with updated layout options - grad_layout = go.Layout(updated_layout_options) - - # Create fig - grad_fig = go.Figure( - data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout - ) - grad_fig.update_layout(**layout_kwargs) + # if gradient: + # grad_figs = [] + # for i, grad_costs in enumerate(grad_parameter_costs): + # # Update title for gradient plots + # updated_layout_options = layout_options.copy() + # updated_layout_options["title"] = f"Gradient for Parameter: {i + 1}" + + # # Create contour plot with updated layout options + # grad_layout = go.Layout(updated_layout_options) + + # # Create fig + # grad_fig = go.Figure( + # data=[go.Contour(x=x, y=y, z=grad_costs)], layout=grad_layout + # ) + # grad_fig.update_layout(**layout_kwargs) - if show: - grad_fig.show() + # if show: + # grad_fig.show() - # append grad_fig to list - grad_figs.append(grad_fig) + # # append grad_fig to list + # grad_figs.append(grad_fig) - return fig, grad_figs + # return fig, grad_figs return fig diff --git a/pybop/plot/convergence.py b/pybop/plot/convergence.py index deabf66d5..5d3390959 100644 --- a/pybop/plot/convergence.py +++ b/pybop/plot/convergence.py @@ -1,12 +1,13 @@ from typing import TYPE_CHECKING from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import update_and_show if TYPE_CHECKING: from pybop._result import Result -def convergence(result: "Result", show=True, **layout_kwargs): +def convergence(result: "Result", show=True, backend=None): """ Plot the convergence of the optimisation algorithm. @@ -37,18 +38,15 @@ def convergence(result: "Result", show=True, **layout_kwargs): plot_dict = StandardPlot( x=iteration_numbers, y=cost_log, - layout_options=dict( - xaxis_title="Evaluation", - yaxis_title="Cost", - title="Convergence", - ), + xaxis_title="Evaluation", + yaxis_title="Cost", + title="Convergence", trace_names=result.method_name, + backend=backend, ) # Generate and display the figure fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) if show: - fig.show() - + update_and_show(fig, backend=backend) return fig diff --git a/pybop/plot/dataset.py b/pybop/plot/dataset.py index 24257a732..b92d2b78f 100644 --- a/pybop/plot/dataset.py +++ b/pybop/plot/dataset.py @@ -1,7 +1,10 @@ from pybop.plot.standard_plots import StandardPlot, trajectories +from pybop.plot.util import update_and_show -def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs): +def dataset( + dataset, signal=None, trace_names=None, show=True, backend=None, **layout_kwargs +): """ Quickly plot a PyBOP Dataset using Plotly. @@ -50,9 +53,9 @@ def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs): show=False, xaxis_title=StandardPlot.remove_brackets(dataset.domain), yaxis_title=yaxis_title, + backend=backend, ) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + + fig = update_and_show(fig, backend=backend, show=show, **layout_kwargs) return fig diff --git a/pybop/plot/matplotlib/__init__.py b/pybop/plot/matplotlib/__init__.py new file mode 100644 index 000000000..2730122df --- /dev/null +++ b/pybop/plot/matplotlib/__init__.py @@ -0,0 +1,3 @@ +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path +from .voronoi import surface +from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/matplotlib/standard_plots.py b/pybop/plot/matplotlib/standard_plots.py new file mode 100644 index 000000000..bd49cd6d5 --- /dev/null +++ b/pybop/plot/matplotlib/standard_plots.py @@ -0,0 +1,385 @@ +import warnings + +from matplotlib import pyplot as plt + +from pybop.plot import StandardPlot + +DEFAULT_TRACE_OPTIONS = dict(linewidth=2.0) + + +class Plotter: + """ + A class for creating and displaying interactive Plotly figures. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + trace_options=None, + figsize=(8, 6), + title=None, + xaxis_title=None, + yaxis_title=None, + xaxis_range=None, + yaxis_range=None, + grid=None, + axis_bg_color=None, + **kwargs, + ): + self.backend = "matplotlib" + self.title = title + # Warning if layout arguments ignored + if len(kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + # Set default trace options and update if provided + self.trace_options = DEFAULT_TRACE_OPTIONS.copy() + if trace_options: + self.trace_options.update(trace_options) + + self.fig = plt.figure(figsize=figsize, dpi=100) + if title is not None: + plt.suptitle(self.title) + if xaxis_title is not None: + plt.xlabel(xaxis_title) + if yaxis_title is not None: + plt.ylabel(yaxis_title) + if grid is not None: + plt.grid(**grid) + if axis_bg_color is not None: + ax = plt.gca() + ax.set_facecolor(axis_bg_color) + ax.set_axisbelow(True) + if xaxis_range is not None: + plt.xlim(xaxis_range) + if yaxis_range is not None: + plt.ylim(yaxis_range) + self.traces = [] + + def __call__(self, show=True): + """ + Generate and show the figure. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + # Add traces + for trace in self.traces: + self._plot_trace(**trace) + + plt.tick_params(axis="both", labelsize=12) + plt.ticklabel_format(axis="both", style="sci", scilimits=(-4, 4)) + + labels_in_fig = True + for ax in self.fig.axes: + if not ax.get_legend_handles_labels() == ([], []): + break + else: + labels_in_fig = False + if labels_in_fig: + plt.legend( + **dict(loc="best", fontsize=12), + ) + + if show: + plt.show() + else: + return self.fig + + def add_traces(self, x, y, trace_names=None, **trace_options): + """ + Add a set of traces to the plot dictionary. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + trace_names : str or list[str], optional + Name(s) for the primary trace(s) (default: None). + """ + + options = self.trace_options.copy() + options.update(trace_options) + + # Create a trace for each trajectory + xi = x[0] + for i in range(0, len(y)): + trace_options = options.copy() + if len(x) > 1: + xi = x[i] + + label = None + if trace_names is not None: + label = trace_names[i] + + self.traces.append(self.create_trace(xi, y[i], label, **trace_options)) + + def create_trace(self, x=None, y=None, label=None, ax=None, **trace_options): + """ + Add line to plot. + + Returns + ------- + plotly.graph_objs.Scatter + A trace for a Plotly figure. + """ + if x is not None and y is not None: + size = min(len(x), len(y)) + trace = dict(positional_args=[x[:size], y[:size]], label=label, ax=ax) + elif y is not None: + trace = dict(positional_args=[y], label=label, ax=ax) + + trace.update(trace_options) + return trace + + def create_fill_trace(self, x, y_upper, y_lower, **options): + trace = dict(positional_args=(x, y_upper, y_lower), plot_type="fill_between") + trace.update(options) + return trace + + def create_histogram(self, x, name, **trace_options): + trace = dict(positional_args=[x], label=name, plot_type="hist") + trace.update(trace_options) + return trace + + def create_vline(self, fig, x, **trace_options): + fig.gca() + plt.axvline(x, **trace_options) + + def create_contour(self, x, y, z, **trace_options): + contour = dict(positional_args=[x, y, z], plot_type="contourf") + contour.update(**trace_options) + self.traces.append(contour) + self.traces.append( + dict( + positional_args=[x, y, z], + colors=("k"), + linestyles="solid", + linewidths=0.2, + plot_type="contour", + ) + ) + + def create_scatter(self, x, y, **trace_options): + scatter = dict(positional_args=[x, y], plot_type="scatter") + scatter.update(**trace_options) + self.traces.append(scatter) + + def _plot_trace( + self, ax=None, plot_type="plot", positional_args=None, **trace_options + ): + if positional_args is None: + positional_args = [] + if ax is None: + ax = plt.gca() + try: + plot_function = getattr(ax, plot_type) + except ValueError: + print("Plot type not recognised") + + obj = plot_function(*positional_args, **trace_options) + if plot_type == "contourf": + plt.colorbar(obj) + + return obj + + +class SubplotPlotter(Plotter): + """ + A class for creating and displaying a set of interactive Plotly figures in a grid layout. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + num_rows : int, optional + Number of rows of subplots, can be set automatically (default: None). + num_cols : int, optional + Number of columns of subplots, can be set automatically (default: None). + layout : Plotly layout, optional + A layout for the figure, overrides the layout options (default: None). + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + axis_titles=None, + trace_options=DEFAULT_TRACE_OPTIONS, + figsize=(8, 6), + **kwargs, + ): + super().__init__(trace_options, figsize, **kwargs) + self.axis_titles = axis_titles + + def __call__(self, show=True, num_rows=1, num_cols=1): + """ + Generate and show the set of figures. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + + color_cycle = plt.rcParams["axes.prop_cycle"]() + + lines = [] + show_legend = False + for idx, trace in enumerate(self.traces): + ax = self.fig.add_subplot(num_rows, num_cols, idx + 1) + trace["ax"] = ax + if self.axis_titles and idx < len(self.axis_titles): + x_title, y_title = self.axis_titles[idx] + ax.set_xlabel(x_title) + ax.set_ylabel(y_title) + if "label" in trace.keys() and trace["label"] is not None: + show_legend = True + + lines.append(self._plot_trace(**trace, **next(color_cycle))) + + lines_labels = [ax.get_legend_handles_labels() for ax in self.fig.axes] + lines, labels = [sum(lol, []) for lol in zip(*lines_labels, strict=False)] + if show_legend: + self.fig.legend( + lines, + labels, + loc="upper right", + ncol=len(lines), + bbox_to_anchor=(0.99, 0.95), + ) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + + if self.title is not None: + plt.suptitle(self.title) + + if show: + plt.show() + + return self.fig + + +def trajectories( + x, + y, + trace_names=None, + show=True, + xaxis_title="", + yaxis_title="", + title="", + **layout_kwargs, +): + """ + Quickly plot one or more trajectories using Plotly. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Y-axis data points for each trajectory. + trace_names : list or str, optional + Name(s) for the trace(s) (default: None). + **layout_kwargs : optional + This argument is ignored for the matplotlib backend. + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + # Create a plot dictionary + plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="matplotlib") + + # Generate the figure and update the layout + fig = plot_dict(show=False) + plt.title(title) + plt.xlabel(xaxis_title, fontsize=12) + plt.ylabel(yaxis_title, fontsize=12) + plt.tight_layout() + if show: + fig.show() + + return fig + + +def show_table(header, values, title): + """ + Display data in a table. + """ + for i, val in enumerate(values): + values[i] = [val[0], ", ".join(val[1].astype(str))] + + fig, ax = plt.subplots(figsize=(6, 2), dpi=100) + + # hide axes + ax.axis("off") + ax.axis("tight") + ax.table( + cellText=values, + colLabels=header, + loc="center", + cellLoc="center", + colColours=["lightsteelblue", "lightsteelblue"], + ) + ax.set_title(title) + fig.tight_layout() + plt.show() + + +def plot_optimisation_path(plot_dict: StandardPlot, x, y): + plot_dict.plotter.create_scatter( + x, + y, + c=[i / len(x) for i in range(len(x))], + cmap="Grays", + zorder=1, + ) diff --git a/pybop/plot/matplotlib/util.py b/pybop/plot/matplotlib/util.py new file mode 100644 index 000000000..b6a7311d8 --- /dev/null +++ b/pybop/plot/matplotlib/util.py @@ -0,0 +1,92 @@ +import warnings + +import matplotlib.pyplot as plt + + +def update_and_show(fig, show=True, **layout_kwargs): + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + if show: + plt.show() + + return fig + + +DEFAULT_PLOT_OPTIONS = { + "contour": { + "plot_options": dict(title="Cost Landscape"), + "trace_options_contour": dict(extend="both", cmap="viridis"), + "trace_options_initial": dict( + marker="X", + markersize=14, + markerfacecolor="w", + markeredgecolor="k", + label="Initial values", + linestyle="None", + ), + "trace_options_optim": dict( + marker="P", + markersize=14, + markerfacecolor="k", + markeredgecolor="w", + label="Final values", + linestyle="None", + ), + }, + "nyquist": { + "plot_options": dict( + xaxis_title=r"$Z_{re} / \Omega$", yaxis_title=r"$-Z_{im} / \Omega$" + ), + "trace_options_model": dict( + label="Model", color="#00CC96", linewidth=2, marker=".", markersize=8 + ), + "trace_options_reference": dict( + label="Reference", + linestyle="none", + marker="o", + fillstyle="none", + markersize=8, + markeredgecolor="#636EFA", + ), + }, + "parameters": dict(figsize=(18, 8), title="Parameter Convergence"), + "problem": { + "default_trace_options": dict(label="Model", marker=None, linestyle="-"), + "design_cost_options": dict(label="Optimised"), + "meta_problem_options": dict(marker=".", linestyle="none"), + "reference_options": dict(label="Reference", marker=".", linestyle="none"), + "fill_options": dict(color=[(1.0, 0.898, 0.800, 0.8)]), + }, + "posterior": { + "plot_options": { + "figsize": (15, 8), + "grid": dict(axis="y", zorder=-1, color="w"), + "axis_bg_color": ( + 0.6784313725490196, + 0.8470588235294118, + 0.9019607843137255, + 0.3, + ), + }, + "trace_options": dict(alpha=0.75), + "trace_options_vline": dict(linewidth=3, linestyle="--", color="k"), + }, + "trace": { + "plot_options": { + "figsize": (15, 8), + "grid": dict(axis="y", zorder=-10, color="w"), + "axis_bg_color": ( + 0.6784313725490196, + 0.8470588235294118, + 0.9019607843137255, + 0.3, + ), + }, + }, +} diff --git a/pybop/plot/matplotlib/voronoi.py b/pybop/plot/matplotlib/voronoi.py new file mode 100644 index 000000000..e7502b941 --- /dev/null +++ b/pybop/plot/matplotlib/voronoi.py @@ -0,0 +1,141 @@ +import warnings +from typing import TYPE_CHECKING + +import matplotlib as mpl +import numpy as np +from matplotlib import pyplot as plt + +if TYPE_CHECKING: + from pybop._result import Result + +from pybop.plot.voronoi import voronoi_data + + +def surface( + result: "Result", + bounds=None, + normalise=True, + title="Voronoi Cost Landscape", + show=True, + **layout_kwargs, +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + + if len(layout_kwargs) > 0: + warnings.warn( + "The following layout argument keys are ignored for the current plotting backend (matplotlib): \n" + f"{list(layout_kwargs.keys())}", + UserWarning, + stacklevel=2, + ) + + points = result.x_model + parameters = result.problem.parameters + + if points[0].shape[0] != 2: + raise ValueError("This plot method requires two parameters.") + + x_optim, y_optim = map(list, zip(*points, strict=False)) + f = result.cost + + # Translate bounds, taking only the first two elements + xlim, ylim = ( + bounds if bounds is not None else [param.bounds for param in parameters] + )[:2] + + _, _, f, regions, relative_sizes = voronoi_data( + xlim, ylim, x_optim, y_optim, f, normalise + ) + + # Construct figure + plt.figure(figsize=(7, 6), dpi=100) + + # normalise cost + f_min = np.nanmin(f[np.isfinite(f)]) + f_max = np.nanmax(f[np.isfinite(f)]) + norm = mpl.colors.Normalize(vmin=f_min, vmax=f_max, clip=True) + norm_f = norm(f, clip=True) + + # get colours + cmap = mpl.colormaps["viridis"] + colors = cmap(norm_f) + + # Add Voronoi edges and fill Voronoi regions + for j, (region, size) in enumerate(zip(regions, relative_sizes, strict=False)): + x_region = region[:, 0].tolist() + [region[0, 0]] + y_region = region[:, 1].tolist() + [region[0, 1]] + + plt.fill(x_region, y_region, color=colors[j]) + plt.plot(x_region, y_region, color="w", linewidth=0.5 + size * 0.1) + + plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca()) + + # Add original points + plt.scatter( + x_optim, + y_optim, + c=[i / len(x_optim) for i in range(len(x_optim))], + cmap="Grays", + zorder=2.5, + ) + + # Plot the initial guess + if len(result.x_model) > 0: + x0 = result.x_model[0] + plt.plot( + [x0[0]], + [x0[1]], + "X", + markersize=14, + markerfacecolor="w", + markeredgecolor="k", + label="Initial values", + linestyle="None", + zorder=2.6, + ) + + # Plot optimised value + if result.x is not None: + x_best = result.x + plt.plot( + [x_best[0]], + [x_best[1]], + "P", + markersize=14, + markerfacecolor="k", + markeredgecolor="w", + label="Final values", + linestyle="None", + zorder=2.6, + ) + + # Layout + names = result.problem.parameters.names + plt.xlabel(names[0], labelpad=15) + plt.ticklabel_format(axis="both", **dict(style="sci", scilimits=(-4, 4))) + plt.ylabel(names[1], labelpad=15) + plt.title(title, pad=40) + plt.legend(ncols=2, loc="lower center", bbox_to_anchor=(0.5, 1.0)) + plt.xlim(xlim[0], xlim[1]) + plt.ylim(ylim[0], ylim[1]) + plt.tight_layout() + + if show: + plt.show() diff --git a/pybop/plot/nyquist.py b/pybop/plot/nyquist.py index 80f7eb77a..5d1452f2e 100644 --- a/pybop/plot/nyquist.py +++ b/pybop/plot/nyquist.py @@ -1,8 +1,11 @@ from pybop.parameters.parameter import Inputs from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import get_default_options, update_and_show -def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): +def nyquist( + problem, inputs: Inputs = None, show=True, title="Nyquist Plot", backend=None +): """ Generates Nyquist plots for the given problem by evaluating the model's output and target values. @@ -39,81 +42,41 @@ def nyquist(problem, inputs: Inputs = None, show=True, **layout_kwargs): >>> nyquist_figures = nyquist(problem, show=True, title="Nyquist Plot", xaxis_title="Real(Z)", yaxis_title="Imag(Z)") >>> # The plots will be displayed and nyquist_figures will contain the list of figure objects. """ + + options = get_default_options("nyquist", backend) + plot_options = options.get("plot_options") or {} + trace_options_model = options.get("trace_options_model") or {} + trace_options_reference = options.get("trace_options_reference") or {} + if not isinstance(inputs, dict): inputs = problem.parameters.to_dict(inputs) model_output = problem.simulate(inputs) domain_data = model_output["Impedance"].data.real target_output = problem.target_data - figure_list = [] for var in problem.target: - default_layout_options = dict( - title="Nyquist Plot", - font=dict(family="Arial", size=14), - plot_bgcolor="white", - paper_bgcolor="white", - xaxis=dict( - title=dict(text="Zre / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - ), - yaxis=dict( - title=dict(text="-Zim / Ω", font=dict(size=16), standoff=15), - showline=True, - linewidth=2, - linecolor="black", - mirror=True, - ticks="outside", - tickwidth=2, - tickcolor="black", - ticklen=5, - scaleanchor="x", - scaleratio=1, - ), - legend=dict( - orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 - ), - width=600, - height=600, - ) - plot_dict = StandardPlot( x=domain_data, y=-model_output[var].data.imag, - layout_options=default_layout_options, trace_names="Model", + title=title, + **plot_options, ) - plot_dict.traces[0].update( - mode="lines+markers", - line=dict(color="#00CC96", width=2), - marker=dict(size=8, color="#00CC96", symbol="circle"), - ) + plot_dict.traces[0].update(trace_options_model) target_trace = plot_dict.create_trace( x=target_output[var].real, y=-target_output[var].imag, - name="Reference", - mode="markers", - marker=dict(size=8, color="#636EFA", symbol="circle-open"), - showlegend=True, + **trace_options_reference, ) plot_dict.traces.append(target_trace) fig = plot_dict(show=False) - - # Overwrite with user-kwargs - fig.update_layout(**layout_kwargs) - if show: - fig.show() - figure_list.append(fig) + if show: + update_and_show(figure_list) + return figure_list diff --git a/pybop/plot/parameters.py b/pybop/plot/parameters.py index a13d9e3fb..934ec67a3 100644 --- a/pybop/plot/parameters.py +++ b/pybop/plot/parameters.py @@ -2,12 +2,13 @@ from pybop.costs.log_likelihoods import GaussianLogLikelihood from pybop.plot.standard_plots import StandardSubplot +from pybop.plot.util import get_default_options, update_and_show if TYPE_CHECKING: from pybop._result import Result -def parameters(result: "Result", show=True, **layout_kwargs): +def parameters(result: "Result", show=True, backend=None, **layout_kwargs): """ Plot the evolution of parameters during the optimisation process using Plotly. @@ -44,27 +45,21 @@ def parameters(result: "Result", show=True, **layout_kwargs): trace_names.append("Sigma") # Set subplot layout options - layout_options = dict( - title="Parameter Convergence", - width=1024, - height=576, - legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), - ) + plot_options = get_default_options("paramters", backend) # Create a plot dictionary plot_dict = StandardSubplot( x=x, y=y, axis_titles=axis_titles, - layout_options=layout_options, trace_names=trace_names, trace_name_width=50, + backend=backend, + **plot_options, ) # Generate the figure and update the layout fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + fig = update_and_show(fig, backend, **layout_kwargs) return fig diff --git a/pybop/plot/plotly/__init__.py b/pybop/plot/plotly/__init__.py new file mode 100644 index 000000000..4de178480 --- /dev/null +++ b/pybop/plot/plotly/__init__.py @@ -0,0 +1,5 @@ +from .plotly_manager import PlotlyManager +from .standard_plots import Plotter, SubplotPlotter, show_table, trajectories, plot_optimisation_path +from .voronoi import surface + +from .util import update_and_show, DEFAULT_PLOT_OPTIONS diff --git a/pybop/plot/plotly_manager.py b/pybop/plot/plotly/plotly_manager.py similarity index 100% rename from pybop/plot/plotly_manager.py rename to pybop/plot/plotly/plotly_manager.py diff --git a/pybop/plot/plotly/standard_plots.py b/pybop/plot/plotly/standard_plots.py new file mode 100644 index 000000000..80429f0c4 --- /dev/null +++ b/pybop/plot/plotly/standard_plots.py @@ -0,0 +1,354 @@ +from pybop.plot import StandardPlot +from pybop.plot.plotly.plotly_manager import PlotlyManager + +DEFAULT_LAYOUT_OPTIONS = dict( + title=None, + title_x=0.5, + xaxis=dict( + title=dict(font={"size": 14}), + showexponent="last", + exponentformat="e", + tickfont=dict(size=12), + ), + yaxis=dict( + title=dict(font={"size": 14}), + showexponent="last", + exponentformat="e", + tickfont=dict(size=12), + ), + legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12), + showlegend=True, + autosize=False, + width=600, + height=600, + margin=dict(l=10, r=10, b=10, t=75, pad=4), + plot_bgcolor="white", +) +DEFAULT_SUBPLOT_OPTIONS = dict( + start_cell="bottom-left", +) +DEFAULT_TRACE_OPTIONS = dict(line=dict(width=4), mode="lines") +DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines") + + +class Plotter: + """ + A class for creating and displaying interactive Plotly figures. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + layout : Plotly layout, optional + A layout for the figure, overrides the layout options (default: None). + layout_options : dict, optional + Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS). + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + title=None, + xaxis_title=None, + yaxis_title=None, + xaxis_range=None, + yaxis_range=None, + layout=None, + layout_options=None, + trace_options=None, + ): + self.backend = "plotly" + + self.traces = [] + self.layout = layout + + # Set default layout options and update if provided + if self.layout is None: + self.layout_options = DEFAULT_LAYOUT_OPTIONS.copy() + if layout_options: + self.layout_options.update(layout_options) + if title is not None: + self.layout_options.update({"title": title}) + if xaxis_title is not None: + self.layout_options.update({"xaxis_title": xaxis_title}) + if yaxis_title is not None: + self.layout_options.update({"yaxis_title": yaxis_title}) + if xaxis_range is not None: + self.layout_options["xaxis"].update({"range": xaxis_range}) + if yaxis_range is not None: + self.layout_options["yaxis"].update({"range": yaxis_range}) + + # Set default trace options and update if provided + self.trace_options = DEFAULT_TRACE_OPTIONS.copy() + if trace_options: + self.trace_options.update(trace_options) + + # Attempt to import plotly when an instance is created + self.go = PlotlyManager().go + + # Create layout + if self.layout is None: + self.layout = self.go.Layout(**self.layout_options) + + def __call__(self, show=True): + """ + Generate and show the figure. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + fig = self.go.Figure(data=self.traces, layout=self.layout) + if show: + fig.show() + + return fig + + def add_traces(self, x, y, trace_names=None, **trace_options): + """ + Add a set of traces to the plot dictionary. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + trace_names : str or list[str], optional + Name(s) for the primary trace(s) (default: None). + """ + options = self.trace_options.copy() + options.update(trace_options) + + # Create a trace for each trajectory + xi = x[0] + for i in range(0, len(y)): + trace_options = options.copy() + if len(x) > 1: + xi = x[i] + if trace_names is not None: + trace_options["name"] = trace_names[i] + else: + trace_options["showlegend"] = False + trace = self.create_trace(xi, y[i], **trace_options) + self.traces.append(trace) + + def create_trace(self, x=None, y=None, label=None, **trace_options): + """ + Create a trace for the Plotly figure. + + Returns + ------- + plotly.graph_objs.Scatter + A trace for a Plotly figure. + """ + if label is not None: + trace_options.update({"name": label}) + if x is not None and y is not None: + return self.go.Scatter( + x=x, + y=y, + **trace_options, + ) + if x is None and y is not None: + return self.go.Scatter( + y=y, + **trace_options, + ) + + def create_fill_trace(self, x, y_upper, y_lower, **options): + return self.create_trace( + x=x + x[::-1], + y=y_upper + y_lower[::-1], + fill="toself", + line=dict(color="rgba(255,255,255,0)"), + hoverinfo="skip", + showlegend=False, + **options, + ) + + def create_histogram(self, x, name, **trace_options): + return self.go.Histogram(x=x, name=name, **trace_options) + + def create_vline(self, fig, x, **trace_options): + fig.add_vline(x=x, **trace_options) + + def create_contour(self, x, y, z, **trace_options): + self.traces.append(self.go.Contour(x=x, y=y, z=z, **trace_options)) + + +class SubplotPlotter(Plotter): + """ + A class for creating and displaying a set of interactive Plotly figures in a grid layout. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + num_rows : int, optional + Number of rows of subplots, can be set automatically (default: None). + num_cols : int, optional + Number of columns of subplots, can be set automatically (default: None). + layout : Plotly layout, optional + A layout for the figure, overrides the layout options (default: None). + layout_options : dict, optional + Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS). + trace_options : dict, optional + Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). + trace_names : str, optional + Name(s) for the primary trace(s) (default: None). + trace_name_width : int, optional + Maximum length of the trace names before text wrapping is used (default: 40). + + Returns + ------- + plotly.graph_objs.Figure + The generated Plotly figure. + """ + + def __init__( + self, + axis_titles=None, + layout=None, + layout_options=DEFAULT_LAYOUT_OPTIONS, + subplot_options=DEFAULT_SUBPLOT_OPTIONS, + trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS, + ): + super().__init__( + layout, layout_options=layout_options, trace_options=trace_options + ) + self.subplot_options = subplot_options.copy() + if subplot_options is not None: + for arg, value in subplot_options.items(): + self.subplot_options[arg] = value + + # Attempt to import plotly when an instance is created + self.make_subplots = PlotlyManager().make_subplots + + self.axis_titles = axis_titles + + def __call__(self, show=True, num_rows=1, num_cols=1): + """ + Generate and show the set of figures. + + Parameters + ---------- + show : bool, optional + If True, the figure is shown upon creation (default: True). + """ + fig = self.make_subplots( + rows=num_rows, + cols=num_cols, + horizontal_spacing=0.1, + vertical_spacing=0.15, + **self.subplot_options, + ) + fig.update_layout(self.layout_options) + + for idx, trace in enumerate(self.traces): + row = (idx // num_cols) + 1 + col = (idx % num_cols) + 1 + fig.add_trace(trace, row=row, col=col) + + if self.axis_titles and idx < len(self.axis_titles): + x_title, y_title = self.axis_titles[idx] + fig.update_xaxes(title_text=x_title, row=row, col=col) + fig.update_yaxes( + title_text=y_title, + row=row, + col=col, + showexponent="last", + exponentformat="e", + ) + + if show: + fig.show() + + return fig + + +def trajectories(x, y, trace_names=None, show=True, **layout_kwargs): + """ + Quickly plot one or more trajectories using Plotly. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Y-axis data points for each trajectory. + trace_names : list or str, optional + Name(s) for the trace(s) (default: None). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time / s"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object for the scatter plot. + """ + # Create a plot dictionary + plot_dict = StandardPlot(x=x, y=y, trace_names=trace_names, backend="plotly") + + # Generate the figure and update the layout + fig = plot_dict(show=False) + fig.update_layout(**layout_kwargs) + if show: + fig.show() + + return fig + + +def show_table(header, values, title): + """ + Display data in a table. + """ + # Import plotly only when needed + go = PlotlyManager().go + fig = go.Figure( + data=[ + go.Table( + header=dict(values=header), + cells=dict( + values=[[row[0] for row in values], [row[1] for row in values]] + ), + ) + ] + ) + + fig.update_layout(title=title) + fig.show() + + +def plot_optimisation_path(plot_dict: StandardPlot, x, y): + plot_dict.traces.append( + plot_dict.create_trace( + x, + y, + mode="markers", + marker=dict( + color=[i / len(x) for i in range(len(x))], + colorscale="Greys", + size=8, + showscale=False, + ), + showlegend=False, + ) + ) diff --git a/pybop/plot/plotly/util.py b/pybop/plot/plotly/util.py new file mode 100644 index 000000000..9a913b3aa --- /dev/null +++ b/pybop/plot/plotly/util.py @@ -0,0 +1,150 @@ +def update_and_show(fig, show=True, **layout_kwargs): + if hasattr(fig, "__len__") and len(fig) > 0: + for f in fig: + f.update_layout(**layout_kwargs) + if show: + f.show() + else: + fig.update_layout(**layout_kwargs) + if show: + fig.show() + return fig + + +DEFAULT_PLOT_OPTIONS = { + "contour": { + "plot_options": dict( + layout_options=dict( + title="Cost Landscape", + title_x=0.5, + title_y=0.905, + width=600, + height=600, + xaxis=dict(showexponent="last", exponentformat="e"), + yaxis=dict(showexponent="last", exponentformat="e"), + legend=dict( + orientation="h", yanchor="bottom", y=1, xanchor="right", x=1 + ), + autosize=None, + showlegend=None, + margin=None, + ) + ), + "trace_options_contour": dict(colorscale="Viridis", connectgaps=True), + "trace_options_initial": dict( + mode="markers", + marker_symbol="x", + marker=dict( + color="white", + line_color="black", + line_width=1, + size=14, + showscale=False, + ), + name="Initial values", + ), + "trace_options_optim": dict( + mode="markers", + marker_symbol="cross", + marker=dict( + color="black", + line_color="white", + line_width=1, + size=14, + showscale=False, + ), + name="Final values", + ), + }, + "nyquist": { + "plot_options": dict( + layout_options=dict( + title="Nyquist Plot", + font=dict(family="Arial", size=14), + plot_bgcolor="white", + paper_bgcolor="white", + xaxis=dict( + title=dict(font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + ), + yaxis=dict( + title=dict(font=dict(size=16), standoff=15), + showline=True, + linewidth=2, + linecolor="black", + mirror=True, + ticks="outside", + tickwidth=2, + tickcolor="black", + ticklen=5, + scaleanchor="x", + scaleratio=1, + ), + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + width=600, + height=600, + ), + xaxis_title="Zre / Ω", + yaxis_title="-Zim / Ω", + ), + "trace_options_model": dict( + mode="lines+markers", + line=dict(color="#00CC96", width=2), + marker=dict(size=8, color="#00CC96", symbol="circle"), + ), + "trace_options_reference": dict( + name="Reference", + mode="markers", + marker=dict(size=8, color="#636EFA", symbol="circle-open"), + showlegend=True, + ), + }, + "parameters": dict( + layout_options=dict( + title="Parameter Convergence", + width=1024, + height=576, + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + ) + ), + "posterior": { + "plot_options": { + "layout_options": dict( + barmode="overlay", + width=None, + height=None, + plot_bgcolor=None, + autosize=None, + legend=None, + ) + }, + "trace_options": dict(opacity=0.75), + "trace_options_vline": dict(line_width=3, line_dash="dash", line_color="black"), + }, + "problem": { + "default_trace_options": dict(name="Model", mode="lines", showlegend=True), + "design_cost_options": dict(name="Optimised"), + "meta_problem_options": dict(mode="lines"), + "reference_options": dict(name="Reference", mode="markers", showlegend=True), + "fill_options": dict(fillcolor="rgba(255,229,204,0.8)"), + }, + "trace": { + "plot_options": { + "layout_options": dict( + width=None, height=None, plot_bgcolor=None, autosize=None, legend=None + ) + }, + "trace_options": dict(mode="lines"), + }, +} diff --git a/pybop/plot/plotly/voronoi.py b/pybop/plot/plotly/voronoi.py new file mode 100644 index 000000000..d60b9597e --- /dev/null +++ b/pybop/plot/plotly/voronoi.py @@ -0,0 +1,216 @@ +from typing import TYPE_CHECKING + +import numpy as np +from scipy.spatial import cKDTree + +if TYPE_CHECKING: + from pybop._result import Result +from pybop.plot import voronoi_data +from pybop.plot.plotly.plotly_manager import PlotlyManager + + +def assign_nearest_value(x, y, f, xi, yi): + """ + Computes an array of values given by the score of the nearest point. + + Parameters + ---------- + x : array-like + The x coordinates of points with known scores. + y : array-like + The y coordinates of points with known scores. + f : array-like + The score function at the given x and y coordinates. + xi : array-like + The x coordinates of grid points. + yi : array-like + The y coordinates of grid points. + + Returns + ------- + A numpy array containing the scores corresponding to the grid points. + """ + # Create a KD-tree for efficient nearest neighbor search + tree = cKDTree(np.column_stack((x, y))) + + # Find the nearest point for each grid point + _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel()))) + zi = f[indices].reshape(xi.shape) + + return zi + + +def surface( + result: "Result", + bounds=None, + normalise=True, + resolution=250, + show=True, + **layout_kwargs, +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + """ + points = result.x_model + parameters = result.problem.parameters + + if points[0].shape[0] != 2: + raise ValueError("This plot method requires two parameters.") + + x_optim, y_optim = map(list, zip(*points, strict=False)) + f = result.cost + + # Translate bounds, taking only the first two elements + xlim, ylim = ( + bounds if bounds is not None else [param.bounds for param in parameters] + )[:2] + + x, y, f, regions, relative_sizes = voronoi_data( + xlim, ylim, x_optim, y_optim, f, normalise + ) + + # Create a grid for plot + xi = np.linspace(xlim[0], xlim[1], resolution) + yi = np.linspace(ylim[0], ylim[1], resolution) + xi, yi = np.meshgrid(xi, yi) + + if normalise: + # Create a normalised grid + norm_xi = np.linspace(0, 1, resolution) + norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi) + + # Assign a value to each point in the grid + zi = assign_nearest_value(x, y, f, norm_xi, norm_yi) + else: + # Assign a value to each point in the grid + zi = assign_nearest_value(x, y, f, xi, yi) + + # Calculate the size of each Voronoi region + region_sizes = np.array([len(region) for region in regions]) + relative_sizes = (region_sizes - region_sizes.min()) / ( + region_sizes.max() - region_sizes.min() + ) + + # Construct figure + go = PlotlyManager().go + fig = go.Figure() + + # Heatmap + fig.add_trace( + go.Heatmap( + x=xi[0], + y=yi[:, 0], + z=zi, + colorscale="Viridis", + zsmooth="best", + ) + ) + + # Add Voronoi edges + for region, size in zip(regions, relative_sizes, strict=False): + x_region = region[:, 0].tolist() + [region[0, 0]] + y_region = region[:, 1].tolist() + [region[0, 1]] + + fig.add_trace( + go.Scatter( + x=x_region, + y=y_region, + mode="lines", + line=dict(color="white", width=0.5 + size * 0.1), + showlegend=False, + ) + ) + + # Add original points + fig.add_trace( + go.Scatter( + x=x_optim, + y=y_optim, + mode="markers", + marker=dict( + color=[i / len(x_optim) for i in range(len(x_optim))], + colorscale="Greys", + size=8, + showscale=False, + ), + text=[f"f={val:.2f}" for val in f], + hoverinfo="text", + showlegend=False, + ) + ) + + # Plot the initial guess + if len(result.x_model) > 0: + x0 = result.x_model[0] + fig.add_trace( + go.Scatter( + x=[x0[0]], + y=[x0[1]], + mode="markers", + marker_symbol="x", + marker=dict( + color="white", + line_color="black", + line_width=1, + size=14, + showscale=False, + ), + name="Initial values", + ) + ) + + # Plot optimised value + if result.x is not None: + x_best = result.x + fig.add_trace( + go.Scatter( + x=[x_best[0]], + y=[x_best[1]], + mode="markers", + marker_symbol="cross", + marker=dict( + color="black", + line_color="white", + line_width=1, + size=14, + showscale=False, + ), + name="Final values", + ) + ) + + names = parameters.names + fig.update_layout( + title="Voronoi Cost Landscape", + title_x=0.5, + title_y=0.905, + xaxis_title=names[0], + yaxis_title=names[1], + width=600, + height=600, + xaxis=dict(range=xlim, showexponent="last", exponentformat="e"), + yaxis=dict(range=ylim, showexponent="last", exponentformat="e"), + legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), + ) + fig.update_layout(**layout_kwargs) + if show: + fig.show() diff --git a/pybop/plot/plots.py b/pybop/plot/plots.py new file mode 100644 index 000000000..873af49fe --- /dev/null +++ b/pybop/plot/plots.py @@ -0,0 +1,113 @@ +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from pybop._result import Result + +from pybop.plot.util import call_plotting_function +from pybop.problems.problem import Problem + + +def contour( + call_object: "Problem | Result", + gradient: bool = False, + bounds: np.ndarray | None = None, + transformed: bool = False, + steps: int = 10, + show: bool = True, + backend=None, + **layout_kwargs, +): + """ + Plot a 2D visualisation of a cost landscape using Plotly. + + This function generates a contour plot representing the cost landscape for a provided + callable cost function over a grid of parameter values within the specified bounds. + + Parameters + ---------- + call_object : pybop.Problem | pybop.Result + Either: + - the cost function to be evaluated. Must accept a list of parameter values and return a cost value. + - an optimiser result which provides a specific optimisation trace overlaid on the cost landscape. + gradient : bool, optional + If True, the gradient is shown (default: False). + bounds : numpy.ndarray | list[list[float]], optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `parameters.get_bounds_for_plotly`. + transformed : bool, optional + Uses the transformed parameter values (as seen by the optimiser) for plotting. + steps : int, optional + The number of grid points to divide the parameter space into along each dimension (default: 10). + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + + Returns + ------- + plotly.graph_objs.Figure + The Plotly figure object containing the cost landscape plot. + + Raises + ------ + ValueError + If the cost function does not return a valid cost when called with a parameter list. + """ + return call_plotting_function( + "contour", + backend, + call_object=call_object, + gradient=gradient, + bounds=bounds, + transformed=transformed, + steps=steps, + show=show, + **layout_kwargs, + ) + + +def surface( + result: "Result", + bounds=None, + normalise=True, + resolution=250, + show=True, + backend=None, + **layout_kwargs, +): + """ + Plot a 2D representation of the Voronoi diagram with color-coded regions. + + Parameters: + ----------- + result : pybop.Result + Optimisation result containing the history of parameter values and associated cost. + bounds : numpy.ndarray, optional + A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses + `cost.parameters.get_bounds_for_plotly`. + normalise : bool, optional + If True, the voronoi regions are computed using the Euclidean distance between + points normalised with respect to the bounds (default: True). + resolution : int, optional + Resolution of the plot. Default is 500. + show : bool, optional + If True, the figure is shown upon creation (default: True). + **layout_kwargs : optional + Valid Plotly layout keys and their values, + e.g. `xaxis_title="Time [s]"` or + `xaxis={"title": "Time [s]", font={"size":14}}` + """ + return call_plotting_function( + "surface", + backend, + result=result, + bounds=bounds, + normalise=normalise, + resolution=resolution, + show=show, + **layout_kwargs, + ) diff --git a/pybop/plot/problem.py b/pybop/plot/problem.py index 0372601e6..27b573641 100644 --- a/pybop/plot/problem.py +++ b/pybop/plot/problem.py @@ -4,6 +4,7 @@ from pybop.costs.error_measures import ErrorMeasure from pybop.parameters.parameter import Inputs from pybop.plot.standard_plots import StandardPlot +from pybop.plot.util import get_default_options, update_and_show from pybop.problems.meta_problem import MetaProblem from pybop.problems.problem import Problem from pybop.simulators.solution import Solution @@ -13,7 +14,8 @@ def problem( problem: Problem, inputs: Inputs = None, show: bool = True, - **layout_kwargs, + title="Scatter Plot", + backend=None, ): """ Produce a quick plot of the target dataset against optimised model output. @@ -29,10 +31,6 @@ def problem( Optimised (or example) parameter values. show : bool, optional If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time / s"` or - `xaxis={"title": "Time [s]", font={"size":14}}` Returns ------- @@ -66,33 +64,38 @@ def problem( model_output = problem.simulate(inputs) model_domain = target_domain[: len(model_output[target].data)] + # Retrieve default layout options + plot_options = get_default_options("problem", backend) + trace_options = plot_options.get("default_trace_options") or {} + design_cost_options = plot_options.get("design_cost_options") or {} + meta_problem_options = plot_options.get("meta_problem_options") or {} + reference_options = plot_options.get("reference_options") or {} + fill_options = plot_options.get("fill_options") or {} + # Create a plot for each output figure_list = [] for var in problem.target: + options = trace_options.copy() + if isinstance(problem, MetaProblem): + options.update(meta_problem_options) + if isinstance(problem.cost, DesignCost): + options.update(design_cost_options) + # Create a plot dictionary plot_dict = StandardPlot( - layout_options=dict( - title="Scatter Plot", - xaxis_title=StandardPlot.remove_brackets(domain), - yaxis_title=StandardPlot.remove_brackets(var), - ) + title=title, + xaxis_title=StandardPlot.remove_brackets(domain), + yaxis_title=StandardPlot.remove_brackets(var), + backend=backend, ) model_trace = plot_dict.create_trace( - x=model_domain, - y=model_output[var].data, - name="Optimised" if isinstance(problem.cost, DesignCost) else "Model", - mode="markers" if isinstance(problem, MetaProblem) else "lines", - showlegend=True, + x=model_domain, y=model_output[var].data, **options ) plot_dict.traces.append(model_trace) target_trace = plot_dict.create_trace( - x=target_domain, - y=target_output[var].data, - name="Reference", - mode="markers", - showlegend=True, + x=target_domain, y=target_output[var].data, **reference_options ) plot_dict.traces.append(target_trace) @@ -107,14 +110,8 @@ def problem( y_upper = (model_output[var].data + plot_dict.sigma).tolist() y_lower = (model_output[var].data - plot_dict.sigma).tolist() - fill_trace = plot_dict.create_trace( - x=x + x[::-1], - y=y_upper + y_lower[::-1], - fill="toself", - fillcolor="rgba(255,229,204,0.8)", - line=dict(color="rgba(255,255,255,0)"), - hoverinfo="skip", - showlegend=False, + fill_trace = plot_dict.create_fill_trace( + x, y_upper, y_lower, **fill_options ) plot_dict.traces.append(fill_trace) @@ -123,9 +120,7 @@ def problem( # Generate the figure and update the layout fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + fig = update_and_show(fig, backend=backend) figure_list.append(fig) diff --git a/pybop/plot/samples.py b/pybop/plot/samples.py index 55ee77cd0..3ea022cce 100644 --- a/pybop/plot/samples.py +++ b/pybop/plot/samples.py @@ -1,109 +1,107 @@ from typing import TYPE_CHECKING -from pybop.plot import PlotlyManager +from pybop.plot import StandardPlot +from pybop.plot.util import call_plotting_function, get_default_options, update_and_show if TYPE_CHECKING: from pybop.samplers.base_pints_sampler import SamplingResult -def trace(result: "SamplingResult", **kwargs): +def chains(result: "SamplingResult", show=True, backend=None): """ - Plot trace plots for the posterior samples. + Plot posterior distributions for each chain. """ - # Import plotly only when needed - go = PlotlyManager().go + options = get_default_options("posterior", backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + trace_options_vline = options.get("trace_options_vline") or {} - for i in range(result.n_parameters): - fig = go.Figure() + plot_dict = StandardPlot( + backend=backend, + title="Posterior Distribution", + xaxis_title="Value", + yaxis_title="Density", + **plot_options, + ) - for j, chain in enumerate(result.chains): - fig.add_trace(go.Scatter(y=chain[:, i], mode="lines", name=f"Chain {j}")) + for i, chain in enumerate(result.chains): + for j in range(chain.shape[1]): + hist = plot_dict.create_histogram( + x=chain[:, j], name=f"Chain {i} - Parameter {j}", **trace_options + ) + plot_dict.traces.append(hist) - fig.update_layout( - title=f"Parameter {i} Trace Plot", - xaxis_title="Sample Index", - yaxis_title="Value", - ) - fig.update_layout(**kwargs) - fig.show() + fig = plot_dict(show=False) + for j in range(chain.shape[1]): + plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) + update_and_show(fig, backend=backend) -def chains(result: "SamplingResult", **kwargs): + +def trace(result: "SamplingResult", show=True, backend=None): """ - Plot posterior distributions for each chain. + Plot trace plots for the posterior samples. """ - # Import plotly only when needed - go = PlotlyManager().go - - fig = go.Figure() + figlist = [] + options = get_default_options("trace", backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + for i in range(result.n_parameters): + plots = StandardPlot( + title=f"Parameter {i} Trace Plot", + xaxis_title="Sample Index", + yaxis_title="Value", + backend=backend, + **plot_options, + ) - for i, chain in enumerate(result.chains): - for j in range(chain.shape[1]): - fig.add_trace( - go.Histogram( - x=chain[:, j], - name=f"Chain {i} - Parameter {j}", - opacity=0.75, - ) + for j, chain in enumerate(result.chains): + plots.traces.append( + plots.create_trace(y=chain[:, i], label=f"Chain {j}", **trace_options) ) + fig = plots(show=False) + figlist.append(fig) - fig.add_shape( - type="line", - x0=result.mean[j], - y0=0, - x1=result.mean[j], - y1=result.max[j], - name=f"Mean - Parameter {j}", - line=dict(color="Black", width=1.5, dash="dash"), - ) + update_and_show(figlist, show=show, backend=backend) - fig.update_layout( - barmode="overlay", - title="Posterior Distribution", - xaxis_title="Value", - yaxis_title="Density", - ) - fig.update_layout(**kwargs) - fig.show() + return figlist -def posterior(result: "SamplingResult", **kwargs): +def posterior(result: "SamplingResult", backend=None): """ Plot the summed posterior distribution across chains. """ + options = get_default_options("posterior", backend) + plot_options = options.get("plot_options") or {} + trace_options = options.get("trace_options") or {} + trace_options_vline = options.get("trace_options_vline") or {} # Import plotly only when needed - go = PlotlyManager().go - - fig = go.Figure() - - for j in range(result.all_samples.shape[1]): - histogram = go.Histogram( - x=result.all_samples[:, j], - name=f"Parameter {j}", - opacity=0.75, - ) - fig.add_trace(histogram) - fig.add_vline( - x=result.mean[j], line_width=3, line_dash="dash", line_color="black" - ) - - fig.update_layout( - barmode="overlay", + plot_dict = StandardPlot( + backend=backend, title="Posterior Distribution", xaxis_title="Value", yaxis_title="Density", + **plot_options, ) - fig.update_layout(**kwargs) - fig.show() - return fig + + for j in range(result.all_samples.shape[1]): + hist = plot_dict.create_histogram( + x=result.all_samples[:, j], name=f"Parameter {j}", **trace_options + ) + + plot_dict.traces.append(hist) + + fig = plot_dict(show=False) + for j in range(result.all_samples.shape[1]): + plot_dict.create_vline(fig, result.mean[j], **trace_options_vline) + + update_and_show(fig, backend=backend) -def summary_table(result: "SamplingResult"): +def summary_table(result: "SamplingResult", backend=None): """ Display summary statistics in a table. """ - # Import plotly only when needed - go = PlotlyManager().go summary_stats = result.get_summary_statistics() @@ -116,16 +114,10 @@ def summary_table(result: "SamplingResult"): ["95% CI Upper", summary_stats["ci_upper"]], ] - fig = go.Figure( - data=[ - go.Table( - header=dict(values=header), - cells=dict( - values=[[row[0] for row in values], [row[1] for row in values]] - ), - ) - ] + call_plotting_function( + "show_table", + backend=backend, + header=header, + values=values, + title="Summary Statistics", ) - - fig.update_layout(title="Summary Statistics") - fig.show() diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index 4422516b8..3287789c5 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -3,154 +3,66 @@ import numpy as np -from pybop.plot.plotly_manager import PlotlyManager - -DEFAULT_LAYOUT_OPTIONS = dict( - title=None, - title_x=0.5, - xaxis=dict( - title=dict(font={"size": 14}), - showexponent="last", - exponentformat="e", - tickfont=dict(size=12), - ), - yaxis=dict( - title=dict(font={"size": 14}), - showexponent="last", - exponentformat="e", - tickfont=dict(size=12), - ), - legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12), - showlegend=True, - autosize=False, - width=600, - height=600, - margin=dict(l=10, r=10, b=10, t=75, pad=4), - plot_bgcolor="white", -) -DEFAULT_SUBPLOT_OPTIONS = dict( - start_cell="bottom-left", -) -DEFAULT_TRACE_OPTIONS = dict(line=dict(width=4), mode="lines") -DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines") +from pybop.plot.util import call_plotting_function class StandardPlot: - """ - A class for creating and displaying interactive Plotly figures. - - Parameters - ---------- - x : list or np.ndarray, optional - X-axis data points. - y : list or np.ndarray, optional - Primary Y-axis data points for simulated model output. - layout : Plotly layout, optional - A layout for the figure, overrides the layout options (default: None). - layout_options : dict, optional - Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS). - trace_options : dict, optional - Settings to modify the default trace type (default: DEFAULT_TRACE_OPTIONS). - trace_names : str, optional - Name(s) for the primary trace(s) (default: None). - trace_name_width : int, optional - Maximum length of the trace names before text wrapping is used (default: 40). - - Returns - ------- - plotly.graph_objs.Figure - The generated Plotly figure. - """ - def __init__( self, x=None, y=None, - layout=None, - layout_options=None, + title=None, + xaxis_title=None, + yaxis_title=None, trace_options=None, trace_names=None, - trace_name_width=40, + trace_name_width=20, + backend=None, + **kwargs, ): - self.traces = [] - self.layout = layout - self.trace_name_width = trace_name_width - - # Set default layout options and update if provided - if self.layout is None: - self.layout_options = DEFAULT_LAYOUT_OPTIONS.copy() - if layout_options: - self.layout_options.update(layout_options) - # Set default trace options and update if provided - self.trace_options = DEFAULT_TRACE_OPTIONS.copy() - if trace_options: - self.trace_options.update(trace_options) - - # Attempt to import plotly when an instance is created - self.go = PlotlyManager().go + self.plotter = call_plotting_function( + "Plotter", + backend, + title=title, + xaxis_title=xaxis_title, + yaxis_title=yaxis_title, + trace_options=trace_options, + **kwargs, + ) - # Create layout - if self.layout is None: - self.layout = self.go.Layout(**self.layout_options) + self.trace_name_width = trace_name_width # Add traces if x is not None and y is not None: self.add_traces(x, y, trace_names) def __call__(self, show=True): - """ - Generate and show the figure. + return self.plotter(show=show) - Parameters - ---------- - show : bool, optional - If True, the figure is shown upon creation (default: True). - """ - fig = self.go.Figure(data=self.traces, layout=self.layout) - if show: - fig.show() + @property + def traces(self): + return self.plotter.traces - return fig - - def add_traces(self, x, y, trace_names=None, **trace_options): - """ - Add a set of traces to the plot dictionary. - - Parameters - ---------- - x : list or np.ndarray - X-axis data points. - y : list or np.ndarray - Primary Y-axis data points for simulated model output. - trace_names : str or list[str], optional - Name(s) for the primary trace(s) (default: None). - """ - options = self.trace_options.copy() - options.update(trace_options) + @traces.setter + def traces(self, value): + self.plotter.traces = value + def add_traces(self, x, y, trace_names): # Check and wrap trace names if trace_names is not None: if isinstance(trace_names, str): trace_names = [trace_names] for i, name in enumerate(trace_names): - trace_names[i] = self.wrap_text(name, width=self.trace_name_width) + trace_names[i] = self.wrap_text( + name, width=self.trace_name_width, backend=self.plotter.backend + ) # Parse the data x, y = self.parse_data(x, y) - # Create a trace for each trajectory - xi = x[0] - for i in range(0, len(y)): - trace_options = options.copy() - if len(x) > 1: - xi = x[i] - if trace_names is not None: - trace_options["name"] = trace_names[i] - else: - trace_options["showlegend"] = False - trace = self.create_trace(xi, y[i], **trace_options) - self.traces.append(trace) + # Add traces + self.plotter.add_traces(x, y, trace_names) def parse_data(self, x, y): """ @@ -191,24 +103,23 @@ def parse_data(self, x, y): ) return x, y - def create_trace(self, x, y, **trace_options): - """ - Create a trace for the Plotly figure. + def create_trace(self, x=None, y=None, label=None, **trace_options): + return self.plotter.create_trace(x, y, label, **trace_options) - Returns - ------- - plotly.graph_objs.Scatter - A trace for a Plotly figure. - """ + def create_fill_trace(self, x, y_upper, y_lower, **options): + return self.plotter.create_fill_trace(x, y_upper, y_lower, **options) - return self.go.Scatter( - x=x, - y=y, - **trace_options, - ) + def create_histogram(self, x, name, **trace_options): + return self.plotter.create_histogram(x, name, **trace_options) + + def create_vline(self, fig, x, **trace_options): + return self.plotter.create_vline(fig, x, **trace_options) + + def create_contour(self, x, y, z, **trace_options): + return self.plotter.create_contour(x, y, z, **trace_options) @staticmethod - def wrap_text(text, width): + def wrap_text(text, width, backend="matplotlib"): """ Wrap text to a specified width with HTML line breaks. @@ -225,7 +136,10 @@ def wrap_text(text, width): The wrapped text. """ wrapped_text = textwrap.fill(text, width=width, break_long_words=False) - return wrapped_text.replace("\n", "
") + if backend == "plotly": + return wrapped_text.replace("\n", "
") + else: + return wrapped_text @staticmethod def remove_brackets(s): @@ -280,20 +194,30 @@ def __init__( self, x, y, + backend=None, num_rows=None, num_cols=None, axis_titles=None, - layout=None, - layout_options=DEFAULT_LAYOUT_OPTIONS, - subplot_options=DEFAULT_SUBPLOT_OPTIONS, - trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS, + trace_options=None, trace_names=None, trace_name_width=40, + **kwargs, ): - super().__init__( - x, y, layout, layout_options, trace_options, trace_names, trace_name_width + self.plotter = call_plotting_function( + "SubplotPlotter", + backend, + axis_titles=axis_titles, + trace_options=trace_options, + **kwargs, ) - self.num_traces = len(self.traces) + + self.trace_name_width = trace_name_width + + # Add traces + if x is not None and y is not None: + self.add_traces(x, y, trace_names) + + self.num_traces = len(self.plotter.traces) self.num_rows = num_rows self.num_cols = num_cols if self.num_rows is None and self.num_cols is None: @@ -304,56 +228,12 @@ def __init__( self.num_rows = int(math.ceil(self.num_traces / self.num_cols)) elif self.num_cols is None: self.num_cols = int(math.ceil(self.num_traces / self.num_rows)) - self.axis_titles = axis_titles - self.subplot_options = subplot_options.copy() - if subplot_options is not None: - for arg, value in subplot_options.items(): - self.subplot_options[arg] = value - - # Attempt to import plotly when an instance is created - self.make_subplots = PlotlyManager().make_subplots - - def __call__(self, show): - """ - Generate and show the set of figures. - - Parameters - ---------- - show : bool, optional - If True, the figure is shown upon creation (default: True). - """ - fig = self.make_subplots( - rows=self.num_rows, - cols=self.num_cols, - horizontal_spacing=0.1, - vertical_spacing=0.15, - **self.subplot_options, - ) - fig.update_layout(self.layout_options) - - for idx, trace in enumerate(self.traces): - row = (idx // self.num_cols) + 1 - col = (idx % self.num_cols) + 1 - fig.add_trace(trace, row=row, col=col) - - if self.axis_titles and idx < len(self.axis_titles): - x_title, y_title = self.axis_titles[idx] - fig.update_xaxes(title_text=x_title, row=row, col=col) - fig.update_yaxes( - title_text=y_title, - row=row, - col=col, - showexponent="last", - exponentformat="e", - ) - - if show: - fig.show() - return fig + def __call__(self, show=True): + return self.plotter(show=show, num_rows=self.num_rows, num_cols=self.num_cols) -def trajectories(x, y, trace_names=None, show=True, **layout_kwargs): +def trajectories(x, y, trace_names=None, show=True, backend=None, **layout_kwargs): """ Quickly plot one or more trajectories using Plotly. @@ -375,17 +255,13 @@ def trajectories(x, y, trace_names=None, show=True, **layout_kwargs): plotly.graph_objs.Figure The Plotly figure object for the scatter plot. """ - # Create a plot dictionary - plot_dict = StandardPlot( + + return call_plotting_function( + "trajectories", + backend, x=x, y=y, trace_names=trace_names, + show=show, + **layout_kwargs, ) - - # Generate the figure and update the layout - fig = plot_dict(show=False) - fig.update_layout(**layout_kwargs) - if show: - fig.show() - - return fig diff --git a/pybop/plot/util.py b/pybop/plot/util.py new file mode 100644 index 000000000..a30b0cee1 --- /dev/null +++ b/pybop/plot/util.py @@ -0,0 +1,57 @@ +import importlib.util + +import pybop.plot + + +def set_backend(backend): + err_msg = ( + f"Plotting backend {backend} is not available. The default backend has not been updated. \n" + f"The default backend is set to {pybop.plot.backend}" + ) + try: + importlib.import_module("pybop.plot." + backend) + pybop.plot.backend = backend + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + + +def call_plotting_function(function_name, backend, **kwargs): + if backend is None: + backend = pybop.plot.backend + err_msg = f"Plotting backend {backend} is not available." + try: + module = importlib.import_module("pybop.plot." + backend) + if hasattr(module, function_name): + plotting_function = getattr(module, function_name) + # Return the imported attribute + return plotting_function(**kwargs) + else: + err_msg = f"Plotting backend {backend} has no attribute {function_name}." + raise ModuleNotFoundError(err_msg) + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error + + +def update_and_show(fig, backend=None, **kwargs): + return call_plotting_function("update_and_show", backend, fig=fig, **kwargs) + + +def get_default_options(plot_type, backend): + if backend is None: + backend = pybop.plot.backend + + if backend == "plotly": + opts = pybop.plot.plotly.DEFAULT_PLOT_OPTIONS + elif backend == "matplotlib": + opts = pybop.plot.matplotlib.DEFAULT_PLOT_OPTIONS + else: + opts = {} + + if plot_type in opts.keys(): + return opts[plot_type] + else: + return {} diff --git a/pybop/plot/voronoi.py b/pybop/plot/voronoi.py index 29be3c096..2ac57fada 100644 --- a/pybop/plot/voronoi.py +++ b/pybop/plot/voronoi.py @@ -1,11 +1,10 @@ from typing import TYPE_CHECKING import numpy as np -from scipy.spatial import Voronoi, cKDTree +from scipy.spatial import Voronoi if TYPE_CHECKING: - from pybop._result import Result -from pybop.plot.plotly_manager import PlotlyManager + pass def _voronoi_regions(x, y, f, xlim, ylim): @@ -195,85 +194,7 @@ def interpolate_point(p, q, axis, boundary_val): return np.array([boundary_val, s]) if axis == 0 else np.array([s, boundary_val]) -def assign_nearest_value(x, y, f, xi, yi): - """ - Computes an array of values given by the score of the nearest point. - - Parameters - ---------- - x : array-like - The x coordinates of points with known scores. - y : array-like - The y coordinates of points with known scores. - f : array-like - The score function at the given x and y coordinates. - xi : array-like - The x coordinates of grid points. - yi : array-like - The y coordinates of grid points. - - Returns - ------- - A numpy array containing the scores corresponding to the grid points. - """ - # Create a KD-tree for efficient nearest neighbor search - tree = cKDTree(np.column_stack((x, y))) - - # Find the nearest point for each grid point - _, indices = tree.query(np.column_stack((xi.ravel(), yi.ravel()))) - zi = f[indices].reshape(xi.shape) - - return zi - - -def surface( - result: "Result", - bounds=None, - normalise=True, - resolution=250, - show=True, - **layout_kwargs, -): - """ - Plot a 2D representation of the Voronoi diagram with color-coded regions. - - Parameters: - ----------- - result : pybop.Result - Optimisation result containing the history of parameter values and associated cost. - bounds : numpy.ndarray, optional - A 2x2 array specifying the [min, max] bounds for each parameter. If None, uses - `cost.parameters.get_bounds_for_plotly`. - normalise : bool, optional - If True, the voronoi regions are computed using the Euclidean distance between - points normalised with respect to the bounds (default: True). - resolution : int, optional - Resolution of the plot. Default is 500. - show : bool, optional - If True, the figure is shown upon creation (default: True). - **layout_kwargs : optional - Valid Plotly layout keys and their values, - e.g. `xaxis_title="Time [s]"` or - `xaxis={"title": "Time [s]", font={"size":14}}` - """ - points = result.x_model - parameters = result.problem.parameters - - if points[0].shape[0] != 2: - raise ValueError("This plot method requires two parameters.") - - x_optim, y_optim = map(list, zip(*points, strict=False)) - f = result.cost - - # Translate bounds, taking only the first two elements - xlim, ylim = ( - bounds if bounds is not None else [param.bounds for param in parameters] - )[:2] - - # Create a grid for plot - xi = np.linspace(xlim[0], xlim[1], resolution) - yi = np.linspace(ylim[0], ylim[1], resolution) - xi, yi = np.meshgrid(xi, yi) +def voronoi_data(xlim, ylim, pts_x, pts_y, f, normalise=True): if normalise: if xlim[1] <= xlim[0] or ylim[1] <= ylim[0]: @@ -282,21 +203,14 @@ def surface( # Normalise the region x_range = xlim[1] - xlim[0] y_range = ylim[1] - ylim[0] - norm_x_optim = (np.asarray(x_optim) - xlim[0]) / x_range - norm_y_optim = (np.asarray(y_optim) - ylim[0]) / y_range + norm_x_optim = (np.asarray(pts_x) - xlim[0]) / x_range + norm_y_optim = (np.asarray(pts_y) - ylim[0]) / y_range # Compute regions - norm_x, norm_y, f, norm_regions = _voronoi_regions( + x, y, f, norm_regions = _voronoi_regions( norm_x_optim, norm_y_optim, f, (0, 1), (0, 1) ) - # Create a normalised grid - norm_xi = np.linspace(0, 1, resolution) - norm_xi, norm_yi = np.meshgrid(norm_xi, norm_xi) - - # Assign a value to each point in the grid - zi = assign_nearest_value(norm_x, norm_y, f, norm_xi, norm_yi) - # Rescale for plotting regions = [] for norm_region in norm_regions: @@ -307,10 +221,7 @@ def surface( else: # Compute regions - x, y, f, regions = _voronoi_regions(x_optim, y_optim, f, xlim, ylim) - - # Assign a value to each point in the grid - zi = assign_nearest_value(x, y, f, xi, yi) + x, y, f, regions = _voronoi_regions(pts_x, pts_y, f, xlim, ylim) # Calculate the size of each Voronoi region region_sizes = np.array([len(region) for region in regions]) @@ -318,107 +229,4 @@ def surface( region_sizes.max() - region_sizes.min() ) - # Construct figure - go = PlotlyManager().go - fig = go.Figure() - - # Heatmap - fig.add_trace( - go.Heatmap( - x=xi[0], - y=yi[:, 0], - z=zi, - colorscale="Viridis", - zsmooth="best", - ) - ) - - # Add Voronoi edges - for region, size in zip(regions, relative_sizes, strict=False): - x_region = region[:, 0].tolist() + [region[0, 0]] - y_region = region[:, 1].tolist() + [region[0, 1]] - - fig.add_trace( - go.Scatter( - x=x_region, - y=y_region, - mode="lines", - line=dict(color="white", width=0.5 + size * 0.1), - showlegend=False, - ) - ) - - # Add original points - fig.add_trace( - go.Scatter( - x=x_optim, - y=y_optim, - mode="markers", - marker=dict( - color=[i / len(x_optim) for i in range(len(x_optim))], - colorscale="Greys", - size=8, - showscale=False, - ), - text=[f"f={val:.2f}" for val in f], - hoverinfo="text", - showlegend=False, - ) - ) - - # Plot the initial guess - if len(result.x_model) > 0: - x0 = result.x_model[0] - fig.add_trace( - go.Scatter( - x=[x0[0]], - y=[x0[1]], - mode="markers", - marker_symbol="x", - marker=dict( - color="white", - line_color="black", - line_width=1, - size=14, - showscale=False, - ), - name="Initial values", - ) - ) - - # Plot optimised value - if result.x is not None: - x_best = result.x - fig.add_trace( - go.Scatter( - x=[x_best[0]], - y=[x_best[1]], - mode="markers", - marker_symbol="cross", - marker=dict( - color="black", - line_color="white", - line_width=1, - size=14, - showscale=False, - ), - name="Final values", - ) - ) - - names = parameters.names - fig.update_layout( - title="Voronoi Cost Landscape", - title_x=0.5, - title_y=0.905, - xaxis_title=names[0], - yaxis_title=names[1], - width=600, - height=600, - xaxis=dict(range=xlim, showexponent="last", exponentformat="e"), - yaxis=dict(range=ylim, showexponent="last", exponentformat="e"), - legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1), - ) - fig.update_layout(**layout_kwargs) - if show: - fig.show() + return x, y, f, regions, relative_sizes diff --git a/tests/plotting/test_plotly_manager.py b/tests/plotting/test_plotly_manager.py index f64e608f1..b39df37cd 100644 --- a/tests/plotting/test_plotly_manager.py +++ b/tests/plotting/test_plotly_manager.py @@ -8,7 +8,7 @@ import pytest import pybop -from pybop.plot import PlotlyManager +from pybop.plot.plotly import PlotlyManager # Find the Python executable python_executable = which("python") @@ -128,6 +128,9 @@ def dataset(plotly_installed): @pytest.mark.unit def test_standard_plot(dataset, plotly_installed): + # Set plotting backend + pybop.plot.set_backend("plotly") + # Check the StandardPlot class pybop.plot.StandardPlot(dataset["Time [s]"], dataset["Voltage [V]"]) @@ -167,6 +170,9 @@ def test_standard_plot(dataset, plotly_installed): @pytest.mark.unit def test_plot_dataset(dataset, plotly_installed): + # Set plotting backend + pybop.plot.set_backend("plotly") + # Test plot of a dataset pybop.plot.dataset(dataset, signal=["Voltage [V]"]) pybop.plot.dataset(dataset, signal=["Voltage [V]", "Current [A]"]) diff --git a/tests/unit/test_plots.py b/tests/unit/test_plots.py index 58b707900..db9e0c6c1 100644 --- a/tests/unit/test_plots.py +++ b/tests/unit/test_plots.py @@ -6,6 +6,7 @@ import pybop +@pytest.mark.parametrize("backend", ["plotly", "matplotlib"]) class TestPlots: """ A class to test the plot classes. @@ -13,7 +14,8 @@ class TestPlots: pytestmark = pytest.mark.unit - def test_standard_plot(self): + def test_standard_plot(self, backend): + pybop.plot.set_backend(backend) # Test standard plot trace_names = pybop.plot.StandardPlot.remove_brackets( ["Trace [1]", "Trace [2]"] @@ -69,7 +71,8 @@ def dataset(self, model): solution = pybamm.Simulation(model).solve(t_eval=t_eval, t_interp=t_eval) return pybop.import_pybamm_solution(solution) - def test_dataset_plots(self, dataset): + def test_dataset_plots(self, dataset, backend): + pybop.plot.set_backend(backend) # Test plot of Dataset objects pybop.plot.trajectories( dataset["Time [s]"], @@ -112,7 +115,8 @@ def design_problem(self, model, parameters, experiment): ) return pybop.Problem(simulator) - def test_problem_plots(self, fitting_problem, design_problem): + def test_problem_plots(self, fitting_problem, design_problem, backend): + pybop.plot.set_backend(backend) # Test plot of Problem objects pybop.plot.problem(fitting_problem, title="Optimised Comparison") pybop.plot.problem(design_problem) @@ -122,7 +126,8 @@ def test_problem_plots(self, fitting_problem, design_problem): fitting_problem, inputs=fitting_problem.parameters.to_dict([0.6, 0.6]) ) - def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds): + def test_cost_plots(self, fitting_problem, fitting_problem_no_bounds, backend): + pybop.plot.set_backend(backend) # Test plot of Cost objects pybop.plot.contour(fitting_problem, gradient=True, steps=5) @@ -143,7 +148,8 @@ def result(self, fitting_problem): optim = pybop.XNES(fitting_problem) return optim.run() - def test_optim_plots(self, result): + def test_optim_plots(self, result, backend): + pybop.plot.set_backend(backend) bounds = np.asarray([[0.5, 0.8], [0.4, 0.7]]) # Plot convergence @@ -185,7 +191,8 @@ def sampling_result(self, model, parameters, dataset): sampler = pybop.SliceStepoutMCMC(log_pdf, options=options) return sampler.run() - def test_posterior_plots(self, sampling_result): + def test_posterior_plots(self, sampling_result, backend): + pybop.plot.set_backend(backend) sampling_result.get_summary_statistics() # Plot trace @@ -200,9 +207,11 @@ def test_posterior_plots(self, sampling_result): # Plot summary table sampling_result.summary_table() - def test_with_ipykernel(self, dataset, fitting_problem, result): + def test_with_ipykernel(self, dataset, fitting_problem, result, backend): import ipykernel + pybop.plot.set_backend(backend) + assert version.parse(ipykernel.__version__) >= version.parse("0.6") pybop.plot.dataset(dataset, signal=["Voltage [V]"]) pybop.plot.contour(fitting_problem, gradient=True, steps=5) @@ -210,7 +219,8 @@ def test_with_ipykernel(self, dataset, fitting_problem, result): result.plot_parameters() result.plot_contour(steps=5) - def test_contour_incorrect_number_of_parameters(self, model, dataset): + def test_contour_incorrect_number_of_parameters(self, model, dataset, backend): + pybop.plot.set_backend(backend) parameter_values = model.default_parameter_values # Test with less than two paramters @@ -251,7 +261,8 @@ def test_contour_incorrect_number_of_parameters(self, model, dataset): fitting_problem = pybop.Problem(simulator, cost) pybop.plot.contour(fitting_problem) - def test_nyquist(self): + def test_nyquist(self, backend): + pybop.plot.set_backend(backend) # Define model model = pybamm.lithium_ion.SPM(options={"surface form": "differential"}) parameter_values = model.default_parameter_values