* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes
This commit is contained in:
Alex
2025-12-24 15:05:35 +00:00
committed by GitHub
parent 83e7a928f1
commit 98e949d2fd
11 changed files with 929 additions and 27 deletions

View File

@@ -0,0 +1,197 @@
"""Tests for SSRF URL validation module."""
import pytest
from unittest.mock import patch
from application.core.url_validation import (
SSRFError,
validate_url,
validate_url_safe,
is_private_ip,
is_metadata_ip,
)
class TestIsPrivateIP:
    """Unit tests covering is_private_ip across the reserved IPv4 ranges."""

    def test_loopback_ipv4(self):
        # Both ends of the 127.0.0.0/8 loopback block.
        for addr in ("127.0.0.1", "127.255.255.255"):
            assert is_private_ip(addr) is True

    def test_private_class_a(self):
        # 10.0.0.0/8 private block.
        for addr in ("10.0.0.1", "10.255.255.255"):
            assert is_private_ip(addr) is True

    def test_private_class_b(self):
        # 172.16.0.0/12 private block.
        for addr in ("172.16.0.1", "172.31.255.255"):
            assert is_private_ip(addr) is True

    def test_private_class_c(self):
        # 192.168.0.0/16 private block.
        for addr in ("192.168.0.1", "192.168.255.255"):
            assert is_private_ip(addr) is True

    def test_link_local(self):
        # 169.254.0.0/16 link-local block.
        assert is_private_ip("169.254.0.1") is True

    def test_public_ip(self):
        # Well-known public addresses must not be classified as private.
        for addr in ("8.8.8.8", "1.1.1.1", "93.184.216.34"):
            assert is_private_ip(addr) is False

    def test_invalid_ip(self):
        # Unparseable input is reported as "not private" rather than raising.
        for addr in ("not-an-ip", ""):
            assert is_private_ip(addr) is False
class TestIsMetadataIP:
    """Unit tests for is_metadata_ip (cloud metadata endpoint detection)."""

    def test_aws_metadata_ip(self):
        # Classic EC2 instance-metadata service address.
        assert is_metadata_ip("169.254.169.254") is True

    def test_aws_ecs_metadata_ip(self):
        # ECS task-metadata endpoint address.
        assert is_metadata_ip("169.254.170.2") is True

    def test_non_metadata_ip(self):
        # Neither public nor private non-metadata addresses should match.
        for addr in ("8.8.8.8", "10.0.0.1"):
            assert is_metadata_ip(addr) is False
class TestValidateUrl:
    """Tests for validate_url, the raising SSRF guard."""

    # Patch target for DNS resolution inside the module under test.
    RESOLVER = "application.core.url_validation.resolve_hostname"
    # Any address outside the private/metadata ranges works here.
    PUBLIC_IP = "93.184.216.34"

    @staticmethod
    def _assert_blocked(url, *fragments):
        """Expect SSRFError for *url* whose message contains any fragment."""
        with pytest.raises(SSRFError) as exc_info:
            validate_url(url)
        message = str(exc_info.value).lower()
        assert any(fragment in message for fragment in fragments)

    def test_adds_scheme_if_missing(self):
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = self.PUBLIC_IP
            assert validate_url("example.com") == "http://example.com"

    def test_preserves_https_scheme(self):
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = self.PUBLIC_IP
            assert validate_url("https://example.com") == "https://example.com"

    def test_blocks_localhost(self):
        self._assert_blocked("http://localhost", "localhost")

    def test_blocks_localhost_localdomain(self):
        self._assert_blocked("http://localhost.localdomain", "not allowed")

    def test_blocks_loopback_ip(self):
        self._assert_blocked("http://127.0.0.1", "private", "internal")

    def test_blocks_private_ip_class_a(self):
        self._assert_blocked("http://10.0.0.1", "private", "internal")

    def test_blocks_private_ip_class_b(self):
        self._assert_blocked("http://172.16.0.1", "private", "internal")

    def test_blocks_private_ip_class_c(self):
        self._assert_blocked("http://192.168.1.1", "private", "internal")

    def test_blocks_aws_metadata_ip(self):
        self._assert_blocked("http://169.254.169.254", "metadata")

    def test_blocks_aws_metadata_with_path(self):
        self._assert_blocked("http://169.254.169.254/latest/meta-data/", "metadata")

    def test_blocks_gcp_metadata_hostname(self):
        self._assert_blocked("http://metadata.google.internal", "not allowed")

    def test_blocks_ftp_scheme(self):
        self._assert_blocked("ftp://example.com", "scheme")

    def test_blocks_file_scheme(self):
        self._assert_blocked("file:///etc/passwd", "scheme")

    def test_blocks_hostname_resolving_to_private_ip(self):
        # DNS rebinding-style case: public-looking name, private resolution.
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = "192.168.1.1"
            self._assert_blocked("http://internal.example.com", "private", "internal")

    def test_blocks_hostname_resolving_to_metadata_ip(self):
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = "169.254.169.254"
            self._assert_blocked("http://evil.example.com", "metadata")

    def test_allows_public_ip(self):
        assert validate_url("http://8.8.8.8") == "http://8.8.8.8"

    def test_allows_public_hostname(self):
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = self.PUBLIC_IP
            assert validate_url("https://example.com") == "https://example.com"

    def test_raises_on_unresolvable_hostname(self):
        with patch(self.RESOLVER) as mock_resolve:
            mock_resolve.return_value = None
            self._assert_blocked("http://nonexistent.invalid", "resolve")

    def test_raises_on_empty_hostname(self):
        self._assert_blocked("http://", "hostname")

    def test_allow_localhost_flag(self):
        # Explicit opt-in should bypass the localhost/loopback checks.
        assert validate_url("http://localhost", allow_localhost=True) == "http://localhost"
        assert validate_url("http://127.0.0.1", allow_localhost=True) == "http://127.0.0.1"
class TestValidateUrlSafe:
    """Tests for validate_url_safe, the non-raising (ok, url, error) variant."""

    def test_returns_tuple_on_success(self):
        with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
            mock_resolve.return_value = "93.184.216.34"
            ok, url, error = validate_url_safe("https://example.com")
            assert ok is True
            assert url == "https://example.com"
            assert error is None

    def test_returns_tuple_on_failure(self):
        ok, url, error = validate_url_safe("http://localhost")
        assert ok is False
        assert url == "http://localhost"
        assert error is not None
        assert "localhost" in error.lower()

    def test_returns_error_message_for_private_ip(self):
        ok, url, error = validate_url_safe("http://192.168.1.1")
        assert ok is False
        lowered = error.lower()
        assert "private" in lowered or "internal" in lowered

View File

@@ -13,8 +13,17 @@ class DummyResponse:
return None
def _mock_validate_url(url):
"""Mock validate_url that allows test URLs through."""
from urllib.parse import urlparse
if not urlparse(url).scheme:
url = "http://" + url
return url
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_crawls_same_domain_links(mock_requests_get):
def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
"http://example.com/about": DummyResponse("<html><body>About page</body></html>"),
}
def response_side_effect(url: str):
def response_side_effect(url: str, timeout=30):
if url not in responses:
raise AssertionError(f"Unexpected request for URL: {url}")
return responses[url]
@@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
assert loader_call_order == ["http://example.com", "http://example.com/about"]
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url):
mock_requests_get.return_value = DummyResponse("<html><body>No links here</body></html>")
doc = MagicMock(spec=LCDocument)
@@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
result = crawler.load_data(["example.com", "unused.com"])
mock_requests_get.assert_called_once_with("http://example.com")
mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
crawler.loader.assert_called_once_with(["http://example.com"])
assert len(result) == 1
@@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
assert result[0].extra_info == {"source": "http://example.com"}
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_respects_limit(mock_requests_get):
def test_load_data_respects_limit(mock_requests_get, mock_validate_url):
responses = {
"http://example.com": DummyResponse(
"""
@@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get):
"http://example.com/about": DummyResponse("<html><body>About</body></html>"),
}
mock_requests_get.side_effect = lambda url: responses[url]
mock_requests_get.side_effect = lambda url, timeout=30: responses[url]
root_doc = MagicMock(spec=LCDocument)
root_doc.page_content = "Root content"
@@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get):
assert crawler.loader.call_count == 1
@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
@patch("application.parser.remote.crawler_loader.logging")
@patch("application.parser.remote.crawler_loader.requests.get")
def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging):
def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url):
mock_requests_get.return_value = DummyResponse("<html><body>Error route</body></html>")
failing_loader_instance = MagicMock()
@@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
result = crawler.load_data("http://example.com")
assert result == []
mock_requests_get.assert_called_once_with("http://example.com")
mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
failing_loader_instance.load.assert_called_once()
mock_logging.error.assert_called_once()
@@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
assert "Error processing URL http://example.com" in message
assert mock_logging.error.call_args.kwargs.get("exc_info") is True
@patch("application.parser.remote.crawler_loader.validate_url")
def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url):
    """A URL rejected by SSRF validation yields no documents."""
    from application.core.url_validation import SSRFError

    mock_validate_url.side_effect = SSRFError("Access to private IP not allowed")

    result = CrawlerLoader().load_data("http://192.168.1.1")

    assert result == []
    mock_validate_url.assert_called_once()

View File

@@ -1,5 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from urllib.parse import urlparse
import pytest
import requests
@@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace:
return SimpleNamespace(domain=domain, suffix=suffix)
def _mock_validate_url(url):
"""Mock validate_url that allows test URLs through."""
if not urlparse(url).scheme:
url = "http://" + url
return url
@pytest.fixture(autouse=True)
def _patch_validate_url(monkeypatch):
    """Swap the crawler's validate_url for the permissive test double."""
    target = "application.parser.remote.crawler_markdown.validate_url"
    monkeypatch.setattr(target, _mock_validate_url)
@pytest.fixture(autouse=True)
def _patch_tldextract(monkeypatch):
monkeypatch.setattr(
@@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify):
assert len(docs) == 2
def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url):
root_html = """
<html><head><title>Home</title></head>
<body><a href="/about">About</a></body>
@@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
assert docs[0].text == "Home Markdown"
assert mock_print.called
def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch):
    """A URL rejected by SSRF validation yields no documents."""
    from application.core.url_validation import SSRFError

    def raise_ssrf_error(url):
        raise SSRFError("Access to private IP not allowed")

    monkeypatch.setattr(
        "application.parser.remote.crawler_markdown.validate_url",
        raise_ssrf_error,
    )

    assert CrawlerLoader().load_data("http://192.168.1.1") == []

View File

@@ -0,0 +1,293 @@
"""Tests for zip extraction security measures."""
import os
import tempfile
import zipfile
import pytest
from application.worker import (
ZipExtractionError,
_is_path_safe,
_validate_zip_safety,
extract_zip_recursive,
MAX_FILE_COUNT,
)
class TestIsPathSafe:
    """Tests for the _is_path_safe containment predicate."""

    # Extraction root used by every case below.
    BASE = "/tmp/extract"

    def test_safe_path_in_directory(self):
        """A file directly under the base directory is safe."""
        assert _is_path_safe(self.BASE, "/tmp/extract/file.txt") is True

    def test_safe_path_in_subdirectory(self):
        """A file in a nested subdirectory is safe."""
        assert _is_path_safe(self.BASE, "/tmp/extract/subdir/file.txt") is True

    def test_unsafe_path_parent_traversal(self):
        """A `..` segment escaping the base directory is unsafe."""
        assert _is_path_safe(self.BASE, "/tmp/extract/../etc/passwd") is False

    def test_unsafe_path_absolute(self):
        """An absolute path outside the base directory is unsafe."""
        assert _is_path_safe(self.BASE, "/etc/passwd") is False

    def test_unsafe_path_sibling(self):
        """A sibling directory under the same parent is unsafe."""
        assert _is_path_safe(self.BASE, "/tmp/other/file.txt") is False

    def test_base_path_itself(self):
        """The base directory itself counts as safe."""
        assert _is_path_safe(self.BASE, "/tmp/extract") is True
class TestValidateZipSafety:
    """Tests for _validate_zip_safety pre-extraction checks."""

    @staticmethod
    def _workspace(temp_dir):
        """Return (zip_path, extract_to) inside temp_dir, creating extract_to."""
        zip_path = os.path.join(temp_dir, "test.zip")
        extract_to = os.path.join(temp_dir, "extract")
        os.makedirs(extract_to)
        return zip_path, extract_to

    def test_valid_small_zip(self):
        """A small, well-formed archive must pass validation."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path, extract_to = self._workspace(temp_dir)
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("test.txt", "Hello, World!")
            # Must not raise.
            _validate_zip_safety(zip_path, extract_to)

    def test_zip_with_too_many_files(self):
        """An archive exceeding the entry-count limit must be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path, extract_to = self._workspace(temp_dir)
            # One entry over the limit triggers the file-count check.
            with zipfile.ZipFile(zip_path, "w") as zf:
                for i in range(MAX_FILE_COUNT + 1):
                    zf.writestr(f"file_{i}.txt", "x")
            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "too many files" in str(exc_info.value).lower()

    def test_zip_with_path_traversal(self):
        """An archive with an escaping entry name must be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path, extract_to = self._workspace(temp_dir)
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("normal.txt", "normal content")
                # Entry name tries to climb out of the extraction root.
                zf.writestr("../../../etc/passwd", "malicious content")
            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "path traversal" in str(exc_info.value).lower()

    def test_corrupted_zip(self):
        """A file that is not a zip archive must be rejected."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path, extract_to = self._workspace(temp_dir)
            with open(zip_path, "wb") as f:
                f.write(b"not a zip file content")
            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            message = str(exc_info.value).lower()
            assert "invalid" in message or "corrupted" in message
class TestExtractZipRecursive:
    """Tests for extract_zip_recursive extraction behavior.

    Fixes applied in review:
    - dropped the dead ``inner_zip_content = b""`` pre-assignment;
    - ``test_rejects_path_traversal`` previously asserted on the global path
      ``/tmp/malicious.txt``, which is flaky when that file pre-exists and
      depends on the depth of the system temp directory.  It now escapes only
      one level into the test's own temporary sandbox so the assertion is
      hermetic and deterministic.
    """

    def test_extract_valid_zip(self):
        """Valid zip file should be extracted successfully."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("test.txt", "Hello, World!")
                zf.writestr("subdir/nested.txt", "Nested content")
            extract_zip_recursive(zip_path, extract_to)
            # Both entries land under extract_to, preserving the subdirectory.
            assert os.path.exists(os.path.join(extract_to, "test.txt"))
            assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt"))
            # The source archive is deleted after a successful extraction.
            assert not os.path.exists(zip_path)

    def test_extract_nested_zip(self):
        """Nested zip files should be extracted recursively."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Build the inner archive in a scratch file and capture its bytes.
            with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp:
                with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf:
                    inner_zf.writestr("inner.txt", "Inner content")
                with open(inner_tmp.name, "rb") as f:
                    inner_zip_content = f.read()
            os.unlink(inner_tmp.name)
            # Outer archive contains a normal file plus the inner archive.
            zip_path = os.path.join(temp_dir, "outer.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("outer.txt", "Outer content")
                zf.writestr("inner.zip", inner_zip_content)
            extract_zip_recursive(zip_path, extract_to)
            assert os.path.exists(os.path.join(extract_to, "outer.txt"))
            # The nested archive's contents surface in extract_to...
            assert os.path.exists(os.path.join(extract_to, "inner.txt"))
            # ...and both archives are cleaned up afterwards.
            assert not os.path.exists(zip_path)
            assert not os.path.exists(os.path.join(extract_to, "inner.zip"))

    def test_respects_max_depth(self):
        """Extraction should stop at the requested recursion depth."""
        with tempfile.TemporaryDirectory() as temp_dir:
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            # Build a 7-deep chain of zips-within-zips, innermost first;
            # 7 exceeds the function's default max_depth of 5.
            current_content = b"Final content"
            for i in range(7):
                inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
                with zipfile.ZipFile(inner_tmp.name, "w") as zf:
                    if i == 0:
                        zf.writestr("content.txt", current_content.decode())
                    else:
                        zf.writestr("nested.zip", current_content)
                with open(inner_tmp.name, "rb") as f:
                    current_content = f.read()
                os.unlink(inner_tmp.name)
            # Write the outermost archive of the chain.
            zip_path = os.path.join(temp_dir, "outer.zip")
            with open(zip_path, "wb") as f:
                f.write(current_content)
            # Must terminate without error; archives below depth 2 are
            # expected to stay packed (exact layout is not asserted here).
            extract_zip_recursive(zip_path, extract_to, max_depth=2)

    def test_rejects_path_traversal(self):
        """Zip with path traversal should be rejected and removed."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "malicious.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            with zipfile.ZipFile(zip_path, "w") as zf:
                # Entry escapes extract_to by one level: if traversal were
                # allowed it would land at temp_dir/escaped.txt, which is
                # inside this test's private sandbox and therefore safe
                # and deterministic to check.
                zf.writestr("../escaped.txt", "malicious")
            extract_zip_recursive(zip_path, extract_to)
            # The offending archive is removed...
            assert not os.path.exists(zip_path)
            # ...and nothing was written outside the extraction root.
            assert not os.path.exists(os.path.join(temp_dir, "escaped.txt"))

    def test_handles_corrupted_zip_gracefully(self):
        """Corrupted zip should be handled gracefully without crashing."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "corrupted.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            with open(zip_path, "wb") as f:
                f.write(b"This is not a valid zip file")
            # Must log-and-return rather than raise.
            extract_zip_recursive(zip_path, extract_to)
class TestZipBombProtection:
    """Tests targeting the zip-bomb defenses of _validate_zip_safety."""

    def test_detects_high_compression_ratio(self):
        """Extreme compression ratios must trigger rejection."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "bomb.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            # 1 MB of a single repeated character deflates to almost
            # nothing, which should trip the compression-ratio guard.
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                zf.writestr("repetitive.txt", "A" * (1024 * 1024))
            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "compression ratio" in str(exc_info.value).lower()

    def test_normal_compression_passes(self):
        """Poorly compressible content must pass the ratio check."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "normal.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            import random

            # Seeded RNG keeps the payload deterministic across runs.
            alphabet = (
                "abcdefghijklmnopqrstuvwxyz"
                "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
            )
            random_content = "".join(random.Random(42).choices(alphabet, k=10240))
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                zf.writestr("random.txt", random_content)
            # Near-random text keeps the ratio low; must not raise.
            _validate_zip_safety(zip_path, extract_to)

    def test_size_limit_check(self):
        """A modest, stored (uncompressed) payload must pass validation."""
        # A real zip bomb is impractical to build in a unit test; this only
        # exercises the validation path with a benign archive.
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)
            # ZIP_STORED avoids the ratio check interfering; 10 KB is well
            # under any sane size limit.
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
                zf.writestr("normal.txt", "x" * 10240)
            _validate_zip_safety(zip_path, extract_to)