From 98e949d2fd9522a91b430ac239c5d945304601b1 Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 24 Dec 2025 15:05:35 +0000
Subject: [PATCH] Patches (#2218)
* feat: implement URL validation to prevent SSRF
* feat: add zip extraction security
* ruff fixes
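
URL checks are centralized in application/core/url_validation.py and reused by the
API tool, the read_webpage tool, both crawlers and the sitemap loader. A rough usage
sketch of the new helpers (the hostnames below are placeholders; behaviour follows
the implementation added in this patch):

    from application.core.url_validation import validate_url, validate_url_safe, SSRFError

    try:
        url = validate_url("example.com")        # scheme added -> "http://example.com"
        validate_url("http://169.254.169.254")   # raises SSRFError (cloud metadata IP)
    except SSRFError as err:
        print(f"blocked: {err}")

    # Non-throwing variant for callers that prefer a result tuple
    ok, url, error = validate_url_safe("http://10.0.0.1")  # (False, "http://10.0.0.1", "Access to ...")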
---
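Note on the zip hardening: extract_zip_recursive() now runs every archive through
_validate_zip_safety() before calling extractall(); the limits are module constants in
application/worker.py (MAX_UNCOMPRESSED_SIZE = 500 MB, MAX_FILE_COUNT = 10000,
MAX_COMPRESSION_RATIO = 100:1). A minimal sketch of the happy path, mirroring the new
tests (paths and file names here are arbitrary):

    import os, tempfile, zipfile
    from application.worker import extract_zip_recursive

    with tempfile.TemporaryDirectory() as tmp:
        zip_path = os.path.join(tmp, "docs.zip")
        out_dir = os.path.join(tmp, "extract")
        os.makedirs(out_dir)
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.writestr("readme.txt", "hello")
        # Safety checks (size, file count, compression ratio, zip slip) run
        # before extraction; a rejected archive is logged and deleted instead.
        extract_zip_recursive(zip_path, out_dir)
        assert os.path.exists(os.path.join(out_dir, "readme.txt"))
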
application/agents/tools/api_tool.py | 24 ++
application/agents/tools/read_webpage.py | 13 +-
application/core/url_validation.py | 181 +++++++++++
application/parser/remote/crawler_loader.py | 19 +-
application/parser/remote/crawler_markdown.py | 15 +-
application/parser/remote/sitemap_loader.py | 17 +-
application/worker.py | 120 ++++++-
tests/core/test_url_validation.py | 197 ++++++++++++
tests/parser/remote/test_crawler_loader.py | 41 ++-
tests/parser/remote/test_crawler_markdown.py | 36 ++-
tests/test_zip_extraction_security.py | 293 ++++++++++++++++++
11 files changed, 929 insertions(+), 27 deletions(-)
create mode 100644 application/core/url_validation.py
create mode 100644 tests/core/test_url_validation.py
create mode 100644 tests/test_zip_extraction_security.py
diff --git a/application/agents/tools/api_tool.py b/application/agents/tools/api_tool.py
index 6bd2eb8d..e010b51b 100644
--- a/application/agents/tools/api_tool.py
+++ b/application/agents/tools/api_tool.py
@@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
+from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
@@ -73,6 +74,17 @@ class APITool(Tool):
request_headers = headers.copy() if headers else {}
response = None
+ # Validate URL to prevent SSRF attacks
+ try:
+ validate_url(request_url)
+ except SSRFError as e:
+ logger.error(f"URL validation failed: {e}")
+ return {
+ "status_code": None,
+ "message": f"URL validation error: {e}",
+ "data": None,
+ }
+
try:
path_params_used = set()
if query_params:
@@ -90,6 +102,18 @@ class APITool(Tool):
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
+
+ # Re-validate URL after parameter substitution to prevent SSRF via path params
+ try:
+ validate_url(request_url)
+ except SSRFError as e:
+ logger.error(f"URL validation failed after parameter substitution: {e}")
+ return {
+ "status_code": None,
+ "message": f"URL validation error: {e}",
+ "data": None,
+ }
+
# Serialize body based on content type
if body and body != {}:
diff --git a/application/agents/tools/read_webpage.py b/application/agents/tools/read_webpage.py
index e87c79e3..f0321a5a 100644
--- a/application/agents/tools/read_webpage.py
+++ b/application/agents/tools/read_webpage.py
@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
-from urllib.parse import urlparse
+from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
- # Ensure the URL has a scheme (if not, default to http)
- parsed_url = urlparse(url)
- if not parsed_url.scheme:
- url = "http://" + url
-
+ # Validate URL to prevent SSRF attacks
+ try:
+ url = validate_url(url)
+ except SSRFError as e:
+ return f"Error: URL validation failed - {e}"
+
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
diff --git a/application/core/url_validation.py b/application/core/url_validation.py
new file mode 100644
index 00000000..acd8a523
--- /dev/null
+++ b/application/core/url_validation.py
@@ -0,0 +1,181 @@
+"""
+URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
+
+This module provides functions to validate URLs before making HTTP requests,
+blocking access to internal networks, cloud metadata services, and other
+potentially dangerous endpoints.
+"""
+
+import ipaddress
+import socket
+from urllib.parse import urlparse
+from typing import Optional, Set
+
+
+class SSRFError(Exception):
+ """Raised when a URL fails SSRF validation."""
+ pass
+
+
+# Blocked hostnames that should never be accessed
+BLOCKED_HOSTNAMES: Set[str] = {
+ "localhost",
+ "localhost.localdomain",
+ "metadata.google.internal",
+ "metadata",
+}
+
+# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
+METADATA_IPS: Set[str] = {
+ "169.254.169.254", # AWS, GCP, Azure metadata
+ "169.254.170.2", # AWS ECS task metadata
+ "fd00:ec2::254", # AWS IPv6 metadata
+}
+
+# Allowed schemes for external requests
+ALLOWED_SCHEMES: Set[str] = {"http", "https"}
+
+
+def is_private_ip(ip_str: str) -> bool:
+ """
+ Check if an IP address is private or otherwise non-routable (loopback, link-local, reserved, multicast, or unspecified).
+
+ Args:
+ ip_str: IP address as a string
+
+ Returns:
+ True if the IP is private/internal, False otherwise
+ """
+ try:
+ ip = ipaddress.ip_address(ip_str)
+ return (
+ ip.is_private or
+ ip.is_loopback or
+ ip.is_link_local or
+ ip.is_reserved or
+ ip.is_multicast or
+ ip.is_unspecified
+ )
+ except ValueError:
+ # If we can't parse it as an IP, return False
+ return False
+
+
+def is_metadata_ip(ip_str: str) -> bool:
+ """
+ Check if an IP address is a cloud metadata service IP.
+
+ Args:
+ ip_str: IP address as a string
+
+ Returns:
+ True if the IP is a metadata service, False otherwise
+ """
+ return ip_str in METADATA_IPS
+
+
+def resolve_hostname(hostname: str) -> Optional[str]:
+ """
+ Resolve a hostname to an IP address.
+
+ Args:
+ hostname: The hostname to resolve
+
+ Returns:
+ The resolved IP address, or None if resolution fails
+ """
+ try:
+ return socket.gethostbyname(hostname)
+ except socket.gaierror:
+ return None
+
+
+def validate_url(url: str, allow_localhost: bool = False) -> str:
+ """
+ Validate a URL to prevent SSRF attacks.
+
+ This function checks that:
+ 1. The URL has an allowed scheme (http or https)
+ 2. The hostname is not a blocked hostname
+ 3. The resolved IP is not a private/internal IP
+ 4. The resolved IP is not a cloud metadata service
+
+ Args:
+ url: The URL to validate
+ allow_localhost: If True, allow localhost connections (for testing only)
+
+ Returns:
+ The validated URL (with scheme added if missing)
+
+ Raises:
+ SSRFError: If the URL fails validation
+ """
+ # Ensure URL has a scheme
+ if not urlparse(url).scheme:
+ url = "http://" + url
+
+ parsed = urlparse(url)
+
+ # Check scheme
+ if parsed.scheme not in ALLOWED_SCHEMES:
+ raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
+
+ hostname = parsed.hostname
+ if not hostname:
+ raise SSRFError("URL must have a valid hostname.")
+
+ hostname_lower = hostname.lower()
+
+ # Check blocked hostnames
+ if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
+ raise SSRFError(f"Access to '{hostname}' is not allowed.")
+
+ # Check if hostname is an IP address directly
+ try:
+ ip = ipaddress.ip_address(hostname)
+ ip_str = str(ip)
+
+ if is_metadata_ip(ip_str):
+ raise SSRFError("Access to cloud metadata services is not allowed.")
+
+ if is_private_ip(ip_str) and not allow_localhost:
+ raise SSRFError("Access to private/internal IP addresses is not allowed.")
+
+ return url
+ except ValueError:
+ # Not an IP address, it's a hostname - resolve it
+ pass
+
+ # Resolve hostname and check the IP
+ resolved_ip = resolve_hostname(hostname)
+ if resolved_ip is None:
+ raise SSRFError(f"Unable to resolve hostname: {hostname}")
+
+ if is_metadata_ip(resolved_ip):
+ raise SSRFError("Access to cloud metadata services is not allowed.")
+
+ if is_private_ip(resolved_ip) and not allow_localhost:
+ raise SSRFError("Access to private/internal networks is not allowed.")
+
+ return url
+
+
+def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
+ """
+ Validate a URL and return a tuple with validation result.
+
+ This is a non-throwing version of validate_url for cases where
+ you want to handle validation failures gracefully.
+
+ Args:
+ url: The URL to validate
+ allow_localhost: If True, allow localhost connections (for testing only)
+
+ Returns:
+ Tuple of (is_valid, validated_url_or_original, error_message_or_none)
+ """
+ try:
+ validated = validate_url(url, allow_localhost)
+ return (True, validated, None)
+ except SSRFError as e:
+ return (False, url, str(e))
diff --git a/application/parser/remote/crawler_loader.py b/application/parser/remote/crawler_loader.py
index 2ff6cf6f..1bfd2276 100644
--- a/application/parser/remote/crawler_loader.py
+++ b/application/parser/remote/crawler_loader.py
@@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
+from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
- # Check if the URL scheme is provided, if not, assume http
- if not urlparse(url).scheme:
- url = "http://" + url
+ # Validate URL to prevent SSRF attacks
+ try:
+ url = validate_url(url)
+ except SSRFError as e:
+ logging.error(f"URL validation failed: {e}")
+ return []
visited_urls = set()
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
- response = requests.get(current_url)
+ # Validate each URL before making requests
+ try:
+ validate_url(current_url)
+ except SSRFError as e:
+ logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
+ continue
+
+ response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()
diff --git a/application/parser/remote/crawler_markdown.py b/application/parser/remote/crawler_markdown.py
index 3d199332..8fc4c92c 100644
--- a/application/parser/remote/crawler_markdown.py
+++ b/application/parser/remote/crawler_markdown.py
@@ -2,6 +2,7 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
+from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
- # Ensure the URL has a scheme (if not, default to http)
- if not urlparse(url).scheme:
- url = "http://" + url
+ # Validate URL to prevent SSRF attacks
+ try:
+ url = validate_url(url)
+ except SSRFError as e:
+ print(f"URL validation failed: {e}")
+ return []
# Keep track of visited URLs to avoid revisiting the same page
visited_urls = set()
@@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
+ # Validate URL before fetching to prevent SSRF
+ validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
+ except SSRFError as e:
+ print(f"URL validation failed for {url}: {e}")
+ return None
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None
diff --git a/application/parser/remote/sitemap_loader.py b/application/parser/remote/sitemap_loader.py
index 6d54ea9b..ff7c1ede 100644
--- a/application/parser/remote/sitemap_loader.py
+++ b/application/parser/remote/sitemap_loader.py
@@ -3,6 +3,7 @@ import requests
import re # Import regular expression library
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
+from application.core.url_validation import validate_url, SSRFError
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
sitemap_url= inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
- url = sitemap_url[0]
+ sitemap_url = sitemap_url[0]
+
+ # Validate URL to prevent SSRF attacks
+ try:
+ sitemap_url = validate_url(sitemap_url)
+ except SSRFError as e:
+ logging.error(f"URL validation failed: {e}")
+ return []
urls = self._extract_urls(sitemap_url)
if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
- response = requests.get(sitemap_url)
+ # Validate URL before fetching to prevent SSRF
+ validate_url(sitemap_url)
+ response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
+ except SSRFError as e:
+ print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
+ return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
diff --git a/application/worker.py b/application/worker.py
index f45e94a5..fa2b6cd7 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -63,10 +63,111 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
+# Zip extraction security limits
+MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size
+MAX_FILE_COUNT = 10000 # Maximum number of files to extract
+MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs)
+
+
+class ZipExtractionError(Exception):
+ """Raised when zip extraction fails due to security constraints."""
+ pass
+
+
+def _is_path_safe(base_path: str, target_path: str) -> bool:
+ """
+ Check if target_path is safely within base_path (prevents zip slip attacks).
+
+ Args:
+ base_path: The base directory where extraction should occur.
+ target_path: The full path where a file would be extracted.
+
+ Returns:
+ True if the path is safe, False otherwise.
+ """
+ # Resolve to absolute paths and check containment
+ base_resolved = os.path.realpath(base_path)
+ target_resolved = os.path.realpath(target_path)
+ return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved
+
+
+def _validate_zip_safety(zip_path: str, extract_to: str) -> None:
+ """
+ Validate a zip file for security issues before extraction.
+
+ Checks for:
+ - Zip bombs (excessive compression ratio or uncompressed size)
+ - Too many files
+ - Path traversal attacks (zip slip)
+
+ Args:
+ zip_path: Path to the zip file.
+ extract_to: Destination directory.
+
+ Raises:
+ ZipExtractionError: If the zip file fails security validation.
+ """
+ try:
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ # Get compressed size
+ compressed_size = os.path.getsize(zip_path)
+
+ # Calculate total uncompressed size and file count
+ total_uncompressed = 0
+ file_count = 0
+
+ for info in zip_ref.infolist():
+ file_count += 1
+
+ # Check file count limit
+ if file_count > MAX_FILE_COUNT:
+ raise ZipExtractionError(
+ f"Zip file contains too many files (>{MAX_FILE_COUNT}). "
+ "This may be a zip bomb attack."
+ )
+
+ # Accumulate uncompressed size
+ total_uncompressed += info.file_size
+
+ # Check total uncompressed size
+ if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
+ raise ZipExtractionError(
+ f"Zip file uncompressed size exceeds limit "
+ f"({total_uncompressed / (1024*1024):.1f} MB > "
+ f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). "
+ "This may be a zip bomb attack."
+ )
+
+ # Check for path traversal (zip slip)
+ target_path = os.path.join(extract_to, info.filename)
+ if not _is_path_safe(extract_to, target_path):
+ raise ZipExtractionError(
+ f"Zip file contains path traversal attempt: {info.filename}"
+ )
+
+ # Check compression ratio (only if compressed size is meaningful)
+ if compressed_size > 0 and total_uncompressed > 0:
+ compression_ratio = total_uncompressed / compressed_size
+ if compression_ratio > MAX_COMPRESSION_RATIO:
+ raise ZipExtractionError(
+ f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > "
+ f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack."
+ )
+
+ except zipfile.BadZipFile as e:
+ raise ZipExtractionError(f"Invalid or corrupted zip file: {e}")
+
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
- Recursively extract zip files with a limit on recursion depth.
+ Recursively extract zip files with security protections.
+
+ Security measures:
+ - Limits recursion depth to prevent infinite loops
+ - Validates uncompressed size to prevent zip bombs
+ - Limits number of files to prevent resource exhaustion
+ - Checks compression ratio to detect zip bombs
+ - Validates paths to prevent zip slip attacks
Args:
zip_path (str): Path to the zip file to be extracted.
@@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
if current_depth > max_depth:
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
+
try:
+ # Validate zip file safety before extraction
+ _validate_zip_safety(zip_path, extract_to)
+
+ # Safe to extract
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
+
+ except ZipExtractionError as e:
+ logging.error(f"Zip security validation failed for {zip_path}: {e}")
+ # Remove the potentially malicious zip file
+ try:
+ os.remove(zip_path)
+ except OSError:
+ pass
+ return
except Exception as e:
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
return
- # Check for nested zip files and extract them
+ # Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
-
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
diff --git a/tests/core/test_url_validation.py b/tests/core/test_url_validation.py
new file mode 100644
index 00000000..924e5cde
--- /dev/null
+++ b/tests/core/test_url_validation.py
@@ -0,0 +1,197 @@
+"""Tests for SSRF URL validation module."""
+
+import pytest
+from unittest.mock import patch
+
+from application.core.url_validation import (
+ SSRFError,
+ validate_url,
+ validate_url_safe,
+ is_private_ip,
+ is_metadata_ip,
+)
+
+
+class TestIsPrivateIP:
+ """Tests for is_private_ip function."""
+
+ def test_loopback_ipv4(self):
+ assert is_private_ip("127.0.0.1") is True
+ assert is_private_ip("127.255.255.255") is True
+
+ def test_private_class_a(self):
+ assert is_private_ip("10.0.0.1") is True
+ assert is_private_ip("10.255.255.255") is True
+
+ def test_private_class_b(self):
+ assert is_private_ip("172.16.0.1") is True
+ assert is_private_ip("172.31.255.255") is True
+
+ def test_private_class_c(self):
+ assert is_private_ip("192.168.0.1") is True
+ assert is_private_ip("192.168.255.255") is True
+
+ def test_link_local(self):
+ assert is_private_ip("169.254.0.1") is True
+
+ def test_public_ip(self):
+ assert is_private_ip("8.8.8.8") is False
+ assert is_private_ip("1.1.1.1") is False
+ assert is_private_ip("93.184.216.34") is False
+
+ def test_invalid_ip(self):
+ assert is_private_ip("not-an-ip") is False
+ assert is_private_ip("") is False
+
+
+class TestIsMetadataIP:
+ """Tests for is_metadata_ip function."""
+
+ def test_aws_metadata_ip(self):
+ assert is_metadata_ip("169.254.169.254") is True
+
+ def test_aws_ecs_metadata_ip(self):
+ assert is_metadata_ip("169.254.170.2") is True
+
+ def test_non_metadata_ip(self):
+ assert is_metadata_ip("8.8.8.8") is False
+ assert is_metadata_ip("10.0.0.1") is False
+
+
+class TestValidateUrl:
+ """Tests for validate_url function."""
+
+ def test_adds_scheme_if_missing(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "93.184.216.34" # Public IP
+ result = validate_url("example.com")
+ assert result == "http://example.com"
+
+ def test_preserves_https_scheme(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "93.184.216.34"
+ result = validate_url("https://example.com")
+ assert result == "https://example.com"
+
+ def test_blocks_localhost(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://localhost")
+ assert "localhost" in str(exc_info.value).lower()
+
+ def test_blocks_localhost_localdomain(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://localhost.localdomain")
+ assert "not allowed" in str(exc_info.value).lower()
+
+ def test_blocks_loopback_ip(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://127.0.0.1")
+ assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
+
+ def test_blocks_private_ip_class_a(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://10.0.0.1")
+ assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
+
+ def test_blocks_private_ip_class_b(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://172.16.0.1")
+ assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
+
+ def test_blocks_private_ip_class_c(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://192.168.1.1")
+ assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
+
+ def test_blocks_aws_metadata_ip(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://169.254.169.254")
+ assert "metadata" in str(exc_info.value).lower()
+
+ def test_blocks_aws_metadata_with_path(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://169.254.169.254/latest/meta-data/")
+ assert "metadata" in str(exc_info.value).lower()
+
+ def test_blocks_gcp_metadata_hostname(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://metadata.google.internal")
+ assert "not allowed" in str(exc_info.value).lower()
+
+ def test_blocks_ftp_scheme(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("ftp://example.com")
+ assert "scheme" in str(exc_info.value).lower()
+
+ def test_blocks_file_scheme(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("file:///etc/passwd")
+ assert "scheme" in str(exc_info.value).lower()
+
+ def test_blocks_hostname_resolving_to_private_ip(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "192.168.1.1"
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://internal.example.com")
+ assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
+
+ def test_blocks_hostname_resolving_to_metadata_ip(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "169.254.169.254"
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://evil.example.com")
+ assert "metadata" in str(exc_info.value).lower()
+
+ def test_allows_public_ip(self):
+ result = validate_url("http://8.8.8.8")
+ assert result == "http://8.8.8.8"
+
+ def test_allows_public_hostname(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "93.184.216.34"
+ result = validate_url("https://example.com")
+ assert result == "https://example.com"
+
+ def test_raises_on_unresolvable_hostname(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = None
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://nonexistent.invalid")
+ assert "resolve" in str(exc_info.value).lower()
+
+ def test_raises_on_empty_hostname(self):
+ with pytest.raises(SSRFError) as exc_info:
+ validate_url("http://")
+ assert "hostname" in str(exc_info.value).lower()
+
+ def test_allow_localhost_flag(self):
+ # Should work with allow_localhost=True
+ result = validate_url("http://localhost", allow_localhost=True)
+ assert result == "http://localhost"
+
+ result = validate_url("http://127.0.0.1", allow_localhost=True)
+ assert result == "http://127.0.0.1"
+
+
+class TestValidateUrlSafe:
+ """Tests for validate_url_safe non-throwing function."""
+
+ def test_returns_tuple_on_success(self):
+ with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
+ mock_resolve.return_value = "93.184.216.34"
+ is_valid, url, error = validate_url_safe("https://example.com")
+ assert is_valid is True
+ assert url == "https://example.com"
+ assert error is None
+
+ def test_returns_tuple_on_failure(self):
+ is_valid, url, error = validate_url_safe("http://localhost")
+ assert is_valid is False
+ assert url == "http://localhost"
+ assert error is not None
+ assert "localhost" in error.lower()
+
+ def test_returns_error_message_for_private_ip(self):
+ is_valid, url, error = validate_url_safe("http://192.168.1.1")
+ assert is_valid is False
+ assert "private" in error.lower() or "internal" in error.lower()
diff --git a/tests/parser/remote/test_crawler_loader.py b/tests/parser/remote/test_crawler_loader.py
index 92ffdc84..06d27517 100644
--- a/tests/parser/remote/test_crawler_loader.py
+++ b/tests/parser/remote/test_crawler_loader.py
@@ -13,8 +13,17 @@ class DummyResponse:
return None
+def _mock_validate_url(url):
+ """Mock validate_url that allows test URLs through."""
+ from urllib.parse import urlparse
+ if not urlparse(url).scheme:
+ url = "http://" + url
+ return url
+
+
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_crawls_same_domain_links(mock_requests_get):
+def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
"http://example.com/about": DummyResponse("About page"),
}
- def response_side_effect(url: str):
+ def response_side_effect(url: str, timeout=30):
if url not in responses:
raise AssertionError(f"Unexpected request for URL: {url}")
return responses[url]
@@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
assert loader_call_order == ["http://example.com", "http://example.com/about"]
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
+def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url):
mock_requests_get.return_value = DummyResponse("No links here")
doc = MagicMock(spec=LCDocument)
@@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
result = crawler.load_data(["example.com", "unused.com"])
- mock_requests_get.assert_called_once_with("http://example.com")
+ mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
crawler.loader.assert_called_once_with(["http://example.com"])
assert len(result) == 1
@@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
assert result[0].extra_info == {"source": "http://example.com"}
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_respects_limit(mock_requests_get):
+def test_load_data_respects_limit(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get):
"http://example.com/about": DummyResponse("About"),
}
- mock_requests_get.side_effect = lambda url: responses[url]
+ mock_requests_get.side_effect = lambda url, timeout=30: responses[url]
root_doc = MagicMock(spec=LCDocument)
root_doc.page_content = "Root content"
@@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get):
assert crawler.loader.call_count == 1
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.logging")
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging):
+def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url):
mock_requests_get.return_value = DummyResponse("Error route")
failing_loader_instance = MagicMock()
@@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
result = crawler.load_data("http://example.com")
assert result == []
- mock_requests_get.assert_called_once_with("http://example.com")
+ mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
failing_loader_instance.load.assert_called_once()
mock_logging.error.assert_called_once()
@@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
assert "Error processing URL http://example.com" in message
assert mock_logging.error.call_args.kwargs.get("exc_info") is True
+
+@patch("application.parser.remote.crawler_loader.validate_url")
+def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url):
+ """Test that SSRF validation failure returns empty list."""
+ from application.core.url_validation import SSRFError
+ mock_validate_url.side_effect = SSRFError("Access to private IP not allowed")
+
+ crawler = CrawlerLoader()
+ result = crawler.load_data("http://192.168.1.1")
+
+ assert result == []
+ mock_validate_url.assert_called_once()
+
diff --git a/tests/parser/remote/test_crawler_markdown.py b/tests/parser/remote/test_crawler_markdown.py
index ac27b3d0..b2b3f21c 100644
--- a/tests/parser/remote/test_crawler_markdown.py
+++ b/tests/parser/remote/test_crawler_markdown.py
@@ -1,5 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
+from urllib.parse import urlparse
import pytest
import requests
@@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace:
return SimpleNamespace(domain=domain, suffix=suffix)
+def _mock_validate_url(url):
+ """Mock validate_url that allows test URLs through."""
+ if not urlparse(url).scheme:
+ url = "http://" + url
+ return url
+
+
+@pytest.fixture(autouse=True)
+def _patch_validate_url(monkeypatch):
+ monkeypatch.setattr(
+ "application.parser.remote.crawler_markdown.validate_url",
+ _mock_validate_url,
+ )
+
+
@pytest.fixture(autouse=True)
def _patch_tldextract(monkeypatch):
monkeypatch.setattr(
@@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify):
assert len(docs) == 2
-def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
+def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url):
root_html = """
Home
About
@@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
assert docs[0].text == "Home Markdown"
assert mock_print.called
+
+def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch):
+ """Test that SSRF validation failure returns empty list."""
+ from application.core.url_validation import SSRFError
+
+ def raise_ssrf_error(url):
+ raise SSRFError("Access to private IP not allowed")
+
+ monkeypatch.setattr(
+ "application.parser.remote.crawler_markdown.validate_url",
+ raise_ssrf_error,
+ )
+
+ loader = CrawlerLoader()
+ result = loader.load_data("http://192.168.1.1")
+
+ assert result == []
+
diff --git a/tests/test_zip_extraction_security.py b/tests/test_zip_extraction_security.py
new file mode 100644
index 00000000..c53452f7
--- /dev/null
+++ b/tests/test_zip_extraction_security.py
@@ -0,0 +1,293 @@
+"""Tests for zip extraction security measures."""
+
+import os
+import tempfile
+import zipfile
+
+import pytest
+
+from application.worker import (
+ ZipExtractionError,
+ _is_path_safe,
+ _validate_zip_safety,
+ extract_zip_recursive,
+ MAX_FILE_COUNT,
+)
+
+
+class TestIsPathSafe:
+ """Tests for _is_path_safe function."""
+
+ def test_safe_path_in_directory(self):
+ """Normal file within directory should be safe."""
+ assert _is_path_safe("/tmp/extract", "/tmp/extract/file.txt") is True
+
+ def test_safe_path_in_subdirectory(self):
+ """File in subdirectory should be safe."""
+ assert _is_path_safe("/tmp/extract", "/tmp/extract/subdir/file.txt") is True
+
+ def test_unsafe_path_parent_traversal(self):
+ """Path traversal to parent directory should be unsafe."""
+ assert _is_path_safe("/tmp/extract", "/tmp/extract/../etc/passwd") is False
+
+ def test_unsafe_path_absolute(self):
+ """Absolute path outside base should be unsafe."""
+ assert _is_path_safe("/tmp/extract", "/etc/passwd") is False
+
+ def test_unsafe_path_sibling(self):
+ """Sibling directory should be unsafe."""
+ assert _is_path_safe("/tmp/extract", "/tmp/other/file.txt") is False
+
+ def test_base_path_itself(self):
+ """Base path itself should be safe."""
+ assert _is_path_safe("/tmp/extract", "/tmp/extract") is True
+
+
+class TestValidateZipSafety:
+ """Tests for _validate_zip_safety function."""
+
+ def test_valid_small_zip(self):
+ """Small valid zip file should pass validation."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a small valid zip
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ zf.writestr("test.txt", "Hello, World!")
+
+ # Should not raise
+ _validate_zip_safety(zip_path, extract_to)
+
+ def test_zip_with_too_many_files(self):
+ """Zip with too many files should be rejected."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a zip with many files (just over limit)
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ for i in range(MAX_FILE_COUNT + 1):
+ zf.writestr(f"file_{i}.txt", "x")
+
+ with pytest.raises(ZipExtractionError) as exc_info:
+ _validate_zip_safety(zip_path, extract_to)
+ assert "too many files" in str(exc_info.value).lower()
+
+ def test_zip_with_path_traversal(self):
+ """Zip with path traversal attempt should be rejected."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a zip with path traversal
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ # Add a normal file first
+ zf.writestr("normal.txt", "normal content")
+ # Add a file with path traversal
+ zf.writestr("../../../etc/passwd", "malicious content")
+
+ with pytest.raises(ZipExtractionError) as exc_info:
+ _validate_zip_safety(zip_path, extract_to)
+ assert "path traversal" in str(exc_info.value).lower()
+
+ def test_corrupted_zip(self):
+ """Corrupted zip file should be rejected."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a corrupted "zip" file
+ with open(zip_path, "wb") as f:
+ f.write(b"not a zip file content")
+
+ with pytest.raises(ZipExtractionError) as exc_info:
+ _validate_zip_safety(zip_path, extract_to)
+ assert "invalid" in str(exc_info.value).lower() or "corrupted" in str(exc_info.value).lower()
+
+
+class TestExtractZipRecursive:
+ """Tests for extract_zip_recursive function."""
+
+ def test_extract_valid_zip(self):
+ """Valid zip file should be extracted successfully."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a valid zip
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ zf.writestr("test.txt", "Hello, World!")
+ zf.writestr("subdir/nested.txt", "Nested content")
+
+ extract_zip_recursive(zip_path, extract_to)
+
+ # Check files were extracted
+ assert os.path.exists(os.path.join(extract_to, "test.txt"))
+ assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt"))
+
+ # Check zip was removed
+ assert not os.path.exists(zip_path)
+
+ def test_extract_nested_zip(self):
+ """Nested zip files should be extracted recursively."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # Create inner zip
+ inner_zip_content = b""
+ with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp:
+ with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf:
+ inner_zf.writestr("inner.txt", "Inner content")
+ with open(inner_tmp.name, "rb") as f:
+ inner_zip_content = f.read()
+ os.unlink(inner_tmp.name)
+
+ # Create outer zip containing inner zip
+ zip_path = os.path.join(temp_dir, "outer.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ zf.writestr("outer.txt", "Outer content")
+ zf.writestr("inner.zip", inner_zip_content)
+
+ extract_zip_recursive(zip_path, extract_to)
+
+ # Check outer file was extracted
+ assert os.path.exists(os.path.join(extract_to, "outer.txt"))
+
+ # Check inner zip was extracted
+ assert os.path.exists(os.path.join(extract_to, "inner.txt"))
+
+ # Check both zips were removed
+ assert not os.path.exists(zip_path)
+ assert not os.path.exists(os.path.join(extract_to, "inner.zip"))
+
+ def test_respects_max_depth(self):
+ """Extraction should stop at max recursion depth."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a chain of nested zips
+ current_content = b"Final content"
+ for i in range(7): # More than default max_depth of 5
+ inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
+ with zipfile.ZipFile(inner_tmp.name, "w") as zf:
+ if i == 0:
+ zf.writestr("content.txt", current_content.decode())
+ else:
+ zf.writestr("nested.zip", current_content)
+ with open(inner_tmp.name, "rb") as f:
+ current_content = f.read()
+ os.unlink(inner_tmp.name)
+
+ # Write the final outermost zip
+ zip_path = os.path.join(temp_dir, "outer.zip")
+ with open(zip_path, "wb") as f:
+ f.write(current_content)
+
+ # Extract with max_depth=2
+ extract_zip_recursive(zip_path, extract_to, max_depth=2)
+
+ # The innermost archives should remain unextracted once max_depth is hit;
+ # the key assertion here is that the call completes without raising.
+
+ def test_rejects_path_traversal(self):
+ """Zip with path traversal should be rejected and removed."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "malicious.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a malicious zip
+ with zipfile.ZipFile(zip_path, "w") as zf:
+ zf.writestr("../../../tmp/malicious.txt", "malicious")
+
+ extract_zip_recursive(zip_path, extract_to)
+
+ # Zip should be removed
+ assert not os.path.exists(zip_path)
+
+ # Malicious file should NOT exist outside extract_to
+ assert not os.path.exists("/tmp/malicious.txt")
+
+ def test_handles_corrupted_zip_gracefully(self):
+ """Corrupted zip should be handled gracefully without crashing."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "corrupted.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a corrupted file
+ with open(zip_path, "wb") as f:
+ f.write(b"This is not a valid zip file")
+
+ # Should not raise, just log error
+ extract_zip_recursive(zip_path, extract_to)
+
+ # Function should complete without exception
+
+
+class TestZipBombProtection:
+ """Tests specifically for zip bomb protection."""
+
+ def test_detects_high_compression_ratio(self):
+ """Highly compressed data should trigger compression ratio check."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "bomb.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a file with highly compressible content (all zeros)
+ # This triggers the compression ratio check
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
+ # Create a large file with repetitive content - compresses extremely well
+ repetitive_content = "A" * (1024 * 1024) # 1 MB of 'A's
+ zf.writestr("repetitive.txt", repetitive_content)
+
+ # This should be rejected due to high compression ratio
+ with pytest.raises(ZipExtractionError) as exc_info:
+ _validate_zip_safety(zip_path, extract_to)
+ assert "compression ratio" in str(exc_info.value).lower()
+
+ def test_normal_compression_passes(self):
+ """Normal compression ratio should pass validation."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "normal.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a zip with random-ish content that doesn't compress well
+ import random
+ random.seed(42)
+ random_content = "".join(
+ random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=10240)
+ )
+
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
+ zf.writestr("random.txt", random_content)
+
+ # Should pass - random content doesn't compress well
+ _validate_zip_safety(zip_path, extract_to)
+
+ def test_size_limit_check(self):
+ """Files exceeding size limit should be rejected."""
+ # Note: We can't easily create a real zip bomb in tests
+ # This test verifies the validation logic works
+ with tempfile.TemporaryDirectory() as temp_dir:
+ zip_path = os.path.join(temp_dir, "test.zip")
+ extract_to = os.path.join(temp_dir, "extract")
+ os.makedirs(extract_to)
+
+ # Create a zip with a reasonable size (no compression to avoid ratio issues)
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
+ # 10 KB file
+ zf.writestr("normal.txt", "x" * 10240)
+
+ # Should pass
+ _validate_zip_safety(zip_path, extract_to)