diff --git a/main.py b/main.py index e6aabc57..98778559 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,7 @@ import concurrent.futures import threading import ipaddress +import socket from urllib.parse import urlparse from typing import Dict, List, Optional, Any, Set, Sequence @@ -209,8 +210,21 @@ def validate_folder_url(url: str) -> bool: log.warning(f"Skipping unsafe URL (private IP): {sanitize_for_log(url)}") return False except ValueError: - # Not an IP literal, it's a domain. - pass + # Not an IP literal, it's a domain. Resolve it to prevent SSRF against internal services. + try: + # Use getaddrinfo to support both IPv4 and IPv6 and check all resolved addresses + addr_info = socket.getaddrinfo(hostname, None) + for res in addr_info: + # res is (family, type, proto, canonname, sockaddr) + # sockaddr is (address, port) for IPv4 and (address, port, flow info, scope id) for IPv6 + ip_str = res[4][0] + ip_obj = ipaddress.ip_address(ip_str) + if ip_obj.is_private or ip_obj.is_loopback: + log.warning(f"Skipping unsafe URL (domain resolves to private IP {ip_str}): {sanitize_for_log(url)}") + return False + except Exception as e: + log.warning(f"Skipping unsafe URL (DNS resolution failed): {sanitize_for_log(url)} ({e})") + return False except Exception as e: log.warning(f"Failed to validate URL {sanitize_for_log(url)}: {e}") diff --git a/tests/test_ssrf.py b/tests/test_ssrf.py new file mode 100644 index 00000000..e67d4d58 --- /dev/null +++ b/tests/test_ssrf.py @@ -0,0 +1,70 @@ +import unittest +from unittest.mock import patch +import sys +import os +import socket + +# Add parent directory to path to import main +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from main import validate_folder_url + +class TestSSRF(unittest.TestCase): + def test_localhost_literal(self): + """Test that explicit localhost strings are rejected.""" + self.assertFalse(validate_folder_url("https://localhost/config.json")) + self.assertFalse(validate_folder_url("https://127.0.0.1/config.json")) + self.assertFalse(validate_folder_url("https://[::1]/config.json")) + + @patch('socket.getaddrinfo') + def test_private_ipv4_resolution(self, mock_getaddrinfo): + """Test that domains resolving to private IPv4 are rejected.""" + # mock returns list of (family, type, proto, canonname, sockaddr) + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 0)) + ] + url = "https://internal.private/config.json" + + self.assertFalse(validate_folder_url(url), "Should reject domain resolving to private IPv4") + + @patch('socket.getaddrinfo') + def test_private_ipv6_resolution(self, mock_getaddrinfo): + """Test that domains resolving to private IPv6 are rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('fd00::1', 0, 0, 0)) + ] + url = "https://internal6.private/config.json" + + self.assertFalse(validate_folder_url(url), "Should reject domain resolving to private IPv6") + + @patch('socket.getaddrinfo') + def test_mixed_resolution_unsafe(self, mock_getaddrinfo): + """Test that if ANY resolved IP is private, it is rejected.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('8.8.8.8', 0)), + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 0)) + ] + url = "https://mixed.private/config.json" + + self.assertFalse(validate_folder_url(url), "Should reject if any IP is private") + + @patch('socket.getaddrinfo') + def test_public_resolution(self, mock_getaddrinfo): + """Test that domains resolving to only public IPs are accepted.""" + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('8.8.8.8', 0)) + ] + url = "https://google.com/config.json" + + self.assertTrue(validate_folder_url(url), "Should accept domain resolving to public IP") + + @patch('socket.getaddrinfo') + def test_dns_resolution_failure(self, mock_getaddrinfo): + """Test that domains failing resolution are rejected.""" + mock_getaddrinfo.side_effect = Exception("DNS lookup failed") + url = "https://nonexistent.domain/config.json" + + self.assertFalse(validate_folder_url(url), "Should reject domain that fails resolution") + +if __name__ == '__main__': + unittest.main()