Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions notebooks/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,27 @@ def _():


@app.function
def func(x, y, **kwargs):
print(kwargs)
def func(x, y, label, **kwargs):
print(label)
ax = plt.gca()
ax.scatter(x, y)
ax.scatter(x, y, label=f"{kwargs['color']}")
ax.legend(fontsize=4)


@app.cell
def _(data):
chart = (
Chart(data, figsize=(4, 2))
.encode("x", "y", color="x", shape="y")
.mapping(color={3: "pink"}, shape={5: "+"})
.mark_point()
.map(func)
.encode("x", "y", color="x", shape=("y", "x"))
.mapping(color={1: "blue", 3: "pink"}, shape={5: "+"})
# .map(func)
.mark_point(s=5)
.facet("a", "x")
# .to_facet()
# .delaxes()
.select(row=0, col=0)
.legend(fontsize=6)
# .set_titles()
)
chart
return (chart,)


@app.cell
def _(chart):
type(chart)
return


Expand Down
15 changes: 12 additions & 3 deletions notebooks/facet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import marimo

__generated_with = "0.19.7"
__generated_with = "0.19.9"
app = marimo.App(width="medium")

with app.setup:
Expand Down Expand Up @@ -28,8 +28,17 @@ def _():

@app.cell
def _(data):
grid = FacetGrid(data, row="a", col="b", figsize=(4, 2))
grid.set_titles({"b": "{:.1f}"}, a="A(m)", margin_titles=True)
FacetGrid(data, row="a", col="b", figsize=(4, 2)).set_titles(
{"b": "{:.1f}"}, a="A(m)", margin_titles=True
)
return


@app.cell
def _(data):
FacetGrid(
data.with_columns(pl.col("a") * 1e-6), col="a", wrap=2, figsize=(4, 2)
).set_titles(a="A(m)")
return


Expand Down
5 changes: 5 additions & 0 deletions src/plotaris/core/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,11 @@ def set_title(

return self

def legend(self, *args: Any, **kwargs: Any) -> Self:
facet_axes = self.facet_axes.filter(has_data=True)
facet_axes.map_axes(lambda ax: ax.legend(*args, **kwargs)) # pyright: ignore[reportUnknownMemberType]
return self

def _display_(self) -> Figure:
"""Return the figure for display in IPython environments."""
return self.figure
80 changes: 32 additions & 48 deletions src/plotaris/core/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,10 @@ class Chart:
color: tuple[str, ...] = ()
size: tuple[str, ...] = ()
shape: tuple[str, ...] = ()
row: tuple[str, ...] = ()
col: tuple[str, ...] = ()
wrap: int | None = None
palette: Palette | None = None
mark: Mark | None = None
_plot: Callable[..., Any] | None = None
plot: Callable[..., Any] | None = None
axes: Axes | None = None
_kwargs: dict[str, Any]

def __init__(
Expand All @@ -56,63 +54,43 @@ def encoding(self) -> dict[str, tuple[str, ...]]:
names = ["color", "size", "shape"]
return {name: value for name in names if (value := getattr(self, name))}

@property
def has_facet(self) -> bool:
return bool(self.row or self.col)

def encode(
self,
x: str | pl.Expr | None = None,
y: str | pl.Expr | None = None,
*,
color: str | Iterable[str] | None = None,
size: str | Iterable[str] | None = None,
shape: str | Iterable[str] | None = None,
row: str | Iterable[str] | None = None,
col: str | Iterable[str] | None = None,
wrap: int | None = None,
size: str | Iterable[str] | None = None,
) -> Self:
if x is not None:
self.x = x
if y is not None:
self.y = y
if color is not None:
self.color = to_tuple(color)
if size is not None:
self.size = to_tuple(size)
if shape is not None:
self.shape = to_tuple(shape)

if row or col:
self.facet(row, col, wrap)
if size is not None:
self.size = to_tuple(size)

self.palette = Palette(**self.encoding).set(self.data)

return self

def mapping(
self,
/,
**mapping: Mapping[Any | tuple[Any, ...], VisualValue],
color: Mapping[Any, VisualValue] | None = None,
shape: Mapping[Any, VisualValue] | None = None,
size: Mapping[Any, VisualValue] | None = None,
) -> Self:
it = [("color", color), ("shape", shape), ("size", size)]
mapping = {k: v for k, v in it if v}
if self.palette is not None:
self.palette.mapping(**mapping).set(self.data)
return self

def facet(
self,
row: str | Iterable[str] | None = None,
col: str | Iterable[str] | None = None,
wrap: int | None = None,
) -> Self:
self.row = to_tuple(row)
self.col = to_tuple(col)
self.wrap = wrap

return self

def map(self, plot: Callable[..., Any], /) -> Self:
self._plot = plot
self.plot = plot
return self

def mark_point(self, **kwargs: Any) -> Self:
Expand All @@ -127,9 +105,7 @@ def mark_bar(self, **kwargs: Any) -> Self:
self.mark = BarMark(**kwargs)
return self

def _get_series(self, data: pl.DataFrame) -> dict[str, Any]:
kwargs: dict[str, Any] = {}

def _get_series(self, data: pl.DataFrame, **kwargs: Any) -> dict[str, Any]:
if self.palette is not None:
kwargs.update(self.palette.get(data))

Expand All @@ -143,27 +119,35 @@ def _get_series(self, data: pl.DataFrame) -> dict[str, Any]:

def _iter_series(self, data: pl.DataFrame) -> Iterator[dict[str, Any]]:
group = Group(data, **self.encoding)
return map(self._get_series, group)
for df, label in zip(group, group.labels(merge=True), strict=True):
yield self._get_series(df, label=label)

def _plot_series(self, data: pl.DataFrame) -> None:
for series in self._iter_series(data):
if self._plot:
self._plot(**series)
if self.plot:
self.plot(**series)
if self.mark:
self.mark.plot(**series)

def to_facet(self) -> FacetGrid:
grid = FacetGrid(self.data, self.row, self.col, self.wrap, **self._kwargs)
def facet(
self,
row: str | Iterable[str] | None = None,
col: str | Iterable[str] | None = None,
wrap: int | None = None,
) -> FacetGrid:
grid = FacetGrid(self.data, row, col, wrap, **self._kwargs)
grid.map_dataframe(self._plot_series)
return grid

def display(self) -> Axes | FacetGrid:
if self.has_facet:
return self.to_facet()
def display(self) -> Axes:
if self.axes is None:
self.axes = plt.figure(**self._kwargs).add_subplot() # pyright: ignore[reportUnknownMemberType]
self._plot_series(self.data)
return self.axes

ax = plt.figure(**self._kwargs).add_subplot() # pyright: ignore[reportUnknownMemberType]
self._plot_series(self.data)
return ax
def legend(self, *args: Any, **kwargs: Any) -> Self:
self.display().legend(*args, **kwargs) # pyright: ignore[reportUnknownMemberType]
return self

def _display_(self) -> Axes | FacetGrid:
def _display_(self) -> Axes:
return self.display()
85 changes: 70 additions & 15 deletions src/plotaris/core/group.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any, Literal, overload

import polars as pl

Expand Down Expand Up @@ -222,37 +223,91 @@ def dimension_keys(self) -> dict[str, pl.DataFrame]:
return {dim: self.keys(dim) for dim in self.mapping}

@overload
def labels(self, index: int) -> dict[str, dict[str, Any]]: ...
def labels(
self,
index: int,
*,
merge: Literal[False] = False,
) -> dict[str, dict[str, Any]]: ...

@overload
def labels(
self,
index: int,
*,
merge: Literal[True],
) -> dict[str, Any]: ...

@overload
def labels(self, index: None = None) -> list[dict[str, dict[str, Any]]]: ...
def labels(
self,
index: None = None,
*,
merge: Literal[False] = False,
) -> list[dict[str, dict[str, Any]]]: ...

@overload
def labels(
self,
index: None = None,
*,
merge: Literal[True],
) -> list[dict[str, Any]]: ...

def labels(
self,
index: int | None = None,
) -> dict[str, dict[str, Any]] | list[dict[str, dict[str, Any]]]:
*,
merge: bool = False,
):
"""Gets the labels for one or all data groups.

Each group is defined by a unique combination of values from the columns
specified in the dimensions. This method retrieves these values.
specified in the dimensions. This method retrieves these values as labels.

Args:
index: The integer index of a specific group. If None (default),
labels for all groups are returned.
merge: If True, the dictionaries for each dimension are merged into
a single dictionary. Defaults to False.

Returns:
If `index` is an integer, returns a dictionary mapping dimension names
to the key-value pairs for that group.
Example: `{"row": {"col_a": 1}, "col": {"col_b": "x"}}`
A dictionary or list of dictionaries representing the group labels.
The structure depends on the `index` and `merge` arguments:

- `index` is `int`, `merge=False` (default):
Returns a dictionary mapping dimension names to the key-value pairs
for that group.
`{"row": {"var1": "a"}, "col": {"var2": 1}}`

- `index` is `int`, `merge=True`:
Returns a single dictionary with all key-value pairs merged.
`{"var1": "a", "var2": 1}`

- `index` is `None`, `merge=False` (default):
Returns a list of dictionaries, one for each group, structured
as in the first case.

If `index` is None, returns a list of these dictionaries, with one
entry for each group.
- `index` is `None`, `merge=True`:
Returns a list of merged dictionaries, one for each group.
"""
if index is not None:
return {dim: self.keys(dim).row(index, named=True) for dim in self.mapping}
labels = {d: self.keys(d).row(index, named=True) for d in self.mapping}
return _merge(labels) if merge else labels

dim_keys = self.dimension_keys()
return [
{dim: keys.row(i, named=True) for dim, keys in dim_keys.items()}
for i in range(len(self))
]
return [_label(dim_keys, i, merge=merge) for i in range(len(self))]


def _merge(labels: dict[str, dict[str, Any]], /) -> dict[str, Any]:
return reduce(lambda x, y: {**x, **y}, labels.values())


def _label(
dim_keys: dict[str, pl.DataFrame],
index: int,
*,
merge: bool,
) -> dict[str, dict[str, Any]] | dict[str, Any]:
labels = {dim: keys.row(index, named=True) for dim, keys in dim_keys.items()}
return _merge(labels) if merge else labels
9 changes: 2 additions & 7 deletions src/plotaris/core/palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,9 @@ def default(self, /, **default: Iterable[VisualValue] | None) -> Self:
self._default = {k: list(v) for k, v in default.items() if v is not None}
return self

def mapping(
self,
/,
**mapping: Mapping[Any | tuple[Any, ...], VisualValue],
) -> Self:
def mapping(self, **mapping: Mapping[Any, VisualValue]) -> Self:
def to_tuple_dict(
x: Mapping[Any | tuple[Any, ...], VisualValue],
/,
x: Mapping[Any, VisualValue],
) -> dict[tuple[Any, ...], VisualValue]:
return {k if isinstance(k, tuple) else (k,): v for k, v in x.items()}

Expand Down
Loading
Loading