188 changes: 186 additions & 2 deletions agent/tools/papers_tool.py
@@ -3,14 +3,16 @@

Operations: trending, search, paper_details, read_paper,
find_datasets, find_models, find_collections, find_all_resources,
citation_graph, snippet_search, recommend,
pubmed_search, fetch_pubmed, fetch_doi
"""

import asyncio
import os
import re
import time
from typing import Any
from xml.etree import ElementTree

import httpx
from bs4 import BeautifulSoup, Tag
@@ -21,6 +23,11 @@
ARXIV_HTML = "https://arxiv.org/html"
AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html"

PUBMED_ESEARCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
PUBMED_ESUMMARY = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
PUBMED_EFETCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
CROSSREF_API = "https://api.crossref.org/works"
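One operational note for reviewers: NCBI asks E-utilities clients to identify themselves via the `tool` (and ideally `email`) parameters, and rate-limits anonymous callers to roughly 3 requests/second; supplying an `api_key` raises that to about 10/s. The handlers below send `tool` only. A minimal sketch of how a key could be threaded in, if we ever need it (the `NCBI_API_KEY` variable name and the `_eutils_params` helper are assumptions, not part of this diff):

def _eutils_params(**extra: Any) -> dict[str, Any]:
    """Base query params for every E-utilities call (sketch, not in this PR)."""
    params: dict[str, Any] = {"tool": "ml-intern", **extra}
    # Hypothetical env var; an API key lifts NCBI's rate limit from ~3 to ~10 req/s.
    if api_key := os.environ.get("NCBI_API_KEY"):
        params["api_key"] = api_key
    return params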

DEFAULT_LIMIT = 10
MAX_LIMIT = 50
MAX_SUMMARY_LEN = 300
@@ -1139,6 +1146,169 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
}


# ---------------------------------------------------------------------------
# PubMed operations
# ---------------------------------------------------------------------------

async def _op_pubmed_search(args: dict[str, Any], limit: int) -> ToolResult:
    """Search PubMed via NCBI E-utilities (biomedical, clinical, pharmacological)."""
    query = args.get("query", "").strip()
    if not query:
        return _error("'query' is required for pubmed_search.")

    params = {
        "db": "pubmed",
        "term": query,
        "retmax": limit,
        "retmode": "json",
        "tool": "ml-intern",
    }
    async with httpx.AsyncClient(timeout=15) as client:
        try:
            resp = await client.get(PUBMED_ESEARCH, params=params)
            resp.raise_for_status()
            ids = resp.json().get("esearchresult", {}).get("idlist", [])
        except Exception as exc:
            return _error(f"PubMed search failed: {exc}")

        if not ids:
            return {"formatted": f"No PubMed results for: {query}", "totalResults": 0, "resultsShared": 0}

        try:
            sum_resp = await client.get(
                PUBMED_ESUMMARY,
                params={"db": "pubmed", "id": ",".join(ids), "retmode": "json", "tool": "ml-intern"},
            )
            sum_resp.raise_for_status()
            result_data = sum_resp.json().get("result", {})
        except Exception as exc:
            return _error(f"PubMed summary failed: {exc}")

    lines = [f"# PubMed results for: {query}\n"]
    for pmid in ids:
        item = result_data.get(pmid, {})
        title = item.get("title", "(no title)")
        authors = ", ".join(a.get("name", "") for a in item.get("authors", [])[:3])
        if len(item.get("authors", [])) > 3:
            authors += " et al."
        source = item.get("source", "")
        pubdate = item.get("pubdate", "")
        lines.append(f"**pmid:{pmid}** — {title}")
        lines.append(f" {authors} · {source} · {pubdate}")
        lines.append(f" → fetch with: fetch_pubmed pmid={pmid}\n")

    return {"formatted": "\n".join(lines), "totalResults": len(ids), "resultsShared": len(ids)}


async def _op_fetch_pubmed(args: dict[str, Any], limit: int) -> ToolResult:
    """Fetch a PubMed abstract by PMID."""
    pmid = str(args.get("pmid", "")).strip().removeprefix("pmid:")
    if not pmid:
        return _error("'pmid' is required for fetch_pubmed.")

    params = {"db": "pubmed", "id": pmid, "rettype": "abstract", "retmode": "xml", "tool": "ml-intern"}
    async with httpx.AsyncClient(timeout=15) as client:
        try:
            resp = await client.get(PUBMED_EFETCH, params=params)
            resp.raise_for_status()
        except Exception as exc:
            return _error(f"PubMed fetch failed for pmid:{pmid}: {exc}")

    try:
        root = ElementTree.fromstring(resp.text)
        article = root.find(".//PubmedArticle")
        if article is None:
            return _error(f"No article found for pmid:{pmid}")

        title = article.findtext(".//ArticleTitle") or "(no title)"
        abstract_parts = article.findall(".//AbstractText")
        abstract = " ".join(
            (f"**{p.get('Label')}:** " if p.get("Label") else "") + (p.text or "")
            for p in abstract_parts
        ).strip()
        authors = []
        for author in article.findall(".//Author")[:5]:
            last = author.findtext("LastName") or ""
            fore = author.findtext("ForeName") or ""
            if last:
                authors.append(f"{fore} {last}".strip())
        journal = article.findtext(".//Journal/Title") or article.findtext(".//MedlineTA") or ""
        pub_year = article.findtext(".//PubDate/Year") or ""
        doi = next(
            (id_el.text for id_el in article.findall(".//ArticleId") if id_el.get("IdType") == "doi"),
            None,
        )
    except ElementTree.ParseError as exc:
        return _error(f"Failed to parse PubMed XML for pmid:{pmid}: {exc}")

    lines = [f"# {title}"]
    lines.append(f"**PMID:** {pmid} | **URL:** https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
    if doi:
        lines.append(f"**DOI:** https://doi.org/{doi}")
    if authors:
        suffix = " et al." if len(article.findall(".//Author")) > 5 else ""
        lines.append(f"**Authors:** {', '.join(authors)}{suffix}")
    lines.append(f"**Journal:** {journal} | **Year:** {pub_year}")
    lines.append("")
    lines.append("## Abstract")
    lines.append(abstract or "(no abstract available)")
    return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}


# ---------------------------------------------------------------------------
# DOI fetch via Crossref
# ---------------------------------------------------------------------------

async def _op_fetch_doi(args: dict[str, Any], limit: int) -> ToolResult:
    """Fetch metadata and abstract for any DOI via Crossref (bioRxiv, medRxiv, journals)."""
    doi = str(args.get("doi", "")).strip().removeprefix("doi:")
    if not doi:
        return _error("'doi' is required for fetch_doi.")

    url = f"{CROSSREF_API}/{doi}"
    headers = {"User-Agent": "ml-intern/1.0 (mailto:ml-intern@huggingface.co)"}
    async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
        try:
            resp = await client.get(url, headers=headers)
            resp.raise_for_status()
            work = resp.json().get("message", {})
        except httpx.HTTPStatusError as exc:
            return _error(f"Crossref returned {exc.response.status_code} for doi:{doi}")
        except Exception as exc:
            return _error(f"Crossref fetch failed for doi:{doi}: {exc}")

    title_parts = work.get("title") or []
    title = title_parts[0] if title_parts else "(no title)"
    authors = []
    for a in (work.get("author") or [])[:5]:
        given = a.get("given", "")
        family = a.get("family", "")
        if family:
            authors.append(f"{given} {family}".strip())
    container = (work.get("container-title") or [""])[0]
    pub_date_parts = (
        work.get("published") or work.get("published-print") or work.get("published-online") or {}
    ).get("date-parts", [[]])
    pub_date = "-".join(str(p) for p in pub_date_parts[0]) if pub_date_parts else ""
    abstract_raw = work.get("abstract", "")
    abstract = re.sub(r"<[^>]+>", "", abstract_raw).strip()
    full_text_url = f"https://doi.org/{doi}"

    lines = [f"# {title}"]
    lines.append(f"**DOI:** https://doi.org/{doi}")
    lines.append(f"**Source:** {container} | **Published:** {pub_date}")
    if authors:
        suffix = " et al." if len(work.get("author") or []) > 5 else ""
        lines.append(f"**Authors:** {', '.join(authors)}{suffix}")
    lines.append(f"**Full text:** {full_text_url}")
    lines.append("")
    if abstract:
        lines.append("## Abstract")
        lines.append(abstract)
    else:
        lines.append("*(Abstract not available via Crossref for this DOI)*")

    return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}


# ---------------------------------------------------------------------------
# Operation dispatch
# ---------------------------------------------------------------------------
@@ -1155,6 +1325,9 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
"find_models": _op_find_models,
"find_collections": _op_find_collections,
"find_all_resources": _op_find_all_resources,
"pubmed_search": _op_pubmed_search,
"fetch_pubmed": _op_fetch_pubmed,
"fetch_doi": _op_fetch_doi,
}


@@ -1183,7 +1356,10 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
"- find_datasets: Find datasets linked to a paper\n"
"- find_models: Find models linked to a paper\n"
"- find_collections: Find collections that include a paper\n"
"- find_all_resources: Parallel fetch of datasets + models + collections for a paper"
"- find_all_resources: Parallel fetch of datasets + models + collections for a paper\n"
"- pubmed_search: Search PubMed (biomedical, clinical, pharmacological literature)\n"
"- fetch_pubmed: Fetch abstract for a PubMed paper by PMID (e.g. pmid=38903003)\n"
"- fetch_doi: Fetch metadata + abstract for any DOI via Crossref (bioRxiv, medRxiv, journals)"
),
"parameters": {
"type": "object",
Expand Down Expand Up @@ -1265,6 +1441,14 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
"type": "integer",
"description": "Maximum results to return (default: 10, max: 50).",
},
"pmid": {
"type": "string",
"description": "PubMed ID (e.g. '38903003' or 'pmid:38903003'). Required for: fetch_pubmed.",
},
"doi": {
"type": "string",
"description": "DOI string (e.g. '10.1101/2023.12.15.571821' or 'doi:10.1101/...'). Required for: fetch_doi.",
},
},
"required": ["operation"],
},
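For anyone testing this branch locally, a rough sketch of what calls to the three new operations could look like. The run() entry point is illustrative only (the actual dispatch wrapper sits outside this hunk); the argument shapes follow the schema above, and the PMID/DOI values are the examples from the parameter descriptions.

async def demo() -> None:
    # Keyword search against PubMed, capped at 5 hits.
    hits = await run({"operation": "pubmed_search", "query": "GLP-1 receptor agonist", "limit": 5})
    print(hits["formatted"])

    # Fetch one abstract by PMID; the "pmid:" prefix is stripped by the handler.
    paper = await run({"operation": "fetch_pubmed", "pmid": "pmid:38903003"})
    print(paper["formatted"])

    # Resolve a preprint DOI through Crossref.
    preprint = await run({"operation": "fetch_doi", "doi": "10.1101/2023.12.15.571821"})
    print(preprint["formatted"])

asyncio.run(demo())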