diff --git a/manager_for_ynab/sankey/__init__.py b/manager_for_ynab/sankey/__init__.py index a269992..78dfc5d 100644 --- a/manager_for_ynab/sankey/__init__.py +++ b/manager_for_ynab/sankey/__init__.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from datetime import date from decimal import Decimal +from enum import Enum from importlib.resources import files from pathlib import Path from typing import TYPE_CHECKING @@ -68,6 +69,11 @@ class SankeyNode: label: str +class SortBy(Enum): + ALPHABETICAL = "alphabetical" + AMOUNT = "amount" + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog=_PACKAGE, @@ -91,8 +97,9 @@ def build_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--sort-by", - choices=("alphabetical", "amount"), - default="alphabetical", + type=SortBy, + choices=list(SortBy), + default=SortBy.ALPHABETICAL, help="How to sort Sankey nodes within each stage.", ) parser.add_argument( @@ -150,7 +157,7 @@ async def sankey( start: date, end: date, out: Path | None, - sort_by: str, + sort_by: SortBy, quiet: bool, token_override: str | None, ) -> int: @@ -213,9 +220,7 @@ async def fetch_sankey_rows( ] -def build_sankey_data( - rows: Sequence[SankeyRow], *, sort_by: str = "alphabetical" -) -> SankeyData: +def build_sankey_data(rows: Sequence[SankeyRow], *, sort_by: SortBy) -> SankeyData: labels: list[str] = [] indexes: dict[SankeyNode, int] = {} links: defaultdict[tuple[SankeyNode, SankeyNode], Decimal] = defaultdict(Decimal) @@ -254,10 +259,6 @@ def add_node(node: SankeyNode) -> None: spending[(category_group, category)] += row.amount categories_by_group[category_group].add(category) - if sort_by not in {"alphabetical", "amount"}: - msg = "sort_by must be 'alphabetical' or 'amount'" - raise ValueError(msg) - group_totals = { group: sum( (spending[(group, category)] for category in categories_by_group[group]), @@ -265,7 +266,7 @@ def add_node(node: SankeyNode) -> None: ) for group in categories_by_group } - if sort_by == "amount": + if sort_by == SortBy.AMOUNT: income_nodes = sorted( income, key=lambda node: (-income[node], node.label.casefold()) ) @@ -287,7 +288,7 @@ def add_node(node: SankeyNode) -> None: ) def sorted_categories(group: SankeyNode) -> list[SankeyNode]: - if sort_by == "amount": + if sort_by == SortBy.AMOUNT: return sorted( categories_by_group[group], key=lambda node: (-spending[(group, node)], node.label.casefold()), diff --git a/tests/sankey/test.py b/tests/sankey/test.py index a3ea989..f04945f 100644 --- a/tests/sankey/test.py +++ b/tests/sankey/test.py @@ -15,6 +15,7 @@ from manager_for_ynab.sankey import run from manager_for_ynab.sankey import sankey from manager_for_ynab.sankey import SankeyRow +from manager_for_ynab.sankey import SortBy _SEED_SQL = Path(__file__).with_name("seed.sql") @@ -86,7 +87,8 @@ def test_build_sankey_data_links_income_to_groups_to_categories(): "Groceries", Decimal("45.5"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == [ @@ -117,10 +119,11 @@ def test_build_sankey_data_skips_zero_rows(): SankeyRow( "Cafe", "food-group", "Food", "restaurants", "Restaurants", Decimal("0") ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) - assert data == build_sankey_data(()) + assert data == build_sankey_data((), sort_by=SortBy.ALPHABETICAL) def test_build_sankey_data_uses_sql_netted_category_outflows(): @@ -134,7 +137,8 @@ def test_build_sankey_data_uses_sql_netted_category_outflows(): "Gifts", Decimal("50"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == ["Income", "Gifts", "Gifts"] @@ -154,7 +158,8 @@ def test_build_sankey_data_treats_sql_netted_category_income_by_category(): "Gifts", Decimal("-60"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == ["Gifts", "Net Category Income", "Income"] @@ -182,7 +187,8 @@ def test_build_sankey_data_keeps_payee_income_separate_from_net_category_income( "Gifts", Decimal("-60"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == [ @@ -222,7 +228,8 @@ def test_build_sankey_data_groups_links_over_whole_range(): SankeyRow( "Landlord", "bills-group", "Bills", "rent", "Rent", Decimal("80") ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == ["Employer", "Ready to Assign", "Income", "Bills", "Rent"] @@ -259,7 +266,8 @@ def test_build_sankey_data_sorts_categories_within_groups_on_right_side(): "Amazon", Decimal("40"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == [ @@ -307,7 +315,7 @@ def test_build_sankey_data_sorts_by_amount_with_label_tiebreaks(): ), SankeyRow("Gym", "health-group", "Health", "gym", "Gym", Decimal("70")), ], - sort_by="amount", + sort_by=SortBy.AMOUNT, ) assert data.labels == [ @@ -323,11 +331,6 @@ def test_build_sankey_data_sorts_by_amount_with_label_tiebreaks(): ] -def test_build_sankey_data_rejects_unknown_sort_by(): - with pytest.raises(ValueError, match="sort_by must be 'alphabetical' or 'amount'"): - build_sankey_data((), sort_by="unknown") - - def test_build_sankey_data_keeps_same_named_nodes_separate_by_stage(): data = build_sankey_data( [ @@ -347,7 +350,8 @@ def test_build_sankey_data_keeps_same_named_nodes_separate_by_stage(): "Taxes", Decimal("120"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) assert data.labels == ["Employer", "Ready to Assign", "Income", "Taxes", "Taxes"] @@ -380,7 +384,8 @@ def test_build_echarts_html_uses_node_keys_and_labels(): "Taxes", Decimal("120"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) html = build_echarts_html(data, start=date(2026, 4, 1), end=date(2026, 4, 30)) @@ -418,7 +423,8 @@ def test_build_echarts_html_floors_rendered_link_value_without_changing_tooltip_ "Taxes", Decimal("0.5"), ), - ] + ], + sort_by=SortBy.ALPHABETICAL, ) html = build_echarts_html(data, start=date(2026, 4, 1), end=date(2026, 4, 30)) @@ -578,7 +584,7 @@ async def test_sankey_skips_empty_data(sync, db, capsys): start=date(2026, 6, 1), end=date(2026, 6, 30), out=None, - sort_by="alphabetical", + sort_by=SortBy.ALPHABETICAL, quiet=False, token_override=None, ) @@ -600,7 +606,7 @@ async def test_sankey_quiet_suppresses_empty_output(sync, db, capsys): start=date(2026, 6, 1), end=date(2026, 6, 30), out=None, - sort_by="alphabetical", + sort_by=SortBy.ALPHABETICAL, quiet=True, token_override=None, )