From 27553d86fef590273e7d694fc37b2fd56764d68b Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:53:19 +0000 Subject: [PATCH] Harden SSRF protection and add security tests Co-authored-by: abhimehro <84992105+abhimehro@users.noreply.github.com> --- main.py | 12 ++++---- tests/test_security.py | 64 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 86792da4..0e3ce7a8 100644 --- a/main.py +++ b/main.py @@ -26,7 +26,6 @@ import sys import threading import time -from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Sequence, Set from urllib.parse import urlparse @@ -321,12 +320,12 @@ def _api_client() -> httpx.Client: _cache_lock = threading.RLock() -@lru_cache(maxsize=128) +# SECURITY: Do not cache validation results. Caching introduces a TOCTOU (Time-of-Check Time-of-Use) +# vulnerability where a DNS record could change from a public IP to a private IP (DNS Rebinding) +# in the time between validation and fetch. We must re-validate immediately before fetching. def validate_folder_url(url: str) -> bool: """ Validates a folder URL. - Cached to avoid repeated DNS lookups (socket.getaddrinfo) for the same URL - during warm-up and sync phases. """ if not url.startswith("https://"): log.warning( @@ -1043,9 +1042,8 @@ def sync_profile( no_delete: bool = False, plan_accumulator: Optional[List[Dict[str, Any]]] = None, ) -> bool: - # SECURITY: Clear cached DNS validations at the start of each sync run. - # This prevents TOCTOU issues where a domain's IP could change between runs. - validate_folder_url.cache_clear() + # SECURITY: Removed cache clearing because we removed the cache on validate_folder_url + # to mitigate TOCTOU (DNS Rebinding) risks. Validation now happens per-fetch. try: # Fetch all folder data first diff --git a/tests/test_security.py b/tests/test_security.py index 536852c1..09e5fba2 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,7 +1,8 @@ import os import stat import sys -from unittest.mock import MagicMock +import socket +from unittest.mock import MagicMock, patch import pytest @@ -285,3 +286,64 @@ def test_is_valid_rule_strict(rule, expected_validity): Tests the is_valid_rule function against a strict whitelist of inputs. """ assert is_valid_rule(rule) == expected_validity, f"Failed for rule: {rule}" + + +# Mock helpers for TestValidateFolderUrl +def mock_getaddrinfo_ipv4(ip): + return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (ip, 443))] + +def mock_getaddrinfo_ipv6(ip): + return [(socket.AF_INET6, socket.SOCK_STREAM, 6, '', (ip, 443, 0, 0))] + +class TestValidateFolderUrl: + """Security tests for SSRF protection in validate_folder_url.""" + + def test_rejects_non_https(self): + assert main.validate_folder_url("http://example.com/foo") is False + assert main.validate_folder_url("ftp://example.com/foo") is False + assert main.validate_folder_url("file:///etc/passwd") is False + + @patch('socket.getaddrinfo') + def test_accepts_valid_public_ip(self, mock_gai): + # Mock domain resolving to public IP (8.8.8.8) + mock_gai.side_effect = lambda host, *args, **kwargs: mock_getaddrinfo_ipv4("8.8.8.8") + assert main.validate_folder_url("https://example.com/foo") is True + + def test_rejects_localhost_literal(self): + # Test explicit localhost hostname check + assert main.validate_folder_url("https://localhost/foo") is False + + @patch('socket.getaddrinfo') + def test_rejects_private_ip_resolution(self, mock_gai): + # Mock domain resolving to private IP (192.168.1.1) + mock_gai.side_effect = lambda host, *args, **kwargs: mock_getaddrinfo_ipv4("192.168.1.1") + assert main.validate_folder_url("https://internal.corp/foo") is False + + @patch('socket.getaddrinfo') + def test_rejects_loopback_ip_resolution(self, mock_gai): + # Mock domain resolving to loopback IP (127.0.0.1) + mock_gai.side_effect = lambda host, *args, **kwargs: mock_getaddrinfo_ipv4("127.0.0.1") + assert main.validate_folder_url("https://evil.com/foo") is False + + @patch('socket.getaddrinfo') + def test_rejects_ipv6_loopback_resolution(self, mock_gai): + # Mock domain resolving to IPv6 loopback (::1) + mock_gai.side_effect = lambda host, *args, **kwargs: mock_getaddrinfo_ipv6("::1") + assert main.validate_folder_url("https://ipv6.local/foo") is False + + def test_rejects_ip_literal_private(self): + # IP literals are checked directly via ipaddress module, bypassing DNS + assert main.validate_folder_url("https://192.168.1.1/foo") is False + assert main.validate_folder_url("https://10.0.0.1/foo") is False + assert main.validate_folder_url("https://127.0.0.1/foo") is False + assert main.validate_folder_url("https://[::1]/foo") is False + + def test_accepts_ip_literal_public(self): + assert main.validate_folder_url("https://8.8.8.8/foo") is True + assert main.validate_folder_url("https://[2001:4860:4860::8888]/foo") is True + + @patch('socket.getaddrinfo') + def test_dns_resolution_failure_is_safe(self, mock_gai): + # Ensure that if DNS fails, we default to False (fail closed) + mock_gai.side_effect = socket.gaierror("Name or service not known") + assert main.validate_folder_url("https://nonexistent.com/foo") is False