From 28295452337b31c021bcfaa8b298eb71966a49dd Mon Sep 17 00:00:00 2001 From: Daizu Date: Wed, 11 Feb 2026 07:32:37 +0000 Subject: [PATCH] Delete facet --- notebooks/chart.py | 25 +++++------ notebooks/facet.py | 15 +++++-- src/plotaris/core/axisgrid.py | 5 +++ src/plotaris/core/chart.py | 80 +++++++++++++-------------------- src/plotaris/core/group.py | 85 ++++++++++++++++++++++++++++------- src/plotaris/core/palette.py | 9 +--- tests/core/test_group.py | 46 +++++++++++++++++++ 7 files changed, 177 insertions(+), 88 deletions(-) diff --git a/notebooks/chart.py b/notebooks/chart.py index 1c2d6fb..180a338 100644 --- a/notebooks/chart.py +++ b/notebooks/chart.py @@ -19,32 +19,27 @@ def _(): @app.function -def func(x, y, **kwargs): - print(kwargs) +def func(x, y, label, **kwargs): + print(label) ax = plt.gca() - ax.scatter(x, y) + ax.scatter(x, y, label=f"{kwargs['color']}") + ax.legend(fontsize=4) @app.cell def _(data): chart = ( Chart(data, figsize=(4, 2)) - .encode("x", "y", color="x", shape="y") - .mapping(color={3: "pink"}, shape={5: "+"}) - .mark_point() - .map(func) + .encode("x", "y", color="x", shape=("y", "x")) + .mapping(color={1: "blue", 3: "pink"}, shape={5: "+"}) + # .map(func) + .mark_point(s=5) .facet("a", "x") - # .to_facet() - # .delaxes() + .select(row=0, col=0) + .legend(fontsize=6) # .set_titles() ) chart - return (chart,) - - -@app.cell -def _(chart): - type(chart) return diff --git a/notebooks/facet.py b/notebooks/facet.py index d508e11..311e792 100644 --- a/notebooks/facet.py +++ b/notebooks/facet.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.19.7" +__generated_with = "0.19.9" app = marimo.App(width="medium") with app.setup: @@ -28,8 +28,17 @@ def _(): @app.cell def _(data): - grid = FacetGrid(data, row="a", col="b", figsize=(4, 2)) - grid.set_titles({"b": "{:.1f}"}, a="A(m)", margin_titles=True) + FacetGrid(data, row="a", col="b", figsize=(4, 2)).set_titles( + {"b": "{:.1f}"}, a="A(m)", margin_titles=True + ) + return + + +@app.cell +def _(data): + FacetGrid( + data.with_columns(pl.col("a") * 1e-6), col="a", wrap=2, figsize=(4, 2) + ).set_titles(a="A(m)") return diff --git a/src/plotaris/core/axisgrid.py b/src/plotaris/core/axisgrid.py index bf2be08..9f66657 100644 --- a/src/plotaris/core/axisgrid.py +++ b/src/plotaris/core/axisgrid.py @@ -458,6 +458,11 @@ def set_title( return self + def legend(self, *args: Any, **kwargs: Any) -> Self: + facet_axes = self.facet_axes.filter(has_data=True) + facet_axes.map_axes(lambda ax: ax.legend(*args, **kwargs)) # pyright: ignore[reportUnknownMemberType] + return self + def _display_(self) -> Figure: """Return the figure for display in IPython environments.""" return self.figure diff --git a/src/plotaris/core/chart.py b/src/plotaris/core/chart.py index baa13f4..d80c1c7 100644 --- a/src/plotaris/core/chart.py +++ b/src/plotaris/core/chart.py @@ -31,12 +31,10 @@ class Chart: color: tuple[str, ...] = () size: tuple[str, ...] = () shape: tuple[str, ...] = () - row: tuple[str, ...] = () - col: tuple[str, ...] = () - wrap: int | None = None palette: Palette | None = None mark: Mark | None = None - _plot: Callable[..., Any] | None = None + plot: Callable[..., Any] | None = None + axes: Axes | None = None _kwargs: dict[str, Any] def __init__( @@ -56,21 +54,14 @@ def encoding(self) -> dict[str, tuple[str, ...]]: names = ["color", "size", "shape"] return {name: value for name in names if (value := getattr(self, name))} - @property - def has_facet(self) -> bool: - return bool(self.row or self.col) - def encode( self, x: str | pl.Expr | None = None, y: str | pl.Expr | None = None, *, color: str | Iterable[str] | None = None, - size: str | Iterable[str] | None = None, shape: str | Iterable[str] | None = None, - row: str | Iterable[str] | None = None, - col: str | Iterable[str] | None = None, - wrap: int | None = None, + size: str | Iterable[str] | None = None, ) -> Self: if x is not None: self.x = x @@ -78,41 +69,28 @@ def encode( self.y = y if color is not None: self.color = to_tuple(color) - if size is not None: - self.size = to_tuple(size) if shape is not None: self.shape = to_tuple(shape) - - if row or col: - self.facet(row, col, wrap) + if size is not None: + self.size = to_tuple(size) self.palette = Palette(**self.encoding).set(self.data) - return self def mapping( self, - /, - **mapping: Mapping[Any | tuple[Any, ...], VisualValue], + color: Mapping[Any, VisualValue] | None = None, + shape: Mapping[Any, VisualValue] | None = None, + size: Mapping[Any, VisualValue] | None = None, ) -> Self: + it = [("color", color), ("shape", shape), ("size", size)] + mapping = {k: v for k, v in it if v} if self.palette is not None: self.palette.mapping(**mapping).set(self.data) return self - def facet( - self, - row: str | Iterable[str] | None = None, - col: str | Iterable[str] | None = None, - wrap: int | None = None, - ) -> Self: - self.row = to_tuple(row) - self.col = to_tuple(col) - self.wrap = wrap - - return self - def map(self, plot: Callable[..., Any], /) -> Self: - self._plot = plot + self.plot = plot return self def mark_point(self, **kwargs: Any) -> Self: @@ -127,9 +105,7 @@ def mark_bar(self, **kwargs: Any) -> Self: self.mark = BarMark(**kwargs) return self - def _get_series(self, data: pl.DataFrame) -> dict[str, Any]: - kwargs: dict[str, Any] = {} - + def _get_series(self, data: pl.DataFrame, **kwargs: Any) -> dict[str, Any]: if self.palette is not None: kwargs.update(self.palette.get(data)) @@ -143,27 +119,35 @@ def _get_series(self, data: pl.DataFrame) -> dict[str, Any]: def _iter_series(self, data: pl.DataFrame) -> Iterator[dict[str, Any]]: group = Group(data, **self.encoding) - return map(self._get_series, group) + for df, label in zip(group, group.labels(merge=True), strict=True): + yield self._get_series(df, label=label) def _plot_series(self, data: pl.DataFrame) -> None: for series in self._iter_series(data): - if self._plot: - self._plot(**series) + if self.plot: + self.plot(**series) if self.mark: self.mark.plot(**series) - def to_facet(self) -> FacetGrid: - grid = FacetGrid(self.data, self.row, self.col, self.wrap, **self._kwargs) + def facet( + self, + row: str | Iterable[str] | None = None, + col: str | Iterable[str] | None = None, + wrap: int | None = None, + ) -> FacetGrid: + grid = FacetGrid(self.data, row, col, wrap, **self._kwargs) grid.map_dataframe(self._plot_series) return grid - def display(self) -> Axes | FacetGrid: - if self.has_facet: - return self.to_facet() + def display(self) -> Axes: + if self.axes is None: + self.axes = plt.figure(**self._kwargs).add_subplot() # pyright: ignore[reportUnknownMemberType] + self._plot_series(self.data) + return self.axes - ax = plt.figure(**self._kwargs).add_subplot() # pyright: ignore[reportUnknownMemberType] - self._plot_series(self.data) - return ax + def legend(self, *args: Any, **kwargs: Any) -> Self: + self.display().legend(*args, **kwargs) # pyright: ignore[reportUnknownMemberType] + return self - def _display_(self) -> Axes | FacetGrid: + def _display_(self) -> Axes: return self.display() diff --git a/src/plotaris/core/group.py b/src/plotaris/core/group.py index 535a381..85f6d5e 100644 --- a/src/plotaris/core/group.py +++ b/src/plotaris/core/group.py @@ -1,7 +1,8 @@ from __future__ import annotations +from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Literal, overload import polars as pl @@ -222,37 +223,91 @@ def dimension_keys(self) -> dict[str, pl.DataFrame]: return {dim: self.keys(dim) for dim in self.mapping} @overload - def labels(self, index: int) -> dict[str, dict[str, Any]]: ... + def labels( + self, + index: int, + *, + merge: Literal[False] = False, + ) -> dict[str, dict[str, Any]]: ... + + @overload + def labels( + self, + index: int, + *, + merge: Literal[True], + ) -> dict[str, Any]: ... @overload - def labels(self, index: None = None) -> list[dict[str, dict[str, Any]]]: ... + def labels( + self, + index: None = None, + *, + merge: Literal[False] = False, + ) -> list[dict[str, dict[str, Any]]]: ... + + @overload + def labels( + self, + index: None = None, + *, + merge: Literal[True], + ) -> list[dict[str, Any]]: ... def labels( self, index: int | None = None, - ) -> dict[str, dict[str, Any]] | list[dict[str, dict[str, Any]]]: + *, + merge: bool = False, + ): """Gets the labels for one or all data groups. Each group is defined by a unique combination of values from the columns - specified in the dimensions. This method retrieves these values. + specified in the dimensions. This method retrieves these values as labels. Args: index: The integer index of a specific group. If None (default), labels for all groups are returned. + merge: If True, the dictionaries for each dimension are merged into + a single dictionary. Defaults to False. Returns: - If `index` is an integer, returns a dictionary mapping dimension names - to the key-value pairs for that group. - Example: `{"row": {"col_a": 1}, "col": {"col_b": "x"}}` + A dictionary or list of dictionaries representing the group labels. + The structure depends on the `index` and `merge` arguments: + + - `index` is `int`, `merge=False` (default): + Returns a dictionary mapping dimension names to the key-value pairs + for that group. + `{"row": {"var1": "a"}, "col": {"var2": 1}}` + + - `index` is `int`, `merge=True`: + Returns a single dictionary with all key-value pairs merged. + `{"var1": "a", "var2": 1}` + + - `index` is `None`, `merge=False` (default): + Returns a list of dictionaries, one for each group, structured + as in the first case. - If `index` is None, returns a list of these dictionaries, with one - entry for each group. + - `index` is `None`, `merge=True`: + Returns a list of merged dictionaries, one for each group. """ if index is not None: - return {dim: self.keys(dim).row(index, named=True) for dim in self.mapping} + labels = {d: self.keys(d).row(index, named=True) for d in self.mapping} + return _merge(labels) if merge else labels dim_keys = self.dimension_keys() - return [ - {dim: keys.row(i, named=True) for dim, keys in dim_keys.items()} - for i in range(len(self)) - ] + return [_label(dim_keys, i, merge=merge) for i in range(len(self))] + + +def _merge(labels: dict[str, dict[str, Any]], /) -> dict[str, Any]: + return reduce(lambda x, y: {**x, **y}, labels.values()) + + +def _label( + dim_keys: dict[str, pl.DataFrame], + index: int, + *, + merge: bool, +) -> dict[str, dict[str, Any]] | dict[str, Any]: + labels = {dim: keys.row(index, named=True) for dim, keys in dim_keys.items()} + return _merge(labels) if merge else labels diff --git a/src/plotaris/core/palette.py b/src/plotaris/core/palette.py index 8aaf477..0e1ed75 100644 --- a/src/plotaris/core/palette.py +++ b/src/plotaris/core/palette.py @@ -33,14 +33,9 @@ def default(self, /, **default: Iterable[VisualValue] | None) -> Self: self._default = {k: list(v) for k, v in default.items() if v is not None} return self - def mapping( - self, - /, - **mapping: Mapping[Any | tuple[Any, ...], VisualValue], - ) -> Self: + def mapping(self, **mapping: Mapping[Any, VisualValue]) -> Self: def to_tuple_dict( - x: Mapping[Any | tuple[Any, ...], VisualValue], - /, + x: Mapping[Any, VisualValue], ) -> dict[tuple[Any, ...], VisualValue]: return {k if isinstance(k, tuple) else (k,): v for k, v in x.items()} diff --git a/tests/core/test_group.py b/tests/core/test_group.py index 0b46eaf..2a86842 100644 --- a/tests/core/test_group.py +++ b/tests/core/test_group.py @@ -193,12 +193,19 @@ def test_group_columns_str_str(data: pl.DataFrame) -> None: assert_frame_equal(dim["col"], expected.select("b")) assert gr.labels(1) == {"row": {"a": 1}, "col": {"b": 4}} + assert gr.labels(2, merge=True) == {"a": 2, "b": 4} assert gr.labels() == [ {"row": {"a": 1}, "col": {"b": 3}}, {"row": {"a": 1}, "col": {"b": 4}}, {"row": {"a": 2}, "col": {"b": 4}}, {"row": {"a": 2}, "col": {"b": 5}}, ] + assert gr.labels(merge=True) == [ + {"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 4}, + {"a": 2, "b": 5}, + ] def test_group_columns_str_str_duplicated(data: pl.DataFrame) -> None: @@ -216,6 +223,15 @@ def test_group_columns_str_str_duplicated(data: pl.DataFrame) -> None: assert_frame_equal(dim["row"], expected) assert_frame_equal(dim["col"], expected) + assert gr.labels(1) == {"row": {"b": 4}, "col": {"b": 4}} + assert gr.labels(2, merge=True) == {"b": 5} + assert gr.labels() == [ + {"row": {"b": 3}, "col": {"b": 3}}, + {"row": {"b": 4}, "col": {"b": 4}}, + {"row": {"b": 5}, "col": {"b": 5}}, + ] + assert gr.labels(merge=True) == [{"b": 3}, {"b": 4}, {"b": 5}] + def test_group_columns_tuple(data: pl.DataFrame) -> None: gr = Group(data, row=("a", "b")) @@ -231,6 +247,21 @@ def test_group_columns_tuple(data: pl.DataFrame) -> None: dim = gr.dimension_keys() assert_frame_equal(dim["row"], expected) + assert gr.labels(1) == {"row": {"a": 1, "b": 4}} + assert gr.labels(2, merge=True) == {"a": 2, "b": 4} + assert gr.labels() == [ + {"row": {"a": 1, "b": 3}}, + {"row": {"a": 1, "b": 4}}, + {"row": {"a": 2, "b": 4}}, + {"row": {"a": 2, "b": 5}}, + ] + assert gr.labels(merge=True) == [ + {"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 4}, + {"a": 2, "b": 5}, + ] + def test_group_columns_tuple_str(data: pl.DataFrame) -> None: gr = Group(data, row=("b", "a"), col="a") @@ -247,6 +278,21 @@ def test_group_columns_tuple_str(data: pl.DataFrame) -> None: assert_frame_equal(dim["row"], expected) assert_frame_equal(dim["col"], expected.select("a")) + assert gr.labels(1) == {"row": {"b": 4, "a": 1}, "col": {"a": 1}} + assert gr.labels(2, merge=True) == {"b": 4, "a": 2} + assert gr.labels() == [ + {"row": {"b": 3, "a": 1}, "col": {"a": 1}}, + {"row": {"b": 4, "a": 1}, "col": {"a": 1}}, + {"row": {"b": 4, "a": 2}, "col": {"a": 2}}, + {"row": {"b": 5, "a": 2}, "col": {"a": 2}}, + ] + assert gr.labels(merge=True) == [ + {"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 4}, + {"a": 2, "b": 5}, + ] + def test_group_columns_str_empty(data: pl.DataFrame) -> None: gr = Group(data, row="a", col=())