Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions plotnine/geoms/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def from_stat(stat: stat) -> geom:
PlotnineError
If unable to create a `geom`.
"""
name = stat.params["geom"]
name = stat.params.get("geom", "blank")

if isinstance(name, geom):
return name
Expand Down Expand Up @@ -494,18 +494,21 @@ def _verify_arguments(self, kwargs: dict[str, Any]):
geom_stat_args = kwargs.keys() | self._stat._kwargs.keys()
unknown = (
geom_stat_args
- self.aesthetics()
- self.DEFAULT_PARAMS.keys() # geom aesthetics
- self._stat.aesthetics() # geom parameters
- self._stat.DEFAULT_PARAMS.keys() # stat aesthetics
- { # stat parameters
- self.aesthetics() # geom aesthetics
- self.DEFAULT_PARAMS.keys() # geom parameters
- self._stat.aesthetics() # stat aesthetics
- self._stat.DEFAULT_PARAMS.keys() # stat parameters
- {
# stat parameters
"data",
"mapping",
"show_legend", # layer parameters
"geom",
# layer parameters
"show_legend",
"inherit_aes",
"raster",
}
) # layer parameters
)
if unknown:
msg = (
"Parameters {}, are not understood by "
Expand Down
2 changes: 1 addition & 1 deletion plotnine/stats/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class stat(ABC, metaclass=Register):
NON_MISSING_AES: set[str] = set()
"""Required aesthetics for the stat"""

DEFAULT_PARAMS: dict[str, Any] = {}
DEFAULT_PARAMS: dict[str, Any] = {"geom": "blank"}
"""Required parameters for the stat"""

CREATES: set[str] = set()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ def draw(pinfo, panel_params, coord, ax, **kwargs):
assert "weight" in g._stat.params


def test_stat_extending():
class stat_xyz(stat):
REQUIRED_AES = {"x", "y"}

def compute_group(self, data, scales):
return data

p = ggplot(mtcars, aes("wt", "mpg")) + stat_xyz(geom="point", size=1)

p.draw_test() # pyright: ignore[reportAttributeAccessIssue]


def test_calculated_expressions():
p = ggplot(mtcars, aes(x="factor(cyl)", y="..count..+1")) + geom_bar()
# No exception
Expand Down