mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 05:15:08 +00:00
Compare commits
1 Commits
feat/defau
...
hardening-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a5ea8fe00 |
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
import requests
|
||||
|
||||
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,18 +70,16 @@ class APITool(Tool):
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
if method.upper() not in _VALID_METHODS:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
@@ -91,8 +89,9 @@ class APITool(Tool):
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
safe_value = quote(str(query_params[param_name]), safe="")
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
f"{{{param_name}}}", safe_value
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
@@ -103,19 +102,6 @@ class APITool(Tool):
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
@@ -141,49 +127,13 @@ class APITool(Tool):
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
response = pinned_request(
|
||||
method,
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
@@ -193,6 +143,13 @@ class APITool(Tool):
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except UnsafeUserUrlError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class NtfyTool(Tool):
|
||||
"""
|
||||
@@ -71,7 +71,12 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
try:
|
||||
response = pinned_request(
|
||||
"POST", url, data=data, headers=headers, timeout=100,
|
||||
)
|
||||
except UnsafeUserUrlError as e:
|
||||
return {"status_code": None, "message": f"URL validation error: {e}"}
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,28 +30,24 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
response = pinned_request(
|
||||
"GET",
|
||||
url,
|
||||
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
html_content = response.text
|
||||
#soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
|
||||
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
|
||||
|
||||
|
||||
return markdown_content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
except UnsafeUserUrlError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
except Exception as e:
|
||||
return f"Error processing URL {url}: {e}"
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
@@ -35,14 +35,7 @@ class CrawlerLoader(BaseRemote):
|
||||
visited_urls.add(current_url)
|
||||
|
||||
try:
|
||||
# Validate each URL before making requests
|
||||
try:
|
||||
validate_url(current_url)
|
||||
except SSRFError as e:
|
||||
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
|
||||
continue
|
||||
|
||||
response = requests.get(current_url, timeout=30)
|
||||
response = pinned_request("GET", current_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
loader = self.loader([current_url])
|
||||
docs = loader.load()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
import re
|
||||
from markdownify import markdownify
|
||||
from application.parser.schema.base import Document
|
||||
@@ -20,7 +20,6 @@ class CrawlerLoader(BaseRemote):
|
||||
"""
|
||||
self.limit = limit
|
||||
self.allow_subdomains = allow_subdomains
|
||||
self.session = requests.Session()
|
||||
|
||||
def load_data(self, inputs):
|
||||
url = inputs
|
||||
@@ -91,15 +90,13 @@ class CrawlerLoader(BaseRemote):
|
||||
|
||||
def _fetch_page(self, url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(url)
|
||||
response = self.session.get(url, timeout=10)
|
||||
response = pinned_request("GET", url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except SSRFError as e:
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for {url}: {e}")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
print(f"Error fetching URL {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import requests
|
||||
import re # Import regular expression library
|
||||
import re
|
||||
import defusedxml.ElementTree as ET
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class SitemapLoader(BaseRemote):
|
||||
def __init__(self, limit=20):
|
||||
@@ -53,14 +53,12 @@ class SitemapLoader(BaseRemote):
|
||||
|
||||
def _extract_urls(self, sitemap_url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(sitemap_url)
|
||||
response = requests.get(sitemap_url, timeout=30)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
except SSRFError as e:
|
||||
response = pinned_request("GET", sitemap_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
|
||||
@@ -97,13 +95,6 @@ class SitemapLoader(BaseRemote):
|
||||
nested_sitemap_url = sitemap.text
|
||||
if not nested_sitemap_url:
|
||||
continue
|
||||
try:
|
||||
nested_sitemap_url = validate_url(nested_sitemap_url)
|
||||
except SSRFError as e:
|
||||
logging.error(
|
||||
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
|
||||
)
|
||||
continue
|
||||
urls.extend(self._extract_urls(nested_sitemap_url))
|
||||
|
||||
return urls
|
||||
|
||||
@@ -291,6 +291,55 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
|
||||
return str(ip)
|
||||
|
||||
|
||||
def pinned_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
data: Any = None,
|
||||
json: Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 90.0,
|
||||
allow_redirects: bool = False,
|
||||
) -> requests.Response:
|
||||
"""Send an HTTP request with the connection pinned to a validated IP,
|
||||
closing the DNS-rebinding TOCTOU window left by the naive
|
||||
validate-then-``requests`` pattern.
|
||||
|
||||
Raises:
|
||||
UnsafeUserUrlError: If the URL fails the SSRF guard.
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.request(
|
||||
method=method.upper(),
|
||||
url=pinned_url,
|
||||
data=data,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def pinned_post(
|
||||
url: str,
|
||||
*,
|
||||
@@ -328,33 +377,15 @@ def pinned_post(
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
return pinned_request(
|
||||
"POST",
|
||||
url,
|
||||
json=json,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.post(
|
||||
pinned_url,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _PinnedHTTPSTransport(httpx.HTTPTransport):
|
||||
"""``httpx`` transport pinned to a single validated IP literal.
|
||||
|
||||
@@ -45,15 +45,14 @@ class TestAPIToolInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any_action")
|
||||
|
||||
@@ -61,54 +60,50 @@ class TestMakeApiCall:
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
|
||||
assert result["status_code"] == 201
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked(self, mock_validate, tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -116,15 +111,14 @@ class TestMakeApiCall:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] == 404
|
||||
assert "HTTP Error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(
|
||||
config={"url": "https://example.com", "method": "CUSTOM"}
|
||||
)
|
||||
@@ -132,69 +126,64 @@ class TestMakeApiCall:
|
||||
assert result["status_code"] is None
|
||||
assert "Unsupported" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
@@ -202,9 +191,8 @@ class TestMakeApiCall:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
@@ -217,11 +205,11 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
|
||||
@@ -81,104 +81,103 @@ class TestAPIToolInit:
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any_action")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
assert mock_pinned.call_args[0][0] == "GET"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
assert result["status_code"] == 201
|
||||
assert mock_pinned.call_args[0][0] == "POST"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PUT"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
assert mock_pinned.call_args[0][0] == "DELETE"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PATCH"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "HEAD"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "OPTIONS"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
@@ -193,19 +192,18 @@ class TestMakeApiCall:
|
||||
@pytest.mark.unit
|
||||
class TestSSRFValidation:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, get_tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked_with_path_params(self, mock_pinned):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/{host}/data",
|
||||
@@ -213,14 +211,7 @@ class TestSSRFValidation:
|
||||
"query_params": {"host": "169.254.169.254"},
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def side_effect(url):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 2:
|
||||
raise SSRFError("blocked after substitution")
|
||||
|
||||
mock_validate.side_effect = side_effect
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
@@ -234,40 +225,36 @@ class TestSSRFValidation:
|
||||
@pytest.mark.unit
|
||||
class TestErrorHandling:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_with_json(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 422
|
||||
mock_resp.json.return_value = {"error": "invalid_field"}
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 422
|
||||
assert result["data"] == {"error": "invalid_field"}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_non_json_body(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -275,29 +262,26 @@ class TestErrorHandling:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 404
|
||||
assert result["data"] == "Not Found"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_request_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("something")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_request_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.RequestException("something")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "API call failed" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = RuntimeError("unexpected")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_unexpected_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = RuntimeError("unexpected")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "Unexpected error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_body_serialization_error(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_body_serialization_error(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -320,9 +304,8 @@ class TestErrorHandling:
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
@@ -333,17 +316,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_remaining_query_params_appended(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_remaining_query_params_appended(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "GET",
|
||||
@@ -354,19 +336,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "page=2" in called_url
|
||||
assert "limit=10" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_query_params_append_with_existing_query_string(
|
||||
self, mock_get, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_query_params_append_with_existing_query_string(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items?existing=true",
|
||||
"method": "GET",
|
||||
@@ -377,27 +356,65 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "&page=1" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_empty_body_no_serialization(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_empty_body_no_serialization(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "POST"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("create")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_are_url_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/profile",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "../../admin"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "../../admin" not in called_url
|
||||
assert "%2F" in called_url or "%2f" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_query_injection_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items/{item_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"item_id": "x?admin=true"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "x?admin=true" not in called_url
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Parse Response
|
||||
@@ -494,11 +511,8 @@ class TestAPIToolMetadata:
|
||||
def test_config_requirements_empty(self, get_tool):
|
||||
assert get_tool.get_config_requirements() == {}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_content_type_set_for_post_with_no_headers(
|
||||
self, mock_post, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -509,8 +523,8 @@ class TestAPIToolMetadata:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("create")
|
||||
call_headers = mock_post.call_args[1]["headers"]
|
||||
call_headers = mock_pinned.call_args.kwargs["headers"]
|
||||
assert "Content-Type" in call_headers
|
||||
|
||||
Reference in New Issue
Block a user