From 98e949d2fd9522a91b430ac239c5d945304601b1 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 24 Dec 2025 15:05:35 +0000 Subject: [PATCH] Patches (#2218) * feat: implement URL validation to prevent SSRF * feat: add zip extraction security * ruff fixes --- application/agents/tools/api_tool.py | 24 ++ application/agents/tools/read_webpage.py | 13 +- application/core/url_validation.py | 181 +++++++++++ application/parser/remote/crawler_loader.py | 19 +- application/parser/remote/crawler_markdown.py | 15 +- application/parser/remote/sitemap_loader.py | 17 +- application/worker.py | 120 ++++++- tests/core/test_url_validation.py | 197 ++++++++++++ tests/parser/remote/test_crawler_loader.py | 41 ++- tests/parser/remote/test_crawler_markdown.py | 36 ++- tests/test_zip_extraction_security.py | 293 ++++++++++++++++++ 11 files changed, 929 insertions(+), 27 deletions(-) create mode 100644 application/core/url_validation.py create mode 100644 tests/core/test_url_validation.py create mode 100644 tests/test_zip_extraction_security.py diff --git a/application/agents/tools/api_tool.py b/application/agents/tools/api_tool.py index 6bd2eb8d..e010b51b 100644 --- a/application/agents/tools/api_tool.py +++ b/application/agents/tools/api_tool.py @@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import ( RequestBodySerializer, ) from application.agents.tools.base import Tool +from application.core.url_validation import validate_url, SSRFError logger = logging.getLogger(__name__) @@ -73,6 +74,17 @@ class APITool(Tool): request_headers = headers.copy() if headers else {} response = None + # Validate URL to prevent SSRF attacks + try: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + try: path_params_used = set() if query_params: @@ -90,6 +102,18 @@ class APITool(Tool): query_string = urlencode(remaining_params) separator = "&" if "?" in request_url else "?" request_url = f"{request_url}{separator}{query_string}" + + # Re-validate URL after parameter substitution to prevent SSRF via path params + try: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed after parameter substitution: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + # Serialize body based on content type if body and body != {}: diff --git a/application/agents/tools/read_webpage.py b/application/agents/tools/read_webpage.py index e87c79e3..f0321a5a 100644 --- a/application/agents/tools/read_webpage.py +++ b/application/agents/tools/read_webpage.py @@ -1,7 +1,7 @@ import requests from markdownify import markdownify from application.agents.tools.base import Tool -from urllib.parse import urlparse +from application.core.url_validation import validate_url, SSRFError class ReadWebpageTool(Tool): """ @@ -31,11 +31,12 @@ class ReadWebpageTool(Tool): if not url: return "Error: URL parameter is missing." 
- # Ensure the URL has a scheme (if not, default to http) - parsed_url = urlparse(url) - if not parsed_url.scheme: - url = "http://" + url - + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + return f"Error: URL validation failed - {e}" + try: response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'}) response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx) diff --git a/application/core/url_validation.py b/application/core/url_validation.py new file mode 100644 index 00000000..acd8a523 --- /dev/null +++ b/application/core/url_validation.py @@ -0,0 +1,181 @@ +""" +URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks. + +This module provides functions to validate URLs before making HTTP requests, +blocking access to internal networks, cloud metadata services, and other +potentially dangerous endpoints. +""" + +import ipaddress +import socket +from urllib.parse import urlparse +from typing import Optional, Set + + +class SSRFError(Exception): + """Raised when a URL fails SSRF validation.""" + pass + + +# Blocked hostnames that should never be accessed +BLOCKED_HOSTNAMES: Set[str] = { + "localhost", + "localhost.localdomain", + "metadata.google.internal", + "metadata", +} + +# Cloud metadata IP addresses (AWS, GCP, Azure, etc.) +METADATA_IPS: Set[str] = { + "169.254.169.254", # AWS, GCP, Azure metadata + "169.254.170.2", # AWS ECS task metadata + "fd00:ec2::254", # AWS IPv6 metadata +} + +# Allowed schemes for external requests +ALLOWED_SCHEMES: Set[str] = {"http", "https"} + + +def is_private_ip(ip_str: str) -> bool: + """ + Check if an IP address is private, loopback, or link-local. + + Args: + ip_str: IP address as a string + + Returns: + True if the IP is private/internal, False otherwise + """ + try: + ip = ipaddress.ip_address(ip_str) + return ( + ip.is_private or + ip.is_loopback or + ip.is_link_local or + ip.is_reserved or + ip.is_multicast or + ip.is_unspecified + ) + except ValueError: + # If we can't parse it as an IP, return False + return False + + +def is_metadata_ip(ip_str: str) -> bool: + """ + Check if an IP address is a cloud metadata service IP. + + Args: + ip_str: IP address as a string + + Returns: + True if the IP is a metadata service, False otherwise + """ + return ip_str in METADATA_IPS + + +def resolve_hostname(hostname: str) -> Optional[str]: + """ + Resolve a hostname to an IP address. + + Args: + hostname: The hostname to resolve + + Returns: + The resolved IP address, or None if resolution fails + """ + try: + return socket.gethostbyname(hostname) + except socket.gaierror: + return None + + +def validate_url(url: str, allow_localhost: bool = False) -> str: + """ + Validate a URL to prevent SSRF attacks. + + This function checks that: + 1. The URL has an allowed scheme (http or https) + 2. The hostname is not a blocked hostname + 3. The resolved IP is not a private/internal IP + 4. The resolved IP is not a cloud metadata service + + Args: + url: The URL to validate + allow_localhost: If True, allow localhost connections (for testing only) + + Returns: + The validated URL (with scheme added if missing) + + Raises: + SSRFError: If the URL fails validation + """ + # Ensure URL has a scheme + if not urlparse(url).scheme: + url = "http://" + url + + parsed = urlparse(url) + + # Check scheme + if parsed.scheme not in ALLOWED_SCHEMES: + raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. 
Only HTTP(S) is permitted.") + + hostname = parsed.hostname + if not hostname: + raise SSRFError("URL must have a valid hostname.") + + hostname_lower = hostname.lower() + + # Check blocked hostnames + if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost: + raise SSRFError(f"Access to '{hostname}' is not allowed.") + + # Check if hostname is an IP address directly + try: + ip = ipaddress.ip_address(hostname) + ip_str = str(ip) + + if is_metadata_ip(ip_str): + raise SSRFError("Access to cloud metadata services is not allowed.") + + if is_private_ip(ip_str) and not allow_localhost: + raise SSRFError("Access to private/internal IP addresses is not allowed.") + + return url + except ValueError: + # Not an IP address, it's a hostname - resolve it + pass + + # Resolve hostname and check the IP + resolved_ip = resolve_hostname(hostname) + if resolved_ip is None: + raise SSRFError(f"Unable to resolve hostname: {hostname}") + + if is_metadata_ip(resolved_ip): + raise SSRFError("Access to cloud metadata services is not allowed.") + + if is_private_ip(resolved_ip) and not allow_localhost: + raise SSRFError("Access to private/internal networks is not allowed.") + + return url + + +def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]: + """ + Validate a URL and return a tuple with validation result. + + This is a non-throwing version of validate_url for cases where + you want to handle validation failures gracefully. + + Args: + url: The URL to validate + allow_localhost: If True, allow localhost connections (for testing only) + + Returns: + Tuple of (is_valid, validated_url_or_original, error_message_or_none) + """ + try: + validated = validate_url(url, allow_localhost) + return (True, validated, None) + except SSRFError as e: + return (False, url, str(e)) diff --git a/application/parser/remote/crawler_loader.py b/application/parser/remote/crawler_loader.py index 2ff6cf6f..1bfd2276 100644 --- a/application/parser/remote/crawler_loader.py +++ b/application/parser/remote/crawler_loader.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin from bs4 import BeautifulSoup from application.parser.remote.base import BaseRemote from application.parser.schema.base import Document +from application.core.url_validation import validate_url, SSRFError from langchain_community.document_loaders import WebBaseLoader class CrawlerLoader(BaseRemote): @@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote): if isinstance(url, list) and url: url = url[0] - # Check if the URL scheme is provided, if not, assume http - if not urlparse(url).scheme: - url = "http://" + url + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + logging.error(f"URL validation failed: {e}") + return [] visited_urls = set() base_url = urlparse(url).scheme + "://" + urlparse(url).hostname @@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote): visited_urls.add(current_url) try: - response = requests.get(current_url) + # Validate each URL before making requests + try: + validate_url(current_url) + except SSRFError as e: + logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}") + continue + + response = requests.get(current_url, timeout=30) response.raise_for_status() loader = self.loader([current_url]) docs = loader.load() diff --git a/application/parser/remote/crawler_markdown.py b/application/parser/remote/crawler_markdown.py index 3d199332..8fc4c92c 100644 --- a/application/parser/remote/crawler_markdown.py +++ 
b/application/parser/remote/crawler_markdown.py @@ -2,6 +2,7 @@ import requests from urllib.parse import urlparse, urljoin from bs4 import BeautifulSoup from application.parser.remote.base import BaseRemote +from application.core.url_validation import validate_url, SSRFError import re from markdownify import markdownify from application.parser.schema.base import Document @@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote): if isinstance(url, list) and url: url = url[0] - # Ensure the URL has a scheme (if not, default to http) - if not urlparse(url).scheme: - url = "http://" + url + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + print(f"URL validation failed: {e}") + return [] # Keep track of visited URLs to avoid revisiting the same page visited_urls = set() @@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote): def _fetch_page(self, url): try: + # Validate URL before fetching to prevent SSRF + validate_url(url) response = self.session.get(url, timeout=10) response.raise_for_status() return response.text + except SSRFError as e: + print(f"URL validation failed for {url}: {e}") + return None except requests.exceptions.RequestException as e: print(f"Error fetching URL {url}: {e}") return None diff --git a/application/parser/remote/sitemap_loader.py b/application/parser/remote/sitemap_loader.py index 6d54ea9b..ff7c1ede 100644 --- a/application/parser/remote/sitemap_loader.py +++ b/application/parser/remote/sitemap_loader.py @@ -3,6 +3,7 @@ import requests import re # Import regular expression library import xml.etree.ElementTree as ET from application.parser.remote.base import BaseRemote +from application.core.url_validation import validate_url, SSRFError class SitemapLoader(BaseRemote): def __init__(self, limit=20): @@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote): sitemap_url= inputs # Check if the input is a list and if it is, use the first element if isinstance(sitemap_url, list) and sitemap_url: - url = sitemap_url[0] + sitemap_url = sitemap_url[0] + + # Validate URL to prevent SSRF attacks + try: + sitemap_url = validate_url(sitemap_url) + except SSRFError as e: + logging.error(f"URL validation failed: {e}") + return [] urls = self._extract_urls(sitemap_url) if not urls: @@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote): def _extract_urls(self, sitemap_url): try: - response = requests.get(sitemap_url) + # Validate URL before fetching to prevent SSRF + validate_url(sitemap_url) + response = requests.get(sitemap_url, timeout=30) response.raise_for_status() # Raise an exception for HTTP errors + except SSRFError as e: + print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}") + return [] except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e: print(f"Failed to fetch sitemap: {sitemap_url}. 
Error: {e}") return [] diff --git a/application/worker.py b/application/worker.py index f45e94a5..fa2b6cd7 100755 --- a/application/worker.py +++ b/application/worker.py @@ -63,10 +63,111 @@ current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +# Zip extraction security limits +MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size +MAX_FILE_COUNT = 10000 # Maximum number of files to extract +MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs) + + +class ZipExtractionError(Exception): + """Raised when zip extraction fails due to security constraints.""" + pass + + +def _is_path_safe(base_path: str, target_path: str) -> bool: + """ + Check if target_path is safely within base_path (prevents zip slip attacks). + + Args: + base_path: The base directory where extraction should occur. + target_path: The full path where a file would be extracted. + + Returns: + True if the path is safe, False otherwise. + """ + # Resolve to absolute paths and check containment + base_resolved = os.path.realpath(base_path) + target_resolved = os.path.realpath(target_path) + return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved + + +def _validate_zip_safety(zip_path: str, extract_to: str) -> None: + """ + Validate a zip file for security issues before extraction. + + Checks for: + - Zip bombs (excessive compression ratio or uncompressed size) + - Too many files + - Path traversal attacks (zip slip) + + Args: + zip_path: Path to the zip file. + extract_to: Destination directory. + + Raises: + ZipExtractionError: If the zip file fails security validation. + """ + try: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + # Get compressed size + compressed_size = os.path.getsize(zip_path) + + # Calculate total uncompressed size and file count + total_uncompressed = 0 + file_count = 0 + + for info in zip_ref.infolist(): + file_count += 1 + + # Check file count limit + if file_count > MAX_FILE_COUNT: + raise ZipExtractionError( + f"Zip file contains too many files (>{MAX_FILE_COUNT}). " + "This may be a zip bomb attack." + ) + + # Accumulate uncompressed size + total_uncompressed += info.file_size + + # Check total uncompressed size + if total_uncompressed > MAX_UNCOMPRESSED_SIZE: + raise ZipExtractionError( + f"Zip file uncompressed size exceeds limit " + f"({total_uncompressed / (1024*1024):.1f} MB > " + f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). " + "This may be a zip bomb attack." + ) + + # Check for path traversal (zip slip) + target_path = os.path.join(extract_to, info.filename) + if not _is_path_safe(extract_to, target_path): + raise ZipExtractionError( + f"Zip file contains path traversal attempt: {info.filename}" + ) + + # Check compression ratio (only if compressed size is meaningful) + if compressed_size > 0 and total_uncompressed > 0: + compression_ratio = total_uncompressed / compressed_size + if compression_ratio > MAX_COMPRESSION_RATIO: + raise ZipExtractionError( + f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > " + f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack." + ) + + except zipfile.BadZipFile as e: + raise ZipExtractionError(f"Invalid or corrupted zip file: {e}") + def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): """ - Recursively extract zip files with a limit on recursion depth. + Recursively extract zip files with security protections. 
+ + Security measures: + - Limits recursion depth to prevent infinite loops + - Validates uncompressed size to prevent zip bombs + - Limits number of files to prevent resource exhaustion + - Checks compression ratio to detect zip bombs + - Validates paths to prevent zip slip attacks Args: zip_path (str): Path to the zip file to be extracted. @@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): if current_depth > max_depth: logging.warning(f"Reached maximum recursion depth of {max_depth}") return + try: + # Validate zip file safety before extraction + _validate_zip_safety(zip_path, extract_to) + + # Safe to extract with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) os.remove(zip_path) # Remove the zip file after extracting + + except ZipExtractionError as e: + logging.error(f"Zip security validation failed for {zip_path}: {e}") + # Remove the potentially malicious zip file + try: + os.remove(zip_path) + except OSError: + pass + return except Exception as e: logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True) return - # Check for nested zip files and extract them + # Check for nested zip files and extract them for root, dirs, files in os.walk(extract_to): for file in files: if file.endswith(".zip"): # If a nested zip file is found, extract it recursively - file_path = os.path.join(root, file) extract_zip_recursive(file_path, root, current_depth + 1, max_depth) diff --git a/tests/core/test_url_validation.py b/tests/core/test_url_validation.py new file mode 100644 index 00000000..924e5cde --- /dev/null +++ b/tests/core/test_url_validation.py @@ -0,0 +1,197 @@ +"""Tests for SSRF URL validation module.""" + +import pytest +from unittest.mock import patch + +from application.core.url_validation import ( + SSRFError, + validate_url, + validate_url_safe, + is_private_ip, + is_metadata_ip, +) + + +class TestIsPrivateIP: + """Tests for is_private_ip function.""" + + def test_loopback_ipv4(self): + assert is_private_ip("127.0.0.1") is True + assert is_private_ip("127.255.255.255") is True + + def test_private_class_a(self): + assert is_private_ip("10.0.0.1") is True + assert is_private_ip("10.255.255.255") is True + + def test_private_class_b(self): + assert is_private_ip("172.16.0.1") is True + assert is_private_ip("172.31.255.255") is True + + def test_private_class_c(self): + assert is_private_ip("192.168.0.1") is True + assert is_private_ip("192.168.255.255") is True + + def test_link_local(self): + assert is_private_ip("169.254.0.1") is True + + def test_public_ip(self): + assert is_private_ip("8.8.8.8") is False + assert is_private_ip("1.1.1.1") is False + assert is_private_ip("93.184.216.34") is False + + def test_invalid_ip(self): + assert is_private_ip("not-an-ip") is False + assert is_private_ip("") is False + + +class TestIsMetadataIP: + """Tests for is_metadata_ip function.""" + + def test_aws_metadata_ip(self): + assert is_metadata_ip("169.254.169.254") is True + + def test_aws_ecs_metadata_ip(self): + assert is_metadata_ip("169.254.170.2") is True + + def test_non_metadata_ip(self): + assert is_metadata_ip("8.8.8.8") is False + assert is_metadata_ip("10.0.0.1") is False + + +class TestValidateUrl: + """Tests for validate_url function.""" + + def test_adds_scheme_if_missing(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" # Public IP + result = validate_url("example.com") + assert result == 
"http://example.com" + + def test_preserves_https_scheme(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + result = validate_url("https://example.com") + assert result == "https://example.com" + + def test_blocks_localhost(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://localhost") + assert "localhost" in str(exc_info.value).lower() + + def test_blocks_localhost_localdomain(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://localhost.localdomain") + assert "not allowed" in str(exc_info.value).lower() + + def test_blocks_loopback_ip(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://127.0.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_a(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://10.0.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_b(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://172.16.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_c(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://192.168.1.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_aws_metadata_ip(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://169.254.169.254") + assert "metadata" in str(exc_info.value).lower() + + def test_blocks_aws_metadata_with_path(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://169.254.169.254/latest/meta-data/") + assert "metadata" in str(exc_info.value).lower() + + def test_blocks_gcp_metadata_hostname(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://metadata.google.internal") + assert "not allowed" in str(exc_info.value).lower() + + def test_blocks_ftp_scheme(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("ftp://example.com") + assert "scheme" in str(exc_info.value).lower() + + def test_blocks_file_scheme(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("file:///etc/passwd") + assert "scheme" in str(exc_info.value).lower() + + def test_blocks_hostname_resolving_to_private_ip(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "192.168.1.1" + with pytest.raises(SSRFError) as exc_info: + validate_url("http://internal.example.com") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_hostname_resolving_to_metadata_ip(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "169.254.169.254" + with pytest.raises(SSRFError) as exc_info: + validate_url("http://evil.example.com") + assert "metadata" in str(exc_info.value).lower() + + def test_allows_public_ip(self): + result = validate_url("http://8.8.8.8") + assert result == "http://8.8.8.8" + + def test_allows_public_hostname(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + result = validate_url("https://example.com") + assert result == "https://example.com" + + def 
test_raises_on_unresolvable_hostname(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = None + with pytest.raises(SSRFError) as exc_info: + validate_url("http://nonexistent.invalid") + assert "resolve" in str(exc_info.value).lower() + + def test_raises_on_empty_hostname(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://") + assert "hostname" in str(exc_info.value).lower() + + def test_allow_localhost_flag(self): + # Should work with allow_localhost=True + result = validate_url("http://localhost", allow_localhost=True) + assert result == "http://localhost" + + result = validate_url("http://127.0.0.1", allow_localhost=True) + assert result == "http://127.0.0.1" + + +class TestValidateUrlSafe: + """Tests for validate_url_safe non-throwing function.""" + + def test_returns_tuple_on_success(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + is_valid, url, error = validate_url_safe("https://example.com") + assert is_valid is True + assert url == "https://example.com" + assert error is None + + def test_returns_tuple_on_failure(self): + is_valid, url, error = validate_url_safe("http://localhost") + assert is_valid is False + assert url == "http://localhost" + assert error is not None + assert "localhost" in error.lower() + + def test_returns_error_message_for_private_ip(self): + is_valid, url, error = validate_url_safe("http://192.168.1.1") + assert is_valid is False + assert "private" in error.lower() or "internal" in error.lower() diff --git a/tests/parser/remote/test_crawler_loader.py b/tests/parser/remote/test_crawler_loader.py index 92ffdc84..06d27517 100644 --- a/tests/parser/remote/test_crawler_loader.py +++ b/tests/parser/remote/test_crawler_loader.py @@ -13,8 +13,17 @@ class DummyResponse: return None +def _mock_validate_url(url): + """Mock validate_url that allows test URLs through.""" + from urllib.parse import urlparse + if not urlparse(url).scheme: + url = "http://" + url + return url + + +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_crawls_same_domain_links(mock_requests_get): +def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url): responses = { "http://example.com": DummyResponse( """ @@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get): "http://example.com/about": DummyResponse("About page"), } - def response_side_effect(url: str): + def response_side_effect(url: str, timeout=30): if url not in responses: raise AssertionError(f"Unexpected request for URL: {url}") return responses[url] @@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get): assert loader_call_order == ["http://example.com", "http://example.com/about"] +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): +def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url): mock_requests_get.return_value = DummyResponse("No links here") doc = MagicMock(spec=LCDocument) @@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): result = crawler.load_data(["example.com", "unused.com"]) - 
mock_requests_get.assert_called_once_with("http://example.com") + mock_requests_get.assert_called_once_with("http://example.com", timeout=30) crawler.loader.assert_called_once_with(["http://example.com"]) assert len(result) == 1 @@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): assert result[0].extra_info == {"source": "http://example.com"} +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_respects_limit(mock_requests_get): +def test_load_data_respects_limit(mock_requests_get, mock_validate_url): responses = { "http://example.com": DummyResponse( """ @@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get): "http://example.com/about": DummyResponse("About"), } - mock_requests_get.side_effect = lambda url: responses[url] + mock_requests_get.side_effect = lambda url, timeout=30: responses[url] root_doc = MagicMock(spec=LCDocument) root_doc.page_content = "Root content" @@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get): assert crawler.loader.call_count == 1 +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.logging") @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging): +def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url): mock_requests_get.return_value = DummyResponse("Error route") failing_loader_instance = MagicMock() @@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin result = crawler.load_data("http://example.com") assert result == [] - mock_requests_get.assert_called_once_with("http://example.com") + mock_requests_get.assert_called_once_with("http://example.com", timeout=30) failing_loader_instance.load.assert_called_once() mock_logging.error.assert_called_once() @@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin assert "Error processing URL http://example.com" in message assert mock_logging.error.call_args.kwargs.get("exc_info") is True + +@patch("application.parser.remote.crawler_loader.validate_url") +def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url): + """Test that SSRF validation failure returns empty list.""" + from application.core.url_validation import SSRFError + mock_validate_url.side_effect = SSRFError("Access to private IP not allowed") + + crawler = CrawlerLoader() + result = crawler.load_data("http://192.168.1.1") + + assert result == [] + mock_validate_url.assert_called_once() + diff --git a/tests/parser/remote/test_crawler_markdown.py b/tests/parser/remote/test_crawler_markdown.py index ac27b3d0..b2b3f21c 100644 --- a/tests/parser/remote/test_crawler_markdown.py +++ b/tests/parser/remote/test_crawler_markdown.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock +from urllib.parse import urlparse import pytest import requests @@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace: return SimpleNamespace(domain=domain, suffix=suffix) +def _mock_validate_url(url): + """Mock validate_url that allows test URLs through.""" + if not urlparse(url).scheme: + url = "http://" + url + return url + + +@pytest.fixture(autouse=True) +def 
_patch_validate_url(monkeypatch): + monkeypatch.setattr( + "application.parser.remote.crawler_markdown.validate_url", + _mock_validate_url, + ) + + @pytest.fixture(autouse=True) def _patch_tldextract(monkeypatch): monkeypatch.setattr( @@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify): assert len(docs) == 2 -def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify): +def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url): root_html = """ Home About @@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify): assert docs[0].text == "Home Markdown" assert mock_print.called + +def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch): + """Test that SSRF validation failure returns empty list.""" + from application.core.url_validation import SSRFError + + def raise_ssrf_error(url): + raise SSRFError("Access to private IP not allowed") + + monkeypatch.setattr( + "application.parser.remote.crawler_markdown.validate_url", + raise_ssrf_error, + ) + + loader = CrawlerLoader() + result = loader.load_data("http://192.168.1.1") + + assert result == [] + diff --git a/tests/test_zip_extraction_security.py b/tests/test_zip_extraction_security.py new file mode 100644 index 00000000..c53452f7 --- /dev/null +++ b/tests/test_zip_extraction_security.py @@ -0,0 +1,293 @@ +"""Tests for zip extraction security measures.""" + +import os +import tempfile +import zipfile + +import pytest + +from application.worker import ( + ZipExtractionError, + _is_path_safe, + _validate_zip_safety, + extract_zip_recursive, + MAX_FILE_COUNT, +) + + +class TestIsPathSafe: + """Tests for _is_path_safe function.""" + + def test_safe_path_in_directory(self): + """Normal file within directory should be safe.""" + assert _is_path_safe("/tmp/extract", "/tmp/extract/file.txt") is True + + def test_safe_path_in_subdirectory(self): + """File in subdirectory should be safe.""" + assert _is_path_safe("/tmp/extract", "/tmp/extract/subdir/file.txt") is True + + def test_unsafe_path_parent_traversal(self): + """Path traversal to parent directory should be unsafe.""" + assert _is_path_safe("/tmp/extract", "/tmp/extract/../etc/passwd") is False + + def test_unsafe_path_absolute(self): + """Absolute path outside base should be unsafe.""" + assert _is_path_safe("/tmp/extract", "/etc/passwd") is False + + def test_unsafe_path_sibling(self): + """Sibling directory should be unsafe.""" + assert _is_path_safe("/tmp/extract", "/tmp/other/file.txt") is False + + def test_base_path_itself(self): + """Base path itself should be safe.""" + assert _is_path_safe("/tmp/extract", "/tmp/extract") is True + + +class TestValidateZipSafety: + """Tests for _validate_zip_safety function.""" + + def test_valid_small_zip(self): + """Small valid zip file should pass validation.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a small valid zip + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "Hello, World!") + + # Should not raise + _validate_zip_safety(zip_path, extract_to) + + def test_zip_with_too_many_files(self): + """Zip with too many files should be rejected.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a zip with many files (just over 
limit) + with zipfile.ZipFile(zip_path, "w") as zf: + for i in range(MAX_FILE_COUNT + 1): + zf.writestr(f"file_{i}.txt", "x") + + with pytest.raises(ZipExtractionError) as exc_info: + _validate_zip_safety(zip_path, extract_to) + assert "too many files" in str(exc_info.value).lower() + + def test_zip_with_path_traversal(self): + """Zip with path traversal attempt should be rejected.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a zip with path traversal + with zipfile.ZipFile(zip_path, "w") as zf: + # Add a normal file first + zf.writestr("normal.txt", "normal content") + # Add a file with path traversal + zf.writestr("../../../etc/passwd", "malicious content") + + with pytest.raises(ZipExtractionError) as exc_info: + _validate_zip_safety(zip_path, extract_to) + assert "path traversal" in str(exc_info.value).lower() + + def test_corrupted_zip(self): + """Corrupted zip file should be rejected.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a corrupted "zip" file + with open(zip_path, "wb") as f: + f.write(b"not a zip file content") + + with pytest.raises(ZipExtractionError) as exc_info: + _validate_zip_safety(zip_path, extract_to) + assert "invalid" in str(exc_info.value).lower() or "corrupted" in str(exc_info.value).lower() + + +class TestExtractZipRecursive: + """Tests for extract_zip_recursive function.""" + + def test_extract_valid_zip(self): + """Valid zip file should be extracted successfully.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a valid zip + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "Hello, World!") + zf.writestr("subdir/nested.txt", "Nested content") + + extract_zip_recursive(zip_path, extract_to) + + # Check files were extracted + assert os.path.exists(os.path.join(extract_to, "test.txt")) + assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt")) + + # Check zip was removed + assert not os.path.exists(zip_path) + + def test_extract_nested_zip(self): + """Nested zip files should be extracted recursively.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create inner zip + inner_zip_content = b"" + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp: + with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf: + inner_zf.writestr("inner.txt", "Inner content") + with open(inner_tmp.name, "rb") as f: + inner_zip_content = f.read() + os.unlink(inner_tmp.name) + + # Create outer zip containing inner zip + zip_path = os.path.join(temp_dir, "outer.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("outer.txt", "Outer content") + zf.writestr("inner.zip", inner_zip_content) + + extract_zip_recursive(zip_path, extract_to) + + # Check outer file was extracted + assert os.path.exists(os.path.join(extract_to, "outer.txt")) + + # Check inner zip was extracted + assert os.path.exists(os.path.join(extract_to, "inner.txt")) + + # Check both zips were removed + assert not os.path.exists(zip_path) + assert not os.path.exists(os.path.join(extract_to, "inner.zip")) + + def test_respects_max_depth(self): + 
"""Extraction should stop at max recursion depth.""" + with tempfile.TemporaryDirectory() as temp_dir: + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a chain of nested zips + current_content = b"Final content" + for i in range(7): # More than default max_depth of 5 + inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) + with zipfile.ZipFile(inner_tmp.name, "w") as zf: + if i == 0: + zf.writestr("content.txt", current_content.decode()) + else: + zf.writestr("nested.zip", current_content) + with open(inner_tmp.name, "rb") as f: + current_content = f.read() + os.unlink(inner_tmp.name) + + # Write the final outermost zip + zip_path = os.path.join(temp_dir, "outer.zip") + with open(zip_path, "wb") as f: + f.write(current_content) + + # Extract with max_depth=2 + extract_zip_recursive(zip_path, extract_to, max_depth=2) + + # The deepest nested zips should remain unextracted + # (we can't easily verify the exact behavior, but the function should not crash) + + def test_rejects_path_traversal(self): + """Zip with path traversal should be rejected and removed.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "malicious.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a malicious zip + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("../../../tmp/malicious.txt", "malicious") + + extract_zip_recursive(zip_path, extract_to) + + # Zip should be removed + assert not os.path.exists(zip_path) + + # Malicious file should NOT exist outside extract_to + assert not os.path.exists("/tmp/malicious.txt") + + def test_handles_corrupted_zip_gracefully(self): + """Corrupted zip should be handled gracefully without crashing.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "corrupted.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a corrupted file + with open(zip_path, "wb") as f: + f.write(b"This is not a valid zip file") + + # Should not raise, just log error + extract_zip_recursive(zip_path, extract_to) + + # Function should complete without exception + + +class TestZipBombProtection: + """Tests specifically for zip bomb protection.""" + + def test_detects_high_compression_ratio(self): + """Highly compressed data should trigger compression ratio check.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "bomb.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a file with highly compressible content (all zeros) + # This triggers the compression ratio check + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + # Create a large file with repetitive content - compresses extremely well + repetitive_content = "A" * (1024 * 1024) # 1 MB of 'A's + zf.writestr("repetitive.txt", repetitive_content) + + # This should be rejected due to high compression ratio + with pytest.raises(ZipExtractionError) as exc_info: + _validate_zip_safety(zip_path, extract_to) + assert "compression ratio" in str(exc_info.value).lower() + + def test_normal_compression_passes(self): + """Normal compression ratio should pass validation.""" + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "normal.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a zip with random-ish content that doesn't compress well + import random + random.seed(42) + 
random_content = "".join( + random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=10240) + ) + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("random.txt", random_content) + + # Should pass - random content doesn't compress well + _validate_zip_safety(zip_path, extract_to) + + def test_size_limit_check(self): + """Files exceeding size limit should be rejected.""" + # Note: We can't easily create a real zip bomb in tests + # This test verifies the validation logic works + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "test.zip") + extract_to = os.path.join(temp_dir, "extract") + os.makedirs(extract_to) + + # Create a zip with a reasonable size (no compression to avoid ratio issues) + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf: + # 10 KB file + zf.writestr("normal.txt", "x" * 10240) + + # Should pass + _validate_zip_safety(zip_path, extract_to)