diff --git a/application/agents/tools/api_tool.py b/application/agents/tools/api_tool.py index 6bd2eb8d..e010b51b 100644 --- a/application/agents/tools/api_tool.py +++ b/application/agents/tools/api_tool.py @@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import ( RequestBodySerializer, ) from application.agents.tools.base import Tool +from application.core.url_validation import validate_url, SSRFError logger = logging.getLogger(__name__) @@ -73,6 +74,17 @@ class APITool(Tool): request_headers = headers.copy() if headers else {} response = None + # Validate URL to prevent SSRF attacks + try: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + try: path_params_used = set() if query_params: @@ -90,6 +102,18 @@ class APITool(Tool): query_string = urlencode(remaining_params) separator = "&" if "?" in request_url else "?" request_url = f"{request_url}{separator}{query_string}" + + # Re-validate URL after parameter substitution to prevent SSRF via path params + try: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed after parameter substitution: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + # Serialize body based on content type if body and body != {}: diff --git a/application/agents/tools/read_webpage.py b/application/agents/tools/read_webpage.py index e87c79e3..f0321a5a 100644 --- a/application/agents/tools/read_webpage.py +++ b/application/agents/tools/read_webpage.py @@ -1,7 +1,7 @@ import requests from markdownify import markdownify from application.agents.tools.base import Tool -from urllib.parse import urlparse +from application.core.url_validation import validate_url, SSRFError class ReadWebpageTool(Tool): """ @@ -31,11 +31,12 @@ class ReadWebpageTool(Tool): if not url: return "Error: URL parameter is missing." - # Ensure the URL has a scheme (if not, default to http) - parsed_url = urlparse(url) - if not parsed_url.scheme: - url = "http://" + url - + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + return f"Error: URL validation failed - {e}" + try: response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'}) response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx) diff --git a/application/core/url_validation.py b/application/core/url_validation.py new file mode 100644 index 00000000..acd8a523 --- /dev/null +++ b/application/core/url_validation.py @@ -0,0 +1,181 @@ +""" +URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks. + +This module provides functions to validate URLs before making HTTP requests, +blocking access to internal networks, cloud metadata services, and other +potentially dangerous endpoints. +""" + +import ipaddress +import socket +from urllib.parse import urlparse +from typing import Optional, Set + + +class SSRFError(Exception): + """Raised when a URL fails SSRF validation.""" + pass + + +# Blocked hostnames that should never be accessed +BLOCKED_HOSTNAMES: Set[str] = { + "localhost", + "localhost.localdomain", + "metadata.google.internal", + "metadata", +} + +# Cloud metadata IP addresses (AWS, GCP, Azure, etc.) 
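For reference, a minimal caller-side sketch of the pattern api_tool.py and read_webpage.py adopt above: validate first, then request. The fetch_page helper and its error strings are illustrative assumptions, not part of the patch; only validate_url, validate_url_safe, and SSRFError come from the module being added here.

import requests

from application.core.url_validation import SSRFError, validate_url, validate_url_safe


def fetch_page(url: str) -> str:
    """Hypothetical caller: refuse to fetch anything that fails SSRF validation."""
    try:
        safe_url = validate_url(url)  # adds a scheme if missing, raises SSRFError otherwise
    except SSRFError as e:
        return f"Error: URL validation failed - {e}"
    response = requests.get(safe_url, timeout=10)
    response.raise_for_status()
    return response.text


# Non-throwing variant for callers that prefer a status tuple:
ok, checked_url, error = validate_url_safe("http://169.254.169.254/latest/meta-data/")
assert ok is False and "metadata" in error.lower()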
+METADATA_IPS: Set[str] = {
+    "169.254.169.254",  # AWS, GCP, Azure metadata
+    "169.254.170.2",    # AWS ECS task metadata
+    "fd00:ec2::254",    # AWS IPv6 metadata
+}
+
+# Allowed schemes for external requests
+ALLOWED_SCHEMES: Set[str] = {"http", "https"}
+
+
+def is_private_ip(ip_str: str) -> bool:
+    """
+    Check if an IP address is private, loopback, or link-local.
+
+    Args:
+        ip_str: IP address as a string
+
+    Returns:
+        True if the IP is private/internal, False otherwise
+    """
+    try:
+        ip = ipaddress.ip_address(ip_str)
+        return (
+            ip.is_private or
+            ip.is_loopback or
+            ip.is_link_local or
+            ip.is_reserved or
+            ip.is_multicast or
+            ip.is_unspecified
+        )
+    except ValueError:
+        # If we can't parse it as an IP, return False
+        return False
+
+
+def is_metadata_ip(ip_str: str) -> bool:
+    """
+    Check if an IP address is a cloud metadata service IP.
+
+    Args:
+        ip_str: IP address as a string
+
+    Returns:
+        True if the IP is a metadata service, False otherwise
+    """
+    return ip_str in METADATA_IPS
+
+
+def resolve_hostname(hostname: str) -> Optional[str]:
+    """
+    Resolve a hostname to an IP address.
+
+    Args:
+        hostname: The hostname to resolve
+
+    Returns:
+        The resolved IP address, or None if resolution fails
+    """
+    try:
+        return socket.gethostbyname(hostname)
+    except socket.gaierror:
+        return None
+
+
+def validate_url(url: str, allow_localhost: bool = False) -> str:
+    """
+    Validate a URL to prevent SSRF attacks.
+
+    This function checks that:
+    1. The URL has an allowed scheme (http or https)
+    2. The hostname is not a blocked hostname
+    3. The resolved IP is not a private/internal IP
+    4. The resolved IP is not a cloud metadata service
+
+    Args:
+        url: The URL to validate
+        allow_localhost: If True, allow localhost connections (for testing only)
+
+    Returns:
+        The validated URL (with scheme added if missing)
+
+    Raises:
+        SSRFError: If the URL fails validation
+    """
+    # Ensure URL has a scheme
+    if not urlparse(url).scheme:
+        url = "http://" + url
+
+    parsed = urlparse(url)
+
+    # Check scheme
+    if parsed.scheme not in ALLOWED_SCHEMES:
+        raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
+
+    hostname = parsed.hostname
+    if not hostname:
+        raise SSRFError("URL must have a valid hostname.")
+
+    hostname_lower = hostname.lower()
+
+    # Check blocked hostnames
+    if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
+        raise SSRFError(f"Access to '{hostname}' is not allowed.")
+
+    # Check if hostname is an IP address directly
+    try:
+        ip = ipaddress.ip_address(hostname)
+        ip_str = str(ip)
+
+        if is_metadata_ip(ip_str):
+            raise SSRFError("Access to cloud metadata services is not allowed.")
+
+        if is_private_ip(ip_str) and not allow_localhost:
+            raise SSRFError("Access to private/internal IP addresses is not allowed.")
+
+        return url
+    except ValueError:
+        # Not an IP address, it's a hostname - resolve it
+        pass
+
+    # Resolve hostname and check the IP
+    resolved_ip = resolve_hostname(hostname)
+    if resolved_ip is None:
+        raise SSRFError(f"Unable to resolve hostname: {hostname}")
+
+    if is_metadata_ip(resolved_ip):
+        raise SSRFError("Access to cloud metadata services is not allowed.")
+
+    if is_private_ip(resolved_ip) and not allow_localhost:
+        raise SSRFError("Access to private/internal networks is not allowed.")
+
+    return url
+
+
+def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
+    """
+    Validate a URL and return a tuple with validation result.
+ + This is a non-throwing version of validate_url for cases where + you want to handle validation failures gracefully. + + Args: + url: The URL to validate + allow_localhost: If True, allow localhost connections (for testing only) + + Returns: + Tuple of (is_valid, validated_url_or_original, error_message_or_none) + """ + try: + validated = validate_url(url, allow_localhost) + return (True, validated, None) + except SSRFError as e: + return (False, url, str(e)) diff --git a/application/parser/remote/crawler_loader.py b/application/parser/remote/crawler_loader.py index 2ff6cf6f..1bfd2276 100644 --- a/application/parser/remote/crawler_loader.py +++ b/application/parser/remote/crawler_loader.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin from bs4 import BeautifulSoup from application.parser.remote.base import BaseRemote from application.parser.schema.base import Document +from application.core.url_validation import validate_url, SSRFError from langchain_community.document_loaders import WebBaseLoader class CrawlerLoader(BaseRemote): @@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote): if isinstance(url, list) and url: url = url[0] - # Check if the URL scheme is provided, if not, assume http - if not urlparse(url).scheme: - url = "http://" + url + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + logging.error(f"URL validation failed: {e}") + return [] visited_urls = set() base_url = urlparse(url).scheme + "://" + urlparse(url).hostname @@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote): visited_urls.add(current_url) try: - response = requests.get(current_url) + # Validate each URL before making requests + try: + validate_url(current_url) + except SSRFError as e: + logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}") + continue + + response = requests.get(current_url, timeout=30) response.raise_for_status() loader = self.loader([current_url]) docs = loader.load() diff --git a/application/parser/remote/crawler_markdown.py b/application/parser/remote/crawler_markdown.py index 3d199332..8fc4c92c 100644 --- a/application/parser/remote/crawler_markdown.py +++ b/application/parser/remote/crawler_markdown.py @@ -2,6 +2,7 @@ import requests from urllib.parse import urlparse, urljoin from bs4 import BeautifulSoup from application.parser.remote.base import BaseRemote +from application.core.url_validation import validate_url, SSRFError import re from markdownify import markdownify from application.parser.schema.base import Document @@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote): if isinstance(url, list) and url: url = url[0] - # Ensure the URL has a scheme (if not, default to http) - if not urlparse(url).scheme: - url = "http://" + url + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + print(f"URL validation failed: {e}") + return [] # Keep track of visited URLs to avoid revisiting the same page visited_urls = set() @@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote): def _fetch_page(self, url): try: + # Validate URL before fetching to prevent SSRF + validate_url(url) response = self.session.get(url, timeout=10) response.raise_for_status() return response.text + except SSRFError as e: + print(f"URL validation failed for {url}: {e}") + return None except requests.exceptions.RequestException as e: print(f"Error fetching URL {url}: {e}") return None diff --git a/application/parser/remote/sitemap_loader.py b/application/parser/remote/sitemap_loader.py index 
6d54ea9b..ff7c1ede 100644 --- a/application/parser/remote/sitemap_loader.py +++ b/application/parser/remote/sitemap_loader.py @@ -3,6 +3,7 @@ import requests import re # Import regular expression library import xml.etree.ElementTree as ET from application.parser.remote.base import BaseRemote +from application.core.url_validation import validate_url, SSRFError class SitemapLoader(BaseRemote): def __init__(self, limit=20): @@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote): sitemap_url= inputs # Check if the input is a list and if it is, use the first element if isinstance(sitemap_url, list) and sitemap_url: - url = sitemap_url[0] + sitemap_url = sitemap_url[0] + + # Validate URL to prevent SSRF attacks + try: + sitemap_url = validate_url(sitemap_url) + except SSRFError as e: + logging.error(f"URL validation failed: {e}") + return [] urls = self._extract_urls(sitemap_url) if not urls: @@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote): def _extract_urls(self, sitemap_url): try: - response = requests.get(sitemap_url) + # Validate URL before fetching to prevent SSRF + validate_url(sitemap_url) + response = requests.get(sitemap_url, timeout=30) response.raise_for_status() # Raise an exception for HTTP errors + except SSRFError as e: + print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}") + return [] except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e: print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}") return [] diff --git a/application/worker.py b/application/worker.py index f45e94a5..fa2b6cd7 100755 --- a/application/worker.py +++ b/application/worker.py @@ -63,10 +63,111 @@ current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +# Zip extraction security limits +MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size +MAX_FILE_COUNT = 10000 # Maximum number of files to extract +MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs) + + +class ZipExtractionError(Exception): + """Raised when zip extraction fails due to security constraints.""" + pass + + +def _is_path_safe(base_path: str, target_path: str) -> bool: + """ + Check if target_path is safely within base_path (prevents zip slip attacks). + + Args: + base_path: The base directory where extraction should occur. + target_path: The full path where a file would be extracted. + + Returns: + True if the path is safe, False otherwise. + """ + # Resolve to absolute paths and check containment + base_resolved = os.path.realpath(base_path) + target_resolved = os.path.realpath(target_path) + return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved + + +def _validate_zip_safety(zip_path: str, extract_to: str) -> None: + """ + Validate a zip file for security issues before extraction. + + Checks for: + - Zip bombs (excessive compression ratio or uncompressed size) + - Too many files + - Path traversal attacks (zip slip) + + Args: + zip_path: Path to the zip file. + extract_to: Destination directory. + + Raises: + ZipExtractionError: If the zip file fails security validation. 
+ """ + try: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + # Get compressed size + compressed_size = os.path.getsize(zip_path) + + # Calculate total uncompressed size and file count + total_uncompressed = 0 + file_count = 0 + + for info in zip_ref.infolist(): + file_count += 1 + + # Check file count limit + if file_count > MAX_FILE_COUNT: + raise ZipExtractionError( + f"Zip file contains too many files (>{MAX_FILE_COUNT}). " + "This may be a zip bomb attack." + ) + + # Accumulate uncompressed size + total_uncompressed += info.file_size + + # Check total uncompressed size + if total_uncompressed > MAX_UNCOMPRESSED_SIZE: + raise ZipExtractionError( + f"Zip file uncompressed size exceeds limit " + f"({total_uncompressed / (1024*1024):.1f} MB > " + f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). " + "This may be a zip bomb attack." + ) + + # Check for path traversal (zip slip) + target_path = os.path.join(extract_to, info.filename) + if not _is_path_safe(extract_to, target_path): + raise ZipExtractionError( + f"Zip file contains path traversal attempt: {info.filename}" + ) + + # Check compression ratio (only if compressed size is meaningful) + if compressed_size > 0 and total_uncompressed > 0: + compression_ratio = total_uncompressed / compressed_size + if compression_ratio > MAX_COMPRESSION_RATIO: + raise ZipExtractionError( + f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > " + f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack." + ) + + except zipfile.BadZipFile as e: + raise ZipExtractionError(f"Invalid or corrupted zip file: {e}") + def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): """ - Recursively extract zip files with a limit on recursion depth. + Recursively extract zip files with security protections. + + Security measures: + - Limits recursion depth to prevent infinite loops + - Validates uncompressed size to prevent zip bombs + - Limits number of files to prevent resource exhaustion + - Checks compression ratio to detect zip bombs + - Validates paths to prevent zip slip attacks Args: zip_path (str): Path to the zip file to be extracted. 
@@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): if current_depth > max_depth: logging.warning(f"Reached maximum recursion depth of {max_depth}") return + try: + # Validate zip file safety before extraction + _validate_zip_safety(zip_path, extract_to) + + # Safe to extract with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) os.remove(zip_path) # Remove the zip file after extracting + + except ZipExtractionError as e: + logging.error(f"Zip security validation failed for {zip_path}: {e}") + # Remove the potentially malicious zip file + try: + os.remove(zip_path) + except OSError: + pass + return except Exception as e: logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True) return - # Check for nested zip files and extract them + # Check for nested zip files and extract them for root, dirs, files in os.walk(extract_to): for file in files: if file.endswith(".zip"): # If a nested zip file is found, extract it recursively - file_path = os.path.join(root, file) extract_zip_recursive(file_path, root, current_depth + 1, max_depth) diff --git a/tests/core/test_url_validation.py b/tests/core/test_url_validation.py new file mode 100644 index 00000000..924e5cde --- /dev/null +++ b/tests/core/test_url_validation.py @@ -0,0 +1,197 @@ +"""Tests for SSRF URL validation module.""" + +import pytest +from unittest.mock import patch + +from application.core.url_validation import ( + SSRFError, + validate_url, + validate_url_safe, + is_private_ip, + is_metadata_ip, +) + + +class TestIsPrivateIP: + """Tests for is_private_ip function.""" + + def test_loopback_ipv4(self): + assert is_private_ip("127.0.0.1") is True + assert is_private_ip("127.255.255.255") is True + + def test_private_class_a(self): + assert is_private_ip("10.0.0.1") is True + assert is_private_ip("10.255.255.255") is True + + def test_private_class_b(self): + assert is_private_ip("172.16.0.1") is True + assert is_private_ip("172.31.255.255") is True + + def test_private_class_c(self): + assert is_private_ip("192.168.0.1") is True + assert is_private_ip("192.168.255.255") is True + + def test_link_local(self): + assert is_private_ip("169.254.0.1") is True + + def test_public_ip(self): + assert is_private_ip("8.8.8.8") is False + assert is_private_ip("1.1.1.1") is False + assert is_private_ip("93.184.216.34") is False + + def test_invalid_ip(self): + assert is_private_ip("not-an-ip") is False + assert is_private_ip("") is False + + +class TestIsMetadataIP: + """Tests for is_metadata_ip function.""" + + def test_aws_metadata_ip(self): + assert is_metadata_ip("169.254.169.254") is True + + def test_aws_ecs_metadata_ip(self): + assert is_metadata_ip("169.254.170.2") is True + + def test_non_metadata_ip(self): + assert is_metadata_ip("8.8.8.8") is False + assert is_metadata_ip("10.0.0.1") is False + + +class TestValidateUrl: + """Tests for validate_url function.""" + + def test_adds_scheme_if_missing(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" # Public IP + result = validate_url("example.com") + assert result == "http://example.com" + + def test_preserves_https_scheme(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + result = validate_url("https://example.com") + assert result == "https://example.com" + + def test_blocks_localhost(self): + with pytest.raises(SSRFError) as exc_info: 
+ validate_url("http://localhost") + assert "localhost" in str(exc_info.value).lower() + + def test_blocks_localhost_localdomain(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://localhost.localdomain") + assert "not allowed" in str(exc_info.value).lower() + + def test_blocks_loopback_ip(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://127.0.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_a(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://10.0.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_b(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://172.16.0.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_private_ip_class_c(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://192.168.1.1") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_aws_metadata_ip(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://169.254.169.254") + assert "metadata" in str(exc_info.value).lower() + + def test_blocks_aws_metadata_with_path(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://169.254.169.254/latest/meta-data/") + assert "metadata" in str(exc_info.value).lower() + + def test_blocks_gcp_metadata_hostname(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("http://metadata.google.internal") + assert "not allowed" in str(exc_info.value).lower() + + def test_blocks_ftp_scheme(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("ftp://example.com") + assert "scheme" in str(exc_info.value).lower() + + def test_blocks_file_scheme(self): + with pytest.raises(SSRFError) as exc_info: + validate_url("file:///etc/passwd") + assert "scheme" in str(exc_info.value).lower() + + def test_blocks_hostname_resolving_to_private_ip(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "192.168.1.1" + with pytest.raises(SSRFError) as exc_info: + validate_url("http://internal.example.com") + assert "private" in str(exc_info.value).lower() or "internal" in str(exc_info.value).lower() + + def test_blocks_hostname_resolving_to_metadata_ip(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "169.254.169.254" + with pytest.raises(SSRFError) as exc_info: + validate_url("http://evil.example.com") + assert "metadata" in str(exc_info.value).lower() + + def test_allows_public_ip(self): + result = validate_url("http://8.8.8.8") + assert result == "http://8.8.8.8" + + def test_allows_public_hostname(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + result = validate_url("https://example.com") + assert result == "https://example.com" + + def test_raises_on_unresolvable_hostname(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = None + with pytest.raises(SSRFError) as exc_info: + validate_url("http://nonexistent.invalid") + assert "resolve" in str(exc_info.value).lower() + + def test_raises_on_empty_hostname(self): + with pytest.raises(SSRFError) as 
exc_info: + validate_url("http://") + assert "hostname" in str(exc_info.value).lower() + + def test_allow_localhost_flag(self): + # Should work with allow_localhost=True + result = validate_url("http://localhost", allow_localhost=True) + assert result == "http://localhost" + + result = validate_url("http://127.0.0.1", allow_localhost=True) + assert result == "http://127.0.0.1" + + +class TestValidateUrlSafe: + """Tests for validate_url_safe non-throwing function.""" + + def test_returns_tuple_on_success(self): + with patch("application.core.url_validation.resolve_hostname") as mock_resolve: + mock_resolve.return_value = "93.184.216.34" + is_valid, url, error = validate_url_safe("https://example.com") + assert is_valid is True + assert url == "https://example.com" + assert error is None + + def test_returns_tuple_on_failure(self): + is_valid, url, error = validate_url_safe("http://localhost") + assert is_valid is False + assert url == "http://localhost" + assert error is not None + assert "localhost" in error.lower() + + def test_returns_error_message_for_private_ip(self): + is_valid, url, error = validate_url_safe("http://192.168.1.1") + assert is_valid is False + assert "private" in error.lower() or "internal" in error.lower() diff --git a/tests/parser/remote/test_crawler_loader.py b/tests/parser/remote/test_crawler_loader.py index 92ffdc84..06d27517 100644 --- a/tests/parser/remote/test_crawler_loader.py +++ b/tests/parser/remote/test_crawler_loader.py @@ -13,8 +13,17 @@ class DummyResponse: return None +def _mock_validate_url(url): + """Mock validate_url that allows test URLs through.""" + from urllib.parse import urlparse + if not urlparse(url).scheme: + url = "http://" + url + return url + + +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_crawls_same_domain_links(mock_requests_get): +def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url): responses = { "http://example.com": DummyResponse( """ @@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get): "http://example.com/about": DummyResponse("
About page"), } - def response_side_effect(url: str): + def response_side_effect(url: str, timeout=30): if url not in responses: raise AssertionError(f"Unexpected request for URL: {url}") return responses[url] @@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get): assert loader_call_order == ["http://example.com", "http://example.com/about"] +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): +def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url): mock_requests_get.return_value = DummyResponse("No links here") doc = MagicMock(spec=LCDocument) @@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): result = crawler.load_data(["example.com", "unused.com"]) - mock_requests_get.assert_called_once_with("http://example.com") + mock_requests_get.assert_called_once_with("http://example.com", timeout=30) crawler.loader.assert_called_once_with(["http://example.com"]) assert len(result) == 1 @@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get): assert result[0].extra_info == {"source": "http://example.com"} +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_respects_limit(mock_requests_get): +def test_load_data_respects_limit(mock_requests_get, mock_validate_url): responses = { "http://example.com": DummyResponse( """ @@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get): "http://example.com/about": DummyResponse("About"), } - mock_requests_get.side_effect = lambda url: responses[url] + mock_requests_get.side_effect = lambda url, timeout=30: responses[url] root_doc = MagicMock(spec=LCDocument) root_doc.page_content = "Root content" @@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get): assert crawler.loader.call_count == 1 +@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url) @patch("application.parser.remote.crawler_loader.logging") @patch("application.parser.remote.crawler_loader.requests.get") -def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging): +def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url): mock_requests_get.return_value = DummyResponse("Error route") failing_loader_instance = MagicMock() @@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin result = crawler.load_data("http://example.com") assert result == [] - mock_requests_get.assert_called_once_with("http://example.com") + mock_requests_get.assert_called_once_with("http://example.com", timeout=30) failing_loader_instance.load.assert_called_once() mock_logging.error.assert_called_once() @@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin assert "Error processing URL http://example.com" in message assert mock_logging.error.call_args.kwargs.get("exc_info") is True + +@patch("application.parser.remote.crawler_loader.validate_url") +def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url): + """Test that SSRF validation failure returns empty list.""" + from application.core.url_validation import SSRFError + 
mock_validate_url.side_effect = SSRFError("Access to private IP not allowed") + + crawler = CrawlerLoader() + result = crawler.load_data("http://192.168.1.1") + + assert result == [] + mock_validate_url.assert_called_once() + diff --git a/tests/parser/remote/test_crawler_markdown.py b/tests/parser/remote/test_crawler_markdown.py index ac27b3d0..b2b3f21c 100644 --- a/tests/parser/remote/test_crawler_markdown.py +++ b/tests/parser/remote/test_crawler_markdown.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock +from urllib.parse import urlparse import pytest import requests @@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace: return SimpleNamespace(domain=domain, suffix=suffix) +def _mock_validate_url(url): + """Mock validate_url that allows test URLs through.""" + if not urlparse(url).scheme: + url = "http://" + url + return url + + +@pytest.fixture(autouse=True) +def _patch_validate_url(monkeypatch): + monkeypatch.setattr( + "application.parser.remote.crawler_markdown.validate_url", + _mock_validate_url, + ) + + @pytest.fixture(autouse=True) def _patch_tldextract(monkeypatch): monkeypatch.setattr( @@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify): assert len(docs) == 2 -def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify): +def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url): root_html = """