From 12074de551361e3651fd21c1d2c8cc673a86d494 Mon Sep 17 00:00:00 2001 From: JasonOA888 Date: Fri, 24 Apr 2026 01:17:50 +0800 Subject: [PATCH 1/2] fix: validate URL origin in fetch_hf_docs to prevent SSRF The fetch_hf_docs tool accepted arbitrary URLs from LLM-generated arguments without origin validation. Since the request carries the user's HF Bearer token and follows redirects, a crafted prompt could trick the LLM into exfiltrating the token to an attacker-controlled server. Add _is_allowed_doc_url() which restricts fetch requests to huggingface.co, hf.co, and gradio.app origins (matching the tool's intended purpose of fetching HF/Gradio documentation). --- agent/tools/docs_tools.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/agent/tools/docs_tools.py b/agent/tools/docs_tools.py index a1782107..c3ae5af4 100644 --- a/agent/tools/docs_tools.py +++ b/agent/tools/docs_tools.py @@ -12,6 +12,7 @@ from whoosh.fields import ID, TEXT, Schema from whoosh.filedb.filestore import RamStorage from whoosh.qparser import MultifieldParser, OrGroup +from urllib.parse import urlparse # --------------------------------------------------------------------------- # Configuration @@ -379,6 +380,28 @@ async def explore_hf_docs_handler( return f"Unexpected error: {str(e)}", False +# Allowed origins for fetch_hf_docs -- prevents SSRF via LLM-generated URLs. 
+_ALLOWED_DOC_ORIGINS = { + "huggingface.co", + "hf.co", + "gradio.app", +} + + +def _is_allowed_doc_url(url: str) -> bool: + """Return True if *url* points to an allowed documentation origin.""" + try: + parsed = urlparse(url) + except Exception: + return False + if parsed.scheme != "https": + return False + host = parsed.hostname or "" + return host in _ALLOWED_DOC_ORIGINS or any( + host.endswith(f".{d}") for d in _ALLOWED_DOC_ORIGINS + ) + + async def hf_docs_fetch_handler( arguments: dict[str, Any], session=None ) -> tuple[str, bool]: @@ -387,6 +410,13 @@ async def hf_docs_fetch_handler( if not url: return "Error: No URL provided", False + if not _is_allowed_doc_url(url): + return ( + f"Error: URL not allowed. Only huggingface.co, hf.co, and gradio.app " + f"documentation URLs are accepted. Got: {url}", + False, + ) + hf_token = session.hf_token if session else None if not hf_token: return "Error: No HF token available (not logged in)", False From c058b74dbc2e31abbadf79057812778e978520a0 Mon Sep 17 00:00:00 2001 From: JasonOA888 Date: Fri, 24 Apr 2026 12:58:40 +0800 Subject: [PATCH 2/2] test: add unit tests for _is_allowed_doc_url SSRF guard 23 test cases covering: - Allowed origins (exact + subdomain) - Blocked schemes (HTTP, FTP) - Disallowed hosts (evil.com, metadata endpoint, prefix attacks) - SSRF payloads (127.0.0.1, 0.0.0.0, ::1, localhost) - Edge cases (empty, garbage, port numbers) --- tests/unit/test_docs_tools_ssrf.py | 133 +++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 tests/unit/test_docs_tools_ssrf.py diff --git a/tests/unit/test_docs_tools_ssrf.py b/tests/unit/test_docs_tools_ssrf.py new file mode 100644 index 00000000..a91e2811 --- /dev/null +++ b/tests/unit/test_docs_tools_ssrf.py @@ -0,0 +1,133 @@ +"""Tests for _is_allowed_doc_url SSRF guard in agent/tools/docs_tools.py.""" + +import importlib.util +import sys +from pathlib import Path + +import pytest + +# Stub heavy dependencies BEFORE any import chain 
from unittest.mock import MagicMock

# Modules that are stubbed with MagicMock while docs_tools.py is executed, so
# that loading it cannot drag in heavy or project-level packages.
_STUBS = [
    "litellm", "datasets", "fastmcp", "huggingface_hub",
    "sentence_transformers", "nbconvert", "torch",
    "agent", "agent.tools", "agent.core",
]

# Import just the source file directly, bypassing __init__.py chains
_spec = importlib.util.spec_from_file_location(
    "docs_tools",
    Path(__file__).resolve().parent.parent.parent / "agent" / "tools" / "docs_tools.py",
)
_docs_tools = importlib.util.module_from_spec(_spec)

# Provide the deps that docs_tools actually uses at module level. These are
# imported for real, before any stub is installed.
_deps = {
    "httpx": __import__("httpx"),
    "bs4": __import__("bs4"),
    "whoosh": __import__("whoosh"),
}
for name, mod in _deps.items():
    sys.modules[name] = mod

# Install the stubs only for the duration of exec_module() and then remove
# exactly the ones we added. Leaving MagicMock entries for "agent", "torch",
# etc. in sys.modules would leak into every test collected afterwards in the
# same session and make their imports resolve to mocks.
_installed_stubs = [name for name in _STUBS if name not in sys.modules]
for name in _installed_stubs:
    sys.modules[name] = MagicMock()
try:
    _spec.loader.exec_module(_docs_tools)
finally:
    for name in _installed_stubs:
        sys.modules.pop(name, None)

_is_allowed_doc_url = _docs_tools._is_allowed_doc_url


# ── Allowed origins ──────────────────────────────────────────────────────

class TestAllowedOrigins:
    """URLs on the allow-listed hosts, and on their subdomains, are accepted."""

    @pytest.mark.parametrize(
        "url",
        [
            "https://huggingface.co/docs/transformers",
            "https://hf.co/docs/trl",
            "https://gradio.app/docs",
            "https://huggingface.co/docs/trl/dpo_trainer",
            "https://hf.co/docs/some-deep/path/page.md",
        ],
    )
    def test_exact_allowed_hosts(self, url: str):
        assert _is_allowed_doc_url(url) is True

    @pytest.mark.parametrize(
        "url",
        [
            "https://sub.huggingface.co/anything",
            "https://cdn.gradio.app/assets/foo",
            "https://mirror.hf.co/docs/x",
        ],
    )
    def test_subdomain_allowed(self, url: str):
        assert _is_allowed_doc_url(url) is True


# ── Blocked: wrong scheme ────────────────────────────────────────────────

class TestBlockedScheme:
    """Anything other than https is rejected, even on an allow-listed host."""

    @pytest.mark.parametrize(
        "url",
        [
            "http://huggingface.co/docs/transformers",
            "http://hf.co/docs/x",
            "ftp://huggingface.co/etc/passwd",
        ],
    )
    def test_non_https_rejected(self, url: str):
        assert _is_allowed_doc_url(url) is False


# ── Blocked: disallowed hosts
class TestBlockedHosts:
    """Hosts outside the allow-list, including look-alike domains, are refused."""

    FOREIGN_HOST_URLS = (
        "https://evil.com/docs",
        "https://169.254.169.254/latest/meta-data/",
        "https://evil-huggingface.co/docs",
        "https://huggingface.co.evil.com/docs",
    )

    @pytest.mark.parametrize("url", FOREIGN_HOST_URLS)
    def test_disallowed_hosts_rejected(self, url: str):
        verdict = _is_allowed_doc_url(url)
        assert verdict is False


# ── Blocked: SSRF payloads ───────────────────────────────────────────────

class TestSSRFPayloads:
    """Loopback and unspecified addresses are refused."""

    INTERNAL_URLS = (
        "https://127.0.0.1/api/internal",
        "https://0.0.0.0/",
        "https://[::1]/admin",
        "https://localhost/etc/passwd",
    )

    @pytest.mark.parametrize("url", INTERNAL_URLS)
    def test_internal_addresses_rejected(self, url: str):
        verdict = _is_allowed_doc_url(url)
        assert verdict is False


# ── Edge cases ───────────────────────────────────────────────────────────

class TestEdgeCases:
    """Degenerate inputs and URL shapes not covered by the other classes."""

    def test_empty_string(self):
        verdict = _is_allowed_doc_url("")
        assert verdict is False

    def test_bare_host_no_path(self):
        verdict = _is_allowed_doc_url("https://huggingface.co")
        assert verdict is True

    def test_garbage_input(self):
        verdict = _is_allowed_doc_url("not-a-url")
        assert verdict is False

    def test_port_number(self):
        verdict = _is_allowed_doc_url("https://huggingface.co:443/docs/x")
        assert verdict is True