Compare commits

...

1 Commits

Author SHA1 Message Date
Pavel
9a5ea8fe00 Harden protection with pinned requests and path-param encoding 2026-05-21 00:31:52 +04:00
9 changed files with 263 additions and 292 deletions

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
from urllib.parse import quote, urlencode
import requests
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
logger = logging.getLogger(__name__)
@@ -70,18 +70,16 @@ class APITool(Tool):
Returns:
Dict with status_code, data, and message
"""
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
if method.upper() not in _VALID_METHODS:
return {
"status_code": None,
"message": f"URL validation error: {e}",
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
@@ -91,8 +89,9 @@ class APITool(Tool):
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
safe_value = quote(str(query_params[param_name]), safe="")
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
f"{{{param_name}}}", safe_value
)
path_params_used.add(param_name)
remaining_params = {
@@ -103,19 +102,6 @@ class APITool(Tool):
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
@@ -141,49 +127,13 @@ class APITool(Tool):
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response = pinned_request(
method,
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
response.raise_for_status()
data = self._parse_response(response)
@@ -193,6 +143,13 @@ class APITool(Tool):
"data": data,
"message": "API call successful.",
}
except UnsafeUserUrlError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {

View File

@@ -1,5 +1,5 @@
import requests
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class NtfyTool(Tool):
"""
@@ -71,7 +71,12 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data, timeout=100)
try:
response = pinned_request(
"POST", url, data=data, headers=headers, timeout=100,
)
except UnsafeUserUrlError as e:
return {"status_code": None, "message": f"URL validation error: {e}"}
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):

View File

@@ -1,7 +1,6 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class ReadWebpageTool(Tool):
"""
@@ -31,28 +30,24 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
response = pinned_request(
"GET",
url,
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
timeout=10,
)
response.raise_for_status()
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
html_content = response.text
#soup = BeautifulSoup(html_content, 'html.parser')
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
return markdown_content
except requests.exceptions.RequestException as e:
return f"Error fetching URL {url}: {e}"
except UnsafeUserUrlError as e:
return f"Error: URL validation failed - {e}"
except Exception as e:
return f"Error processing URL {url}: {e}"
return f"Error fetching URL {url}: {e}"
def get_actions_metadata(self):
"""

View File

@@ -1,11 +1,11 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -35,14 +35,7 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response = pinned_request("GET", current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()

View File

@@ -1,8 +1,8 @@
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -20,7 +20,6 @@ class CrawlerLoader(BaseRemote):
"""
self.limit = limit
self.allow_subdomains = allow_subdomains
self.session = requests.Session()
def load_data(self, inputs):
url = inputs
@@ -91,15 +90,13 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response = pinned_request("GET", url, timeout=10)
response.raise_for_status()
return response.text
except SSRFError as e:
except UnsafeUserUrlError as e:
print(f"URL validation failed for {url}: {e}")
return None
except requests.exceptions.RequestException as e:
except Exception as e:
print(f"Error fetching URL {url}: {e}")
return None

View File

@@ -1,9 +1,9 @@
import logging
import requests
import re # Import regular expression library
import re
import defusedxml.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -53,14 +53,12 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
response = pinned_request("GET", sitemap_url, timeout=30)
response.raise_for_status()
except UnsafeUserUrlError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
except Exception as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
@@ -97,13 +95,6 @@ class SitemapLoader(BaseRemote):
nested_sitemap_url = sitemap.text
if not nested_sitemap_url:
continue
try:
nested_sitemap_url = validate_url(nested_sitemap_url)
except SSRFError as e:
logging.error(
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
)
continue
urls.extend(self._extract_urls(nested_sitemap_url))
return urls

View File

@@ -291,6 +291,55 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
return str(ip)
def pinned_request(
method: str,
url: str,
*,
data: Any = None,
json: Any = None,
headers: dict[str, str] | None = None,
timeout: float = 90.0,
allow_redirects: bool = False,
) -> requests.Response:
"""Send an HTTP request with the connection pinned to a validated IP,
closing the DNS-rebinding TOCTOU window left by the naive
validate-then-``requests`` pattern.
Raises:
UnsafeUserUrlError: If the URL fails the SSRF guard.
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.request(
method=method.upper(),
url=pinned_url,
data=data,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
def pinned_post(
url: str,
*,
@@ -328,33 +377,15 @@ def pinned_post(
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
return pinned_request(
"POST",
url,
json=json,
headers=headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.

View File

@@ -45,15 +45,14 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("any_action")
@@ -61,54 +60,50 @@ class TestMakeApiCall:
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked(self, mock_validate, tool):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, tool):
from application.security.safe_url import UnsafeUserUrlError
mock_validate.side_effect = SSRFError("blocked")
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
result = tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error(self, mock_get, mock_validate, tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error(self, mock_pinned, tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -116,15 +111,14 @@ class TestMakeApiCall:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("any")
assert result["status_code"] == 404
assert "HTTP Error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
def test_unsupported_method(self):
tool = APITool(
config={"url": "https://example.com", "method": "CUSTOM"}
)
@@ -132,69 +126,64 @@ class TestMakeApiCall:
assert result["status_code"] is None
assert "Unsupported" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
@@ -202,9 +191,8 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
tool = APITool(
config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
@@ -217,11 +205,11 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url

View File

@@ -81,104 +81,103 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any_action")
assert result["status_code"] == 200
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
assert mock_pinned.call_args[0][0] == "GET"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
assert mock_pinned.call_args[0][0] == "POST"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PUT"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
assert mock_pinned.call_args[0][0] == "DELETE"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PATCH"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "HEAD"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "OPTIONS"
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
def test_unsupported_method(self):
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
result = tool.execute_action("any")
assert result["status_code"] is None
@@ -193,19 +192,18 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestSSRFValidation:
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, get_tool):
from application.security.safe_url import UnsafeUserUrlError
mock_validate.side_effect = SSRFError("blocked")
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked_with_path_params(self, mock_pinned):
from application.security.safe_url import UnsafeUserUrlError
tool = APITool(config={
"url": "https://api.example.com/{host}/data",
@@ -213,14 +211,7 @@ class TestSSRFValidation:
"query_params": {"host": "169.254.169.254"},
})
call_count = [0]
def side_effect(url):
call_count[0] += 1
if call_count[0] == 2:
raise SSRFError("blocked after substitution")
mock_validate.side_effect = side_effect
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@@ -234,40 +225,36 @@ class TestSSRFValidation:
@pytest.mark.unit
class TestErrorHandling:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_with_json(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 422
mock_resp.json.return_value = {"error": "invalid_field"}
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 422
assert result["data"] == {"error": "invalid_field"}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_non_json_body(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -275,29 +262,26 @@ class TestErrorHandling:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 404
assert result["data"] == "Not Found"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_request_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.RequestException("something")
@patch("application.agents.tools.api_tool.pinned_request")
def test_request_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.RequestException("something")
result = get_tool.execute_action("any")
assert "API call failed" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = RuntimeError("unexpected")
@patch("application.agents.tools.api_tool.pinned_request")
def test_unexpected_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = RuntimeError("unexpected")
result = get_tool.execute_action("any")
assert "Unexpected error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_body_serialization_error(self, mock_post, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_body_serialization_error(self, mock_pinned):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -320,9 +304,8 @@ class TestErrorHandling:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
"method": "GET",
@@ -333,17 +316,16 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_remaining_query_params_appended(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_remaining_query_params_appended(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items",
"method": "GET",
@@ -354,19 +336,16 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "page=2" in called_url
assert "limit=10" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_query_params_append_with_existing_query_string(
self, mock_get, mock_validate
):
@patch("application.agents.tools.api_tool.pinned_request")
def test_query_params_append_with_existing_query_string(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items?existing=true",
"method": "GET",
@@ -377,27 +356,65 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "&page=1" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_empty_body_no_serialization(self, mock_post, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_empty_body_no_serialization(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "POST"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("create")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_are_url_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/profile",
"method": "GET",
"query_params": {"user_id": "../../admin"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "../../admin" not in called_url
assert "%2F" in called_url or "%2f" in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_query_injection_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items/{item_id}",
"method": "GET",
"query_params": {"item_id": "x?admin=true"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "x?admin=true" not in called_url
# =====================================================================
# Parse Response
@@ -494,11 +511,8 @@ class TestAPIToolMetadata:
def test_config_requirements_empty(self, get_tool):
assert get_tool.get_config_requirements() == {}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_content_type_set_for_post_with_no_headers(
self, mock_post, mock_validate
):
@patch("application.agents.tools.api_tool.pinned_request")
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -509,8 +523,8 @@ class TestAPIToolMetadata:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("create")
call_headers = mock_post.call_args[1]["headers"]
call_headers = mock_pinned.call_args.kwargs["headers"]
assert "Content-Type" in call_headers