Mirror of https://github.com/arc53/DocsGPT.git
Patches (#2218)
* feat: implement URL validation to prevent SSRF
* feat: add zip extraction security
* ruff fixes
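The hunks below import validate_url and SSRFError from application.core.url_validation, a module that is not included in this excerpt. As a rough, hypothetical sketch only (the real helper may differ), an SSRF guard of this kind typically normalises the scheme, restricts it to http/https, resolves the hostname, and rejects loopback, private, link-local and reserved addresses:

import ipaddress
import socket
from urllib.parse import urlparse


class SSRFError(Exception):
    """Raised when a URL targets a disallowed scheme or a non-public address."""


def validate_url(url: str) -> str:
    # Default to http:// when no scheme is given, since callers use the
    # returned (normalised) URL for all later requests.
    if not urlparse(url).scheme:
        url = "http://" + url
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise SSRFError(f"Disallowed scheme: {parsed.scheme!r}")
    if not parsed.hostname:
        raise SSRFError("URL has no hostname")
    # Resolve the host and reject addresses that are not publicly routable;
    # this is the core of the SSRF protection.
    try:
        addr_infos = socket.getaddrinfo(parsed.hostname, None)
    except socket.gaierror as exc:
        raise SSRFError(f"Cannot resolve host {parsed.hostname!r}: {exc}")
    for info in addr_infos:
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_reserved:
            raise SSRFError(f"Address {ip} is not publicly routable")
    return url

Because the loaders re-validate every crawled or sitemap URL rather than only the seed, pages that link to internal hosts are skipped as well. The zip extraction change mentioned in the second bullet of the commit message is not part of the hunks shown here; a sketch of that guard follows the last hunk.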
@@ -4,6 +4,7 @@ from urllib.parse import urlparse, urljoin
 from bs4 import BeautifulSoup
 from application.parser.remote.base import BaseRemote
 from application.parser.schema.base import Document
+from application.core.url_validation import validate_url, SSRFError
 from langchain_community.document_loaders import WebBaseLoader
 
 class CrawlerLoader(BaseRemote):
@@ -16,9 +17,12 @@ class CrawlerLoader(BaseRemote):
         if isinstance(url, list) and url:
             url = url[0]
 
-        # Check if the URL scheme is provided, if not, assume http
-        if not urlparse(url).scheme:
-            url = "http://" + url
+        # Validate URL to prevent SSRF attacks
+        try:
+            url = validate_url(url)
+        except SSRFError as e:
+            logging.error(f"URL validation failed: {e}")
+            return []
 
         visited_urls = set()
         base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,7 +34,14 @@ class CrawlerLoader(BaseRemote):
             visited_urls.add(current_url)
 
             try:
-                response = requests.get(current_url)
+                # Validate each URL before making requests
+                try:
+                    validate_url(current_url)
+                except SSRFError as e:
+                    logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
+                    continue
 
+                response = requests.get(current_url, timeout=30)
+                response.raise_for_status()
                 loader = self.loader([current_url])
                 docs = loader.load()
@@ -2,6 +2,7 @@ import requests
 from urllib.parse import urlparse, urljoin
 from bs4 import BeautifulSoup
 from application.parser.remote.base import BaseRemote
+from application.core.url_validation import validate_url, SSRFError
 import re
 from markdownify import markdownify
 from application.parser.schema.base import Document
@@ -25,9 +26,12 @@ class CrawlerLoader(BaseRemote):
         if isinstance(url, list) and url:
             url = url[0]
 
-        # Ensure the URL has a scheme (if not, default to http)
-        if not urlparse(url).scheme:
-            url = "http://" + url
+        # Validate URL to prevent SSRF attacks
+        try:
+            url = validate_url(url)
+        except SSRFError as e:
+            print(f"URL validation failed: {e}")
+            return []
 
         # Keep track of visited URLs to avoid revisiting the same page
         visited_urls = set()
@@ -78,9 +82,14 @@ class CrawlerLoader(BaseRemote):
 
     def _fetch_page(self, url):
         try:
+            # Validate URL before fetching to prevent SSRF
+            validate_url(url)
             response = self.session.get(url, timeout=10)
             response.raise_for_status()
             return response.text
+        except SSRFError as e:
+            print(f"URL validation failed for {url}: {e}")
+            return None
         except requests.exceptions.RequestException as e:
             print(f"Error fetching URL {url}: {e}")
             return None
@@ -3,6 +3,7 @@ import requests
 import re # Import regular expression library
 import xml.etree.ElementTree as ET
 from application.parser.remote.base import BaseRemote
+from application.core.url_validation import validate_url, SSRFError
 
 class SitemapLoader(BaseRemote):
     def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
         sitemap_url= inputs
         # Check if the input is a list and if it is, use the first element
         if isinstance(sitemap_url, list) and sitemap_url:
-            url = sitemap_url[0]
+            sitemap_url = sitemap_url[0]
 
+        # Validate URL to prevent SSRF attacks
+        try:
+            sitemap_url = validate_url(sitemap_url)
+        except SSRFError as e:
+            logging.error(f"URL validation failed: {e}")
+            return []
+
         urls = self._extract_urls(sitemap_url)
         if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):
 
     def _extract_urls(self, sitemap_url):
         try:
-            response = requests.get(sitemap_url)
+            # Validate URL before fetching to prevent SSRF
+            validate_url(sitemap_url)
+            response = requests.get(sitemap_url, timeout=30)
             response.raise_for_status() # Raise an exception for HTTP errors
+        except SSRFError as e:
+            print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
+            return []
         except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
             print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
             return []
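The commit message also mentions zip extraction security, which none of the hunks above touch. As an illustrative sketch only (not taken from the DocsGPT code base), the usual guard rejects archive entries that would escape the destination directory, the so-called zip-slip attack:

import os
import zipfile


def safe_extract(zip_path: str, dest_dir: str) -> None:
    # Resolve the destination once so every member can be checked against it.
    dest_dir = os.path.abspath(dest_dir)
    with zipfile.ZipFile(zip_path) as archive:
        for member in archive.infolist():
            target = os.path.abspath(os.path.join(dest_dir, member.filename))
            # Reject entries such as "../../etc/passwd" or absolute paths that
            # would land outside the destination directory.
            if target != dest_dir and not target.startswith(dest_dir + os.sep):
                raise ValueError(f"Unsafe path in archive: {member.filename!r}")
        archive.extractall(dest_dir)

Capping the total uncompressed size of the archive is another common safeguard against zip bombs, but whether this patch does so is not visible from the hunks shown here.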