diff --git a/doc/changelog.qmd b/doc/changelog.qmd index 346d84ef2e..372019406f 100644 --- a/doc/changelog.qmd +++ b/doc/changelog.qmd @@ -38,6 +38,28 @@ title: Changelog theme(plot_footer_line=element_line(color="black")) ``` +### API Changes + +- Removed `geom.to_layer()`, `stat.to_layer()`, `annotate.to_layer()`, + `layer.from_geom()` and `layer.from_stat()`. Use + `layer(geom=...)` and `layer(stat=...)` directly. + + The [](:class:`~plotnine.layer.layer`) constructor now handles both geom-first + and stat-first creation. When only a `stat` is provided, the geom is + automatically derived from the stat's default. + + ```python + # Before + layer.from_geom(geom_point()) + layer.from_stat(stat_bin()) + geom_point().to_layer() + + # After + layer(geom=geom_point()) + layer(stat=stat_bin()) + layer(geom=geom_point()) + ``` + ### Bug Fixes - Fixed [](:class:`~plotnine.geom_smooth`) / [](:class:`~plotnine.stat_smooth`) when using a linear model via "lm" with weights for the model to do a weighted regression. This bug did not affect the formula API of the linear model. ({{< issue 1005 >}}) diff --git a/plotnine/geoms/annotate.py b/plotnine/geoms/annotate.py index d54443e3b0..72cf275c57 100644 --- a/plotnine/geoms/annotate.py +++ b/plotnine/geoms/annotate.py @@ -15,7 +15,6 @@ from typing import Any from plotnine import ggplot - from plotnine.layer import layer class annotate: @@ -135,16 +134,7 @@ def __radd__(self, other: ggplot) -> ggplot: """ Add to ggplot """ - other += self.to_layer() # Add layer - return other - - def to_layer(self) -> layer: - """ - Make a layer that represents this annotation + from ..layer import layer - Returns - ------- - out : layer - Layer - """ - return self._annotation_geom.to_layer() + other += layer(geom=self._annotation_geom) + return other diff --git a/plotnine/geoms/geom.py b/plotnine/geoms/geom.py index 86822034c3..38ff079b6b 100644 --- a/plotnine/geoms/geom.py +++ b/plotnine/geoms/geom.py @@ -12,13 +12,11 @@ data_mapping_as_kwargs, remove_missing, ) -from .._utils.registry import Register, Registry +from .._utils.registry import Register from ..exceptions import PlotnineError from ..layer import layer from ..mapping.aes import rename_aesthetics from ..mapping.evaluation import evaluate -from ..positions.position import position -from ..stats.stat import stat if typing.TYPE_CHECKING: from typing import Any @@ -81,7 +79,7 @@ def __init__( ): kwargs = rename_aesthetics(kwargs) kwargs = data_mapping_as_kwargs((data, mapping), kwargs) - self._kwargs = kwargs # Will be used to create stat & layer + self._raw_kwargs = kwargs # Will be used to create stat & layer # separate aesthetics and parameters self.aes_params = { @@ -92,47 +90,6 @@ def __init__( } self.mapping = kwargs["mapping"] self.data = kwargs["data"] - self._stat = stat.from_geom(self) - self._position = position.from_geom(self) - self._verify_arguments(kwargs) # geom, stat, layer - - @staticmethod - def from_stat(stat: stat) -> geom: - """ - Return an instantiated geom object - - geoms should not override this method. - - Parameters - ---------- - stat : - `stat` - - Returns - ------- - : - A geom object - - Raises - ------ - PlotnineError - If unable to create a `geom`. - """ - name = stat.params["geom"] - - if isinstance(name, geom): - return name - - if isinstance(name, type) and issubclass(name, geom): - klass = name - elif isinstance(name, str): - if not name.startswith("geom_"): - name = f"geom_{name}" - klass = Registry[name] - else: - raise PlotnineError(f"Unknown geom of type {type(name)}") - - return klass(stat=stat, **stat._kwargs) @classmethod def aesthetics(cls: type[geom]) -> set[str]: @@ -163,7 +120,7 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> geom: new = result.__dict__ # don't make a deepcopy of data, or environment - shallow = {"data", "_kwargs", "environment"} + shallow = {"data", "_raw_kwargs", "environment"} for key, item in old.items(): if key in shallow: new[key] = item # pyright: ignore[reportIndexIssue] @@ -473,46 +430,9 @@ def __radd__(self, other: ggplot) -> ggplot: : ggplot object with added layer. """ - other += self.to_layer() # Add layer + other += layer(geom=self) return other - def to_layer(self) -> layer: - """ - Make a layer that represents this geom - - Returns - ------- - : - Layer - """ - return layer.from_geom(self) - - def _verify_arguments(self, kwargs: dict[str, Any]): - """ - Verify arguments passed to the geom - """ - geom_stat_args = kwargs.keys() | self._stat._kwargs.keys() - unknown = ( - geom_stat_args - - self.aesthetics() - - self.DEFAULT_PARAMS.keys() # geom aesthetics - - self._stat.aesthetics() # geom parameters - - self._stat.DEFAULT_PARAMS.keys() # stat aesthetics - - { # stat parameters - "data", - "mapping", - "show_legend", # layer parameters - "inherit_aes", - "raster", - } - ) # layer parameters - if unknown: - msg = ( - "Parameters {}, are not understood by " - "either the geom, stat or layer." - ) - raise PlotnineError(msg.format(unknown)) - def handle_na(self, data: pd.DataFrame) -> pd.DataFrame: """ Remove rows with NaN values diff --git a/plotnine/geoms/geom_dotplot.py b/plotnine/geoms/geom_dotplot.py index ef0aea1dd6..2fb36ec6d7 100644 --- a/plotnine/geoms/geom_dotplot.py +++ b/plotnine/geoms/geom_dotplot.py @@ -66,7 +66,7 @@ class geom_dotplot(geom): def setup_data(self, data: pd.DataFrame) -> pd.DataFrame: gp = self.params - sp = self._stat.params + sp = self.params["stat_params"] # Issue warnings when parameters don't make sense if gp["position"] == "stack": @@ -207,16 +207,15 @@ def draw_group( size = data["binwidth"].iloc[0] * params["dotsize"] offsets = data["stackpos"] * params["stackratio"] - if params["binaxis"] == "x": + binaxis = params["stat_params"]["binaxis"] + if binaxis == "x": width, height = size, size * factor xpos, ypos = data["x"], data["y"] + height * offsets - elif params["binaxis"] == "y": + elif binaxis == "y": width, height = size / factor, size xpos, ypos = data["x"] + width * offsets, data["y"] else: - raise ValueError( - f"Invalid valid value binaxis={params['binaxis']}" - ) + raise ValueError(f"Invalid valid value binaxis={binaxis}") circles = [] for xy in zip(xpos, ypos): diff --git a/plotnine/layer.py b/plotnine/layer.py index 42dc058503..aada49e4af 100644 --- a/plotnine/layer.py +++ b/plotnine/layer.py @@ -7,12 +7,13 @@ import pandas as pd from ._utils import array_kind, check_required_aesthetics, ninteraction +from ._utils.registry import Registry from .exceptions import PlotnineError from .mapping.aes import NO_GROUP, SCALED_AESTHETICS, aes, make_labels from .mapping.evaluation import evaluate, stage if typing.TYPE_CHECKING: - from typing import Any, Optional, Sequence, SupportsIndex + from typing import Any, Sequence, SupportsIndex from plotnine import ggplot from plotnine.coords.coord import coord @@ -34,38 +35,42 @@ class layer: """ Layer - When a `geom` or `stat` is added to a [](`~plotnine.ggplot`) object, - it creates a single layer. This class is a representation of that layer. + When a `geom` or `stat` is added to a + [](`~plotnine.ggplot`) object, it creates a single layer. + This class is a representation of that layer. Parameters ---------- geom : - geom to used to draw this layer. + Geom used to draw this layer. Accepts an instance, + a class, or a string name (e.g. ``"point"``). stat : - stat used for the statistical transformation of - data in this layer + Stat used for the statistical transformation of data + in this layer. Accepts an instance, a class, or a + string name. If ``None``, the geom's default stat is + used. mapping : Aesthetic mappings. data : Data plotted in this layer. If `None`, the data from the [](`~plotnine.ggplot`) object will be used. position : - Position object to adjust the geometries in this layer. + Position adjustment for geometries in this layer. + Accepts an instance, a class, or a string name. If + ``None``, the geom's default position is used. inherit_aes : If `True` inherit from the aesthetic mappings of the [](`~plotnine.ggplot`) object. show_legend : Whether to make up and show a legend for the mappings of this layer. If `None` then an automatic/good choice - is made + is made. raster : - If `True`, draw onto this layer a raster (bitmap) object - even if the final image format is vector. - - Notes - ----- - There is no benefit to manually creating a layer. You should - always use a `geom` or `stat`. + If `True`, draw onto this layer a raster (bitmap) + object even if the final image format is vector. + **kwargs : + Keyword arguments passed to the geom constructor when + *geom* is a class or string. """ # Data for this layer @@ -73,57 +78,92 @@ class layer: def __init__( self, - geom: geom, - stat: stat, + geom: geom | type[geom] | str | None = None, + stat: stat | type[stat] | str | None = None, *, - mapping: aes, - data: Optional[LayerDataLike], - position: position, + mapping: aes | None = None, + data: LayerDataLike | None = None, + position: position | type[position] | str | None = None, inherit_aes: bool = True, show_legend: bool | dict[str, bool] | None = None, raster: bool = False, + **kwargs: Any, ): - self.geom = geom - self.stat = stat - self._data = data - self.mapping = mapping - self.position = position - self.inherit_aes = inherit_aes - self.show_legend = show_legend - self.raster = raster + # Stat-first: derive geom from stat's default + if stat is not None: + stat_ref = _lookup_stat(stat) + if isinstance(stat_ref, type): + geom = stat_ref.DEFAULT_PARAMS.get("geom", "blank") + else: + geom = stat_ref.params.get("geom", "blank") + # Forward stat instance's kwargs to the geom + if mapping is None and data is None and not kwargs: + mapping = stat_ref._raw_kwargs.get("mapping") + data = stat_ref._raw_kwargs.get("data") + kwargs = { + k: v + for k, v in stat_ref._raw_kwargs.items() + if k not in ("mapping", "data") + } + + if geom is None: + geom = "blank" + + _geom = _resolve_geom(geom, mapping, data, kwargs) + _stat = _resolve_stat(stat, _geom) + _pos = _resolve_position(position, _geom) + self._verify_arguments(_geom, _stat) + + # Layer params: prefer explicit kwargs, fall back to + # geom._raw_kwargs, then geom.DEFAULT_PARAMS + raw = _geom._raw_kwargs + self.inherit_aes = raw.get( + "inherit_aes", + _geom.DEFAULT_PARAMS.get("inherit_aes", inherit_aes), + ) + self.show_legend = raw.get( + "show_legend", + _geom.DEFAULT_PARAMS.get("show_legend", show_legend), + ) + self.raster = raw.get( + "raster", + _geom.DEFAULT_PARAMS.get("raster", raster), + ) + + self.geom = _geom + self.stat = _stat + self._data = _geom.data + self.mapping = _geom.mapping + self.position = _pos self.zorder = 0 @staticmethod - def from_geom(geom: geom) -> layer: - """ - Create a layer given a [](`~plotnine.geoms.geom`) - - Parameters - ---------- - geom : - `geom` from which a layer will be created - - Returns - ------- - out : layer - Layer that represents the specific `geom`. - """ - kwargs = geom._kwargs - lkwargs = { - "geom": geom, - "mapping": geom.mapping, - "data": geom.data, - "stat": geom._stat, - "position": geom._position, - } - - layer_params = ("inherit_aes", "show_legend", "raster") - for param in layer_params: - if param in kwargs: - lkwargs[param] = kwargs[param] - elif param in geom.DEFAULT_PARAMS: - lkwargs[param] = geom.DEFAULT_PARAMS[param] - return layer(**lkwargs) + def _verify_arguments(geom: geom, stat: stat) -> None: + """ + Verify arguments for the geom, stat and layer + """ + geom_stat_args = geom._raw_kwargs.keys() | stat._raw_kwargs.keys() + unknown = ( + geom_stat_args + - geom.aesthetics() + - geom.DEFAULT_PARAMS.keys() + - stat.aesthetics() + - stat.DEFAULT_PARAMS.keys() + - { + "data", + "mapping", + "geom", + "show_legend", + "inherit_aes", + "raster", + } + ) + if unknown: + msg = ( + "Parameters {}, are not understood by " + "either the geom, stat or layer." + ) + raise PlotnineError(msg.format(unknown)) def __radd__(self, other: ggplot) -> ggplot: """ @@ -331,7 +371,7 @@ def setup_data(self): if len(data) == 0: return - self.geom.params.update(self.stat.params) + self.geom.params["stat_params"] = self.stat.params self.geom.setup_params(data) self.geom.setup_aes_params(data) data = self.geom.setup_data(data) @@ -568,3 +608,152 @@ def discrete_columns( continue lst.append(str(col)) return lst + + +def _resolve_geom( + geom_spec: geom | type[geom] | str, + mapping: aes | None, + data: LayerDataLike | None, + kwargs: dict[str, Any], +) -> geom: + """ + Resolve a geom specification to an instantiated geom + + Parameters + ---------- + geom_spec : + A geom instance, class, or string name. + mapping : + Aesthetic mappings. + data : + Layer data. + kwargs : + Additional keyword arguments forwarded to the geom + constructor. + """ + from .geoms.geom import geom as geom_cls + + if isinstance(geom_spec, geom_cls): + return geom_spec + + if isinstance(geom_spec, type) and issubclass(geom_spec, geom_cls): + klass = geom_spec + elif isinstance(geom_spec, str): + name = geom_spec + if not name.startswith("geom_"): + name = f"geom_{name}" + klass = Registry[name] + else: + raise PlotnineError(f"Unknown geom of type {type(geom_spec)}") + + return klass(mapping, data, **kwargs) + + +def _lookup_stat( + stat_spec: stat | type[stat] | str, +) -> stat | type[stat]: + """ + Look up a stat specification without instantiation + + Parameters + ---------- + stat_spec : + A stat instance, class, or string name. + + Returns + ------- + : + The stat instance or class. + """ + from .stats.stat import stat as stat_cls + + # Duck-type guard for module reloads + if not isinstance(stat_spec, type) and hasattr(stat_spec, "compute_layer"): + return stat_spec # type: ignore[return-value] + + if isinstance(stat_spec, stat_cls): + return stat_spec + + if isinstance(stat_spec, type) and issubclass(stat_spec, stat_cls): + return stat_spec + + if isinstance(stat_spec, str): + name = stat_spec + if not name.startswith("stat_"): + name = f"stat_{name}" + return Registry[name] + + raise PlotnineError(f"Unknown stat of type {type(stat_spec)}") + + +def _resolve_stat( + stat_spec: stat | type[stat] | str | None, + geom_obj: geom, +) -> stat: + """ + Resolve a stat specification to an instantiated stat + + Parameters + ---------- + stat_spec : + A stat instance, class, string name, or None to use + the geom's default. + geom_obj : + The resolved geom (used to derive defaults). + """ + from .stats.stat import stat as stat_cls + + if stat_spec is None: + stat_spec = geom_obj.params["stat"] + + result = _lookup_stat(stat_spec) # type: ignore[arg-type] + + if isinstance(result, stat_cls): + return result + + # It's a class — instantiate with filtered geom kwargs + klass = result + kwargs = geom_obj._raw_kwargs + valid_kwargs = ( + klass.aesthetics() | klass.DEFAULT_PARAMS.keys() + ) & kwargs.keys() + params = {k: kwargs[k] for k in valid_kwargs} + return klass(**params) + + +def _resolve_position( + position_spec: position | type[position] | str | None, + geom_obj: geom, +) -> position: + """ + Resolve a position specification to an instantiated position + + Parameters + ---------- + position_spec : + A position instance, class, string name, or None to use + the geom's default. + geom_obj : + The resolved geom (used to derive defaults). + """ + from .positions.position import position as position_cls + + if position_spec is None: + position_spec = geom_obj.params["position"] + + if isinstance(position_spec, position_cls): + return position_spec + + if isinstance(position_spec, type) and issubclass( + position_spec, position_cls + ): + klass = position_spec + elif isinstance(position_spec, str): + name = position_spec + if not name.startswith("position_"): + name = f"position_{name}" + klass = Registry[name] + else: + raise PlotnineError(f"Unknown position of type {type(position_spec)}") + + return klass() diff --git a/plotnine/positions/position.py b/plotnine/positions/position.py index 792946a9b3..ec30d223ab 100644 --- a/plotnine/positions/position.py +++ b/plotnine/positions/position.py @@ -8,7 +8,7 @@ import numpy as np from .._utils import check_required_aesthetics, groupby_apply -from .._utils.registry import Register, Registry +from .._utils.registry import Register from ..exceptions import PlotnineError, PlotnineWarning from ..mapping.aes import X_AESTHETICS, Y_AESTHETICS @@ -18,7 +18,6 @@ import pandas as pd from plotnine.facets.layout import Layout - from plotnine.geoms.geom import geom from plotnine.iapi import pos_scales from plotnine.typing import TransformCol @@ -132,41 +131,6 @@ def transform_position( return data - @staticmethod - def from_geom(geom: geom) -> position: - """ - Create and return a position object for the geom - - Parameters - ---------- - geom : geom - An instantiated geom object. - - Returns - ------- - out : position - A position object - - Raises - ------ - PlotnineError - If unable to create a `position`. - """ - name = geom.params["position"] - if issubclass(type(name), position): - return name - - if isinstance(name, type) and issubclass(name, position): - klass = name - elif isinstance(name, str): - if not name.startswith("position_"): - name = f"position_{name}" - klass = Registry[name] - else: - raise PlotnineError(f"Unknown position of type {type(name)}") - - return klass() - @staticmethod def strategy(data: pd.DataFrame, params: dict[str, Any]) -> pd.DataFrame: """ diff --git a/plotnine/stats/stat.py b/plotnine/stats/stat.py index 43e49ecbaa..cc2537665c 100644 --- a/plotnine/stats/stat.py +++ b/plotnine/stats/stat.py @@ -12,8 +12,7 @@ remove_missing, uniquecols, ) -from .._utils.registry import Register, Registry -from ..exceptions import PlotnineError +from .._utils.registry import Register from ..layer import layer from ..mapping import aes @@ -22,7 +21,6 @@ from plotnine import ggplot from plotnine.facets.layout import Layout - from plotnine.geoms.geom import geom from plotnine.iapi import pos_scales from plotnine.mapping import Environment from plotnine.typing import DataLike @@ -42,7 +40,7 @@ class stat(ABC, metaclass=Register): NON_MISSING_AES: set[str] = set() """Required aesthetics for the stat""" - DEFAULT_PARAMS: dict[str, Any] = {} + DEFAULT_PARAMS: dict[str, Any] = {"geom": "blank"} """Required parameters for the stat""" CREATES: set[str] = set() @@ -73,7 +71,7 @@ def __init__( **kwargs: Any, ): kwargs = data_mapping_as_kwargs((data, mapping), kwargs) - self._kwargs = kwargs # Will be used to create the geom + self._raw_kwargs = kwargs # Will be used to create the geom self.params = self.DEFAULT_PARAMS | { k: v for k, v in kwargs.items() if k in self.DEFAULT_PARAMS } @@ -82,52 +80,6 @@ def __init__( ae: kwargs[ae] for ae in self.aesthetics() & set(kwargs) } - @staticmethod - def from_geom(geom: geom) -> stat: - """ - Return an instantiated stat object - - stats should not override this method. - - Parameters - ---------- - geom : - A geom object - - Returns - ------- - stat - A stat object - - Raises - ------ - [](`~plotnine.exceptions.PlotnineError`) if unable to create a `stat`. - """ - name = geom.params["stat"] - kwargs = geom._kwargs - # More stable when reloading modules than - # using issubclass - if not isinstance(name, type) and hasattr(name, "compute_layer"): - return name - - if isinstance(name, stat): - return name - elif isinstance(name, type) and issubclass(name, stat): - klass = name - elif isinstance(name, str): - if not name.startswith("stat_"): - name = f"stat_{name}" - klass = Registry[name] - else: - raise PlotnineError(f"Unknown stat of type {type(name)}") - - valid_kwargs = ( - klass.aesthetics() | klass.DEFAULT_PARAMS.keys() - ) & kwargs.keys() - - params = {k: kwargs[k] for k in valid_kwargs} - return klass(geom=geom, **params) - def __deepcopy__(self, memo: dict[Any, Any]) -> stat: """ Deep copy without copying the self.data dataframe @@ -140,8 +92,8 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> stat: old = self.__dict__ new = result.__dict__ - # don't make a _kwargs - shallow = {"_kwargs"} + # don't make a _raw_kwargs + shallow = {"_raw_kwargs"} for key, item in old.items(): if key in shallow: new[key] = item # pyright: ignore[reportIndexIssue] @@ -394,19 +346,5 @@ def __radd__(self, other: ggplot) -> ggplot: out : ggplot object with added layer """ - other += self.to_layer() # Add layer + other += layer(stat=self) return other - - def to_layer(self) -> layer: - """ - Make a layer that represents this stat - - Returns - ------- - out : - Layer - """ - # Create, geom from stat, then layer from geom - from ..geoms.geom import geom - - return layer.from_geom(geom.from_stat(self)) diff --git a/tests/test_geom.py b/tests/test_geom.py index bfc3bec72d..4e614e48d7 100644 --- a/tests/test_geom.py +++ b/tests/test_geom.py @@ -4,6 +4,7 @@ from plotnine import aes, geom_point, ggplot, stat_identity from plotnine.exceptions import PlotnineError from plotnine.geoms.geom import geom +from plotnine.layer import layer data = pd.DataFrame({"col1": [1, 2, 3, 4], "col2": 2, "col3": list("abcd")}) @@ -46,18 +47,18 @@ class geom_abc(geom): DEFAULT_PARAMS = {"stat": "identity", "position": "identity"} with pytest.raises(PlotnineError): - geom_abc(do_the_impossible=True) + layer(geom=geom_abc(do_the_impossible=True)) def test_geom_from_stat(): stat = stat_identity(geom="point") - assert isinstance(geom.from_stat(stat), geom_point) + assert isinstance(layer(stat=stat).geom, geom_point) stat = stat_identity(geom="geom_point") - assert isinstance(geom.from_stat(stat), geom_point) + assert isinstance(layer(stat=stat).geom, geom_point) stat = stat_identity(geom=geom_point()) - assert isinstance(geom.from_stat(stat), geom_point) + assert isinstance(layer(stat=stat).geom, geom_point) stat = stat_identity(geom=geom_point) - assert isinstance(geom.from_stat(stat), geom_point) + assert isinstance(layer(stat=stat).geom, geom_point) diff --git a/tests/test_layers.py b/tests/test_layers.py index 5c60d04fcb..d8c343b491 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -42,7 +42,7 @@ def test_addition(self): assert _get_colors(p2) == colors # Real layers - lyrs = Layers(layer.from_geom(obj) for obj in self.lyrs) + lyrs = Layers(layer(geom=obj) for obj in self.lyrs) p3 = p + lyrs assert _get_colors(p3) == colors @@ -50,7 +50,7 @@ def test_addition(self): assert _get_colors(p) == colors with pytest.raises(PlotnineError): - geom_point() + layer.from_geom(geom_point()) + geom_point() + layer(geom=geom_point()) with pytest.raises(PlotnineError): geom_point() + self.lyrs @@ -77,7 +77,7 @@ def __init__(self, obj): self.obj = obj def __radd__(self, other): - other.layers.insert(0, self.obj.to_layer()) + other.layers.insert(0, layer(geom=self.obj)) return other p = ( diff --git a/tests/test_position.py b/tests/test_position.py index 6972de9f9b..f99197a7cf 100644 --- a/tests/test_position.py +++ b/tests/test_position.py @@ -26,7 +26,7 @@ stage, ) from plotnine.exceptions import PlotnineError -from plotnine.positions.position import position +from plotnine.layer import layer n = 6 m = 10 @@ -241,17 +241,17 @@ def test_jitterdodge(): def test_position_from_geom(): - geom = geom_point(position="jitter") - assert isinstance(position.from_geom(geom), position_jitter) + lyr = layer(geom=geom_point(position="jitter")) + assert isinstance(lyr.position, position_jitter) - geom = geom_point(position="position_jitter") - assert isinstance(position.from_geom(geom), position_jitter) + lyr = layer(geom=geom_point(position="position_jitter")) + assert isinstance(lyr.position, position_jitter) - geom = geom_point(position=position_jitter()) - assert isinstance(position.from_geom(geom), position_jitter) + lyr = layer(geom=geom_point(position=position_jitter())) + assert isinstance(lyr.position, position_jitter) - geom = geom_point(position=position_jitter) - assert isinstance(position.from_geom(geom), position_jitter) + lyr = layer(geom=geom_point(position=position_jitter)) + assert isinstance(lyr.position, position_jitter) def test_dodge_empty_data(): diff --git a/tests/test_stat.py b/tests/test_stat.py index d60988fc9f..e597fac8c0 100644 --- a/tests/test_stat.py +++ b/tests/test_stat.py @@ -5,6 +5,7 @@ from plotnine.data import mtcars from plotnine.exceptions import PlotnineError, PlotnineWarning from plotnine.geoms.geom import geom +from plotnine.layer import layer from plotnine.stats.stat import stat @@ -55,11 +56,25 @@ def draw(pinfo, panel_params, coord, ax, **kwargs): # not a geom manual setting g = geom_abc(weight=4) assert "weight" in g.aes_params - assert "weight" in g._stat.params + lyr = layer(geom=g) + assert "weight" in lyr.stat.params g = geom_abc(aes(weight="mpg")) assert "weight" in g.mapping - assert "weight" in g._stat.params + lyr = layer(geom=g) + assert "weight" in lyr.stat.params + + +def test_stat_extending(): + class stat_xyz(stat): + REQUIRED_AES = {"x", "y"} + + def compute_group(self, data, scales): + return data + + p = ggplot(mtcars, aes("wt", "mpg")) + stat_xyz(geom="point", size=1) + + p.draw_test() # pyright: ignore[reportAttributeAccessIssue] def test_calculated_expressions():