Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions manager_for_ynab/sankey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from enum import Enum
from importlib.resources import files
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -68,6 +69,11 @@ class SankeyNode:
label: str


class SortBy(Enum):
ALPHABETICAL = "alphabetical"
AMOUNT = "amount"


def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog=_PACKAGE,
Expand All @@ -91,8 +97,9 @@ def build_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--sort-by",
choices=("alphabetical", "amount"),
default="alphabetical",
type=SortBy,
choices=list(SortBy),
default=SortBy.ALPHABETICAL,
help="How to sort Sankey nodes within each stage.",
)
parser.add_argument(
Expand Down Expand Up @@ -150,7 +157,7 @@ async def sankey(
start: date,
end: date,
out: Path | None,
sort_by: str,
sort_by: SortBy,
quiet: bool,
token_override: str | None,
) -> int:
Expand Down Expand Up @@ -213,9 +220,7 @@ async def fetch_sankey_rows(
]


def build_sankey_data(
rows: Sequence[SankeyRow], *, sort_by: str = "alphabetical"
) -> SankeyData:
def build_sankey_data(rows: Sequence[SankeyRow], *, sort_by: SortBy) -> SankeyData:
labels: list[str] = []
indexes: dict[SankeyNode, int] = {}
links: defaultdict[tuple[SankeyNode, SankeyNode], Decimal] = defaultdict(Decimal)
Expand Down Expand Up @@ -254,18 +259,14 @@ def add_node(node: SankeyNode) -> None:
spending[(category_group, category)] += row.amount
categories_by_group[category_group].add(category)

if sort_by not in {"alphabetical", "amount"}:
msg = "sort_by must be 'alphabetical' or 'amount'"
raise ValueError(msg)

group_totals = {
group: sum(
(spending[(group, category)] for category in categories_by_group[group]),
Decimal(0),
)
for group in categories_by_group
}
if sort_by == "amount":
if sort_by == SortBy.AMOUNT:
income_nodes = sorted(
income, key=lambda node: (-income[node], node.label.casefold())
)
Expand All @@ -287,7 +288,7 @@ def add_node(node: SankeyNode) -> None:
)

def sorted_categories(group: SankeyNode) -> list[SankeyNode]:
if sort_by == "amount":
if sort_by == SortBy.AMOUNT:
return sorted(
categories_by_group[group],
key=lambda node: (-spending[(group, node)], node.label.casefold()),
Expand Down
44 changes: 25 additions & 19 deletions tests/sankey/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from manager_for_ynab.sankey import run
from manager_for_ynab.sankey import sankey
from manager_for_ynab.sankey import SankeyRow
from manager_for_ynab.sankey import SortBy


_SEED_SQL = Path(__file__).with_name("seed.sql")
Expand Down Expand Up @@ -86,7 +87,8 @@ def test_build_sankey_data_links_income_to_groups_to_categories():
"Groceries",
Decimal("45.5"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == [
Expand Down Expand Up @@ -117,10 +119,11 @@ def test_build_sankey_data_skips_zero_rows():
SankeyRow(
"Cafe", "food-group", "Food", "restaurants", "Restaurants", Decimal("0")
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data == build_sankey_data(())
assert data == build_sankey_data((), sort_by=SortBy.ALPHABETICAL)


def test_build_sankey_data_uses_sql_netted_category_outflows():
Expand All @@ -134,7 +137,8 @@ def test_build_sankey_data_uses_sql_netted_category_outflows():
"Gifts",
Decimal("50"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == ["Income", "Gifts", "Gifts"]
Expand All @@ -154,7 +158,8 @@ def test_build_sankey_data_treats_sql_netted_category_income_by_category():
"Gifts",
Decimal("-60"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == ["Gifts", "Net Category Income", "Income"]
Expand Down Expand Up @@ -182,7 +187,8 @@ def test_build_sankey_data_keeps_payee_income_separate_from_net_category_income(
"Gifts",
Decimal("-60"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == [
Expand Down Expand Up @@ -222,7 +228,8 @@ def test_build_sankey_data_groups_links_over_whole_range():
SankeyRow(
"Landlord", "bills-group", "Bills", "rent", "Rent", Decimal("80")
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == ["Employer", "Ready to Assign", "Income", "Bills", "Rent"]
Expand Down Expand Up @@ -259,7 +266,8 @@ def test_build_sankey_data_sorts_categories_within_groups_on_right_side():
"Amazon",
Decimal("40"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == [
Expand Down Expand Up @@ -307,7 +315,7 @@ def test_build_sankey_data_sorts_by_amount_with_label_tiebreaks():
),
SankeyRow("Gym", "health-group", "Health", "gym", "Gym", Decimal("70")),
],
sort_by="amount",
sort_by=SortBy.AMOUNT,
)

assert data.labels == [
Expand All @@ -323,11 +331,6 @@ def test_build_sankey_data_sorts_by_amount_with_label_tiebreaks():
]


def test_build_sankey_data_rejects_unknown_sort_by():
with pytest.raises(ValueError, match="sort_by must be 'alphabetical' or 'amount'"):
build_sankey_data((), sort_by="unknown")


def test_build_sankey_data_keeps_same_named_nodes_separate_by_stage():
data = build_sankey_data(
[
Expand All @@ -347,7 +350,8 @@ def test_build_sankey_data_keeps_same_named_nodes_separate_by_stage():
"Taxes",
Decimal("120"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

assert data.labels == ["Employer", "Ready to Assign", "Income", "Taxes", "Taxes"]
Expand Down Expand Up @@ -380,7 +384,8 @@ def test_build_echarts_html_uses_node_keys_and_labels():
"Taxes",
Decimal("120"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

html = build_echarts_html(data, start=date(2026, 4, 1), end=date(2026, 4, 30))
Expand Down Expand Up @@ -418,7 +423,8 @@ def test_build_echarts_html_floors_rendered_link_value_without_changing_tooltip_
"Taxes",
Decimal("0.5"),
),
]
],
sort_by=SortBy.ALPHABETICAL,
)

html = build_echarts_html(data, start=date(2026, 4, 1), end=date(2026, 4, 30))
Expand Down Expand Up @@ -578,7 +584,7 @@ async def test_sankey_skips_empty_data(sync, db, capsys):
start=date(2026, 6, 1),
end=date(2026, 6, 30),
out=None,
sort_by="alphabetical",
sort_by=SortBy.ALPHABETICAL,
quiet=False,
token_override=None,
)
Expand All @@ -600,7 +606,7 @@ async def test_sankey_quiet_suppresses_empty_output(sync, db, capsys):
start=date(2026, 6, 1),
end=date(2026, 6, 30),
out=None,
sort_by="alphabetical",
sort_by=SortBy.ALPHABETICAL,
quiet=True,
token_override=None,
)
Expand Down
Loading