Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/perspicacite/search/domain_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ def is_available(self, name: str) -> bool:
return False


def _within_year_bounds(paper: Any, year_min: int | None, year_max: int | None) -> bool:
"""True if ``paper.year`` lies within [year_min, year_max].

Papers with an unknown year are kept (missing publication metadata is common and
dropping them would silently shrink recall). Used to enforce the year window on the
merged results, since the window is not forwarded to the providers (see ``search``).
"""
year = getattr(paper, "year", None)
if year is None:
return True
if year_min is not None and year < year_min:
return False
return not (year_max is not None and year > year_max)


class DomainAwareAggregator:
"""Routes queries to domain-appropriate providers and merges results."""

Expand Down Expand Up @@ -202,8 +217,14 @@ async def search(
p,
query=query,
max_results=self._max_per,
year_min=year_min,
year_max=year_max,
# Year bounds are enforced post-merge on each paper's ``year`` rather
# than forwarded here: a year window makes SciLEx fan out into one
# query per year (26y x N APIs), blowing the per-provider timeout
# budget and returning [] — so an in-range request paradoxically
# yielded zero hits. Fetching on the fast (no-year) path and filtering
# the merged results is provider-independent and correct.
year_min=None,
year_max=None,
extra_kwargs=extra,
)
)
Expand Down Expand Up @@ -243,6 +264,8 @@ async def search(
seen_title_hashes[title_hash] = paper
merged.append(paper)

if year_min is not None or year_max is not None:
merged = [p for p in merged if _within_year_bounds(p, year_min, year_max)]
return merged[:max_results]


Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_domain_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,44 @@ async def search(self, query, max_results=20, year_min=None, year_max=None, **kw
return self._papers


class _YearRecordingProvider:
    """Test double that remembers the year window passed to its latest ``search`` call.

    ``search`` ignores its arguments apart from recording ``year_min``/``year_max`` and
    always returns the papers handed to the constructor.
    """

    def __init__(self, papers: list[Paper]):
        # ``-1`` is a sentinel meaning "search() has not run yet": it differs from
        # ``None``, which is what a call without a year bound records.
        self.seen_year_min: int | None = -1
        self.seen_year_max: int | None = -1
        self._papers = papers
        # Fixed provider attributes (presumably what the aggregator inspects when
        # routing — confirm against DomainAwareAggregator).
        self.name = self.description = "gen"
        self.domains = ["general"]
        self.tier = "reliable"
        self.retry = 0

    async def search(self, query, max_results=20, year_min=None, year_max=None, **kwargs):
        """Record the received year bounds, then return the canned papers."""
        self.seen_year_min = year_min
        self.seen_year_max = year_max
        return self._papers


@pytest.mark.asyncio
async def test_year_window_enforced_post_merge_not_forwarded():
    """The aggregator applies ``year_max`` itself instead of passing it to providers.

    In-range and unknown-year papers survive, the post-cutoff paper is removed, and the
    provider is called with no year window at all.
    """
    provider = _YearRecordingProvider(
        [
            Paper(id="10.1/old", title="Old", doi="10.1/old", source=PaperSource.PUBMED, year=2018),
            Paper(id="10.1/new", title="New", doi="10.1/new", source=PaperSource.PUBMED, year=2024),
            Paper(id="10.1/none", title="NoYr", doi="10.1/none", source=PaperSource.PUBMED, year=None),
        ]
    )
    aggregator = DomainAwareAggregator([provider], provider_timeout_s=5.0)

    hits = await aggregator.search("any query", year_max=2023)

    # 2024 is past the cutoff and is dropped; the paper without a year is kept.
    surviving_dois = {hit.doi for hit in hits}
    assert surviving_dois == {"10.1/old", "10.1/none"}
    # The provider must have been called (sentinel -1 replaced) with no year bounds.
    assert provider.seen_year_min is None
    assert provider.seen_year_max is None


@pytest.mark.asyncio
async def test_basic_routing_general_provider():
p = _Provider("gen", [_paper("10.1/a")])
Expand Down
Loading