diff --git a/src/perspicacite/search/domain_aggregator.py b/src/perspicacite/search/domain_aggregator.py index 0cf840a..26c925c 100644 --- a/src/perspicacite/search/domain_aggregator.py +++ b/src/perspicacite/search/domain_aggregator.py @@ -51,6 +51,21 @@ def is_available(self, name: str) -> bool: return False +def _within_year_bounds(paper: Any, year_min: int | None, year_max: int | None) -> bool: + """True if ``paper.year`` lies within [year_min, year_max]. + + Papers with an unknown year are kept (missing publication metadata is common and + dropping them would silently shrink recall). Used to enforce the year window on the + merged results, since the window is not forwarded to the providers (see ``search``). + """ + year = getattr(paper, "year", None) + if year is None: + return True + if year_min is not None and year < year_min: + return False + return not (year_max is not None and year > year_max) + + class DomainAwareAggregator: """Routes queries to domain-appropriate providers and merges results.""" @@ -202,8 +217,14 @@ async def search( p, query=query, max_results=self._max_per, - year_min=year_min, - year_max=year_max, + # Year bounds are enforced post-merge on each paper's ``year`` rather + # than forwarded here: a year window makes SciLEx fan out into one + # query per year (26y x N APIs), blowing the per-provider timeout + # budget and returning [] — so an in-range request paradoxically + # yielded zero hits. Fetching on the fast (no-year) path and filtering + # the merged results is provider-independent and correct. + year_min=None, + year_max=None, extra_kwargs=extra, ) ) @@ -243,6 +264,8 @@ async def search( seen_title_hashes[title_hash] = paper merged.append(paper) + if year_min is not None or year_max is not None: + merged = [p for p in merged if _within_year_bounds(p, year_min, year_max)] return merged[:max_results] diff --git a/tests/unit/test_domain_aggregator.py b/tests/unit/test_domain_aggregator.py index c773b88..01d166d 100644 --- a/tests/unit/test_domain_aggregator.py +++ b/tests/unit/test_domain_aggregator.py @@ -36,6 +36,44 @@ async def search(self, query, max_results=20, year_min=None, year_max=None, **kw return self._papers +class _YearRecordingProvider: + """Records the year window it was called with and returns papers of mixed years.""" + + def __init__(self, papers: list[Paper]): + self.name = "gen" + self.description = "gen" + self.domains = ["general"] + self.tier = "reliable" + self.retry = 0 + self._papers = papers + self.seen_year_min: int | None = -1 # sentinel: never called + self.seen_year_max: int | None = -1 + + async def search(self, query, max_results=20, year_min=None, year_max=None, **kwargs): + self.seen_year_min, self.seen_year_max = year_min, year_max + return self._papers + + +@pytest.mark.asyncio +async def test_year_window_enforced_post_merge_not_forwarded(): + """A year_max keeps in-range + unknown-year papers, drops post-cutoff ones, and is + NOT forwarded to providers (fetch on the fast no-year path, then filter).""" + papers = [ + Paper(id="10.1/old", title="Old", doi="10.1/old", source=PaperSource.PUBMED, year=2018), + Paper(id="10.1/new", title="New", doi="10.1/new", source=PaperSource.PUBMED, year=2024), + Paper(id="10.1/none", title="NoYr", doi="10.1/none", source=PaperSource.PUBMED, year=None), + ] + p = _YearRecordingProvider(papers) + agg = DomainAwareAggregator([p], provider_timeout_s=5.0) + results = await agg.search("any query", year_max=2023) + + dois = {r.doi for r in results} + assert dois == {"10.1/old", "10.1/none"} # 2024 dropped; unknown-year kept + # The year window is enforced locally, never pushed to the provider (avoids the + # SciLEx per-year fan-out that returned zero results). + assert p.seen_year_min is None and p.seen_year_max is None + + @pytest.mark.asyncio async def test_basic_routing_general_provider(): p = _Provider("gen", [_paper("10.1/a")])