mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-08 23:30:42 +00:00
feat: enhance API tool with body serialization and content type handling (#2192)
* feat: enhance API tool with body serialization and content type handling * feat: enhance ToolConfig with import functionality and user action management - Added ImportSpecModal to allow importing actions into the tool configuration. - Implemented search functionality for user actions with expandable action details. - Introduced method colors for better visual distinction of HTTP methods. - Updated APIActionType and ParameterGroupType to include optional 'required' field. - Refactored action rendering to improve usability and maintainability. * feat: add base URL input to ImportSpecModal for action URL customization * feat: update TestBaseAgentTools to include 'required' field for parameters * feat: standardize API call timeout to DEFAULT_TIMEOUT constant * feat: add import specification functionality and related translations for multiple languages --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -1,72 +1,256 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT = 90 # seconds
|
||||
|
||||
|
||||
class APITool(Tool):
|
||||
"""
|
||||
API Tool
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.url = config.get("url", "")
|
||||
self.method = config.get("method", "GET")
|
||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
||||
self.headers = config.get("headers", {})
|
||||
self.query_params = config.get("query_params", {})
|
||||
self.body_content_type = config.get("body_content_type", ContentType.JSON)
|
||||
self.body_encoding_rules = config.get("body_encoding_rules", {})
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
"""Execute an API action with the given arguments."""
|
||||
return self._make_api_call(
|
||||
self.url, self.method, self.headers, self.query_params, kwargs
|
||||
self.url,
|
||||
self.method,
|
||||
self.headers,
|
||||
self.query_params,
|
||||
kwargs,
|
||||
self.body_content_type,
|
||||
self.body_encoding_rules,
|
||||
)
|
||||
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
# if isinstance(body, dict):
|
||||
# body = json.dumps(body)
|
||||
def _make_api_call(
|
||||
self,
|
||||
url: str,
|
||||
method: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an API call with proper body serialization and error handling.
|
||||
|
||||
Args:
|
||||
url: API endpoint URL
|
||||
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
|
||||
headers: Request headers dict
|
||||
query_params: URL query parameters
|
||||
body: Request body as dict
|
||||
content_type: Content-Type for serialization
|
||||
encoding_rules: OpenAPI encoding rules
|
||||
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
try:
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
if body == "{}":
|
||||
body = None
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
).lower()
|
||||
if "application/json" in content_type:
|
||||
path_params_used = set()
|
||||
if query_params:
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
k: v for k, v in query_params.items() if k not in path_params_used
|
||||
}
|
||||
if remaining_params:
|
||||
query_string = urlencode(remaining_params)
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
body, content_type, encoding_rules
|
||||
)
|
||||
request_headers.update(body_headers)
|
||||
except ValueError as e:
|
||||
logger.error(f"Body serialization failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"API call returned invalid JSON. Error: {e}",
|
||||
"data": response.text,
|
||||
"status_code": None,
|
||||
"message": f"Body serialization error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
elif "text/" in content_type or "application/xml" in content_type:
|
||||
data = response.text
|
||||
elif not response.content:
|
||||
data = None
|
||||
else:
|
||||
print(f"Unsupported content type: {content_type}")
|
||||
data = response.content
|
||||
serialized_body = None
|
||||
if "Content-Type" not in request_headers and method not in [
|
||||
"GET",
|
||||
"HEAD",
|
||||
"DELETE",
|
||||
]:
|
||||
request_headers["Content-Type"] = ContentType.JSON
|
||||
logger.debug(
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.error(f"Connection error: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Connection error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"HTTP error {response.status_code}: {str(e)}")
|
||||
try:
|
||||
error_data = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
error_data = response.text
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"HTTP Error {response.status_code}",
|
||||
"data": error_data,
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Request failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code if response else None,
|
||||
"message": f"API call failed: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unexpected error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
def _parse_response(self, response: requests.Response) -> Any:
|
||||
"""
|
||||
Parse response based on Content-Type header.
|
||||
|
||||
Supports: JSON, XML, plain text, binary data.
|
||||
"""
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
|
||||
if not response.content:
|
||||
return None
|
||||
# JSON response
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON response: {str(e)}")
|
||||
return response.text
|
||||
# XML response
|
||||
|
||||
elif "application/xml" in content_type or "text/xml" in content_type:
|
||||
return response.text
|
||||
# Plain text response
|
||||
|
||||
elif "text/plain" in content_type or "text/html" in content_type:
|
||||
return response.text
|
||||
# Binary/unknown response
|
||||
|
||||
else:
|
||||
# Try to decode as text first, fall back to base64
|
||||
|
||||
try:
|
||||
return response.text
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
import base64
|
||||
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
|
||||
return []
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""Return configuration requirements for the tool."""
|
||||
return {}
|
||||
|
||||
Reference in New Issue
Block a user