* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes

Author: Alex
Date: 2025-12-24 15:05:35 +00:00
Committed by: GitHub
Parent: 83e7a928f1
Commit: 98e949d2fd
11 changed files with 929 additions and 27 deletions

View File

@@ -11,6 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
@@ -73,6 +74,17 @@ class APITool(Tool):
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
@@ -90,6 +102,18 @@ class APITool(Tool):
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
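
The hunk above validates the request URL twice: once before any work is done and again after query/path parameters are substituted, because a placeholder expanded inside the host portion of a URL template can redirect the request to an internal address. A minimal sketch of the scenario the re-validation guards against; the template, placeholder, and attacker value are hypothetical and only `validate_url`/`SSRFError` come from the new module:

```python
# Hypothetical illustration of why the second validate_url() call is needed.
# The template and the "tenant" value are assumptions, not code from this commit.
from application.core.url_validation import validate_url, SSRFError

template = "https://{tenant}.example.com/v1/items"        # assumed URL template
request_url = template.format(tenant="169.254.169.254/latest#")
# -> "https://169.254.169.254/latest#.example.com/v1/items"

try:
    validate_url(request_url)
except SSRFError as e:
    print(f"Blocked after substitution: {e}")   # metadata IP rejected
```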

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
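
validate_url() now takes over the scheme-defaulting the tool previously did inline and additionally rejects internal targets. A short sketch of both behaviours, assuming the public hostname resolves normally:

```python
# Sketch of the behaviour the tool now relies on; validate_url comes from the
# application/core/url_validation.py module added in this commit.
from application.core.url_validation import validate_url, SSRFError

print(validate_url("example.com"))   # prints "http://example.com" (scheme defaulted to http)

try:
    validate_url("https://127.0.0.1:8080/admin")
except SSRFError as e:
    print(f"Error: URL validation failed - {e}")   # loopback address rejected
```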

View File

@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))
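
A short usage sketch of the two entry points above; the non-throwing wrapper is convenient where a failed check should degrade gracefully rather than raise. The hostnames are illustrative and assume normal DNS resolution:

```python
# Usage sketch for the helpers defined above (illustrative, not part of the diff).
from application.core.url_validation import validate_url, validate_url_safe

ok, checked, err = validate_url_safe("https://docs.example.com/guide")
print(ok, checked, err)   # -> True https://docs.example.com/guide None

ok, checked, err = validate_url_safe("http://169.254.169.254/latest/meta-data/")
print(ok, err)            # -> False Access to cloud metadata services is not allowed.

# allow_localhost exists for tests only, per the docstring.
print(validate_url("http://localhost:5000/health", allow_localhost=True))
```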

View File

@@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Check if the URL scheme is provided, if not, assume http
if not urlparse(url).scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
visited_urls = set()
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
response = requests.get(current_url)
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()
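
Validating every discovered URL, not just the seed, matters because a page on an allowed public site can link to internal hosts; the crawl loop above now skips such links instead of fetching them. A minimal sketch of that filter, with illustrative links (not taken from the diff):

```python
# Illustrative only: the discovered links are assumptions; the filtering mirrors
# the per-URL check added to the crawl loop above.
from application.core.url_validation import validate_url, SSRFError

discovered = [
    "https://docs.example.com/page2",
    "http://10.0.0.5/internal-dashboard",                   # private-range link
    "http://metadata.google.internal/computeMetadata/v1/",  # blocked hostname
]

for link in discovered:
    try:
        validate_url(link)
    except SSRFError as e:
        print(f"Skipping URL due to validation failure: {link} - {e}")
        continue
    print(f"Would fetch: {link}")
```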

View File

@@ -2,6 +2,7 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Ensure the URL has a scheme (if not, default to http)
if not urlparse(url).scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
print(f"URL validation failed: {e}")
return []
# Keep track of visited URLs to avoid revisiting the same page
visited_urls = set()
@@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
except SSRFError as e:
print(f"URL validation failed for {url}: {e}")
return None
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None

View File

@@ -3,6 +3,7 @@ import requests
import re # Import regular expression library
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
sitemap_url= inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
url = sitemap_url[0]
sitemap_url = sitemap_url[0]
# Validate URL to prevent SSRF attacks
try:
sitemap_url = validate_url(sitemap_url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
urls = self._extract_urls(sitemap_url)
if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
response = requests.get(sitemap_url)
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
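
The same guard now applies to the sitemap URL itself before it is fetched; on an SSRFError the loader returns an empty list. A one-off sketch with an illustrative metadata address:

```python
# Illustrative sketch of the check _extract_urls now performs before fetching.
from application.core.url_validation import validate_url, SSRFError

try:
    validate_url("http://169.254.170.2/sitemap.xml")   # ECS task metadata IP
except SSRFError as e:
    print(f"URL validation failed for sitemap: {e}")   # loader would return []
```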

View File

@@ -63,10 +63,111 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Zip extraction security limits
MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size
MAX_FILE_COUNT = 10000 # Maximum number of files to extract
MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs)
class ZipExtractionError(Exception):
"""Raised when zip extraction fails due to security constraints."""
pass
def _is_path_safe(base_path: str, target_path: str) -> bool:
"""
Check if target_path is safely within base_path (prevents zip slip attacks).
Args:
base_path: The base directory where extraction should occur.
target_path: The full path where a file would be extracted.
Returns:
True if the path is safe, False otherwise.
"""
# Resolve to absolute paths and check containment
base_resolved = os.path.realpath(base_path)
target_resolved = os.path.realpath(target_path)
return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved
def _validate_zip_safety(zip_path: str, extract_to: str) -> None:
"""
Validate a zip file for security issues before extraction.
Checks for:
- Zip bombs (excessive compression ratio or uncompressed size)
- Too many files
- Path traversal attacks (zip slip)
Args:
zip_path: Path to the zip file.
extract_to: Destination directory.
Raises:
ZipExtractionError: If the zip file fails security validation.
"""
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# Get compressed size
compressed_size = os.path.getsize(zip_path)
# Calculate total uncompressed size and file count
total_uncompressed = 0
file_count = 0
for info in zip_ref.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILE_COUNT:
raise ZipExtractionError(
f"Zip file contains too many files (>{MAX_FILE_COUNT}). "
"This may be a zip bomb attack."
)
# Accumulate uncompressed size
total_uncompressed += info.file_size
# Check total uncompressed size
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
raise ZipExtractionError(
f"Zip file uncompressed size exceeds limit "
f"({total_uncompressed / (1024*1024):.1f} MB > "
f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). "
"This may be a zip bomb attack."
)
# Check for path traversal (zip slip)
target_path = os.path.join(extract_to, info.filename)
if not _is_path_safe(extract_to, target_path):
raise ZipExtractionError(
f"Zip file contains path traversal attempt: {info.filename}"
)
# Check compression ratio (only if compressed size is meaningful)
if compressed_size > 0 and total_uncompressed > 0:
compression_ratio = total_uncompressed / compressed_size
if compression_ratio > MAX_COMPRESSION_RATIO:
raise ZipExtractionError(
f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > "
f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack."
)
except zipfile.BadZipFile as e:
raise ZipExtractionError(f"Invalid or corrupted zip file: {e}")
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
Recursively extract zip files with a limit on recursion depth.
Recursively extract zip files with security protections.
Security measures:
- Limits recursion depth to prevent infinite loops
- Validates uncompressed size to prevent zip bombs
- Limits number of files to prevent resource exhaustion
- Checks compression ratio to detect zip bombs
- Validates paths to prevent zip slip attacks
Args:
zip_path (str): Path to the zip file to be extracted.
@@ -77,20 +178,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
if current_depth > max_depth:
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
try:
# Validate zip file safety before extraction
_validate_zip_safety(zip_path, extract_to)
# Safe to extract
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
except ZipExtractionError as e:
logging.error(f"Zip security validation failed for {zip_path}: {e}")
# Remove the potentially malicious zip file
try:
os.remove(zip_path)
except OSError:
pass
return
except Exception as e:
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
return
# Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
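
A sketch of the new validation catching a zip-slip archive. The import path is an assumption (the worker module containing these helpers is not named in this view); everything else uses only the standard library and the helpers shown above:

```python
# Sketch: build an archive with a path-traversal entry and confirm it is rejected.
# The import path below is assumed; _validate_zip_safety and ZipExtractionError
# are the helpers introduced in the hunk above.
import os
import tempfile
import zipfile

from application.worker import ZipExtractionError, _validate_zip_safety  # path assumed

with tempfile.TemporaryDirectory() as tmp:
    bad_zip = os.path.join(tmp, "evil.zip")
    with zipfile.ZipFile(bad_zip, "w") as zf:
        zf.writestr("../../outside.txt", "escaped")    # zip-slip entry

    extract_to = os.path.join(tmp, "out")
    os.makedirs(extract_to)

    try:
        _validate_zip_safety(bad_zip, extract_to)
    except ZipExtractionError as e:
        print(f"Rejected: {e}")   # "Zip file contains path traversal attempt: ..."
```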