* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes
Commit 98e949d2fd by Alex (committed via GitHub), 2025-12-24 15:05:35 +00:00
Parent commit: 83e7a928f1
11 changed files with 929 additions and 27 deletions
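
Every call site touched by this commit follows the same validate-then-fetch shape. A minimal sketch of that pattern (the safe_get helper below is illustrative only, not part of the commit):

from application.core.url_validation import validate_url, SSRFError
import requests

def safe_get(url: str, **kwargs):
    # Illustrative helper (an assumption, not in the diff): validate first,
    # which raises SSRFError for private, loopback, link-local, and cloud
    # metadata targets, then fetch with an explicit timeout.
    url = validate_url(url)
    return requests.get(url, timeout=10, **kwargs)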


@@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
@@ -73,6 +74,17 @@ class APITool(Tool):
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
@@ -90,6 +102,18 @@ class APITool(Tool):
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:

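On validation failure the tool short-circuits with the structured error dict above instead of raising. A hedged sketch of what a caller then sees (the execute call is a hypothetical invocation; only the dict shape and message prefix come from the diff):

result = api_tool.execute(url="http://169.254.169.254/latest/meta-data/")  # hypothetical call
assert result["status_code"] is None
assert result["message"].startswith("URL validation error:")
assert result["data"] is None
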

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
-from urllib.parse import urlparse
+from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
-# Ensure the URL has a scheme (if not, default to http)
-parsed_url = urlparse(url)
-if not parsed_url.scheme:
-    url = "http://" + url
+# Validate URL to prevent SSRF attacks
+try:
+    url = validate_url(url)
+except SSRFError as e:
+    return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)

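Note that validate_url subsumes the old urlparse-based scheme defaulting, so the tool no longer needs its own normalization. A small sketch of both behaviors (assumes example.com resolves to a public IP at runtime):

from application.core.url_validation import validate_url, SSRFError

url = validate_url("example.com")  # missing scheme is added -> "http://example.com"
try:
    validate_url("http://169.254.169.254/latest/meta-data/")
except SSRFError as e:
    print(f"Error: URL validation failed - {e}")  # the string the tool returns
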

@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))

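Taken together, the module exposes a throwing and a non-throwing entry point plus the two IP predicates. A short usage sketch (the example.com call assumes public DNS resolution at runtime):

from application.core.url_validation import (
    SSRFError, is_metadata_ip, is_private_ip, validate_url, validate_url_safe,
)

try:
    safe = validate_url("https://example.com")  # returned unchanged if it resolves publicly
except SSRFError:
    safe = None

ok, url, err = validate_url_safe("http://10.0.0.1")  # non-throwing variant
assert ok is False and "private" in err.lower()

assert is_private_ip("192.168.1.1") and not is_private_ip("8.8.8.8")
assert is_metadata_ip("169.254.169.254")
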

@@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
-# Check if the URL scheme is provided, if not, assume http
-if not urlparse(url).scheme:
-    url = "http://" + url
+# Validate URL to prevent SSRF attacks
+try:
+    url = validate_url(url)
+except SSRFError as e:
+    logging.error(f"URL validation failed: {e}")
+    return []
visited_urls = set()
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
-response = requests.get(current_url)
+# Validate each URL before making requests
+try:
+    validate_url(current_url)
+except SSRFError as e:
+    logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
+    continue
+response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()


@@ -2,6 +2,7 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
-# Ensure the URL has a scheme (if not, default to http)
-if not urlparse(url).scheme:
-    url = "http://" + url
+# Validate URL to prevent SSRF attacks
+try:
+    url = validate_url(url)
+except SSRFError as e:
+    print(f"URL validation failed: {e}")
+    return []
# Keep track of visited URLs to avoid revisiting the same page
visited_urls = set()
@@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
+# Validate URL before fetching to prevent SSRF
+validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
+except SSRFError as e:
+    print(f"URL validation failed for {url}: {e}")
+    return None
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None


@@ -3,6 +3,7 @@ import requests
import re # Import regular expression library
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
sitemap_url = inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
-    url = sitemap_url[0]
+    sitemap_url = sitemap_url[0]
# Validate URL to prevent SSRF attacks
try:
sitemap_url = validate_url(sitemap_url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
urls = self._extract_urls(sitemap_url)
if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
-response = requests.get(sitemap_url)
+# Validate URL before fetching to prevent SSRF
+validate_url(sitemap_url)
+response = requests.get(sitemap_url, timeout=30)
response.raise_for_status()  # Raise an exception for HTTP errors
+except SSRFError as e:
+    print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
+    return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []


@@ -63,10 +63,111 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Zip extraction security limits
MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size
MAX_FILE_COUNT = 10000 # Maximum number of files to extract
MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs)
class ZipExtractionError(Exception):
"""Raised when zip extraction fails due to security constraints."""
pass
def _is_path_safe(base_path: str, target_path: str) -> bool:
"""
Check if target_path is safely within base_path (prevents zip slip attacks).
Args:
base_path: The base directory where extraction should occur.
target_path: The full path where a file would be extracted.
Returns:
True if the path is safe, False otherwise.
"""
# Resolve to absolute paths and check containment
base_resolved = os.path.realpath(base_path)
target_resolved = os.path.realpath(target_path)
return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved
def _validate_zip_safety(zip_path: str, extract_to: str) -> None:
"""
Validate a zip file for security issues before extraction.
Checks for:
- Zip bombs (excessive compression ratio or uncompressed size)
- Too many files
- Path traversal attacks (zip slip)
Args:
zip_path: Path to the zip file.
extract_to: Destination directory.
Raises:
ZipExtractionError: If the zip file fails security validation.
"""
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# Get compressed size
compressed_size = os.path.getsize(zip_path)
# Calculate total uncompressed size and file count
total_uncompressed = 0
file_count = 0
for info in zip_ref.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILE_COUNT:
raise ZipExtractionError(
f"Zip file contains too many files (>{MAX_FILE_COUNT}). "
"This may be a zip bomb attack."
)
# Accumulate uncompressed size
total_uncompressed += info.file_size
# Check total uncompressed size
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
raise ZipExtractionError(
f"Zip file uncompressed size exceeds limit "
f"({total_uncompressed / (1024*1024):.1f} MB > "
f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). "
"This may be a zip bomb attack."
)
# Check for path traversal (zip slip)
target_path = os.path.join(extract_to, info.filename)
if not _is_path_safe(extract_to, target_path):
raise ZipExtractionError(
f"Zip file contains path traversal attempt: {info.filename}"
)
# Check compression ratio (only if compressed size is meaningful)
if compressed_size > 0 and total_uncompressed > 0:
compression_ratio = total_uncompressed / compressed_size
if compression_ratio > MAX_COMPRESSION_RATIO:
raise ZipExtractionError(
f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > "
f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack."
)
except zipfile.BadZipFile as e:
raise ZipExtractionError(f"Invalid or corrupted zip file: {e}")
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
-Recursively extract zip files with a limit on recursion depth.
+Recursively extract zip files with security protections.
Security measures:
- Limits recursion depth to prevent infinite loops
- Validates uncompressed size to prevent zip bombs
- Limits number of files to prevent resource exhaustion
- Checks compression ratio to detect zip bombs
- Validates paths to prevent zip slip attacks
Args:
zip_path (str): Path to the zip file to be extracted.
@@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
if current_depth > max_depth:
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
try:
# Validate zip file safety before extraction
_validate_zip_safety(zip_path, extract_to)
# Safe to extract
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
except ZipExtractionError as e:
logging.error(f"Zip security validation failed for {zip_path}: {e}")
# Remove the potentially malicious zip file
try:
os.remove(zip_path)
except OSError:
pass
return
except Exception as e:
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
return
# Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)

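A short sketch of the hardened extraction path as the diff defines it: a benign archive validates, extracts in place, and the archive file is deleted afterwards, while a zip-slip or zip-bomb archive raises ZipExtractionError internally and is logged and removed instead.

import os
import tempfile
import zipfile

from application.worker import extract_zip_recursive

with tempfile.TemporaryDirectory() as tmp:
    zip_path = os.path.join(tmp, "upload.zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("doc.txt", "hello")
    extract_zip_recursive(zip_path, tmp)
    assert os.path.exists(os.path.join(tmp, "doc.txt"))
    assert not os.path.exists(zip_path)  # the zip is removed after extraction
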

@@ -0,0 +1,197 @@
"""Tests for SSRF URL validation module."""
import pytest
from unittest.mock import patch
from application.core.url_validation import (
SSRFError,
validate_url,
validate_url_safe,
is_private_ip,
is_metadata_ip,
)
class TestIsPrivateIP:
"""Tests for is_private_ip function."""
def test_loopback_ipv4(self):
assert is_private_ip("127.0.0.1") is True
assert is_private_ip("127.255.255.255") is True
def test_private_class_a(self):
assert is_private_ip("10.0.0.1") is True
assert is_private_ip("10.255.255.255") is True
def test_private_class_b(self):
assert is_private_ip("172.16.0.1") is True
assert is_private_ip("172.31.255.255") is True
def test_private_class_c(self):
assert is_private_ip("192.168.0.1") is True
assert is_private_ip("192.168.255.255") is True
def test_link_local(self):
assert is_private_ip("169.254.0.1") is True
def test_public_ip(self):
assert is_private_ip("8.8.8.8") is False
assert is_private_ip("1.1.1.1") is False
assert is_private_ip("93.184.216.34") is False
def test_invalid_ip(self):
assert is_private_ip("not-an-ip") is False
assert is_private_ip("") is False
class TestIsMetadataIP:
"""Tests for is_metadata_ip function."""
def test_aws_metadata_ip(self):
assert is_metadata_ip("169.254.169.254") is True
def test_aws_ecs_metadata_ip(self):
assert is_metadata_ip("169.254.170.2") is True
def test_non_metadata_ip(self):
assert is_metadata_ip("8.8.8.8") is False
assert is_metadata_ip("10.0.0.1") is False
class TestValidateUrl:
"""Tests for validate_url function."""
def test_adds_scheme_if_missing(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "93.184.216.34" # Public IP
result = validate_url("example.com")
assert result == "http://example.com"
def test_preserves_https_scheme(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "93.184.216.34"
result = validate_url("https://example.com")
assert result == "https://example.com"
def test_blocks_localhost(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://localhost")
assert "localhost" in str(exc_info.value).lower()
def test_blocks_localhost_localdomain(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://localhost.localdomain")
assert "not allowed" in str(exc_info.value).lower()
def test_blocks_loopback_ip(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://127.0.0.1")
assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
def test_blocks_private_ip_class_a(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://10.0.0.1")
assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
def test_blocks_private_ip_class_b(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://172.16.0.1")
assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
def test_blocks_private_ip_class_c(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://192.168.1.1")
assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
def test_blocks_aws_metadata_ip(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://169.254.169.254")
assert "metadata" in str(exc_info.value).lower()
def test_blocks_aws_metadata_with_path(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://169.254.169.254/latest/meta-data/")
assert "metadata" in str(exc_info.value).lower()
def test_blocks_gcp_metadata_hostname(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://metadata.google.internal")
assert "not allowed" in str(exc_info.value).lower()
def test_blocks_ftp_scheme(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("ftp://example.com")
assert "scheme" in str(exc_info.value).lower()
def test_blocks_file_scheme(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("file:///etc/passwd")
assert "scheme" in str(exc_info.value).lower()
def test_blocks_hostname_resolving_to_private_ip(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "192.168.1.1"
with pytest.raises(SSRFError) as exc_info:
validate_url("http://internal.example.com")
assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()
def test_blocks_hostname_resolving_to_metadata_ip(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "169.254.169.254"
with pytest.raises(SSRFError) as exc_info:
validate_url("http://evil.example.com")
assert "metadata" in str(exc_info.value).lower()
def test_allows_public_ip(self):
result = validate_url("http://8.8.8.8")
assert result == "http://8.8.8.8"
def test_allows_public_hostname(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "93.184.216.34"
result = validate_url("https://example.com")
assert result == "https://example.com"
def test_raises_on_unresolvable_hostname(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = None
with pytest.raises(SSRFError) as exc_info:
validate_url("http://nonexistent.invalid")
assert "resolve" in str(exc_info.value).lower()
def test_raises_on_empty_hostname(self):
with pytest.raises(SSRFError) as exc_info:
validate_url("http://")
assert "hostname" in str(exc_info.value).lower()
def test_allow_localhost_flag(self):
# Should work with allow_localhost=True
result = validate_url("http://localhost", allow_localhost=True)
assert result == "http://localhost"
result = validate_url("http://127.0.0.1", allow_localhost=True)
assert result == "http://127.0.0.1"
class TestValidateUrlSafe:
"""Tests for validate_url_safe non-throwing function."""
def test_returns_tuple_on_success(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "93.184.216.34"
is_valid, url, error = validate_url_safe("https://example.com")
assert is_valid is True
assert url == "https://example.com"
assert error is None
def test_returns_tuple_on_failure(self):
is_valid, url, error = validate_url_safe("http://localhost")
assert is_valid is False
assert url == "http://localhost"
assert error is not None
assert "localhost" in error.lower()
def test_returns_error_message_for_private_ip(self):
is_valid, url, error = validate_url_safe("http://192.168.1.1")
assert is_valid is False
assert "private" in error.lower() or "internal" in error.lower()


@@ -13,8 +13,17 @@ class DummyResponse:
return None
def _mock_validate_url(url):
"""Mock validate_url that allows test URLs through."""
from urllib.parse import urlparse
if not urlparse(url).scheme:
url = "http://" + url
return url
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_crawls_same_domain_links(mock_requests_get):
+def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
"http://example.com/about": DummyResponse("<html><body>About page</body></html>"),
}
-def response_side_effect(url: str):
+def response_side_effect(url: str, timeout=30):
if url not in responses:
raise AssertionError(f"Unexpected request for URL: {url}")
return responses[url]
@@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
assert loader_call_order == ["http://example.com", "http://example.com/about"]
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
+def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url):
mock_requests_get.return_value = DummyResponse("<html><body>No links here</body></html>")
doc = MagicMock(spec=LCDocument)
@@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
result = crawler.load_data(["example.com", "unused.com"])
-mock_requests_get.assert_called_once_with("http://example.com")
+mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
crawler.loader.assert_called_once_with(["http://example.com"])
assert len(result) == 1
@@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
assert result[0].extra_info == {"source": "http://example.com"}
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_respects_limit(mock_requests_get):
+def test_load_data_respects_limit(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get):
"http://example.com/about": DummyResponse("<html><body>About</body></html>"),
}
-mock_requests_get.side_effect = lambda url: responses[url]
+mock_requests_get.side_effect = lambda url, timeout=30: responses[url]
root_doc = MagicMock(spec=LCDocument)
root_doc.page_content = "Root content"
@@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get):
assert crawler.loader.call_count == 1
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.logging")
@patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging):
+def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url):
mock_requests_get.return_value = DummyResponse("<html><body>Error route</body></html>")
failing_loader_instance = MagicMock()
@@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
result = crawler.load_data("http://example.com")
assert result == []
-mock_requests_get.assert_called_once_with("http://example.com")
+mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
failing_loader_instance.load.assert_called_once()
mock_logging.error.assert_called_once()
@@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
assert "Error processing URL http://example.com" in message
assert mock_logging.error.call_args.kwargs.get("exc_info") is True
@patch("application.parser.remote.crawler_loader.validate_url")
def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url):
"""Test that SSRF validation failure returns empty list."""
from application.core.url_validation import SSRFError
mock_validate_url.side_effect = SSRFError("Access to private IP not allowed")
crawler = CrawlerLoader()
result = crawler.load_data("http://192.168.1.1")
assert result == []
mock_validate_url.assert_called_once()


@@ -1,5 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from urllib.parse import urlparse
import pytest
import requests
@@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace:
return SimpleNamespace(domain=domain, suffix=suffix)
def _mock_validate_url(url):
"""Mock validate_url that allows test URLs through."""
if not urlparse(url).scheme:
url = "http://" + url
return url
@pytest.fixture(autouse=True)
def _patch_validate_url(monkeypatch):
monkeypatch.setattr(
"application.parser.remote.crawler_markdown.validate_url",
_mock_validate_url,
)
@pytest.fixture(autouse=True)
def _patch_tldextract(monkeypatch):
monkeypatch.setattr(
@@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify):
assert len(docs) == 2
-def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
+def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url):
root_html = """
<html><head><title>Home</title></head>
<body><a href="/about">About</a></body>
@@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
assert docs[0].text == "Home Markdown"
assert mock_print.called
def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch):
"""Test that SSRF validation failure returns empty list."""
from application.core.url_validation import SSRFError
def raise_ssrf_error(url):
raise SSRFError("Access to private IP not allowed")
monkeypatch.setattr(
"application.parser.remote.crawler_markdown.validate_url",
raise_ssrf_error,
)
loader = CrawlerLoader()
result = loader.load_data("http://192.168.1.1")
assert result == []


@@ -0,0 +1,293 @@
"""Tests for zip extraction security measures."""
import os
import tempfile
import zipfile
import pytest
from application.worker import (
ZipExtractionError,
_is_path_safe,
_validate_zip_safety,
extract_zip_recursive,
MAX_FILE_COUNT,
)
class TestIsPathSafe:
"""Tests for _is_path_safe function."""
def test_safe_path_in_directory(self):
"""Normal file within directory should be safe."""
assert _is_path_safe("/tmp/extract", "/tmp/extract/file.txt") is True
def test_safe_path_in_subdirectory(self):
"""File in subdirectory should be safe."""
assert _is_path_safe("/tmp/extract", "/tmp/extract/subdir/file.txt") is True
def test_unsafe_path_parent_traversal(self):
"""Path traversal to parent directory should be unsafe."""
assert _is_path_safe("/tmp/extract", "/tmp/extract/../etc/passwd") is False
def test_unsafe_path_absolute(self):
"""Absolute path outside base should be unsafe."""
assert _is_path_safe("/tmp/extract", "/etc/passwd") is False
def test_unsafe_path_sibling(self):
"""Sibling directory should be unsafe."""
assert _is_path_safe("/tmp/extract", "/tmp/other/file.txt") is False
def test_base_path_itself(self):
"""Base path itself should be safe."""
assert _is_path_safe("/tmp/extract", "/tmp/extract") is True
class TestValidateZipSafety:
"""Tests for _validate_zip_safety function."""
def test_valid_small_zip(self):
"""Small valid zip file should pass validation."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a small valid zip
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "Hello, World!")
# Should not raise
_validate_zip_safety(zip_path, extract_to)
def test_zip_with_too_many_files(self):
"""Zip with too many files should be rejected."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a zip with many files (just over limit)
with zipfile.ZipFile(zip_path, "w") as zf:
for i in range(MAX_FILE_COUNT + 1):
zf.writestr(f"file_{i}.txt", "x")
with pytest.raises(ZipExtractionError) as exc_info:
_validate_zip_safety(zip_path, extract_to)
assert "too many files" in str(exc_info.value).lower()
def test_zip_with_path_traversal(self):
"""Zip with path traversal attempt should be rejected."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a zip with path traversal
with zipfile.ZipFile(zip_path, "w") as zf:
# Add a normal file first
zf.writestr("normal.txt", "normal content")
# Add a file with path traversal
zf.writestr("../../../etc/passwd", "malicious content")
with pytest.raises(ZipExtractionError) as exc_info:
_validate_zip_safety(zip_path, extract_to)
assert "path traversal" in str(exc_info.value).lower()
def test_corrupted_zip(self):
"""Corrupted zip file should be rejected."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a corrupted "zip" file
with open(zip_path, "wb") as f:
f.write(b"not a zip file content")
with pytest.raises(ZipExtractionError) as exc_info:
_validate_zip_safety(zip_path, extract_to)
assert "invalid" in str(exc_info.value).lower() or "corrupted" in str(exc_info.value).lower()
class TestExtractZipRecursive:
"""Tests for extract_zip_recursive function."""
def test_extract_valid_zip(self):
"""Valid zip file should be extracted successfully."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a valid zip
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "Hello, World!")
zf.writestr("subdir/nested.txt", "Nested content")
extract_zip_recursive(zip_path, extract_to)
# Check files were extracted
assert os.path.exists(os.path.join(extract_to, "test.txt"))
assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt"))
# Check zip was removed
assert not os.path.exists(zip_path)
def test_extract_nested_zip(self):
"""Nested zip files should be extracted recursively."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create inner zip
inner_zip_content = b""
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp:
with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf:
inner_zf.writestr("inner.txt", "Inner content")
with open(inner_tmp.name, "rb") as f:
inner_zip_content = f.read()
os.unlink(inner_tmp.name)
# Create outer zip containing inner zip
zip_path = os.path.join(temp_dir, "outer.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("outer.txt", "Outer content")
zf.writestr("inner.zip", inner_zip_content)
extract_zip_recursive(zip_path, extract_to)
# Check outer file was extracted
assert os.path.exists(os.path.join(extract_to, "outer.txt"))
# Check inner zip was extracted
assert os.path.exists(os.path.join(extract_to, "inner.txt"))
# Check both zips were removed
assert not os.path.exists(zip_path)
assert not os.path.exists(os.path.join(extract_to, "inner.zip"))
def test_respects_max_depth(self):
"""Extraction should stop at max recursion depth."""
with tempfile.TemporaryDirectory() as temp_dir:
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a chain of nested zips
current_content = b"Final content"
for i in range(7): # More than default max_depth of 5
inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
with zipfile.ZipFile(inner_tmp.name, "w") as zf:
if i == 0:
zf.writestr("content.txt", current_content.decode())
else:
zf.writestr("nested.zip", current_content)
with open(inner_tmp.name, "rb") as f:
current_content = f.read()
os.unlink(inner_tmp.name)
# Write the final outermost zip
zip_path = os.path.join(temp_dir, "outer.zip")
with open(zip_path, "wb") as f:
f.write(current_content)
# Extract with max_depth=2
extract_zip_recursive(zip_path, extract_to, max_depth=2)
# The deepest nested zips should remain unextracted
# (we can't easily verify the exact behavior, but the function should not crash)
def test_rejects_path_traversal(self):
"""Zip with path traversal should be rejected and removed."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "malicious.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a malicious zip
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("../../../tmp/malicious.txt", "malicious")
extract_zip_recursive(zip_path, extract_to)
# Zip should be removed
assert not os.path.exists(zip_path)
# Malicious file should NOT exist outside extract_to
assert not os.path.exists("/tmp/malicious.txt")
def test_handles_corrupted_zip_gracefully(self):
"""Corrupted zip should be handled gracefully without crashing."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "corrupted.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a corrupted file
with open(zip_path, "wb") as f:
f.write(b"This is not a valid zip file")
# Should not raise, just log error
extract_zip_recursive(zip_path, extract_to)
# Function should complete without exception
class TestZipBombProtection:
"""Tests specifically for zip bomb protection."""
def test_detects_high_compression_ratio(self):
"""Highly compressed data should trigger compression ratio check."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "bomb.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a file with highly compressible content (a run of identical characters)
# This triggers the compression ratio check
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
# Create a large file with repetitive content - compresses extremely well
repetitive_content = "A" * (1024 * 1024) # 1 MB of 'A's
zf.writestr("repetitive.txt", repetitive_content)
# This should be rejected due to high compression ratio
with pytest.raises(ZipExtractionError) as exc_info:
_validate_zip_safety(zip_path, extract_to)
assert "compression ratio" in str(exc_info.value).lower()
def test_normal_compression_passes(self):
"""Normal compression ratio should pass validation."""
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "normal.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a zip with random-ish content that doesn't compress well
import random
random.seed(42)
random_content = "".join(
random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=10240)
)
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr("random.txt", random_content)
# Should pass - random content doesn't compress well
_validate_zip_safety(zip_path, extract_to)
def test_size_limit_check(self):
"""Files exceeding size limit should be rejected."""
# Note: We can't easily create a real zip bomb in tests
# This test verifies the validation logic works
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, "test.zip")
extract_to = os.path.join(temp_dir, "extract")
os.makedirs(extract_to)
# Create a zip with a reasonable size (no compression to avoid ratio issues)
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
# 10 KB file
zf.writestr("normal.txt", "x" * 10240)
# Should pass
_validate_zip_safety(zip_path, extract_to)