mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 14:00:55 +00:00
* feat: implement URL validation to prevent SSRF * feat: add zip extraction security * ruff fixes
294 lines
12 KiB
Python
294 lines
12 KiB
Python
"""Tests for zip extraction security measures."""
|
|
|
|
import io
import os
import random
import tempfile
import zipfile

import pytest

from application.worker import (
    ZipExtractionError,
    _is_path_safe,
    _validate_zip_safety,
    extract_zip_recursive,
    MAX_FILE_COUNT,
)
|
|
|
|
|
|
class TestIsPathSafe:
    """Behavioural checks for ``_is_path_safe``.

    Every case asks whether a candidate path stays inside the same
    extraction root, ``/tmp/extract``.
    """

    # Shared extraction root used as the first argument in every case.
    _BASE = "/tmp/extract"

    def test_safe_path_in_directory(self):
        """A file placed directly under the root is accepted."""
        verdict = _is_path_safe(self._BASE, "/tmp/extract/file.txt")
        assert verdict is True

    def test_safe_path_in_subdirectory(self):
        """A file nested one directory below the root is accepted."""
        verdict = _is_path_safe(self._BASE, "/tmp/extract/subdir/file.txt")
        assert verdict is True

    def test_unsafe_path_parent_traversal(self):
        """A ``..`` component that climbs out of the root is refused."""
        verdict = _is_path_safe(self._BASE, "/tmp/extract/../etc/passwd")
        assert verdict is False

    def test_unsafe_path_absolute(self):
        """An absolute path elsewhere on the filesystem is refused."""
        verdict = _is_path_safe(self._BASE, "/etc/passwd")
        assert verdict is False

    def test_unsafe_path_sibling(self):
        """A directory that merely shares the root's parent is refused."""
        verdict = _is_path_safe(self._BASE, "/tmp/other/file.txt")
        assert verdict is False

    def test_base_path_itself(self):
        """The root directory itself counts as being inside the root."""
        verdict = _is_path_safe(self._BASE, self._BASE)
        assert verdict is True
|
|
|
|
|
|
class TestValidateZipSafety:
    """Checks that ``_validate_zip_safety`` accepts a benign archive and
    raises ``ZipExtractionError`` for over-full, traversing or corrupt ones."""

    @staticmethod
    def _workspace(temp_dir):
        """Return ``(zip_path, extract_to)`` inside *temp_dir*.

        The extraction directory is created (empty); the archive path is
        only computed, each test writes its own contents there.
        """
        archive = os.path.join(temp_dir, "test.zip")
        target = os.path.join(temp_dir, "extract")
        os.makedirs(target)
        return archive, target

    def test_valid_small_zip(self):
        """A one-entry archive of a few bytes validates without error."""
        with tempfile.TemporaryDirectory() as temp_dir:
            archive, target = self._workspace(temp_dir)

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr("test.txt", "Hello, World!")

            # Returning normally means the archive was accepted.
            _validate_zip_safety(archive, target)

    def test_zip_with_too_many_files(self):
        """One entry beyond ``MAX_FILE_COUNT`` triggers a rejection."""
        with tempfile.TemporaryDirectory() as temp_dir:
            archive, target = self._workspace(temp_dir)

            entry_total = MAX_FILE_COUNT + 1
            with zipfile.ZipFile(archive, "w") as zf:
                for idx in range(entry_total):
                    zf.writestr(f"file_{idx}.txt", "x")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(archive, target)
            message = str(exc_info.value).lower()
            assert "too many files" in message

    def test_zip_with_path_traversal(self):
        """A member whose name climbs out of the target is rejected, even
        when it follows an innocent member."""
        with tempfile.TemporaryDirectory() as temp_dir:
            archive, target = self._workspace(temp_dir)

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr("normal.txt", "normal content")
                # The offending entry comes second on purpose.
                zf.writestr("../../../etc/passwd", "malicious content")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(archive, target)
            message = str(exc_info.value).lower()
            assert "path traversal" in message

    def test_corrupted_zip(self):
        """Bytes that are not a zip archive are reported as invalid/corrupt."""
        with tempfile.TemporaryDirectory() as temp_dir:
            archive, target = self._workspace(temp_dir)

            with open(archive, "wb") as fh:
                fh.write(b"not a zip file content")

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(archive, target)
            message = str(exc_info.value).lower()
            assert "invalid" in message or "corrupted" in message
|
|
|
|
|
|
class TestExtractZipRecursive:
|
|
"""Tests for extract_zip_recursive function."""
|
|
|
|
def test_extract_valid_zip(self):
|
|
"""Valid zip file should be extracted successfully."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
zip_path = os.path.join(temp_dir, "test.zip")
|
|
extract_to = os.path.join(temp_dir, "extract")
|
|
os.makedirs(extract_to)
|
|
|
|
# Create a valid zip
|
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
|
zf.writestr("test.txt", "Hello, World!")
|
|
zf.writestr("subdir/nested.txt", "Nested content")
|
|
|
|
extract_zip_recursive(zip_path, extract_to)
|
|
|
|
# Check files were extracted
|
|
assert os.path.exists(os.path.join(extract_to, "test.txt"))
|
|
assert os.path.exists(os.path.join(extract_to, "subdir", "nested.txt"))
|
|
|
|
# Check zip was removed
|
|
assert not os.path.exists(zip_path)
|
|
|
|
def test_extract_nested_zip(self):
|
|
"""Nested zip files should be extracted recursively."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Create inner zip
|
|
inner_zip_content = b""
|
|
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as inner_tmp:
|
|
with zipfile.ZipFile(inner_tmp.name, "w") as inner_zf:
|
|
inner_zf.writestr("inner.txt", "Inner content")
|
|
with open(inner_tmp.name, "rb") as f:
|
|
inner_zip_content = f.read()
|
|
os.unlink(inner_tmp.name)
|
|
|
|
# Create outer zip containing inner zip
|
|
zip_path = os.path.join(temp_dir, "outer.zip")
|
|
extract_to = os.path.join(temp_dir, "extract")
|
|
os.makedirs(extract_to)
|
|
|
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
|
zf.writestr("outer.txt", "Outer content")
|
|
zf.writestr("inner.zip", inner_zip_content)
|
|
|
|
extract_zip_recursive(zip_path, extract_to)
|
|
|
|
# Check outer file was extracted
|
|
assert os.path.exists(os.path.join(extract_to, "outer.txt"))
|
|
|
|
# Check inner zip was extracted
|
|
assert os.path.exists(os.path.join(extract_to, "inner.txt"))
|
|
|
|
# Check both zips were removed
|
|
assert not os.path.exists(zip_path)
|
|
assert not os.path.exists(os.path.join(extract_to, "inner.zip"))
|
|
|
|
def test_respects_max_depth(self):
|
|
"""Extraction should stop at max recursion depth."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
extract_to = os.path.join(temp_dir, "extract")
|
|
os.makedirs(extract_to)
|
|
|
|
# Create a chain of nested zips
|
|
current_content = b"Final content"
|
|
for i in range(7): # More than default max_depth of 5
|
|
inner_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
|
|
with zipfile.ZipFile(inner_tmp.name, "w") as zf:
|
|
if i == 0:
|
|
zf.writestr("content.txt", current_content.decode())
|
|
else:
|
|
zf.writestr("nested.zip", current_content)
|
|
with open(inner_tmp.name, "rb") as f:
|
|
current_content = f.read()
|
|
os.unlink(inner_tmp.name)
|
|
|
|
# Write the final outermost zip
|
|
zip_path = os.path.join(temp_dir, "outer.zip")
|
|
with open(zip_path, "wb") as f:
|
|
f.write(current_content)
|
|
|
|
# Extract with max_depth=2
|
|
extract_zip_recursive(zip_path, extract_to, max_depth=2)
|
|
|
|
# The deepest nested zips should remain unextracted
|
|
# (we can't easily verify the exact behavior, but the function should not crash)
|
|
|
|
def test_rejects_path_traversal(self):
|
|
"""Zip with path traversal should be rejected and removed."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
zip_path = os.path.join(temp_dir, "malicious.zip")
|
|
extract_to = os.path.join(temp_dir, "extract")
|
|
os.makedirs(extract_to)
|
|
|
|
# Create a malicious zip
|
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
|
zf.writestr("../../../tmp/malicious.txt", "malicious")
|
|
|
|
extract_zip_recursive(zip_path, extract_to)
|
|
|
|
# Zip should be removed
|
|
assert not os.path.exists(zip_path)
|
|
|
|
# Malicious file should NOT exist outside extract_to
|
|
assert not os.path.exists("/tmp/malicious.txt")
|
|
|
|
def test_handles_corrupted_zip_gracefully(self):
|
|
"""Corrupted zip should be handled gracefully without crashing."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
zip_path = os.path.join(temp_dir, "corrupted.zip")
|
|
extract_to = os.path.join(temp_dir, "extract")
|
|
os.makedirs(extract_to)
|
|
|
|
# Create a corrupted file
|
|
with open(zip_path, "wb") as f:
|
|
f.write(b"This is not a valid zip file")
|
|
|
|
# Should not raise, just log error
|
|
extract_zip_recursive(zip_path, extract_to)
|
|
|
|
# Function should complete without exception
|
|
|
|
|
|
class TestZipBombProtection:
    """Tests specifically for zip bomb protection."""

    def test_detects_high_compression_ratio(self):
        """Highly compressed data should trigger compression ratio check."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "bomb.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # 1 MB of a single repeated character deflates extremely well,
            # which is what the compression-ratio check is meant to catch.
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                repetitive_content = "A" * (1024 * 1024)  # 1 MB of 'A's
                zf.writestr("repetitive.txt", repetitive_content)

            with pytest.raises(ZipExtractionError) as exc_info:
                _validate_zip_safety(zip_path, extract_to)
            assert "compression ratio" in str(exc_info.value).lower()

    def test_normal_compression_passes(self):
        """Normal compression ratio should pass validation."""
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "normal.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # A private seeded generator keeps the content reproducible
            # without reseeding the global ``random`` state that other tests
            # (and plugins) share.  Random(42) yields the same sequence that
            # random.seed(42) + random.choices did.
            rng = random.Random(42)
            random_content = "".join(
                rng.choices(
                    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
                    k=10240,
                )
            )

            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                zf.writestr("random.txt", random_content)

            # Should pass - random alphanumerics compress only slightly.
            _validate_zip_safety(zip_path, extract_to)

    def test_size_limit_check(self):
        """A file well within the size limit should pass validation.

        Building an archive that genuinely exceeds the uncompressed-size limit
        is impractical in a unit test, so this only covers the accepting side
        of the size check.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "test.zip")
            extract_to = os.path.join(temp_dir, "extract")
            os.makedirs(extract_to)

            # ZIP_STORED keeps the compression ratio at 1:1 so the ratio check
            # cannot interfere with what this test is about.
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
                zf.writestr("normal.txt", "x" * 10240)  # 10 KB file

            _validate_zip_safety(zip_path, extract_to)
|