Patches (#2218)
* feat: implement URL validation to prevent SSRF
* feat: add zip extraction security
* ruff fixes
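The tests below stub out validate_url; the real implementation lives in application.core.url_validation (the diff only imports its SSRFError). For orientation, here is a minimal sketch of what such an SSRF guard typically does — an assumption for illustration, not the module's actual code:

    import ipaddress
    import socket
    from urllib.parse import urlparse


    class SSRFError(Exception):
        """Raised when a URL points at a disallowed network location."""


    def validate_url(url: str) -> str:
        # Default to http:// for scheme-less input, mirroring the test mocks.
        if not urlparse(url).scheme:
            url = "http://" + url
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            raise SSRFError(f"Unsupported scheme: {parsed.scheme}")
        if not parsed.hostname:
            raise SSRFError("URL has no hostname")
        # Resolve the host and reject private, loopback, link-local, or reserved IPs.
        for _family, _type, _proto, _canon, sockaddr in socket.getaddrinfo(parsed.hostname, None):
            ip = ipaddress.ip_address(sockaddr[0])
            if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
                raise SSRFError("Access to private IP not allowed")
        return url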
@@ -13,8 +13,17 @@ class DummyResponse:
         return None
 
 
+def _mock_validate_url(url):
+    """Mock validate_url that allows test URLs through."""
+    from urllib.parse import urlparse
+    if not urlparse(url).scheme:
+        url = "http://" + url
+    return url
+
+
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
 @patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_crawls_same_domain_links(mock_requests_get):
+def test_load_data_crawls_same_domain_links(mock_requests_get, mock_validate_url):
     responses = {
         "http://example.com": DummyResponse(
             """
@@ -29,7 +38,7 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
         "http://example.com/about": DummyResponse("<html><body>About page</body></html>"),
     }
 
-    def response_side_effect(url: str):
+    def response_side_effect(url: str, timeout=30):
         if url not in responses:
             raise AssertionError(f"Unexpected request for URL: {url}")
         return responses[url]
@@ -76,8 +85,9 @@ def test_load_data_crawls_same_domain_links(mock_requests_get):
     assert loader_call_order == ["http://example.com", "http://example.com/about"]
 
 
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
 @patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
+def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get, mock_validate_url):
     mock_requests_get.return_value = DummyResponse("<html><body>No links here</body></html>")
 
     doc = MagicMock(spec=LCDocument)
@@ -92,7 +102,7 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
 
     result = crawler.load_data(["example.com", "unused.com"])
 
-    mock_requests_get.assert_called_once_with("http://example.com")
+    mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
     crawler.loader.assert_called_once_with(["http://example.com"])
 
     assert len(result) == 1
@@ -100,8 +110,9 @@ def test_load_data_accepts_list_input_and_adds_scheme(mock_requests_get):
     assert result[0].extra_info == {"source": "http://example.com"}
 
 
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
 @patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_respects_limit(mock_requests_get):
+def test_load_data_respects_limit(mock_requests_get, mock_validate_url):
     responses = {
         "http://example.com": DummyResponse(
             """
@@ -115,7 +126,7 @@ def test_load_data_respects_limit(mock_requests_get):
         "http://example.com/about": DummyResponse("<html><body>About</body></html>"),
     }
 
-    mock_requests_get.side_effect = lambda url: responses[url]
+    mock_requests_get.side_effect = lambda url, timeout=30: responses[url]
 
     root_doc = MagicMock(spec=LCDocument)
     root_doc.page_content = "Root content"
@@ -143,9 +154,10 @@ def test_load_data_respects_limit(mock_requests_get):
     assert crawler.loader.call_count == 1
 
 
+@patch("application.parser.remote.crawler_loader.validate_url", side_effect=_mock_validate_url)
 @patch("application.parser.remote.crawler_loader.logging")
 @patch("application.parser.remote.crawler_loader.requests.get")
-def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging):
+def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_logging, mock_validate_url):
     mock_requests_get.return_value = DummyResponse("<html><body>Error route</body></html>")
 
     failing_loader_instance = MagicMock()
@@ -157,7 +169,7 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
     result = crawler.load_data("http://example.com")
 
     assert result == []
-    mock_requests_get.assert_called_once_with("http://example.com")
+    mock_requests_get.assert_called_once_with("http://example.com", timeout=30)
     failing_loader_instance.load.assert_called_once()
 
     mock_logging.error.assert_called_once()
@@ -165,3 +177,16 @@ def test_load_data_logs_and_skips_on_loader_error(mock_requests_get, mock_loggin
     assert "Error processing URL http://example.com" in message
     assert mock_logging.error.call_args.kwargs.get("exc_info") is True
+
+
+@patch("application.parser.remote.crawler_loader.validate_url")
+def test_load_data_returns_empty_on_ssrf_validation_failure(mock_validate_url):
+    """Test that SSRF validation failure returns empty list."""
+    from application.core.url_validation import SSRFError
+    mock_validate_url.side_effect = SSRFError("Access to private IP not allowed")
+
+    crawler = CrawlerLoader()
+    result = crawler.load_data("http://192.168.1.1")
+
+    assert result == []
+    mock_validate_url.assert_called_once()
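A note on the stacked decorators above: unittest.mock applies @patch decorators bottom-up, so the innermost patch supplies the first mock argument and the newly added outermost validate_url patch arrives last — which is why each updated signature appends mock_validate_url at the end. A self-contained illustration of that ordering rule (the stdlib targets here are chosen only so the snippet runs as-is):

    from unittest.mock import patch


    @patch("os.path.isdir")    # outermost decorator -> injected last
    @patch("os.path.exists")   # innermost decorator -> injected first
    def test_order(mock_exists, mock_isdir):
        assert mock_exists is not mock_isdir


    test_order()

The second file in the commit gives the crawler_markdown tests the same scheme-defaulting stub, but through an autouse pytest fixture rather than a decorator: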
@@ -1,5 +1,6 @@
 from types import SimpleNamespace
 from unittest.mock import MagicMock
+from urllib.parse import urlparse
 
 import pytest
 import requests
@@ -29,6 +30,21 @@ def _fake_extract(value: str) -> SimpleNamespace:
     return SimpleNamespace(domain=domain, suffix=suffix)
 
 
+def _mock_validate_url(url):
+    """Mock validate_url that allows test URLs through."""
+    if not urlparse(url).scheme:
+        url = "http://" + url
+    return url
+
+
+@pytest.fixture(autouse=True)
+def _patch_validate_url(monkeypatch):
+    monkeypatch.setattr(
+        "application.parser.remote.crawler_markdown.validate_url",
+        _mock_validate_url,
+    )
+
+
 @pytest.fixture(autouse=True)
 def _patch_tldextract(monkeypatch):
     monkeypatch.setattr(
@@ -112,7 +128,7 @@ def test_load_data_allows_subdomains(_patch_markdownify):
     assert len(docs) == 2
 
 
-def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
+def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify, _patch_validate_url):
     root_html = """
     <html><head><title>Home</title></head>
     <body><a href="/about">About</a></body>
@@ -137,3 +153,21 @@ def test_load_data_handles_fetch_errors(monkeypatch, _patch_markdownify):
     assert docs[0].text == "Home Markdown"
     assert mock_print.called
+
+
+def test_load_data_returns_empty_on_ssrf_validation_failure(monkeypatch):
+    """Test that SSRF validation failure returns empty list."""
+    from application.core.url_validation import SSRFError
+
+    def raise_ssrf_error(url):
+        raise SSRFError("Access to private IP not allowed")
+
+    monkeypatch.setattr(
+        "application.parser.remote.crawler_markdown.validate_url",
+        raise_ssrf_error,
+    )
+
+    loader = CrawlerLoader()
+    result = loader.load_data("http://192.168.1.1")
+
+    assert result == []
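Taken together, the assertions pin down the control flow the loaders are expected to follow: validate first, fetch with an explicit timeout, and turn SSRF failures into an empty result. A rough reconstruction inferred from the tests alone — not the actual CrawlerLoader source; it reuses the validate_url/SSRFError sketch above, and _parse_links is a hypothetical placeholder:

    import logging

    import requests


    class CrawlerLoaderSketch:
        def load_data(self, inputs):
            # Accept a single URL or a list; the list test fetches only the first item.
            url = inputs[0] if isinstance(inputs, list) else inputs
            try:
                url = validate_url(url)  # normalizes the scheme; raises SSRFError on private IPs
            except SSRFError:
                return []  # both SSRF tests assert load_data(...) == []
            try:
                response = requests.get(url, timeout=30)  # the tests pin the timeout=30 kwarg
                return self._parse_links(url, response)  # hypothetical helper, not in the diff
            except Exception:
                # the loader-error test pins this exact message and exc_info=True
                logging.error(f"Error processing URL {url}", exc_info=True)
                return []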