55
66import matplotlib .patches as patches
77import matplotlib .pyplot as plt
8+ import numpy as np
89from plotly .subplots import make_subplots
910from tikzfigure import TikzFigure
1011
@@ -579,6 +580,41 @@ def text(
579580 """Add a text label at (x, y) on a subplot."""
580581 self ._get_or_create_subplot (row , col ).text (x , y , s , layer = layer , ** kwargs )
581582
583+ def imshow (
584+ self ,
585+ data ,
586+ layer = 0 ,
587+ row : int | None = None ,
588+ col : int | None = None ,
589+ ** kwargs ,
590+ ):
591+ """Add an image/matrix plot to a subplot."""
592+ self ._get_or_create_subplot (row , col ).add_imshow (data , layer = layer , ** kwargs )
593+
594+ def add_patch (
595+ self ,
596+ patch ,
597+ layer = 0 ,
598+ row : int | None = None ,
599+ col : int | None = None ,
600+ ** kwargs ,
601+ ):
602+ """Add a Matplotlib patch to a subplot."""
603+ self ._get_or_create_subplot (row , col ).add_patch (patch , layer = layer , ** kwargs )
604+
605+ def colorbar (
606+ self ,
607+ label : str = "" ,
608+ layer = 0 ,
609+ row : int | None = None ,
610+ col : int | None = None ,
611+ ** kwargs ,
612+ ):
613+ """Add a colorbar to the most recent imshow() on a subplot (matplotlib backend)."""
614+ self ._get_or_create_subplot (row , col ).add_colorbar (
615+ label = label , layer = layer , ** kwargs
616+ )
617+
582618 # ------------------------------------------------------------------
583619 # Multi-subplot helpers
584620 # ------------------------------------------------------------------
@@ -773,6 +809,34 @@ def savefig(
773809 figure .savefig (full_filepath )
774810 if verbose :
775811 print (f"Saved { full_filepath } " )
812+ elif backend == "plotly" :
813+ if layer_by_layer :
814+ layers = []
815+ for layer in self .layers :
816+ layers .append (layer )
817+ full_filepath = f"{ filename_no_extension } _{ layers } { extension } "
818+ fig = self .plot (
819+ backend = "plotly" ,
820+ savefig = False ,
821+ layers = layers ,
822+ )
823+ self ._save_plotly (fig , full_filepath )
824+ if verbose :
825+ print (f"Saved { full_filepath } " )
826+ else :
827+ if layers is None :
828+ layers = self .layers
829+ full_filepath = filename
830+ else :
831+ full_filepath = f"{ filename_no_extension } _{ layers } { extension } "
832+ fig = self .plot (
833+ backend = "plotly" ,
834+ savefig = False ,
835+ layers = layers ,
836+ )
837+ self ._save_plotly (fig , full_filepath )
838+ if verbose :
839+ print (f"Saved { full_filepath } " )
776840
777841 def plot (
778842 self ,
@@ -797,6 +861,7 @@ def plot(
797861 elif backend == "plotly" :
798862 return self .plot_plotly (
799863 savefig = savefig ,
864+ layers = layers ,
800865 usetex = resolved_usetex ,
801866 verbose = verbose ,
802867 )
@@ -832,7 +897,11 @@ def show(
832897 # self._matplotlib_fig.show()
833898 elif backend == "plotly" :
834899 resolved_usetex = self ._usetex if usetex is None else usetex
835- self .plot_plotly (savefig = False , usetex = resolved_usetex )
900+ fig = self .plot_plotly (
901+ savefig = False , layers = layers , usetex = resolved_usetex , verbose = verbose
902+ )
903+ fig .show ()
904+ return fig
836905 elif backend == "plotext" :
837906 figure = self .plot_plotext (
838907 savefig = False ,
@@ -1034,6 +1103,7 @@ def plot_plotly(
10341103 self ,
10351104 show = True ,
10361105 savefig = None ,
1106+ layers : list | None = None ,
10371107 usetex : bool | None = None ,
10381108 verbose : bool = False ,
10391109 ):
@@ -1063,38 +1133,134 @@ def plot_plotly(
10631133 ratio = self ._ratio ,
10641134 )
10651135 # print(self._width, fig_width, fig_height)
1066- # Create subplots
1136+ # Create subplot titles in row-major order (Plotly expects rows*cols entries)
1137+ subplot_titles = ["" ] * (self .nrows * self .ncols )
1138+ for (row , col ), sp in self ._subplot_dict .items ():
1139+ index = row * self .ncols + col
1140+ subplot_titles [index ] = sp ._title or f"({ row } , { col } )"
1141+
10671142 fig = make_subplots (
10681143 rows = self .nrows ,
10691144 cols = self .ncols ,
1070- subplot_titles = [
1071- sp ._title or f"({ row } , { col } )"
1072- for (row , col ), sp in self ._subplot_dict .items ()
1073- ],
1145+ subplot_titles = subplot_titles ,
10741146 )
10751147
10761148 # Plot each subplot and propagate axis labels/scale
1077- axis_index = 1
10781149 for (row , col ), line_plot in self ._subplot_dict .items ():
1079- traces = line_plot .plot_plotly ()
1150+ traces , shapes , annotations = line_plot .plot_plotly (layers = layers )
10801151 for trace in traces :
10811152 fig .add_trace (trace , row = row + 1 , col = col + 1 )
10821153
1083- # Axis label keys are "xaxis", "xaxis2", "xaxis3", ...
1084- xkey = "xaxis" if axis_index == 1 else f"xaxis{ axis_index } "
1085- ykey = "yaxis" if axis_index == 1 else f"yaxis{ axis_index } "
1086- layout_patch = {}
1087- if line_plot ._xlabel :
1088- layout_patch [xkey ] = {"title" : {"text" : line_plot ._xlabel }}
1089- if line_plot ._ylabel :
1090- layout_patch [ykey ] = {"title" : {"text" : line_plot ._ylabel }}
1154+ # Axis indices are row-major: (row*ncols + col + 1)
1155+ axis_index = row * self .ncols + col + 1
1156+ xref = "x" if axis_index == 1 else f"x{ axis_index } "
1157+ yref = "y" if axis_index == 1 else f"y{ axis_index } "
1158+
1159+ for shape in shapes :
1160+ shape = dict (shape )
1161+ if shape .get ("xref" ) not in {"paper" }:
1162+ shape ["xref" ] = xref
1163+ if shape .get ("yref" ) not in {"paper" }:
1164+ shape ["yref" ] = yref
1165+ fig .add_shape (shape )
1166+
1167+ for annotation in annotations :
1168+ annotation = dict (annotation )
1169+ annotation .setdefault ("xref" , xref )
1170+ annotation .setdefault ("yref" , yref )
1171+ fig .add_annotation (annotation )
1172+
1173+ # Apply per-axis config in a row/col-safe way
1174+ xaxis_kwargs = dict (
1175+ title_text = line_plot ._xlabel or None ,
1176+ showgrid = bool (line_plot ._grid ),
1177+ row = row + 1 ,
1178+ col = col + 1 ,
1179+ )
10911180 if line_plot ._xaxis_scale == "log" :
1092- layout_patch .setdefault (xkey , {})["type" ] = "log"
1181+ xaxis_kwargs ["type" ] = "log"
1182+ fig .update_xaxes (** xaxis_kwargs )
1183+
1184+ yaxis_kwargs = dict (
1185+ title_text = line_plot ._ylabel or None ,
1186+ showgrid = bool (line_plot ._grid ),
1187+ row = row + 1 ,
1188+ col = col + 1 ,
1189+ )
10931190 if line_plot ._yaxis_scale == "log" :
1094- layout_patch .setdefault (ykey , {})["type" ] = "log"
1095- if layout_patch :
1096- fig .update_layout (** layout_patch )
1097- axis_index += 1
1191+ yaxis_kwargs ["type" ] = "log"
1192+ fig .update_yaxes (** yaxis_kwargs )
1193+
1194+ # Axis limits
1195+ if line_plot ._xmin is not None or line_plot ._xmax is not None :
1196+ x_range = [line_plot ._xmin , line_plot ._xmax ]
1197+ if x_range [0 ] is not None :
1198+ x_range [0 ] = line_plot ._transform_scalar_x (x_range [0 ])
1199+ if x_range [1 ] is not None :
1200+ x_range [1 ] = line_plot ._transform_scalar_x (x_range [1 ])
1201+ if (
1202+ line_plot ._xaxis_scale == "log"
1203+ and x_range [0 ] is not None
1204+ and x_range [1 ] is not None
1205+ and x_range [0 ] > 0
1206+ and x_range [1 ] > 0
1207+ ):
1208+ x_range = [np .log10 (x_range [0 ]), np .log10 (x_range [1 ])]
1209+ fig .update_xaxes (
1210+ range = x_range ,
1211+ row = row + 1 ,
1212+ col = col + 1 ,
1213+ )
1214+ if line_plot ._ymin is not None or line_plot ._ymax is not None :
1215+ y_range = [line_plot ._ymin , line_plot ._ymax ]
1216+ if y_range [0 ] is not None :
1217+ y_range [0 ] = line_plot ._transform_scalar_y (y_range [0 ])
1218+ if y_range [1 ] is not None :
1219+ y_range [1 ] = line_plot ._transform_scalar_y (y_range [1 ])
1220+ if (
1221+ line_plot ._yaxis_scale == "log"
1222+ and y_range [0 ] is not None
1223+ and y_range [1 ] is not None
1224+ and y_range [0 ] > 0
1225+ and y_range [1 ] > 0
1226+ ):
1227+ y_range = [np .log10 (y_range [0 ]), np .log10 (y_range [1 ])]
1228+ fig .update_yaxes (
1229+ range = y_range ,
1230+ row = row + 1 ,
1231+ col = col + 1 ,
1232+ )
1233+
1234+ # Custom ticks (positions + optional labels)
1235+ if line_plot ._xticks is not None :
1236+ tickvals = [line_plot ._transform_scalar_x (v ) for v in line_plot ._xticks ]
1237+ fig .update_xaxes (
1238+ tickmode = "array" ,
1239+ tickvals = tickvals ,
1240+ ticktext = line_plot ._xticklabels ,
1241+ row = row + 1 ,
1242+ col = col + 1 ,
1243+ )
1244+ if line_plot ._yticks is not None :
1245+ tickvals = [line_plot ._transform_scalar_y (v ) for v in line_plot ._yticks ]
1246+ fig .update_yaxes (
1247+ tickmode = "array" ,
1248+ tickvals = tickvals ,
1249+ ticktext = line_plot ._yticklabels ,
1250+ row = row + 1 ,
1251+ col = col + 1 ,
1252+ )
1253+
1254+ # Aspect ratio
1255+ if line_plot ._aspect == "equal" :
1256+ fig .update_yaxes (scaleanchor = xref , row = row + 1 , col = col + 1 )
1257+ elif isinstance (line_plot ._aspect , (int , float )):
1258+ fig .update_yaxes (
1259+ scaleanchor = xref ,
1260+ scaleratio = float (line_plot ._aspect ),
1261+ row = row + 1 ,
1262+ col = col + 1 ,
1263+ )
10981264
10991265 # Update layout settings
11001266 fig .update_layout (
@@ -1105,10 +1271,30 @@ def plot_plotly(
11051271 fig .update_layout (title = dict (text = self ._suptitle , x = 0.5 ))
11061272
11071273 if savefig :
1108- fig .write_image (savefig )
1274+ try :
1275+ fig .write_image (savefig )
1276+ except Exception as exc :
1277+ raise RuntimeError (
1278+ "Plotly image export failed. If you are exporting to PNG/PDF/SVG, "
1279+ "install kaleido (e.g., `pip install -U kaleido`)."
1280+ ) from exc
11091281
11101282 return fig
11111283
1284+ def _save_plotly (self , fig , filename : str ) -> None :
1285+ _ , extension = os .path .splitext (filename )
1286+ extension = extension .lower ()
1287+ if extension in {".html" , ".htm" }:
1288+ fig .write_html (filename )
1289+ return
1290+ try :
1291+ fig .write_image (filename )
1292+ except Exception as exc :
1293+ raise RuntimeError (
1294+ "Plotly image export failed. For PNG/PDF/SVG export, install kaleido "
1295+ "(e.g., `pip install -U kaleido`), or export to HTML instead."
1296+ ) from exc
1297+
11121298 # Property getters
11131299
11141300 @property
0 commit comments