Skip to content

Commit 9003d07

Browse files
authored
Improve plotly coverage (#46)
* extended the plotly coverage * updated tutorials * formatting
1 parent f2a2952 commit 9003d07

17 files changed

Lines changed: 1348 additions & 311 deletions

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ ax.set_legend(True)
7575
canvas.show(backend="plotext")
7676
```
7777

78+
## Examples
79+
80+
Runnable example scripts live in `examples/`:
81+
82+
``` bash
83+
python examples/plotly_backend_basic.py
84+
python examples/plotly_backend_parity.py
85+
```
86+
7887
### Layers
7988

8089
``` python

examples/plotly_backend_basic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
3+
from maxplotlib import Canvas
4+
5+
6+
def main() -> None:
7+
x = np.linspace(0, 2 * np.pi, 200)
8+
9+
canvas = Canvas(width="12cm", ratio=0.5)
10+
canvas.add_line(x, np.sin(x), color="royalblue", label="sin(x)")
11+
canvas.scatter(x[::12], np.sin(x[::12]), color="tomato", label="samples")
12+
canvas.axhline(0, color="black", linestyle="dotted")
13+
canvas.set_title("Plotly backend (basic)")
14+
canvas.set_xlabel("x")
15+
canvas.set_ylabel("y")
16+
canvas.set_grid(True)
17+
canvas.set_legend(True)
18+
19+
canvas.savefig("plotly_basic.html", backend="plotly")
20+
21+
22+
if __name__ == "__main__":
23+
main()

examples/plotly_backend_parity.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import matplotlib.patches as mpatches
2+
import numpy as np
3+
4+
from maxplotlib import Canvas
5+
6+
7+
def main() -> None:
8+
x = np.linspace(0.5, 10, 60)
9+
y = np.sqrt(x)
10+
11+
canvas = Canvas(width="12cm", ratio=0.55)
12+
13+
canvas.add_line(x, y, color="steelblue", label="sqrt(x)")
14+
canvas.errorbar(
15+
x[::10],
16+
y[::10],
17+
yerr=0.15,
18+
color="tomato",
19+
marker="o",
20+
label="samples ± err",
21+
)
22+
canvas.fill_between(x, y - 0.1, y + 0.1, color="steelblue", alpha=0.2, label="band")
23+
canvas.vlines([2, 5, 8], ymin=0, ymax=3.5, color="gray", linestyle="dashed")
24+
canvas.text(7.2, 2.8, "note", color="purple")
25+
canvas.annotate(
26+
"peak-ish", xy=(9.5, np.sqrt(9.5)), xytext=(6.0, 3.1), color="purple"
27+
)
28+
29+
canvas.add_patch(
30+
mpatches.Rectangle((1.2, 0.0), 2.5, 1.2, fill=True),
31+
facecolor="rgba(255,0,0,0.1)",
32+
edgecolor="crimson",
33+
alpha=0.3,
34+
)
35+
36+
canvas.set_title("Plotly backend (parity features)")
37+
canvas.set_xlabel("x")
38+
canvas.set_ylabel("y")
39+
canvas.set_xscale("log")
40+
canvas.set_grid(True)
41+
canvas.set_legend(True)
42+
43+
canvas.savefig("plotly_parity.html", backend="plotly")
44+
45+
46+
if __name__ == "__main__":
47+
main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"pint",
2020
"plotly",
2121
"plotext",
22-
"tikzfigure[vis]>=0.2.1",
22+
"tikzfigure[vis]>=0.3.0",
2323
]
2424
[project.optional-dependencies]
2525
test = [

src/maxplotlib/canvas/canvas.py

Lines changed: 208 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import matplotlib.patches as patches
77
import matplotlib.pyplot as plt
8+
import numpy as np
89
from plotly.subplots import make_subplots
910
from tikzfigure import TikzFigure
1011

@@ -579,6 +580,41 @@ def text(
579580
"""Add a text label at (x, y) on a subplot."""
580581
self._get_or_create_subplot(row, col).text(x, y, s, layer=layer, **kwargs)
581582

583+
def imshow(
584+
self,
585+
data,
586+
layer=0,
587+
row: int | None = None,
588+
col: int | None = None,
589+
**kwargs,
590+
):
591+
"""Add an image/matrix plot to a subplot."""
592+
self._get_or_create_subplot(row, col).add_imshow(data, layer=layer, **kwargs)
593+
594+
def add_patch(
595+
self,
596+
patch,
597+
layer=0,
598+
row: int | None = None,
599+
col: int | None = None,
600+
**kwargs,
601+
):
602+
"""Add a Matplotlib patch to a subplot."""
603+
self._get_or_create_subplot(row, col).add_patch(patch, layer=layer, **kwargs)
604+
605+
def colorbar(
606+
self,
607+
label: str = "",
608+
layer=0,
609+
row: int | None = None,
610+
col: int | None = None,
611+
**kwargs,
612+
):
613+
"""Add a colorbar to the most recent imshow() on a subplot (matplotlib backend)."""
614+
self._get_or_create_subplot(row, col).add_colorbar(
615+
label=label, layer=layer, **kwargs
616+
)
617+
582618
# ------------------------------------------------------------------
583619
# Multi-subplot helpers
584620
# ------------------------------------------------------------------
@@ -773,6 +809,34 @@ def savefig(
773809
figure.savefig(full_filepath)
774810
if verbose:
775811
print(f"Saved {full_filepath}")
812+
elif backend == "plotly":
813+
if layer_by_layer:
814+
layers = []
815+
for layer in self.layers:
816+
layers.append(layer)
817+
full_filepath = f"{filename_no_extension}_{layers}{extension}"
818+
fig = self.plot(
819+
backend="plotly",
820+
savefig=False,
821+
layers=layers,
822+
)
823+
self._save_plotly(fig, full_filepath)
824+
if verbose:
825+
print(f"Saved {full_filepath}")
826+
else:
827+
if layers is None:
828+
layers = self.layers
829+
full_filepath = filename
830+
else:
831+
full_filepath = f"{filename_no_extension}_{layers}{extension}"
832+
fig = self.plot(
833+
backend="plotly",
834+
savefig=False,
835+
layers=layers,
836+
)
837+
self._save_plotly(fig, full_filepath)
838+
if verbose:
839+
print(f"Saved {full_filepath}")
776840

777841
def plot(
778842
self,
@@ -797,6 +861,7 @@ def plot(
797861
elif backend == "plotly":
798862
return self.plot_plotly(
799863
savefig=savefig,
864+
layers=layers,
800865
usetex=resolved_usetex,
801866
verbose=verbose,
802867
)
@@ -832,7 +897,11 @@ def show(
832897
# self._matplotlib_fig.show()
833898
elif backend == "plotly":
834899
resolved_usetex = self._usetex if usetex is None else usetex
835-
self.plot_plotly(savefig=False, usetex=resolved_usetex)
900+
fig = self.plot_plotly(
901+
savefig=False, layers=layers, usetex=resolved_usetex, verbose=verbose
902+
)
903+
fig.show()
904+
return fig
836905
elif backend == "plotext":
837906
figure = self.plot_plotext(
838907
savefig=False,
@@ -1034,6 +1103,7 @@ def plot_plotly(
10341103
self,
10351104
show=True,
10361105
savefig=None,
1106+
layers: list | None = None,
10371107
usetex: bool | None = None,
10381108
verbose: bool = False,
10391109
):
@@ -1063,38 +1133,134 @@ def plot_plotly(
10631133
ratio=self._ratio,
10641134
)
10651135
# print(self._width, fig_width, fig_height)
1066-
# Create subplots
1136+
# Create subplot titles in row-major order (Plotly expects rows*cols entries)
1137+
subplot_titles = [""] * (self.nrows * self.ncols)
1138+
for (row, col), sp in self._subplot_dict.items():
1139+
index = row * self.ncols + col
1140+
subplot_titles[index] = sp._title or f"({row}, {col})"
1141+
10671142
fig = make_subplots(
10681143
rows=self.nrows,
10691144
cols=self.ncols,
1070-
subplot_titles=[
1071-
sp._title or f"({row}, {col})"
1072-
for (row, col), sp in self._subplot_dict.items()
1073-
],
1145+
subplot_titles=subplot_titles,
10741146
)
10751147

10761148
# Plot each subplot and propagate axis labels/scale
1077-
axis_index = 1
10781149
for (row, col), line_plot in self._subplot_dict.items():
1079-
traces = line_plot.plot_plotly()
1150+
traces, shapes, annotations = line_plot.plot_plotly(layers=layers)
10801151
for trace in traces:
10811152
fig.add_trace(trace, row=row + 1, col=col + 1)
10821153

1083-
# Axis label keys are "xaxis", "xaxis2", "xaxis3", ...
1084-
xkey = "xaxis" if axis_index == 1 else f"xaxis{axis_index}"
1085-
ykey = "yaxis" if axis_index == 1 else f"yaxis{axis_index}"
1086-
layout_patch = {}
1087-
if line_plot._xlabel:
1088-
layout_patch[xkey] = {"title": {"text": line_plot._xlabel}}
1089-
if line_plot._ylabel:
1090-
layout_patch[ykey] = {"title": {"text": line_plot._ylabel}}
1154+
# Axis indices are row-major: (row*ncols + col + 1)
1155+
axis_index = row * self.ncols + col + 1
1156+
xref = "x" if axis_index == 1 else f"x{axis_index}"
1157+
yref = "y" if axis_index == 1 else f"y{axis_index}"
1158+
1159+
for shape in shapes:
1160+
shape = dict(shape)
1161+
if shape.get("xref") not in {"paper"}:
1162+
shape["xref"] = xref
1163+
if shape.get("yref") not in {"paper"}:
1164+
shape["yref"] = yref
1165+
fig.add_shape(shape)
1166+
1167+
for annotation in annotations:
1168+
annotation = dict(annotation)
1169+
annotation.setdefault("xref", xref)
1170+
annotation.setdefault("yref", yref)
1171+
fig.add_annotation(annotation)
1172+
1173+
# Apply per-axis config in a row/col-safe way
1174+
xaxis_kwargs = dict(
1175+
title_text=line_plot._xlabel or None,
1176+
showgrid=bool(line_plot._grid),
1177+
row=row + 1,
1178+
col=col + 1,
1179+
)
10911180
if line_plot._xaxis_scale == "log":
1092-
layout_patch.setdefault(xkey, {})["type"] = "log"
1181+
xaxis_kwargs["type"] = "log"
1182+
fig.update_xaxes(**xaxis_kwargs)
1183+
1184+
yaxis_kwargs = dict(
1185+
title_text=line_plot._ylabel or None,
1186+
showgrid=bool(line_plot._grid),
1187+
row=row + 1,
1188+
col=col + 1,
1189+
)
10931190
if line_plot._yaxis_scale == "log":
1094-
layout_patch.setdefault(ykey, {})["type"] = "log"
1095-
if layout_patch:
1096-
fig.update_layout(**layout_patch)
1097-
axis_index += 1
1191+
yaxis_kwargs["type"] = "log"
1192+
fig.update_yaxes(**yaxis_kwargs)
1193+
1194+
# Axis limits
1195+
if line_plot._xmin is not None or line_plot._xmax is not None:
1196+
x_range = [line_plot._xmin, line_plot._xmax]
1197+
if x_range[0] is not None:
1198+
x_range[0] = line_plot._transform_scalar_x(x_range[0])
1199+
if x_range[1] is not None:
1200+
x_range[1] = line_plot._transform_scalar_x(x_range[1])
1201+
if (
1202+
line_plot._xaxis_scale == "log"
1203+
and x_range[0] is not None
1204+
and x_range[1] is not None
1205+
and x_range[0] > 0
1206+
and x_range[1] > 0
1207+
):
1208+
x_range = [np.log10(x_range[0]), np.log10(x_range[1])]
1209+
fig.update_xaxes(
1210+
range=x_range,
1211+
row=row + 1,
1212+
col=col + 1,
1213+
)
1214+
if line_plot._ymin is not None or line_plot._ymax is not None:
1215+
y_range = [line_plot._ymin, line_plot._ymax]
1216+
if y_range[0] is not None:
1217+
y_range[0] = line_plot._transform_scalar_y(y_range[0])
1218+
if y_range[1] is not None:
1219+
y_range[1] = line_plot._transform_scalar_y(y_range[1])
1220+
if (
1221+
line_plot._yaxis_scale == "log"
1222+
and y_range[0] is not None
1223+
and y_range[1] is not None
1224+
and y_range[0] > 0
1225+
and y_range[1] > 0
1226+
):
1227+
y_range = [np.log10(y_range[0]), np.log10(y_range[1])]
1228+
fig.update_yaxes(
1229+
range=y_range,
1230+
row=row + 1,
1231+
col=col + 1,
1232+
)
1233+
1234+
# Custom ticks (positions + optional labels)
1235+
if line_plot._xticks is not None:
1236+
tickvals = [line_plot._transform_scalar_x(v) for v in line_plot._xticks]
1237+
fig.update_xaxes(
1238+
tickmode="array",
1239+
tickvals=tickvals,
1240+
ticktext=line_plot._xticklabels,
1241+
row=row + 1,
1242+
col=col + 1,
1243+
)
1244+
if line_plot._yticks is not None:
1245+
tickvals = [line_plot._transform_scalar_y(v) for v in line_plot._yticks]
1246+
fig.update_yaxes(
1247+
tickmode="array",
1248+
tickvals=tickvals,
1249+
ticktext=line_plot._yticklabels,
1250+
row=row + 1,
1251+
col=col + 1,
1252+
)
1253+
1254+
# Aspect ratio
1255+
if line_plot._aspect == "equal":
1256+
fig.update_yaxes(scaleanchor=xref, row=row + 1, col=col + 1)
1257+
elif isinstance(line_plot._aspect, (int, float)):
1258+
fig.update_yaxes(
1259+
scaleanchor=xref,
1260+
scaleratio=float(line_plot._aspect),
1261+
row=row + 1,
1262+
col=col + 1,
1263+
)
10981264

10991265
# Update layout settings
11001266
fig.update_layout(
@@ -1105,10 +1271,30 @@ def plot_plotly(
11051271
fig.update_layout(title=dict(text=self._suptitle, x=0.5))
11061272

11071273
if savefig:
1108-
fig.write_image(savefig)
1274+
try:
1275+
fig.write_image(savefig)
1276+
except Exception as exc:
1277+
raise RuntimeError(
1278+
"Plotly image export failed. If you are exporting to PNG/PDF/SVG, "
1279+
"install kaleido (e.g., `pip install -U kaleido`)."
1280+
) from exc
11091281

11101282
return fig
11111283

1284+
def _save_plotly(self, fig, filename: str) -> None:
1285+
_, extension = os.path.splitext(filename)
1286+
extension = extension.lower()
1287+
if extension in {".html", ".htm"}:
1288+
fig.write_html(filename)
1289+
return
1290+
try:
1291+
fig.write_image(filename)
1292+
except Exception as exc:
1293+
raise RuntimeError(
1294+
"Plotly image export failed. For PNG/PDF/SVG export, install kaleido "
1295+
"(e.g., `pip install -U kaleido`), or export to HTML instead."
1296+
) from exc
1297+
11121298
# Property getters
11131299

11141300
@property

0 commit comments

Comments
 (0)