Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-01-26 00:40:32 +00:00)
Patches (#2218)

* feat: implement URL validation to prevent SSRF
* feat: add zip extraction security
* ruff fixes
@@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import (
    RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError

logger = logging.getLogger(__name__)

@@ -73,6 +74,17 @@ class APITool(Tool):
        request_headers = headers.copy() if headers else {}
        response = None

        # Validate URL to prevent SSRF attacks
        try:
            validate_url(request_url)
        except SSRFError as e:
            logger.error(f"URL validation failed: {e}")
            return {
                "status_code": None,
                "message": f"URL validation error: {e}",
                "data": None,
            }

        try:
            path_params_used = set()
            if query_params:
@@ -90,6 +102,18 @@ class APITool(Tool):
                query_string = urlencode(remaining_params)
                separator = "&" if "?" in request_url else "?"
                request_url = f"{request_url}{separator}{query_string}"

            # Re-validate URL after parameter substitution to prevent SSRF via path params
            try:
                validate_url(request_url)
            except SSRFError as e:
                logger.error(f"URL validation failed after parameter substitution: {e}")
                return {
                    "status_code": None,
                    "message": f"URL validation error: {e}",
                    "data": None,
                }

            # Serialize body based on content type
            if body and body != {}:

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
from application.core.url_validation import validate_url, SSRFError

class ReadWebpageTool(Tool):
    """
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
        if not url:
            return "Error: URL parameter is missing."

        # Ensure the URL has a scheme (if not, default to http)
        parsed_url = urlparse(url)
        if not parsed_url.scheme:
            url = "http://" + url

        # Validate URL to prevent SSRF attacks
        try:
            url = validate_url(url)
        except SSRFError as e:
            return f"Error: URL validation failed - {e}"

        try:
            response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
            response.raise_for_status()  # Raise an exception for HTTP errors (4xx or 5xx)

application/core/url_validation.py (new file, 181 lines)
@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.

This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""

import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set


class SSRFError(Exception):
    """Raised when a URL fails SSRF validation."""
    pass


# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
    "localhost",
    "localhost.localdomain",
    "metadata.google.internal",
    "metadata",
}

# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
    "169.254.169.254",  # AWS, GCP, Azure metadata
    "169.254.170.2",  # AWS ECS task metadata
    "fd00:ec2::254",  # AWS IPv6 metadata
}

# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}


def is_private_ip(ip_str: str) -> bool:
    """
    Check if an IP address is private, loopback, or link-local.

    Args:
        ip_str: IP address as a string

    Returns:
        True if the IP is private/internal, False otherwise
    """
    try:
        ip = ipaddress.ip_address(ip_str)
        return (
            ip.is_private or
            ip.is_loopback or
            ip.is_link_local or
            ip.is_reserved or
            ip.is_multicast or
            ip.is_unspecified
        )
    except ValueError:
        # If we can't parse it as an IP, return False
        return False


def is_metadata_ip(ip_str: str) -> bool:
    """
    Check if an IP address is a cloud metadata service IP.

    Args:
        ip_str: IP address as a string

    Returns:
        True if the IP is a metadata service, False otherwise
    """
    return ip_str in METADATA_IPS


def resolve_hostname(hostname: str) -> Optional[str]:
    """
    Resolve a hostname to an IP address.

    Args:
        hostname: The hostname to resolve

    Returns:
        The resolved IP address, or None if resolution fails
    """
    try:
        return socket.gethostbyname(hostname)
    except socket.gaierror:
        return None


def validate_url(url: str, allow_localhost: bool = False) -> str:
    """
    Validate a URL to prevent SSRF attacks.

    This function checks that:
    1. The URL has an allowed scheme (http or https)
    2. The hostname is not a blocked hostname
    3. The resolved IP is not a private/internal IP
    4. The resolved IP is not a cloud metadata service

    Args:
        url: The URL to validate
        allow_localhost: If True, allow localhost connections (for testing only)

    Returns:
        The validated URL (with scheme added if missing)

    Raises:
        SSRFError: If the URL fails validation
    """
    # Ensure URL has a scheme
    if not urlparse(url).scheme:
        url = "http://" + url

    parsed = urlparse(url)

    # Check scheme
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")

    hostname = parsed.hostname
    if not hostname:
        raise SSRFError("URL must have a valid hostname.")

    hostname_lower = hostname.lower()

    # Check blocked hostnames
    if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
        raise SSRFError(f"Access to '{hostname}' is not allowed.")

    # Check if hostname is an IP address directly
    try:
        ip = ipaddress.ip_address(hostname)
        ip_str = str(ip)

        if is_metadata_ip(ip_str):
            raise SSRFError("Access to cloud metadata services is not allowed.")

        if is_private_ip(ip_str) and not allow_localhost:
            raise SSRFError("Access to private/internal IP addresses is not allowed.")

        return url
    except ValueError:
        # Not an IP address, it's a hostname - resolve it
        pass

    # Resolve hostname and check the IP
    resolved_ip = resolve_hostname(hostname)
    if resolved_ip is None:
        raise SSRFError(f"Unable to resolve hostname: {hostname}")

    if is_metadata_ip(resolved_ip):
        raise SSRFError("Access to cloud metadata services is not allowed.")

    if is_private_ip(resolved_ip) and not allow_localhost:
        raise SSRFError("Access to private/internal networks is not allowed.")

    return url


def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
    """
    Validate a URL and return a tuple with validation result.

    This is a non-throwing version of validate_url for cases where
    you want to handle validation failures gracefully.

    Args:
        url: The URL to validate
        allow_localhost: If True, allow localhost connections (for testing only)

    Returns:
        Tuple of (is_valid, validated_url_or_original, error_message_or_none)
    """
    try:
        validated = validate_url(url, allow_localhost)
        return (True, validated, None)
    except SSRFError as e:
        return (False, url, str(e))
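
For orientation, a minimal usage sketch of the new module; the fetch_preview helper below is hypothetical and not part of this patch, only validate_url, validate_url_safe and SSRFError come from the file above:

    from application.core.url_validation import SSRFError, validate_url, validate_url_safe

    def fetch_preview(url: str) -> str:
        # Throwing variant: raises SSRFError for disallowed schemes, blocked
        # hostnames, private/internal IPs, and cloud metadata endpoints.
        try:
            safe_url = validate_url(url)
        except SSRFError as e:
            return f"Refused: {e}"

        # Non-throwing variant: returns (is_valid, url, error_message_or_none).
        ok, checked_url, error = validate_url_safe(safe_url)
        if not ok:
            return f"Refused: {error}"
        return f"Would fetch {checked_url}"
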
@@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader

class CrawlerLoader(BaseRemote):
@@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote):
        if isinstance(url, list) and url:
            url = url[0]

        # Check if the URL scheme is provided, if not, assume http
        if not urlparse(url).scheme:
            url = "http://" + url
        # Validate URL to prevent SSRF attacks
        try:
            url = validate_url(url)
        except SSRFError as e:
            logging.error(f"URL validation failed: {e}")
            return []

        visited_urls = set()
        base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote):
            visited_urls.add(current_url)

            try:
                response = requests.get(current_url)
                # Validate each URL before making requests
                try:
                    validate_url(current_url)
                except SSRFError as e:
                    logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
                    continue

                response = requests.get(current_url, timeout=30)
                response.raise_for_status()
                loader = self.loader([current_url])
                docs = loader.load()

@@ -2,6 +2,7 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote):
        if isinstance(url, list) and url:
            url = url[0]

        # Ensure the URL has a scheme (if not, default to http)
        if not urlparse(url).scheme:
            url = "http://" + url
        # Validate URL to prevent SSRF attacks
        try:
            url = validate_url(url)
        except SSRFError as e:
            print(f"URL validation failed: {e}")
            return []

        # Keep track of visited URLs to avoid revisiting the same page
        visited_urls = set()
@@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote):

    def _fetch_page(self, url):
        try:
            # Validate URL before fetching to prevent SSRF
            validate_url(url)
            response = self.session.get(url, timeout=10)
            response.raise_for_status()
            return response.text
        except SSRFError as e:
            print(f"URL validation failed for {url}: {e}")
            return None
        except requests.exceptions.RequestException as e:
            print(f"Error fetching URL {url}: {e}")
            return None

@@ -3,6 +3,7 @@ import requests
import re  # Import regular expression library
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError

class SitemapLoader(BaseRemote):
    def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
        sitemap_url = inputs
        # Check if the input is a list and if it is, use the first element
        if isinstance(sitemap_url, list) and sitemap_url:
            url = sitemap_url[0]
            sitemap_url = sitemap_url[0]

        # Validate URL to prevent SSRF attacks
        try:
            sitemap_url = validate_url(sitemap_url)
        except SSRFError as e:
            logging.error(f"URL validation failed: {e}")
            return []

        urls = self._extract_urls(sitemap_url)
        if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):

    def _extract_urls(self, sitemap_url):
        try:
            response = requests.get(sitemap_url)
            # Validate URL before fetching to prevent SSRF
            validate_url(sitemap_url)
            response = requests.get(sitemap_url, timeout=30)
            response.raise_for_status()  # Raise an exception for HTTP errors
        except SSRFError as e:
            print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
            return []
        except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
            print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
            return []

@@ -63,10 +63,111 @@ current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Zip extraction security limits
MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024  # 500 MB max uncompressed size
MAX_FILE_COUNT = 10000  # Maximum number of files to extract
MAX_COMPRESSION_RATIO = 100  # Maximum compression ratio (to detect zip bombs)


class ZipExtractionError(Exception):
    """Raised when zip extraction fails due to security constraints."""
    pass


def _is_path_safe(base_path: str, target_path: str) -> bool:
    """
    Check if target_path is safely within base_path (prevents zip slip attacks).

    Args:
        base_path: The base directory where extraction should occur.
        target_path: The full path where a file would be extracted.

    Returns:
        True if the path is safe, False otherwise.
    """
    # Resolve to absolute paths and check containment
    base_resolved = os.path.realpath(base_path)
    target_resolved = os.path.realpath(target_path)
    return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved


def _validate_zip_safety(zip_path: str, extract_to: str) -> None:
    """
    Validate a zip file for security issues before extraction.

    Checks for:
    - Zip bombs (excessive compression ratio or uncompressed size)
    - Too many files
    - Path traversal attacks (zip slip)

    Args:
        zip_path: Path to the zip file.
        extract_to: Destination directory.

    Raises:
        ZipExtractionError: If the zip file fails security validation.
    """
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            # Get compressed size
            compressed_size = os.path.getsize(zip_path)

            # Calculate total uncompressed size and file count
            total_uncompressed = 0
            file_count = 0

            for info in zip_ref.infolist():
                file_count += 1

                # Check file count limit
                if file_count > MAX_FILE_COUNT:
                    raise ZipExtractionError(
                        f"Zip file contains too many files (>{MAX_FILE_COUNT}). "
                        "This may be a zip bomb attack."
                    )

                # Accumulate uncompressed size
                total_uncompressed += info.file_size

                # Check total uncompressed size
                if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
                    raise ZipExtractionError(
                        f"Zip file uncompressed size exceeds limit "
                        f"({total_uncompressed / (1024*1024):.1f} MB > "
                        f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). "
                        "This may be a zip bomb attack."
                    )

                # Check for path traversal (zip slip)
                target_path = os.path.join(extract_to, info.filename)
                if not _is_path_safe(extract_to, target_path):
                    raise ZipExtractionError(
                        f"Zip file contains path traversal attempt: {info.filename}"
                    )

            # Check compression ratio (only if compressed size is meaningful)
            if compressed_size > 0 and total_uncompressed > 0:
                compression_ratio = total_uncompressed / compressed_size
                if compression_ratio > MAX_COMPRESSION_RATIO:
                    raise ZipExtractionError(
                        f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > "
                        f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack."
                    )

    except zipfile.BadZipFile as e:
        raise ZipExtractionError(f"Invalid or corrupted zip file: {e}")


def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
    """
    Recursively extract zip files with a limit on recursion depth.
    Recursively extract zip files with security protections.

    Security measures:
    - Limits recursion depth to prevent infinite loops
    - Validates uncompressed size to prevent zip bombs
    - Limits number of files to prevent resource exhaustion
    - Checks compression ratio to detect zip bombs
    - Validates paths to prevent zip slip attacks

    Args:
        zip_path (str): Path to the zip file to be extracted.

@@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
    if current_depth > max_depth:
        logging.warning(f"Reached maximum recursion depth of {max_depth}")
        return

    try:
        # Validate zip file safety before extraction
        _validate_zip_safety(zip_path, extract_to)

        # Safe to extract
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_to)
        os.remove(zip_path)  # Remove the zip file after extracting

    except ZipExtractionError as e:
        logging.error(f"Zip security validation failed for {zip_path}: {e}")
        # Remove the potentially malicious zip file
        try:
            os.remove(zip_path)
        except OSError:
            pass
        return
    except Exception as e:
        logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
        return
    # Check for nested zip files and extract them

    # Check for nested zip files and extract them
    for root, dirs, files in os.walk(extract_to):
        for file in files:
            if file.endswith(".zip"):
                # If a nested zip file is found, extract it recursively

                file_path = os.path.join(root, file)
                extract_zip_recursive(file_path, root, current_depth + 1, max_depth)

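A short, hypothetical caller sketch for the hardened extractor; ingest_archive and its paths are illustrative, only extract_zip_recursive comes from application.worker above:

    import os
    from application.worker import extract_zip_recursive

    def ingest_archive(zip_path: str, workdir: str) -> list:
        # Zip bombs, zip-slip paths, and archives with too many files are rejected
        # inside extract_zip_recursive; the error is logged and the archive removed.
        os.makedirs(workdir, exist_ok=True)
        extract_zip_recursive(zip_path, workdir)
        collected = []
        for root, _dirs, files in os.walk(workdir):
            collected.extend(os.path.join(root, name) for name in files)
        return collected
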
tests/core/test_url_validation.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""Tests for SSRF URL validation module."""

import pytest
from unittest.mock import patch

from application.core.url_validation import (
    SSRFError,
    validate_url,
    validate_url_safe,
    is_private_ip,
    is_metadata_ip,
)


class TestIsPrivateIP:
    """Tests for is_private_ip function."""

    def test_loopback_ipv4(self):
        assert is_private_ip("127.0.0.1") is True
        assert is_private_ip("127.255.255.255") is True

    def test_private_class_a(self):
        assert is_private_ip("10.0.0.1") is True
        assert is_private_ip("10.255.255.255") is True

    def test_private_class_b(self):
        assert is_private_ip("172.16.0.1") is True
        assert is_private_ip("172.31.255.255") is True

    def test_private_class_c(self):
        assert is_private_ip("192.168.0.1") is True
        assert is_private_ip("192.168.255.255") is True

    def test_link_local(self):
        assert is_private_ip("169.254.0.1") is True

    def test_public_ip(self):
        assert is_private_ip("8.8.8.8") is False
        assert is_private_ip("1.1.1.1") is False
        assert is_private_ip("93.184.216.34") is False

    def test_invalid_ip(self):
        assert is_private_ip("not-an-ip") is False
        assert is_private_ip("") is False


class TestIsMetadataIP:
    """Tests for is_metadata_ip function."""

    def test_aws_metadata_ip(self):
        assert is_metadata_ip("169.254.169.254") is True

    def test_aws_ecs_metadata_ip(self):
        assert is_metadata_ip("169.254.170.2") is True

    def test_non_metadata_ip(self):
        assert is_metadata_ip("8.8.8.8") is False
        assert is_metadata_ip("10.0.0.1") is False


class TestValidateUrl:
    """Tests for validate_url function."""

    def test_adds_scheme_if_missing(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "93.184.216.34"  # Public IP
            result = validate_url("example.com")
            assert result == "http://example.com"

    def test_preserves_https_scheme(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "93.184.216.34"
            result = validate_url("https://example.com")
            assert result == "https://example.com"

    def test_blocks_localhost(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://localhost")
        assert "localhost" in str(exc_info.value).lower()

    def test_blocks_localhost_localdomain(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://localhost.localdomain")
        assert "not allowed" in str(exc_info.value).lower()

    def test_blocks_loopback_ip(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://127.0.0.1")
        assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()

    def test_blocks_private_ip_class_a(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://10.0.0.1")
        assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()

    def test_blocks_private_ip_class_b(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://172.16.0.1")
        assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()

    def test_blocks_private_ip_class_c(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://192.168.1.1")
        assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()

    def test_blocks_aws_metadata_ip(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://169.254.169.254")
        assert "metadata" in str(exc_info.value).lower()

    def test_blocks_aws_metadata_with_path(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://169.254.169.254/latest/meta-data/")
        assert "metadata" in str(exc_info.value).lower()

    def test_blocks_gcp_metadata_hostname(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://metadata.google.internal")
        assert "not allowed" in str(exc_info.value).lower()

    def test_blocks_ftp_scheme(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("ftp://example.com")
        assert "scheme" in str(exc_info.value).lower()

    def test_blocks_file_scheme(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("file:///etc/passwd")
        assert "scheme" in str(exc_info.value).lower()

    def test_blocks_hostname_resolving_to_private_ip(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "192.168.1.1"
            with pytest.raises(SSRFError) as exc_info:
                validate_url("http://internal.example.com")
            assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower()

    def test_blocks_hostname_resolving_to_metadata_ip(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "169.254.169.254"
            with pytest.raises(SSRFError) as exc_info:
                validate_url("http://evil.example.com")
            assert "metadata" in str(exc_info.value).lower()

    def test_allows_public_ip(self):
        result = validate_url("http://8.8.8.8")
        assert result == "http://8.8.8.8"

    def test_allows_public_hostname(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "93.184.216.34"
            result = validate_url("https://example.com")
            assert result == "https://example.com"

    def test_raises_on_unresolvable_hostname(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = None
            with pytest.raises(SSRFError) as exc_info:
                validate_url("http://nonexistent.invalid")
            assert "resolve" in str(exc_info.value).lower()

    def test_raises_on_empty_hostname(self):
        with pytest.raises(SSRFError) as exc_info:
            validate_url("http://")
        assert "hostname" in str(exc_info.value).lower()

    def test_allow_localhost_flag(self):
        # Should work with allow_localhost=True
        result = validate_url("http://localhost", allow_localhost=True)
        assert result == "http://localhost"

        result = validate_url("http://127.0.0.1", allow_localhost=True)
        assert result == "http://127.0.0.1"


class TestValidateUrlSafe:
    """Tests for validate_url_safe non-throwing function."""

    def test_returns_tuple_on_success(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "93.184.216.34"
            is_valid, url, error = validate_url_safe("https://example.com")
            assert is_valid is True
            assert url == "https://example.com"
            assert error is None

    def test_returns_tuple_on_failure(self):
        is_valid, url, error = validate_url_safe("http://localhost")
        assert is_valid is False
        assert url == "http://localhost"
        assert error is not None
        assert "localhost" in error.lower()

    def test_returns_error_message_for_private_ip(self):
        is_valid, url, error = validate_url_safe("http://192.168.1.1")
        assert is_valid is False
        assert "private" in error.lower() or "internal" in error.lower()
@@ -13,8 +13,17 @@ class DummyResponse:
        return None


def _mock_validate_url(url):
    """Mock validate_url that allows test URLs through."""
    from urllib.parse import urlparse
    if not urlparse(url).scheme:
        url = "http://" + url
    return url


@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_crawls_same_domain_links(mock_requests_get):
def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url):
    responses = {
        "http://example.com": DummyResponse(
            """
@@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
        "http://example.com/about": DummyResponse("<html><body>About page</body></html>"),
    }

    def response_side_effect(url: str):
    def response_side_effect(url: str, timeout=30):
        if url not in responses:
            raise AssertionError(f"Unexpected request for URL: {url}")
        return responses[url]
@@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
    assert loader_call_order == ["http://example.com", "http://example.com/about"]


@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url):
    mock_requests_get.return_value = DummyResponse("<html><body>No links here</body></html>")

    doc = MagicMock(spec=LCDocument)
@@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):

    result = crawler.load_data(["example.com", "unused.com"])

    mock_requests_get.assert_called_once_with("http://example.com")
    mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
    crawler.loader.assert_called_once_with(["http://example.com"])

    assert len(result) == 1
@@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
    assert result[0].extra_info == {"source": "http://example.com"}


@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_respects_limit(mock_requests_get):
def test_load_data_respects_limit(mock_requests_get, mock_validate_url):
    responses = {
        "http://example.com": DummyResponse(
            """
@@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get):
        "http://example.com/about": DummyResponse("<html><body>About</body></html>"),
    }

    mock_requests_get.side_effect = lambda url: responses[url]
    mock_requests_get.side_effect = lambda url, timeout=30: responses[url]

    root_doc = MagicMock(spec=LCDocument)
    root_doc.page_content = "Root content"
@@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get):
    assert crawler.loader.call_count == 1


@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.logging")
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging):
def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url):
    mock_requests_get.return_value = DummyResponse("<html><body>Error route</body></html>")

    failing_loader_instance = MagicMock()
@@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
    result = crawler.load_data("http://example.com")

    assert result == []
    mock_requests_get.assert_called_once_with("http://example.com")
    mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
    failing_loader_instance.load.assert_called_once()

    mock_logging.error.assert_called_once()
@@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
    assert "Error processing URL http://example.com" in message
    assert mock_logging.error.call_args.kwargs.get("exc_info") is True


@patch("application.parser.remote.crawler_loader.validate_url")
def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url):
    """Test that SSRF validation failure returns empty list."""
    from application.core.url_validation import SSRFError
    mock_validate_url.side_effect = SSRFError("Access to private IP not allowed")

    crawler = CrawlerLoader()
    result = crawler.load_data("http://192.168.1.1")

    assert result == []
    mock_validate_url.assert_called_once()

@@ -1,5 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from urllib.parse import urlparse

import pytest
import requests
@@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace:
    return SimpleNamespace(domain=domain, suffix=suffix)


def _mock_validate_url(url):
    """Mock validate_url that allows test URLs through."""
    if not urlparse(url).scheme:
        url = "http://" + url
    return url


@pytest.fixture(autouse=True)
def _patch_validate_url(monkeypatch):
    monkeypatch.setattr(
        "application.parser.remote.crawler_markdown.validate_url",
        _mock_validate_url,
    )


@pytest.fixture(autouse=True)
def _patch_tldextract(monkeypatch):
    monkeypatch.setattr(
@@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify):
    assert len(docs) == 2


def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url):
    root_html = """
    <html><head><title>Home</title></head>
    <body><a href="/about">About</a></body>
@@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
    assert docs[0].text == "Home Markdown"
    assert mock_print.called


def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch):
    """Test that SSRF validation failure returns empty list."""
    from application.core.url_validation import SSRFError

    def raise_ssrf_error(url):
        raise SSRFError("Access to private IP not allowed")

    monkeypatch.setattr(
        "application.parser.remote.crawler_markdown.validate_url",
        raise_ssrf_error,
    )

    loader = CrawlerLoader()
    result = loader.load_data("http://192.168.1.1")

    assert result == []

tests/test_zip_extraction_security.py (new file, 293 lines)
@@ -0,0 +1,293 @@
"""Tests for zip extraction security measures."""

import os
import tempfile
import zipfile

import pytest

from application.worker import (
    ZipExtractionError,
    _is_path_safe,
    _validate_zip_safety,
    extract_zip_recursive,
    MAX_FILE_COUNT,
)


class TestIsPathSafe:
    """Tests for _is_path_safe function."""

    def test_safe_path_in_directory(self):
        """Normal file within directory should be safe."""
        assert _is_path_safe("/tmp/extract", "/tmp/extract/file.txt") is True

    def test_safe_path_in_subdirectory(self):
        """File in subdirectory should be safe."""
        assert _is_path_safe("/tmp/extract", "/tmp/extract/subdir/file.txt") is True

    def test_unsafe_path_parent_traversal(self):
        """Path traversal to parent directory should be unsafe."""
        assert _is_path_safe("/tmp/extract", "/tmp/extract/../etc/passwd") is False

    def test_unsafe_path_absolute(self):
        """Absolute path outside base should be unsafe."""
        assert _is_path_safe("/tmp/extract", "/etc/passwd") is False

    def test_unsafe_path_sibling(self):
        """Sibling directory should be unsafe."""
        assert _is_path_safe("/tmp/extract", "/tmp/other/file.txt") is False

    def test_base_path_itself(self):
        """Base path itself should be safe."""
        assert _is_path_safe("/tmp/extract", "/tmp/extract") is True


class TestValidateZipSafety:
    """Tests for _validate_zip_safety function."""

    def test_valid_small_zip(self):
        """Small valid zip file should pass validation."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a small valid zip
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("test.txt", "Hello, World!")

            # Should not raise
            _validate_zip_safety(zip_path, extract_to)

    def test_zip_with_too_many_files(self):
        """Zip with too many files should be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a zip with many files (just over limit)
            with zipfile.ZipFile(zip_path, "w") as zf:
                for i in range(MAX_FILE_COUNT + 1):
                    zf.writestr(f"file_{i}.txt", "x")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "too many files" in str(exc_info.value).lower()

    def test_zip_with_path_traversal(self):
        """Zip with path traversal attempt should be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a zip with path traversal
            with zipfile.ZipFile(zip_path, "w") as zf:
                # Add a normal file first
                zf.writestr("normal.txt", "normal content")
                # Add a file with path traversal
                zf.writestr("../../../etc/passwd", "malicious content")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "path traversal" in str(exc_info.value).lower()

    def test_corrupted_zip(self):
        """Corrupted zip file should be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a corrupted "zip" file
            with open(zip_path, "wb") as f:
                f.write(b"not a zip file content")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "invalid" in str(exc_info.value).lower() or "corrupted" in str(exc_info.value).lower()


class TestExtractZipRecursive:
    """Tests for extract_zip_recursive function."""

    def test_extract_valid_zip(self):
        """Valid zip file should be extracted successfully."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a valid zip
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("test.txt", "Hello, World!")
                zf.writestr("subdir/nested.txt", "Nested content")

            extract_zip_recursive(zip_path, extract_to)

            # Check files were extracted
            assert os.path.exists(os.path.join(extract_to, "test.txt"))
            assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt"))

            # Check zip was removed
            assert not os.path.exists(zip_path)

    def test_extract_nested_zip(self):
        """Nested zip files should be extracted recursively."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create inner zip
            inner_zip_content = b""
            with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp:
                with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf:
                    inner_zf.writestr("inner.txt", "Inner content")
                with open(inner_tmp.name, "rb") as f:
                    inner_zip_content = f.read()
                os.unlink(inner_tmp.name)

            # Create outer zip containing inner zip
            zip_path = os.path.join(temp_dir, "outer.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("outer.txt", "Outer content")
                zf.writestr("inner.zip", inner_zip_content)

            extract_zip_recursive(zip_path, extract_to)

            # Check outer file was extracted
            assert os.path.exists(os.path.join(extract_to, "outer.txt"))

            # Check inner zip was extracted
            assert os.path.exists(os.path.join(extract_to, "inner.txt"))

            # Check both zips were removed
            assert not os.path.exists(zip_path)
            assert not os.path.exists(os.path.join(extract_to, "inner.zip"))

    def test_respects_max_depth(self):
        """Extraction should stop at max recursion depth."""
        with tempfile.TemporaryDirectory() as temp_dir:
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a chain of nested zips
            current_content = b"Final content"
            for i in range(7):  # More than default max_depth of 5
                inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
                with zipfile.ZipFile(inner_tmp.name, "w") as zf:
                    if i == 0:
                        zf.writestr("content.txt", current_content.decode())
                    else:
                        zf.writestr("nested.zip", current_content)
                with open(inner_tmp.name, "rb") as f:
                    current_content = f.read()
                os.unlink(inner_tmp.name)

            # Write the final outermost zip
            zip_path = os.path.join(temp_dir, "outer.zip")
            with open(zip_path, "wb") as f:
                f.write(current_content)

            # Extract with max_depth=2
            extract_zip_recursive(zip_path, extract_to, max_depth=2)

            # The deepest nested zips should remain unextracted
            # (we can't easily verify the exact behavior, but the function should not crash)

    def test_rejects_path_traversal(self):
        """Zip with path traversal should be rejected and removed."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "malicious.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a malicious zip
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("../../../tmp/malicious.txt", "malicious")

            extract_zip_recursive(zip_path, extract_to)

            # Zip should be removed
            assert not os.path.exists(zip_path)

            # Malicious file should NOT exist outside extract_to
            assert not os.path.exists("/tmp/malicious.txt")

    def test_handles_corrupted_zip_gracefully(self):
        """Corrupted zip should be handled gracefully without crashing."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "corrupted.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a corrupted file
            with open(zip_path, "wb") as f:
                f.write(b"This is not a valid zip file")

            # Should not raise, just log error
            extract_zip_recursive(zip_path, extract_to)

            # Function should complete without exception


class TestZipBombProtection:
    """Tests specifically for zip bomb protection."""

    def test_detects_high_compression_ratio(self):
        """Highly compressed data should trigger compression ratio check."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "bomb.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a file with highly compressible content (all zeros)
            # This triggers the compression ratio check
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                # Create a large file with repetitive content - compresses extremely well
                repetitive_content = "A" * (1024 * 1024)  # 1 MB of 'A's
                zf.writestr("repetitive.txt", repetitive_content)

            # This should be rejected due to high compression ratio
            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "compression ratio" in str(exc_info.value).lower()

    def test_normal_compression_passes(self):
        """Normal compression ratio should pass validation."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "normal.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a zip with random-ish content that doesn't compress well
            import random
            random.seed(42)
            random_content = "".join(
                random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=10240)
            )

            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                zf.writestr("random.txt", random_content)

            # Should pass - random content doesn't compress well
            _validate_zip_safety(zip_path, extract_to)

    def test_size_limit_check(self):
        """Files exceeding size limit should be rejected."""
        # Note: We can't easily create a real zip bomb in tests
        # This test verifies the validation logic works
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # Create a zip with a reasonable size (no compression to avoid ratio issues)
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
                # 10 KB file
                zf.writestr("normal.txt", "x" * 10240)

            # Should pass
            _validate_zip_safety(zip_path, extract_to)
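
Both new suites are plain pytest files; assuming the repository's standard test environment, they can be run directly, for example:

    python -m pytest tests/core/test_url_validation.py tests/test_zip_extraction_security.py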