mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1947 from siiddhantt/feat/remote-mcp
feat: remote mcp
This commit is contained in:
@@ -140,28 +140,28 @@ class BaseAgent(ABC):
|
|||||||
tool_id, action_name, call_args = parser.parse_args(call)
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
|
||||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
# Check if parsing failed
|
# Check if parsing failed
|
||||||
if tool_id is None or action_name is None:
|
if tool_id is None or action_name is None:
|
||||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
"action_name": getattr(call, 'name', 'unknown'),
|
"action_name": getattr(call, "name", "unknown"),
|
||||||
"arguments": call_args or {},
|
"arguments": call_args or {},
|
||||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||||
}
|
}
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
return "Failed to parse tool call.", call_id
|
return "Failed to parse tool call.", call_id
|
||||||
|
|
||||||
# Check if tool_id exists in available tools
|
# Check if tool_id exists in available tools
|
||||||
if tool_id not in tools_dict:
|
if tool_id not in tools_dict:
|
||||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
# Return error result
|
# Return error result
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
@@ -173,7 +173,7 @@ class BaseAgent(ABC):
|
|||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
return f"Tool with ID {tool_id} not found.", call_id
|
return f"Tool with ID {tool_id} not found.", call_id
|
||||||
|
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": tools_dict[tool_id]["name"],
|
"tool_name": tools_dict[tool_id]["name"],
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
|
|||||||
if tool_data["name"] == "api_tool"
|
if tool_data["name"] == "api_tool"
|
||||||
else tool_data["config"]
|
else tool_data["config"]
|
||||||
),
|
),
|
||||||
|
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||||
)
|
)
|
||||||
if tool_data["name"] == "api_tool":
|
if tool_data["name"] == "api_tool":
|
||||||
print(
|
print(
|
||||||
|
|||||||
405
application/agents/tools/mcp_tool.py
Normal file
405
application/agents/tools/mcp_tool.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.agents.tools.base import Tool
|
||||||
|
from application.security.encryption import decrypt_credentials
|
||||||
|
|
||||||
|
|
||||||
|
_mcp_session_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTool(Tool):
|
||||||
|
"""
|
||||||
|
MCP Tool
|
||||||
|
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the MCP Tool with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Dictionary containing MCP server configuration:
|
||||||
|
- server_url: URL of the remote MCP server
|
||||||
|
- auth_type: Type of authentication (api_key, bearer, basic, none)
|
||||||
|
- encrypted_credentials: Encrypted credentials (if available)
|
||||||
|
- timeout: Request timeout in seconds (default: 30)
|
||||||
|
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.server_url = config.get("server_url", "")
|
||||||
|
self.auth_type = config.get("auth_type", "none")
|
||||||
|
self.timeout = config.get("timeout", 30)
|
||||||
|
|
||||||
|
self.auth_credentials = {}
|
||||||
|
if config.get("encrypted_credentials") and user_id:
|
||||||
|
self.auth_credentials = decrypt_credentials(
|
||||||
|
config["encrypted_credentials"], user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.auth_credentials = config.get("auth_credentials", {})
|
||||||
|
self.available_tools = []
|
||||||
|
self._session = requests.Session()
|
||||||
|
self._mcp_session_id = None
|
||||||
|
self._setup_authentication()
|
||||||
|
self._cache_key = self._generate_cache_key()
|
||||||
|
|
||||||
|
def _setup_authentication(self):
|
||||||
|
"""Setup authentication for the MCP server connection."""
|
||||||
|
if self.auth_type == "api_key":
|
||||||
|
api_key = self.auth_credentials.get("api_key", "")
|
||||||
|
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||||
|
if api_key:
|
||||||
|
self._session.headers.update({header_name: api_key})
|
||||||
|
elif self.auth_type == "bearer":
|
||||||
|
token = self.auth_credentials.get("bearer_token", "")
|
||||||
|
if token:
|
||||||
|
self._session.headers.update({"Authorization": f"Bearer {token}"})
|
||||||
|
elif self.auth_type == "basic":
|
||||||
|
username = self.auth_credentials.get("username", "")
|
||||||
|
password = self.auth_credentials.get("password", "")
|
||||||
|
if username and password:
|
||||||
|
self._session.auth = (username, password)
|
||||||
|
|
||||||
|
def _generate_cache_key(self) -> str:
|
||||||
|
"""Generate a unique cache key for this MCP server configuration."""
|
||||||
|
auth_key = ""
|
||||||
|
if self.auth_type == "bearer":
|
||||||
|
token = self.auth_credentials.get("bearer_token", "")
|
||||||
|
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||||
|
elif self.auth_type == "api_key":
|
||||||
|
api_key = self.auth_credentials.get("api_key", "")
|
||||||
|
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||||
|
elif self.auth_type == "basic":
|
||||||
|
username = self.auth_credentials.get("username", "")
|
||||||
|
auth_key = f"basic:{username}"
|
||||||
|
else:
|
||||||
|
auth_key = "none"
|
||||||
|
return f"{self.server_url}#{auth_key}"
|
||||||
|
|
||||||
|
def _get_cached_session(self) -> Optional[str]:
|
||||||
|
"""Get cached session ID if available and not expired."""
|
||||||
|
global _mcp_session_cache
|
||||||
|
|
||||||
|
if self._cache_key in _mcp_session_cache:
|
||||||
|
session_data = _mcp_session_cache[self._cache_key]
|
||||||
|
if time.time() - session_data["created_at"] < 1800:
|
||||||
|
return session_data["session_id"]
|
||||||
|
else:
|
||||||
|
del _mcp_session_cache[self._cache_key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cache_session(self, session_id: str):
|
||||||
|
"""Cache the session ID for reuse."""
|
||||||
|
global _mcp_session_cache
|
||||||
|
_mcp_session_cache[self._cache_key] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _initialize_mcp_connection(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Initialize MCP connection with the server, using cached session if available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server capabilities and information
|
||||||
|
"""
|
||||||
|
cached_session = self._get_cached_session()
|
||||||
|
if cached_session:
|
||||||
|
self._mcp_session_id = cached_session
|
||||||
|
return {"cached": True}
|
||||||
|
try:
|
||||||
|
init_params = {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"roots": {"listChanged": True}, "sampling": {}},
|
||||||
|
"clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
|
||||||
|
}
|
||||||
|
response = self._make_mcp_request("initialize", init_params)
|
||||||
|
self._make_mcp_request("notifications/initialized")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e), "fallback": True}
|
||||||
|
|
||||||
|
def _ensure_valid_session(self):
|
||||||
|
"""Ensure we have a valid MCP session, reinitializing if needed."""
|
||||||
|
if not self._mcp_session_id:
|
||||||
|
self._initialize_mcp_connection()
|
||||||
|
|
||||||
|
def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Make an MCP protocol request to the server with automatic session recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: MCP method name (e.g., "tools/list", "tools/call")
|
||||||
|
params: Parameters for the MCP method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response data as dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If request fails after retry
|
||||||
|
"""
|
||||||
|
mcp_message = {"jsonrpc": "2.0", "method": method}
|
||||||
|
|
||||||
|
if not method.startswith("notifications/"):
|
||||||
|
mcp_message["id"] = 1
|
||||||
|
if params:
|
||||||
|
mcp_message["params"] = params
|
||||||
|
return self._execute_mcp_request(mcp_message, method)
|
||||||
|
|
||||||
|
def _execute_mcp_request(
|
||||||
|
self, mcp_message: Dict, method: str, is_retry: bool = False
|
||||||
|
) -> Dict:
|
||||||
|
"""Execute MCP request with optional retry on session failure."""
|
||||||
|
try:
|
||||||
|
final_headers = self._session.headers.copy()
|
||||||
|
final_headers.update(
|
||||||
|
{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._mcp_session_id:
|
||||||
|
final_headers["Mcp-Session-Id"] = self._mcp_session_id
|
||||||
|
response = self._session.post(
|
||||||
|
self.server_url.rstrip("/"),
|
||||||
|
json=mcp_message,
|
||||||
|
headers=final_headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "mcp-session-id" in response.headers:
|
||||||
|
self._mcp_session_id = response.headers["mcp-session-id"]
|
||||||
|
self._cache_session(self._mcp_session_id)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
if method.startswith("notifications/"):
|
||||||
|
return {}
|
||||||
|
response_text = response.text.strip()
|
||||||
|
if response_text.startswith("event:") and "data:" in response_text:
|
||||||
|
lines = response_text.split("\n")
|
||||||
|
data_line = None
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("data:"):
|
||||||
|
data_line = line[5:].strip()
|
||||||
|
break
|
||||||
|
if data_line:
|
||||||
|
try:
|
||||||
|
result = json.loads(data_line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise Exception(f"Invalid JSON in SSE data: {data_line}")
|
||||||
|
else:
|
||||||
|
raise Exception(f"No data found in SSE response: {response_text}")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = response.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise Exception(f"Invalid JSON response: {response.text}")
|
||||||
|
if "error" in result:
|
||||||
|
error_msg = result["error"]
|
||||||
|
if isinstance(error_msg, dict):
|
||||||
|
error_msg = error_msg.get("message", str(error_msg))
|
||||||
|
raise Exception(f"MCP server error: {error_msg}")
|
||||||
|
return result.get("result", result)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if not is_retry and self._should_retry_with_new_session(e):
|
||||||
|
self._invalidate_and_refresh_session()
|
||||||
|
return self._execute_mcp_request(mcp_message, method, is_retry=True)
|
||||||
|
raise Exception(f"MCP server request failed: {str(e)}")
|
||||||
|
|
||||||
|
def _should_retry_with_new_session(self, error: Exception) -> bool:
|
||||||
|
"""Check if error indicates session invalidation and retry is warranted."""
|
||||||
|
error_str = str(error).lower()
|
||||||
|
return (
|
||||||
|
any(
|
||||||
|
indicator in error_str
|
||||||
|
for indicator in [
|
||||||
|
"invalid session",
|
||||||
|
"session expired",
|
||||||
|
"unauthorized",
|
||||||
|
"401",
|
||||||
|
"403",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
and self._mcp_session_id is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _invalidate_and_refresh_session(self) -> None:
|
||||||
|
"""Invalidate current session and create a new one."""
|
||||||
|
global _mcp_session_cache
|
||||||
|
if self._cache_key in _mcp_session_cache:
|
||||||
|
del _mcp_session_cache[self._cache_key]
|
||||||
|
self._mcp_session_id = None
|
||||||
|
self._initialize_mcp_connection()
|
||||||
|
|
||||||
|
def discover_tools(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Discover available tools from the MCP server using MCP protocol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool definitions from the server
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._ensure_valid_session()
|
||||||
|
|
||||||
|
response = self._make_mcp_request("tools/list")
|
||||||
|
|
||||||
|
# Handle both formats: response with 'tools' key or response that IS the tools list
|
||||||
|
|
||||||
|
if isinstance(response, dict):
|
||||||
|
if "tools" in response:
|
||||||
|
self.available_tools = response["tools"]
|
||||||
|
elif (
|
||||||
|
"result" in response
|
||||||
|
and isinstance(response["result"], dict)
|
||||||
|
and "tools" in response["result"]
|
||||||
|
):
|
||||||
|
self.available_tools = response["result"]["tools"]
|
||||||
|
else:
|
||||||
|
self.available_tools = [response] if response else []
|
||||||
|
elif isinstance(response, list):
|
||||||
|
self.available_tools = response
|
||||||
|
else:
|
||||||
|
self.available_tools = []
|
||||||
|
return self.available_tools
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||||
|
|
||||||
|
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Execute an action on the remote MCP server using MCP protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: Name of the action to execute
|
||||||
|
**kwargs: Parameters for the action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from the MCP server
|
||||||
|
"""
|
||||||
|
self._ensure_valid_session()
|
||||||
|
|
||||||
|
# Skipping empty/None values - letting the server use defaults
|
||||||
|
|
||||||
|
cleaned_kwargs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value == "" or value is None:
|
||||||
|
continue
|
||||||
|
cleaned_kwargs[key] = value
|
||||||
|
call_params = {"name": action_name, "arguments": cleaned_kwargs}
|
||||||
|
try:
|
||||||
|
result = self._make_mcp_request("tools/call", call_params)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
||||||
|
|
||||||
|
def get_actions_metadata(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Get metadata for all available actions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of action metadata dictionaries
|
||||||
|
"""
|
||||||
|
actions = []
|
||||||
|
for tool in self.available_tools:
|
||||||
|
input_schema = (
|
||||||
|
tool.get("inputSchema")
|
||||||
|
or tool.get("input_schema")
|
||||||
|
or tool.get("schema")
|
||||||
|
or tool.get("parameters")
|
||||||
|
)
|
||||||
|
|
||||||
|
parameters_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if input_schema:
|
||||||
|
if isinstance(input_schema, dict):
|
||||||
|
if "properties" in input_schema:
|
||||||
|
parameters_schema = {
|
||||||
|
"type": input_schema.get("type", "object"),
|
||||||
|
"properties": input_schema.get("properties", {}),
|
||||||
|
"required": input_schema.get("required", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in ["additionalProperties", "description"]:
|
||||||
|
if key in input_schema:
|
||||||
|
parameters_schema[key] = input_schema[key]
|
||||||
|
else:
|
||||||
|
parameters_schema["properties"] = input_schema
|
||||||
|
action = {
|
||||||
|
"name": tool.get("name", ""),
|
||||||
|
"description": tool.get("description", ""),
|
||||||
|
"parameters": parameters_schema,
|
||||||
|
}
|
||||||
|
actions.append(action)
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def test_connection(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Test the connection to the MCP server and validate functionality.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with connection test results including tool count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._mcp_session_id = None
|
||||||
|
|
||||||
|
init_result = self._initialize_mcp_connection()
|
||||||
|
|
||||||
|
tools = self.discover_tools()
|
||||||
|
|
||||||
|
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||||
|
if init_result.get("cached"):
|
||||||
|
message += " (Using cached session)"
|
||||||
|
elif init_result.get("fallback"):
|
||||||
|
message += " (No formal initialization required)"
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": message,
|
||||||
|
"tools_count": len(tools),
|
||||||
|
"session_id": self._mcp_session_id,
|
||||||
|
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"Connection failed: {str(e)}",
|
||||||
|
"tools_count": 0,
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_config_requirements(self) -> Dict:
|
||||||
|
return {
|
||||||
|
"server_url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"auth_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Authentication type",
|
||||||
|
"enum": ["none", "api_key", "bearer", "basic"],
|
||||||
|
"default": "none",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"auth_credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Authentication credentials (varies by auth_type)",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Request timeout in seconds",
|
||||||
|
"default": 30,
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 300,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -23,16 +23,23 @@ class ToolManager:
|
|||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
self.tools[name] = obj(tool_config)
|
self.tools[name] = obj(tool_config)
|
||||||
|
|
||||||
def load_tool(self, tool_name, tool_config):
|
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||||
self.config[tool_name] = tool_config
|
self.config[tool_name] = tool_config
|
||||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool:
|
||||||
return obj(tool_config)
|
if tool_name == "mcp_tool" and user_id:
|
||||||
|
return obj(tool_config, user_id)
|
||||||
|
else:
|
||||||
|
return obj(tool_config)
|
||||||
|
|
||||||
def execute_action(self, tool_name, action_name, **kwargs):
|
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||||
|
if tool_name == "mcp_tool" and user_id:
|
||||||
|
tool_config = self.config.get(tool_name, {})
|
||||||
|
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||||
|
return tool.execute_action(action_name, **kwargs)
|
||||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||||
|
|
||||||
def get_all_actions_metadata(self):
|
def get_all_actions_metadata(self):
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from flask_restx import fields, inputs, Namespace, Resource
|
|||||||
from pymongo import ReturnDocument
|
from pymongo import ReturnDocument
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
|
from application.agents.tools.mcp_tool import MCPTool
|
||||||
|
|
||||||
from application.agents.tools.tool_manager import ToolManager
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
from application.api import api
|
from application.api import api
|
||||||
|
|
||||||
@@ -38,6 +40,7 @@ from application.api.user.tasks import (
|
|||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
from application.tts.google_tts import GoogleTTS
|
from application.tts.google_tts import GoogleTTS
|
||||||
from application.utils import (
|
from application.utils import (
|
||||||
@@ -491,6 +494,7 @@ class DeleteOldIndexes(Resource):
|
|||||||
)
|
)
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(jsonify({"status": "not found"}), 404)
|
return make_response(jsonify({"status": "not found"}), 404)
|
||||||
|
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -507,6 +511,7 @@ class DeleteOldIndexes(Resource):
|
|||||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||||
)
|
)
|
||||||
vectorstore.delete_index()
|
vectorstore.delete_index()
|
||||||
|
|
||||||
if "file_path" in doc and doc["file_path"]:
|
if "file_path" in doc and doc["file_path"]:
|
||||||
file_path = doc["file_path"]
|
file_path = doc["file_path"]
|
||||||
if storage.is_directory(file_path):
|
if storage.is_directory(file_path):
|
||||||
@@ -515,6 +520,7 @@ class DeleteOldIndexes(Resource):
|
|||||||
storage.delete_file(f)
|
storage.delete_file(f)
|
||||||
else:
|
else:
|
||||||
storage.delete_file(file_path)
|
storage.delete_file(file_path)
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@@ -522,6 +528,7 @@ class DeleteOldIndexes(Resource):
|
|||||||
f"Error deleting files and indexes: {err}", exc_info=True
|
f"Error deleting files and indexes: {err}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
@@ -593,6 +600,7 @@ class UploadFile(Resource):
|
|||||||
== temp_file_path
|
== temp_file_path
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
rel_path = os.path.relpath(
|
rel_path = os.path.relpath(
|
||||||
os.path.join(root, extracted_file), temp_dir
|
os.path.join(root, extracted_file), temp_dir
|
||||||
)
|
)
|
||||||
@@ -617,6 +625,7 @@ class UploadFile(Resource):
|
|||||||
file_path = f"{base_path}/{safe_file}"
|
file_path = f"{base_path}/{safe_file}"
|
||||||
with open(temp_file_path, "rb") as f:
|
with open(temp_file_path, "rb") as f:
|
||||||
storage.save_file(f, file_path)
|
storage.save_file(f, file_path)
|
||||||
|
|
||||||
task = ingest.delay(
|
task = ingest.delay(
|
||||||
settings.UPLOAD_FOLDER,
|
settings.UPLOAD_FOLDER,
|
||||||
[
|
[
|
||||||
@@ -688,6 +697,7 @@ class ManageSourceFiles(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||||
)
|
)
|
||||||
|
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
source_id = request.form.get("source_id")
|
source_id = request.form.get("source_id")
|
||||||
operation = request.form.get("operation")
|
operation = request.form.get("operation")
|
||||||
@@ -737,6 +747,7 @@ class ManageSourceFiles(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Database error"}), 500
|
jsonify({"success": False, "message": "Database error"}), 500
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
source_file_path = source.get("file_path", "")
|
source_file_path = source.get("file_path", "")
|
||||||
@@ -793,6 +804,7 @@ class ManageSourceFiles(Resource):
|
|||||||
),
|
),
|
||||||
200,
|
200,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif operation == "remove":
|
elif operation == "remove":
|
||||||
file_paths_str = request.form.get("file_paths")
|
file_paths_str = request.form.get("file_paths")
|
||||||
if not file_paths_str:
|
if not file_paths_str:
|
||||||
@@ -846,6 +858,7 @@ class ManageSourceFiles(Resource):
|
|||||||
),
|
),
|
||||||
200,
|
200,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif operation == "remove_directory":
|
elif operation == "remove_directory":
|
||||||
directory_path = request.form.get("directory_path")
|
directory_path = request.form.get("directory_path")
|
||||||
if not directory_path:
|
if not directory_path:
|
||||||
@@ -871,6 +884,7 @@ class ManageSourceFiles(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_directory_path = (
|
full_directory_path = (
|
||||||
f"{source_file_path}/{directory_path}"
|
f"{source_file_path}/{directory_path}"
|
||||||
if directory_path
|
if directory_path
|
||||||
@@ -929,6 +943,7 @@ class ManageSourceFiles(Resource):
|
|||||||
),
|
),
|
||||||
200,
|
200,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||||
if operation == "remove_directory":
|
if operation == "remove_directory":
|
||||||
@@ -940,6 +955,7 @@ class ManageSourceFiles(Resource):
|
|||||||
elif operation == "add":
|
elif operation == "add":
|
||||||
parent_dir = request.form.get("parent_dir", "")
|
parent_dir = request.form.get("parent_dir", "")
|
||||||
error_context += f", parent_dir={parent_dir}"
|
error_context += f", parent_dir={parent_dir}"
|
||||||
|
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error managing source files: {err} ({error_context})", exc_info=True
|
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||||
)
|
)
|
||||||
@@ -1616,6 +1632,7 @@ class CreateAgent(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that it has either a 'schema' property or is itself a schema
|
# Validate that it has either a 'schema' property or is itself a schema
|
||||||
|
|
||||||
if "schema" not in json_schema and "type" not in json_schema:
|
if "schema" not in json_schema and "type" not in json_schema:
|
||||||
@@ -3625,7 +3642,60 @@ class UpdateTool(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
update_data["config"] = data["config"]
|
tool_doc = user_tools_collection.find_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user}
|
||||||
|
)
|
||||||
|
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
||||||
|
config = data["config"]
|
||||||
|
existing_config = tool_doc.get("config", {})
|
||||||
|
storage_config = existing_config.copy()
|
||||||
|
|
||||||
|
storage_config.update(config)
|
||||||
|
existing_credentials = {}
|
||||||
|
if "encrypted_credentials" in existing_config:
|
||||||
|
existing_credentials = decrypt_credentials(
|
||||||
|
existing_config["encrypted_credentials"], user
|
||||||
|
)
|
||||||
|
auth_credentials = existing_credentials.copy()
|
||||||
|
auth_type = storage_config.get("auth_type", "none")
|
||||||
|
if auth_type == "api_key":
|
||||||
|
if "api_key" in config and config["api_key"]:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config[
|
||||||
|
"api_key_header"
|
||||||
|
]
|
||||||
|
elif auth_type == "bearer":
|
||||||
|
if "bearer_token" in config and config["bearer_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif "encrypted_token" in config and config["encrypted_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["encrypted_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config and config["username"]:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config and config["password"]:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
if auth_type != "none" and auth_credentials:
|
||||||
|
encrypted_credentials_string = encrypt_credentials(
|
||||||
|
auth_credentials, user
|
||||||
|
)
|
||||||
|
storage_config["encrypted_credentials"] = (
|
||||||
|
encrypted_credentials_string
|
||||||
|
)
|
||||||
|
elif auth_type == "none":
|
||||||
|
storage_config.pop("encrypted_credentials", None)
|
||||||
|
for field in [
|
||||||
|
"api_key",
|
||||||
|
"bearer_token",
|
||||||
|
"encrypted_token",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"api_key_header",
|
||||||
|
]:
|
||||||
|
storage_config.pop(field, None)
|
||||||
|
update_data["config"] = storage_config
|
||||||
|
else:
|
||||||
|
update_data["config"] = data["config"]
|
||||||
if "status" in data:
|
if "status" in data:
|
||||||
update_data["status"] = data["status"]
|
update_data["status"] = data["status"]
|
||||||
user_tools_collection.update_one(
|
user_tools_collection.update_one(
|
||||||
@@ -3837,6 +3907,7 @@ class GetChunks(Resource):
|
|||||||
if not (text_match or title_match):
|
if not (text_match or title_match):
|
||||||
continue
|
continue
|
||||||
filtered_chunks.append(chunk)
|
filtered_chunks.append(chunk)
|
||||||
|
|
||||||
chunks = filtered_chunks
|
chunks = filtered_chunks
|
||||||
|
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
@@ -4027,6 +4098,7 @@ class UpdateChunk(Resource):
|
|||||||
current_app.logger.warning(
|
current_app.logger.warning(
|
||||||
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||||
)
|
)
|
||||||
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
@@ -4154,19 +4226,23 @@ class DirectoryStructure(Resource):
|
|||||||
decoded_token = request.decoded_token
|
decoded_token = request.decoded_token
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
doc_id = request.args.get("id")
|
doc_id = request.args.get("id")
|
||||||
|
|
||||||
if not doc_id:
|
if not doc_id:
|
||||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||||
|
|
||||||
if not ObjectId.is_valid(doc_id):
|
if not ObjectId.is_valid(doc_id):
|
||||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
)
|
)
|
||||||
|
|
||||||
directory_structure = doc.get("directory_structure", {})
|
directory_structure = doc.get("directory_structure", {})
|
||||||
base_path = doc.get("file_path", "")
|
base_path = doc.get("file_path", "")
|
||||||
|
|
||||||
@@ -4196,3 +4272,204 @@ class DirectoryStructure(Resource):
|
|||||||
f"Error retrieving directory structure: {e}", exc_info=True
|
f"Error retrieving directory structure: {e}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@user_ns.route("/api/mcp_server/test")
|
||||||
|
class TestMCPServerConfig(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MCPServerTestModel",
|
||||||
|
{
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="MCP server configuration to test"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Test MCP server connection with provided configuration")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
required_fields = ["config"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
config = data["config"]
|
||||||
|
|
||||||
|
auth_credentials = {}
|
||||||
|
auth_type = config.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "api_key" and "api_key" in config:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||||
|
elif auth_type == "bearer" and "bearer_token" in config:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
|
||||||
|
test_config = config.copy()
|
||||||
|
test_config["auth_credentials"] = auth_credentials
|
||||||
|
|
||||||
|
mcp_tool = MCPTool(test_config, user)
|
||||||
|
result = mcp_tool.test_connection()
|
||||||
|
|
||||||
|
return make_response(jsonify(result), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@user_ns.route("/api/mcp_server/save")
|
||||||
|
class MCPServerSave(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MCPServerSaveModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(
|
||||||
|
required=False, description="Tool ID for updates (optional)"
|
||||||
|
),
|
||||||
|
"displayName": fields.String(
|
||||||
|
required=True, description="Display name for the MCP server"
|
||||||
|
),
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="MCP server configuration"
|
||||||
|
),
|
||||||
|
"status": fields.Boolean(
|
||||||
|
required=False, default=True, description="Tool status"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
required_fields = ["displayName", "config"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
config = data["config"]
|
||||||
|
|
||||||
|
auth_credentials = {}
|
||||||
|
auth_type = config.get("auth_type", "none")
|
||||||
|
if auth_type == "api_key":
|
||||||
|
if "api_key" in config and config["api_key"]:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||||
|
elif auth_type == "bearer":
|
||||||
|
if "bearer_token" in config and config["bearer_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config and config["username"]:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config and config["password"]:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
mcp_config = config.copy()
|
||||||
|
mcp_config["auth_credentials"] = auth_credentials
|
||||||
|
|
||||||
|
if auth_type == "none" or auth_credentials:
|
||||||
|
mcp_tool = MCPTool(mcp_config, user)
|
||||||
|
mcp_tool.discover_tools()
|
||||||
|
actions_metadata = mcp_tool.get_actions_metadata()
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"No valid credentials provided for the selected authentication type"
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_config = config.copy()
|
||||||
|
if auth_credentials:
|
||||||
|
encrypted_credentials_string = encrypt_credentials(
|
||||||
|
auth_credentials, user
|
||||||
|
)
|
||||||
|
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||||
|
|
||||||
|
for field in [
|
||||||
|
"api_key",
|
||||||
|
"bearer_token",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"api_key_header",
|
||||||
|
]:
|
||||||
|
storage_config.pop(field, None)
|
||||||
|
transformed_actions = []
|
||||||
|
for action in actions_metadata:
|
||||||
|
action["active"] = True
|
||||||
|
if "parameters" in action:
|
||||||
|
if "properties" in action["parameters"]:
|
||||||
|
for param_name, param_details in action["parameters"][
|
||||||
|
"properties"
|
||||||
|
].items():
|
||||||
|
param_details["filled_by_llm"] = True
|
||||||
|
param_details["value"] = ""
|
||||||
|
transformed_actions.append(action)
|
||||||
|
tool_data = {
|
||||||
|
"name": "mcp_tool",
|
||||||
|
"displayName": data["displayName"],
|
||||||
|
"customName": data["displayName"],
|
||||||
|
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||||
|
"config": storage_config,
|
||||||
|
"actions": transformed_actions,
|
||||||
|
"status": data.get("status", True),
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_id = data.get("id")
|
||||||
|
if tool_id:
|
||||||
|
result = user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||||
|
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||||
|
)
|
||||||
|
if result.matched_count == 0:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Tool not found or access denied",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
response_data = {
|
||||||
|
"success": True,
|
||||||
|
"id": tool_id,
|
||||||
|
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||||
|
"tools_count": len(transformed_actions),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
result = user_tools_collection.insert_one(tool_data)
|
||||||
|
tool_id = str(result.inserted_id)
|
||||||
|
response_data = {
|
||||||
|
"success": True,
|
||||||
|
"id": tool_id,
|
||||||
|
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||||
|
"tools_count": len(transformed_actions),
|
||||||
|
}
|
||||||
|
return make_response(jsonify(response_data), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
|
|||||||
"gpt-4o-mini": 128000,
|
"gpt-4o-mini": 128000,
|
||||||
"gpt-3.5-turbo": 4096,
|
"gpt-3.5-turbo": 4096,
|
||||||
"claude-2": 1e5,
|
"claude-2": 1e5,
|
||||||
"gemini-2.0-flash-exp": 1e6,
|
"gemini-2.5-flash": 1e6,
|
||||||
}
|
}
|
||||||
UPLOAD_FOLDER: str = "inputs"
|
UPLOAD_FOLDER: str = "inputs"
|
||||||
PARSE_PDF_AS_IMAGE: bool = False
|
PARSE_PDF_AS_IMAGE: bool = False
|
||||||
@@ -96,7 +96,7 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_HOST: Optional[str] = None
|
QDRANT_HOST: Optional[str] = None
|
||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
# PGVector vectorstore config
|
# PGVector vectorstore config
|
||||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||||
# Milvus vectorstore config
|
# Milvus vectorstore config
|
||||||
@@ -116,6 +116,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
JWT_SECRET_KEY: str = ""
|
JWT_SECRET_KEY: str = ""
|
||||||
|
|
||||||
|
# Encryption settings
|
||||||
|
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||||
|
|
||||||
|
|
||||||
path = Path(__file__).parent.parent.absolute()
|
path = Path(__file__).parent.parent.absolute()
|
||||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _clean_messages_google(self, messages):
|
def _clean_messages_google(self, messages):
|
||||||
|
"""Convert OpenAI format messages to Google AI format."""
|
||||||
cleaned_messages = []
|
cleaned_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
role = "model"
|
role = "model"
|
||||||
|
elif role == "tool":
|
||||||
|
role = "model"
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||||
|
|
||||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
if parts:
|
||||||
|
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||||
|
|
||||||
return cleaned_messages
|
return cleaned_messages
|
||||||
|
|
||||||
|
def _clean_schema(self, schema_obj):
|
||||||
|
"""
|
||||||
|
Recursively remove unsupported fields from schema objects
|
||||||
|
and validate required properties.
|
||||||
|
"""
|
||||||
|
if not isinstance(schema_obj, dict):
|
||||||
|
return schema_obj
|
||||||
|
allowed_fields = {
|
||||||
|
"type",
|
||||||
|
"description",
|
||||||
|
"items",
|
||||||
|
"properties",
|
||||||
|
"required",
|
||||||
|
"enum",
|
||||||
|
"pattern",
|
||||||
|
"minimum",
|
||||||
|
"maximum",
|
||||||
|
"nullable",
|
||||||
|
"default",
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned = {}
|
||||||
|
for key, value in schema_obj.items():
|
||||||
|
if key not in allowed_fields:
|
||||||
|
continue
|
||||||
|
elif key == "type" and isinstance(value, str):
|
||||||
|
cleaned[key] = value.upper()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
cleaned[key] = self._clean_schema(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||||
|
else:
|
||||||
|
cleaned[key] = value
|
||||||
|
|
||||||
|
# Validate that required properties actually exist in properties
|
||||||
|
if "required" in cleaned and "properties" in cleaned:
|
||||||
|
valid_required = []
|
||||||
|
properties_keys = set(cleaned["properties"].keys())
|
||||||
|
for required_prop in cleaned["required"]:
|
||||||
|
if required_prop in properties_keys:
|
||||||
|
valid_required.append(required_prop)
|
||||||
|
if valid_required:
|
||||||
|
cleaned["required"] = valid_required
|
||||||
|
else:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
elif "required" in cleaned and "properties" not in cleaned:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
def _clean_tools_format(self, tools_list):
|
def _clean_tools_format(self, tools_list):
|
||||||
|
"""Convert OpenAI format tools to Google AI format."""
|
||||||
genai_tools = []
|
genai_tools = []
|
||||||
for tool_data in tools_list:
|
for tool_data in tools_list:
|
||||||
if tool_data["type"] == "function":
|
if tool_data["type"] == "function":
|
||||||
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
|
|||||||
properties = parameters.get("properties", {})
|
properties = parameters.get("properties", {})
|
||||||
|
|
||||||
if properties:
|
if properties:
|
||||||
|
cleaned_properties = {}
|
||||||
|
for k, v in properties.items():
|
||||||
|
cleaned_properties[k] = self._clean_schema(v)
|
||||||
|
|
||||||
genai_function = dict(
|
genai_function = dict(
|
||||||
name=function["name"],
|
name=function["name"],
|
||||||
description=function["description"],
|
description=function["description"],
|
||||||
parameters={
|
parameters={
|
||||||
"type": "OBJECT",
|
"type": "OBJECT",
|
||||||
"properties": {
|
"properties": cleaned_properties,
|
||||||
k: {
|
|
||||||
**v,
|
|
||||||
"type": v["type"].upper() if v["type"] else None,
|
|
||||||
}
|
|
||||||
for k, v in properties.items()
|
|
||||||
},
|
|
||||||
"required": (
|
"required": (
|
||||||
parameters["required"]
|
parameters["required"]
|
||||||
if "required" in parameters
|
if "required" in parameters
|
||||||
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API without streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API with streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
|
|||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
|
"""Return whether this LLM supports function calling."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _supports_structured_output(self):
|
def _supports_structured_output(self):
|
||||||
|
"""Return whether this LLM supports structured JSON output."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def prepare_structured_output_format(self, json_schema):
|
def prepare_structured_output_format(self, json_schema):
|
||||||
|
"""Convert JSON schema to Google AI structured output format."""
|
||||||
if not json_schema:
|
if not json_schema:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -205,7 +205,6 @@ class LLMHandler(ABC):
|
|||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
tool_response, call_id = e.value
|
tool_response, call_id = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
updated_messages.append(
|
updated_messages.append(
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||||
updated_messages.append(
|
error_call = ToolCall(
|
||||||
{
|
id=call.id, name=call.name, arguments=call.arguments
|
||||||
"role": "tool",
|
|
||||||
"content": f"Error executing tool: {str(e)}",
|
|
||||||
"tool_call_id": call.id,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
error_response = f"Error executing tool: {str(e)}"
|
||||||
|
error_message = self.create_tool_message(error_call, error_response)
|
||||||
|
updated_messages.append(error_message)
|
||||||
|
|
||||||
|
call_parts = call.name.split("_")
|
||||||
|
if len(call_parts) >= 2:
|
||||||
|
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||||
|
action_name = "_".join(call_parts[:-1])
|
||||||
|
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||||
|
full_action_name = f"{action_name}_{tool_id}"
|
||||||
|
else:
|
||||||
|
tool_name = "unknown_tool"
|
||||||
|
action_name = call.name
|
||||||
|
full_action_name = call.name
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"call_id": call.id,
|
||||||
|
"action_name": full_action_name,
|
||||||
|
"arguments": call.arguments,
|
||||||
|
"error": error_response,
|
||||||
|
"status": "error",
|
||||||
|
},
|
||||||
|
}
|
||||||
return updated_messages
|
return updated_messages
|
||||||
|
|
||||||
def handle_non_streaming(
|
def handle_non_streaming(
|
||||||
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
|
|||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
messages = e.value
|
messages = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
response = agent.llm.gen(
|
response = agent.llm.gen(
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
parsed = self.parse_response(response)
|
parsed = self.parse_response(response)
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
return parsed.content
|
return parsed.content
|
||||||
|
|
||||||
def handle_streaming(
|
def handle_streaming(
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(response, "candidates"):
|
if hasattr(response, "candidates"):
|
||||||
parts = response.candidates[0].content.parts if response.candidates else []
|
parts = response.candidates[0].content.parts if response.candidates else []
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
finish_reason="tool_calls" if tool_calls else "stop",
|
finish_reason="tool_calls" if tool_calls else "stop",
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
if hasattr(response, "function_call"):
|
if hasattr(response, "function_call"):
|
||||||
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
|
|
||||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
"""Create Google-style tool message."""
|
"""Create Google-style tool message."""
|
||||||
from google.genai import types
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"role": "tool",
|
"role": "model",
|
||||||
"content": [
|
"content": [
|
||||||
types.Part.from_function_response(
|
{
|
||||||
name=tool_call.name, response={"result": result}
|
"function_response": {
|
||||||
).to_json_dict()
|
"name": tool_call.name,
|
||||||
|
"response": {"result": result},
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ anthropic==0.49.0
|
|||||||
boto3==1.38.18
|
boto3==1.38.18
|
||||||
beautifulsoup4==4.13.4
|
beautifulsoup4==4.13.4
|
||||||
celery==5.4.0
|
celery==5.4.0
|
||||||
|
cryptography==42.0.8
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
docx2txt==0.8
|
docx2txt==0.8
|
||||||
duckduckgo-search==7.5.2
|
duckduckgo-search==7.5.2
|
||||||
|
|||||||
0
application/security/__init__.py
Normal file
0
application/security/__init__.py
Normal file
85
application/security/encryption.py
Normal file
85
application/security/encryption.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||||
|
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||||
|
|
||||||
|
password = f"{app_secret}#{user_id}".encode()
|
||||||
|
|
||||||
|
kdf = PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
salt=salt,
|
||||||
|
iterations=100000,
|
||||||
|
backend=default_backend(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return kdf.derive(password)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||||
|
if not credentials:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
salt = os.urandom(16)
|
||||||
|
iv = os.urandom(16)
|
||||||
|
key = _derive_key(user_id, salt)
|
||||||
|
|
||||||
|
json_str = json.dumps(credentials)
|
||||||
|
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||||
|
encryptor = cipher.encryptor()
|
||||||
|
|
||||||
|
padded_data = _pad_data(json_str.encode())
|
||||||
|
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
|
||||||
|
result = salt + iv + encrypted_data
|
||||||
|
return base64.b64encode(result).decode()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||||
|
if not encrypted_data:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
data = base64.b64decode(encrypted_data.encode())
|
||||||
|
|
||||||
|
salt = data[:16]
|
||||||
|
iv = data[16:32]
|
||||||
|
encrypted_content = data[32:]
|
||||||
|
|
||||||
|
key = _derive_key(user_id, salt)
|
||||||
|
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||||
|
decryptor = cipher.decryptor()
|
||||||
|
|
||||||
|
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||||
|
decrypted_data = _unpad_data(decrypted_padded)
|
||||||
|
|
||||||
|
return json.loads(decrypted_data.decode())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _pad_data(data: bytes) -> bytes:
|
||||||
|
block_size = 16
|
||||||
|
padding_len = block_size - (len(data) % block_size)
|
||||||
|
padding = bytes([padding_len]) * padding_len
|
||||||
|
return data + padding
|
||||||
|
|
||||||
|
|
||||||
|
def _unpad_data(data: bytes) -> bytes:
|
||||||
|
padding_len = data[-1]
|
||||||
|
return data[:-padding_len]
|
||||||
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
|
||||||
|
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||||
|
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 831 B |
@@ -57,6 +57,8 @@ const endpoints = {
|
|||||||
DIRECTORY_STRUCTURE: (docId: string) =>
|
DIRECTORY_STRUCTURE: (docId: string) =>
|
||||||
`/api/directory_structure?id=${docId}`,
|
`/api/directory_structure?id=${docId}`,
|
||||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||||
|
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||||
|
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||||
},
|
},
|
||||||
CONVERSATION: {
|
CONVERSATION: {
|
||||||
ANSWER: '/api/answer',
|
ANSWER: '/api/answer',
|
||||||
|
|||||||
@@ -108,6 +108,10 @@ const userService = {
|
|||||||
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
||||||
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
||||||
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
||||||
|
testMCPConnection: (data: any, token: string | null): Promise<any> =>
|
||||||
|
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||||
|
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||||
|
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||||
syncConnector: (
|
syncConnector: (
|
||||||
docId: string,
|
docId: string,
|
||||||
provider: string,
|
provider: string,
|
||||||
|
|||||||
0
frontend/src/assets/server.svg
Normal file
0
frontend/src/assets/server.svg
Normal file
@@ -1,6 +1,6 @@
|
|||||||
import 'katex/dist/katex.min.css';
|
import 'katex/dist/katex.min.css';
|
||||||
|
|
||||||
import { forwardRef, Fragment, useRef, useState, useEffect } from 'react';
|
import { forwardRef, Fragment, useEffect, useRef, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import ReactMarkdown from 'react-markdown';
|
import ReactMarkdown from 'react-markdown';
|
||||||
import { useSelector } from 'react-redux';
|
import { useSelector } from 'react-redux';
|
||||||
@@ -12,12 +12,13 @@ import {
|
|||||||
import rehypeKatex from 'rehype-katex';
|
import rehypeKatex from 'rehype-katex';
|
||||||
import remarkGfm from 'remark-gfm';
|
import remarkGfm from 'remark-gfm';
|
||||||
import remarkMath from 'remark-math';
|
import remarkMath from 'remark-math';
|
||||||
import DocumentationDark from '../assets/documentation-dark.svg';
|
|
||||||
import ChevronDown from '../assets/chevron-down.svg';
|
import ChevronDown from '../assets/chevron-down.svg';
|
||||||
import Cloud from '../assets/cloud.svg';
|
import Cloud from '../assets/cloud.svg';
|
||||||
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
||||||
import Dislike from '../assets/dislike.svg?react';
|
import Dislike from '../assets/dislike.svg?react';
|
||||||
import Document from '../assets/document.svg';
|
import Document from '../assets/document.svg';
|
||||||
|
import DocumentationDark from '../assets/documentation-dark.svg';
|
||||||
import Edit from '../assets/edit.svg';
|
import Edit from '../assets/edit.svg';
|
||||||
import Like from '../assets/like.svg?react';
|
import Like from '../assets/like.svg?react';
|
||||||
import Link from '../assets/link.svg';
|
import Link from '../assets/link.svg';
|
||||||
@@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
|||||||
Response
|
Response
|
||||||
</span>{' '}
|
</span>{' '}
|
||||||
<CopyButton
|
<CopyButton
|
||||||
textToCopy={JSON.stringify(toolCall.result, null, 2)}
|
textToCopy={
|
||||||
|
toolCall.status === 'error'
|
||||||
|
? toolCall.error || 'Unknown error'
|
||||||
|
: JSON.stringify(toolCall.result, null, 2)
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
</p>
|
</p>
|
||||||
{toolCall.status === 'pending' && (
|
{toolCall.status === 'pending' && (
|
||||||
@@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
|||||||
</span>
|
</span>
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
|
{toolCall.status === 'error' && (
|
||||||
|
<p className="dark:bg-raisin-black rounded-b-2xl p-2 font-mono text-sm break-words">
|
||||||
|
<span
|
||||||
|
className="leading-[23px] text-red-500 dark:text-red-400"
|
||||||
|
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||||
|
>
|
||||||
|
{toolCall.error}
|
||||||
|
</span>
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|||||||
@@ -4,5 +4,6 @@ export type ToolCallsType = {
|
|||||||
call_id: string;
|
call_id: string;
|
||||||
arguments: Record<string, any>;
|
arguments: Record<string, any>;
|
||||||
result?: Record<string, any>;
|
result?: Record<string, any>;
|
||||||
status?: 'pending' | 'completed';
|
error?: string;
|
||||||
|
status?: 'pending' | 'completed' | 'error';
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -18,7 +18,10 @@ export default function useDefaultDocument() {
|
|||||||
const fetchDocs = () => {
|
const fetchDocs = () => {
|
||||||
getDocs(token).then((data) => {
|
getDocs(token).then((data) => {
|
||||||
dispatch(setSourceDocs(data));
|
dispatch(setSourceDocs(data));
|
||||||
if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0))
|
if (
|
||||||
|
!selectedDoc ||
|
||||||
|
(Array.isArray(selectedDoc) && selectedDoc.length === 0)
|
||||||
|
)
|
||||||
Array.isArray(data) &&
|
Array.isArray(data) &&
|
||||||
data?.forEach((doc: Doc) => {
|
data?.forEach((doc: Doc) => {
|
||||||
if (doc.model && doc.name === 'default') {
|
if (doc.model && doc.name === 'default') {
|
||||||
|
|||||||
@@ -184,7 +184,39 @@
|
|||||||
"cancel": "Cancel",
|
"cancel": "Cancel",
|
||||||
"addNew": "Add New",
|
"addNew": "Add New",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"type": "Type"
|
"type": "Type",
|
||||||
|
"mcp": {
|
||||||
|
"addServer": "Add MCP Server",
|
||||||
|
"editServer": "Edit Server",
|
||||||
|
"serverName": "Server Name",
|
||||||
|
"serverUrl": "Server URL",
|
||||||
|
"headerName": "Header Name",
|
||||||
|
"timeout": "Timeout (seconds)",
|
||||||
|
"testConnection": "Test Connection",
|
||||||
|
"testing": "Testing...",
|
||||||
|
"saving": "Saving...",
|
||||||
|
"save": "Save",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"noAuth": "No Authentication",
|
||||||
|
"placeholders": {
|
||||||
|
"serverUrl": "https://api.example.com",
|
||||||
|
"apiKey": "Your secret API key",
|
||||||
|
"bearerToken": "Your secret token",
|
||||||
|
"username": "Your username",
|
||||||
|
"password": "Your password"
|
||||||
|
},
|
||||||
|
"errors": {
|
||||||
|
"nameRequired": "Server name is required",
|
||||||
|
"urlRequired": "Server URL is required",
|
||||||
|
"invalidUrl": "Please enter a valid URL",
|
||||||
|
"apiKeyRequired": "API key is required",
|
||||||
|
"tokenRequired": "Bearer token is required",
|
||||||
|
"usernameRequired": "Username is required",
|
||||||
|
"passwordRequired": "Password is required",
|
||||||
|
"testFailed": "Connection test failed",
|
||||||
|
"saveFailed": "Failed to save MCP server"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"modals": {
|
"modals": {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks';
|
|||||||
import { ActiveState } from '../models/misc';
|
import { ActiveState } from '../models/misc';
|
||||||
import { selectToken } from '../preferences/preferenceSlice';
|
import { selectToken } from '../preferences/preferenceSlice';
|
||||||
import ConfigToolModal from './ConfigToolModal';
|
import ConfigToolModal from './ConfigToolModal';
|
||||||
|
import MCPServerModal from './MCPServerModal';
|
||||||
import { AvailableToolType } from './types';
|
import { AvailableToolType } from './types';
|
||||||
import WrapperComponent from './WrapperModal';
|
import WrapperComponent from './WrapperModal';
|
||||||
|
|
||||||
@@ -34,6 +35,8 @@ export default function AddToolModal({
|
|||||||
React.useState<AvailableToolType | null>(null);
|
React.useState<AvailableToolType | null>(null);
|
||||||
const [configModalState, setConfigModalState] =
|
const [configModalState, setConfigModalState] =
|
||||||
React.useState<ActiveState>('INACTIVE');
|
React.useState<ActiveState>('INACTIVE');
|
||||||
|
const [mcpModalState, setMcpModalState] =
|
||||||
|
React.useState<ActiveState>('INACTIVE');
|
||||||
const [loading, setLoading] = React.useState(false);
|
const [loading, setLoading] = React.useState(false);
|
||||||
|
|
||||||
useOutsideAlerter(modalRef, () => {
|
useOutsideAlerter(modalRef, () => {
|
||||||
@@ -86,6 +89,9 @@ export default function AddToolModal({
|
|||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error('Failed to create tool:', error);
|
console.error('Failed to create tool:', error);
|
||||||
});
|
});
|
||||||
|
} else if (tool.name === 'mcp_tool') {
|
||||||
|
setModalState('INACTIVE');
|
||||||
|
setMcpModalState('ACTIVE');
|
||||||
} else {
|
} else {
|
||||||
setModalState('INACTIVE');
|
setModalState('INACTIVE');
|
||||||
setConfigModalState('ACTIVE');
|
setConfigModalState('ACTIVE');
|
||||||
@@ -95,6 +101,12 @@ export default function AddToolModal({
|
|||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
if (modalState === 'ACTIVE') getAvailableTools();
|
if (modalState === 'ACTIVE') getAvailableTools();
|
||||||
}, [modalState]);
|
}, [modalState]);
|
||||||
|
|
||||||
|
const handleMcpServerAdded = () => {
|
||||||
|
getUserTools();
|
||||||
|
setMcpModalState('INACTIVE');
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{modalState === 'ACTIVE' && (
|
{modalState === 'ACTIVE' && (
|
||||||
@@ -166,6 +178,11 @@ export default function AddToolModal({
|
|||||||
tool={selectedTool}
|
tool={selectedTool}
|
||||||
getUserTools={getUserTools}
|
getUserTools={getUserTools}
|
||||||
/>
|
/>
|
||||||
|
<MCPServerModal
|
||||||
|
modalState={mcpModalState}
|
||||||
|
setModalState={setMcpModalState}
|
||||||
|
onServerSaved={handleMcpServerAdded}
|
||||||
|
/>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
482
frontend/src/modals/MCPServerModal.tsx
Normal file
482
frontend/src/modals/MCPServerModal.tsx
Normal file
@@ -0,0 +1,482 @@
|
|||||||
|
import { useRef, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useSelector } from 'react-redux';
|
||||||
|
|
||||||
|
import userService from '../api/services/userService';
|
||||||
|
import Dropdown from '../components/Dropdown';
|
||||||
|
import Input from '../components/Input';
|
||||||
|
import Spinner from '../components/Spinner';
|
||||||
|
import { useOutsideAlerter } from '../hooks';
|
||||||
|
import { ActiveState } from '../models/misc';
|
||||||
|
import { selectToken } from '../preferences/preferenceSlice';
|
||||||
|
import WrapperComponent from './WrapperModal';
|
||||||
|
|
||||||
|
interface MCPServerModalProps {
|
||||||
|
modalState: ActiveState;
|
||||||
|
setModalState: (state: ActiveState) => void;
|
||||||
|
server?: any;
|
||||||
|
onServerSaved: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
const authTypes = [
|
||||||
|
{ label: 'No Authentication', value: 'none' },
|
||||||
|
{ label: 'API Key', value: 'api_key' },
|
||||||
|
{ label: 'Bearer Token', value: 'bearer' },
|
||||||
|
// { label: 'Basic Authentication', value: 'basic' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function MCPServerModal({
|
||||||
|
modalState,
|
||||||
|
setModalState,
|
||||||
|
server,
|
||||||
|
onServerSaved,
|
||||||
|
}: MCPServerModalProps) {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const token = useSelector(selectToken);
|
||||||
|
const modalRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
const [formData, setFormData] = useState({
|
||||||
|
name: server?.displayName || 'My MCP Server',
|
||||||
|
server_url: server?.server_url || '',
|
||||||
|
auth_type: server?.auth_type || 'none',
|
||||||
|
api_key: '',
|
||||||
|
header_name: 'X-API-Key',
|
||||||
|
bearer_token: '',
|
||||||
|
username: '',
|
||||||
|
password: '',
|
||||||
|
timeout: server?.timeout || 30,
|
||||||
|
});
|
||||||
|
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [testing, setTesting] = useState(false);
|
||||||
|
const [testResult, setTestResult] = useState<{
|
||||||
|
success: boolean;
|
||||||
|
message: string;
|
||||||
|
} | null>(null);
|
||||||
|
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||||
|
|
||||||
|
useOutsideAlerter(modalRef, () => {
|
||||||
|
if (modalState === 'ACTIVE') {
|
||||||
|
setModalState('INACTIVE');
|
||||||
|
resetForm();
|
||||||
|
}
|
||||||
|
}, [modalState]);
|
||||||
|
|
||||||
|
const resetForm = () => {
|
||||||
|
setFormData({
|
||||||
|
name: 'My MCP Server',
|
||||||
|
server_url: '',
|
||||||
|
auth_type: 'none',
|
||||||
|
api_key: '',
|
||||||
|
header_name: 'X-API-Key',
|
||||||
|
bearer_token: '',
|
||||||
|
username: '',
|
||||||
|
password: '',
|
||||||
|
timeout: 30,
|
||||||
|
});
|
||||||
|
setErrors({});
|
||||||
|
setTestResult(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const validateForm = () => {
|
||||||
|
const requiredFields: { [key: string]: boolean } = {
|
||||||
|
name: !formData.name.trim(),
|
||||||
|
server_url: !formData.server_url.trim(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const authFieldChecks: { [key: string]: () => void } = {
|
||||||
|
api_key: () => {
|
||||||
|
if (!formData.api_key.trim())
|
||||||
|
newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired');
|
||||||
|
},
|
||||||
|
bearer: () => {
|
||||||
|
if (!formData.bearer_token.trim())
|
||||||
|
newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired');
|
||||||
|
},
|
||||||
|
basic: () => {
|
||||||
|
if (!formData.username.trim())
|
||||||
|
newErrors.username = t('settings.tools.mcp.errors.usernameRequired');
|
||||||
|
if (!formData.password.trim())
|
||||||
|
newErrors.password = t('settings.tools.mcp.errors.passwordRequired');
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const newErrors: { [key: string]: string } = {};
|
||||||
|
Object.entries(requiredFields).forEach(([field, isEmpty]) => {
|
||||||
|
if (isEmpty)
|
||||||
|
newErrors[field] = t(
|
||||||
|
`settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (formData.server_url.trim()) {
|
||||||
|
try {
|
||||||
|
new URL(formData.server_url);
|
||||||
|
} catch {
|
||||||
|
newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const timeoutValue = formData.timeout === '' ? 30 : formData.timeout;
|
||||||
|
if (
|
||||||
|
typeof timeoutValue === 'number' &&
|
||||||
|
(timeoutValue < 1 || timeoutValue > 300)
|
||||||
|
)
|
||||||
|
newErrors.timeout = 'Timeout must be between 1 and 300 seconds';
|
||||||
|
|
||||||
|
if (authFieldChecks[formData.auth_type])
|
||||||
|
authFieldChecks[formData.auth_type]();
|
||||||
|
|
||||||
|
setErrors(newErrors);
|
||||||
|
return Object.keys(newErrors).length === 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleInputChange = (name: string, value: string | number) => {
|
||||||
|
setFormData((prev) => ({ ...prev, [name]: value }));
|
||||||
|
if (errors[name]) {
|
||||||
|
setErrors((prev) => ({ ...prev, [name]: '' }));
|
||||||
|
}
|
||||||
|
setTestResult(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildToolConfig = () => {
|
||||||
|
const config: any = {
|
||||||
|
server_url: formData.server_url.trim(),
|
||||||
|
auth_type: formData.auth_type,
|
||||||
|
timeout: formData.timeout === '' ? 30 : formData.timeout,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (formData.auth_type === 'api_key') {
|
||||||
|
config.api_key = formData.api_key.trim();
|
||||||
|
config.api_key_header = formData.header_name.trim() || 'X-API-Key';
|
||||||
|
} else if (formData.auth_type === 'bearer') {
|
||||||
|
config.bearer_token = formData.bearer_token.trim();
|
||||||
|
} else if (formData.auth_type === 'basic') {
|
||||||
|
config.username = formData.username.trim();
|
||||||
|
config.password = formData.password.trim();
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
};
|
||||||
|
|
||||||
|
const testConnection = async () => {
|
||||||
|
if (!validateForm()) return;
|
||||||
|
setTesting(true);
|
||||||
|
setTestResult(null);
|
||||||
|
try {
|
||||||
|
const config = buildToolConfig();
|
||||||
|
const response = await userService.testMCPConnection({ config }, token);
|
||||||
|
const result = await response.json();
|
||||||
|
|
||||||
|
setTestResult(result);
|
||||||
|
} catch (error) {
|
||||||
|
setTestResult({
|
||||||
|
success: false,
|
||||||
|
message: t('settings.tools.mcp.errors.testFailed'),
|
||||||
|
});
|
||||||
|
} finally {
|
||||||
|
setTesting(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSave = async () => {
|
||||||
|
if (!validateForm()) return;
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const config = buildToolConfig();
|
||||||
|
const serverData = {
|
||||||
|
displayName: formData.name,
|
||||||
|
config,
|
||||||
|
status: true,
|
||||||
|
...(server?.id && { id: server.id }),
|
||||||
|
};
|
||||||
|
|
||||||
|
const response = await userService.saveMCPServer(serverData, token);
|
||||||
|
const result = await response.json();
|
||||||
|
|
||||||
|
if (response.ok && result.success) {
|
||||||
|
setTestResult({
|
||||||
|
success: true,
|
||||||
|
message: result.message,
|
||||||
|
});
|
||||||
|
onServerSaved();
|
||||||
|
setModalState('INACTIVE');
|
||||||
|
resetForm();
|
||||||
|
} else {
|
||||||
|
setErrors({
|
||||||
|
general: result.error || t('settings.tools.mcp.errors.saveFailed'),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error saving MCP server:', error);
|
||||||
|
setErrors({ general: t('settings.tools.mcp.errors.saveFailed') });
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderAuthFields = () => {
|
||||||
|
switch (formData.auth_type) {
|
||||||
|
case 'api_key':
|
||||||
|
return (
|
||||||
|
<div className="mb-10">
|
||||||
|
<div className="mt-6">
|
||||||
|
<Input
|
||||||
|
name="api_key"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.api_key}
|
||||||
|
onChange={(e) => handleInputChange('api_key', e.target.value)}
|
||||||
|
placeholder={t('settings.tools.mcp.placeholders.apiKey')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.api_key && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">{errors.api_key}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="mt-5">
|
||||||
|
<Input
|
||||||
|
name="header_name"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.header_name}
|
||||||
|
onChange={(e) =>
|
||||||
|
handleInputChange('header_name', e.target.value)
|
||||||
|
}
|
||||||
|
placeholder={t('settings.tools.mcp.headerName')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
case 'bearer':
|
||||||
|
return (
|
||||||
|
<div className="mb-10">
|
||||||
|
<Input
|
||||||
|
name="bearer_token"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.bearer_token}
|
||||||
|
onChange={(e) =>
|
||||||
|
handleInputChange('bearer_token', e.target.value)
|
||||||
|
}
|
||||||
|
placeholder={t('settings.tools.mcp.placeholders.bearerToken')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.bearer_token && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">{errors.bearer_token}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
case 'basic':
|
||||||
|
return (
|
||||||
|
<div className="mb-10">
|
||||||
|
<div className="mt-6">
|
||||||
|
<Input
|
||||||
|
name="username"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.username}
|
||||||
|
onChange={(e) => handleInputChange('username', e.target.value)}
|
||||||
|
placeholder={t('settings.tools.mcp.username')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.username && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">{errors.username}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="mt-5">
|
||||||
|
<Input
|
||||||
|
name="password"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.password}
|
||||||
|
onChange={(e) => handleInputChange('password', e.target.value)}
|
||||||
|
placeholder={t('settings.tools.mcp.password')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.password && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">{errors.password}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
default:
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
modalState === 'ACTIVE' && (
|
||||||
|
<WrapperComponent
|
||||||
|
close={() => {
|
||||||
|
setModalState('INACTIVE');
|
||||||
|
resetForm();
|
||||||
|
}}
|
||||||
|
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
|
||||||
|
>
|
||||||
|
<div className="flex h-full flex-col">
|
||||||
|
<div className="px-6 py-4">
|
||||||
|
<h2 className="text-jet dark:text-bright-gray text-xl font-semibold">
|
||||||
|
{server
|
||||||
|
? t('settings.tools.mcp.editServer')
|
||||||
|
: t('settings.tools.mcp.addServer')}
|
||||||
|
</h2>
|
||||||
|
</div>
|
||||||
|
<div className="flex-1 px-6">
|
||||||
|
<div className="space-y-6 py-6">
|
||||||
|
<div>
|
||||||
|
<Input
|
||||||
|
name="name"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.name}
|
||||||
|
onChange={(e) => handleInputChange('name', e.target.value)}
|
||||||
|
borderVariant="thin"
|
||||||
|
placeholder={t('settings.tools.mcp.serverName')}
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.name && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">{errors.name}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Input
|
||||||
|
name="server_url"
|
||||||
|
type="text"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.server_url}
|
||||||
|
onChange={(e) =>
|
||||||
|
handleInputChange('server_url', e.target.value)
|
||||||
|
}
|
||||||
|
placeholder={t('settings.tools.mcp.serverUrl')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.server_url && (
|
||||||
|
<p className="mt-1 text-sm text-red-600">
|
||||||
|
{errors.server_url}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Dropdown
|
||||||
|
placeholder={t('settings.tools.mcp.authType')}
|
||||||
|
selectedValue={
|
||||||
|
authTypes.find((type) => type.value === formData.auth_type)
|
||||||
|
?.label || null
|
||||||
|
}
|
||||||
|
onSelect={(selection: { label: string; value: string }) => {
|
||||||
|
handleInputChange('auth_type', selection.value);
|
||||||
|
}}
|
||||||
|
options={authTypes}
|
||||||
|
size="w-full"
|
||||||
|
rounded="3xl"
|
||||||
|
border="border"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{renderAuthFields()}
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Input
|
||||||
|
name="timeout"
|
||||||
|
type="number"
|
||||||
|
className="rounded-md"
|
||||||
|
value={formData.timeout}
|
||||||
|
onChange={(e) => {
|
||||||
|
const value = e.target.value;
|
||||||
|
if (value === '') {
|
||||||
|
handleInputChange('timeout', '');
|
||||||
|
} else {
|
||||||
|
const numValue = parseInt(value);
|
||||||
|
if (!isNaN(numValue) && numValue >= 1) {
|
||||||
|
handleInputChange('timeout', numValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
placeholder={t('settings.tools.mcp.timeout')}
|
||||||
|
borderVariant="thin"
|
||||||
|
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||||
|
/>
|
||||||
|
{errors.timeout && (
|
||||||
|
<p className="mt-2 text-sm text-red-600">{errors.timeout}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{testResult && (
|
||||||
|
<div
|
||||||
|
className={`rounded-md p-5 ${
|
||||||
|
testResult.success
|
||||||
|
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||||
|
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{testResult.message}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{errors.general && (
|
||||||
|
<div className="rounded-2xl bg-red-50 p-5 text-red-700 dark:bg-red-900 dark:text-red-300">
|
||||||
|
{errors.general}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="px-6 py-2">
|
||||||
|
<div className="flex flex-col gap-4 sm:flex-row sm:justify-between">
|
||||||
|
<button
|
||||||
|
onClick={testConnection}
|
||||||
|
disabled={testing}
|
||||||
|
className="border-silver dark:border-dim-gray dark:text-light-gray w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all hover:bg-gray-100 disabled:opacity-50 sm:w-auto dark:hover:bg-[#767183]/50"
|
||||||
|
>
|
||||||
|
{testing ? (
|
||||||
|
<div className="flex items-center justify-center">
|
||||||
|
<Spinner size="small" />
|
||||||
|
<span className="ml-2">
|
||||||
|
{t('settings.tools.mcp.testing')}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
t('settings.tools.mcp.testConnection')
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
setModalState('INACTIVE');
|
||||||
|
resetForm();
|
||||||
|
}}
|
||||||
|
className="dark:text-light-gray w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium hover:bg-gray-100 sm:w-auto dark:bg-transparent dark:hover:bg-[#767183]/50"
|
||||||
|
>
|
||||||
|
{t('settings.tools.mcp.cancel')}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={handleSave}
|
||||||
|
disabled={loading}
|
||||||
|
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
|
||||||
|
>
|
||||||
|
{loading ? (
|
||||||
|
<div className="flex items-center justify-center">
|
||||||
|
<Spinner size="small" />
|
||||||
|
<span className="ml-2">
|
||||||
|
{t('settings.tools.mcp.saving')}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
t('settings.tools.mcp.save')
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</WrapperComponent>
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -92,7 +92,7 @@ export function getLocalApiKey(): string | null {
|
|||||||
|
|
||||||
export function getLocalRecentDocs(): Doc[] | null {
|
export function getLocalRecentDocs(): Doc[] | null {
|
||||||
const docs = localStorage.getItem('DocsGPTRecentDocs');
|
const docs = localStorage.getItem('DocsGPTRecentDocs');
|
||||||
return docs ? JSON.parse(docs) as Doc[] : null;
|
return docs ? (JSON.parse(docs) as Doc[]) : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getLocalPrompt(): string | null {
|
export function getLocalPrompt(): string | null {
|
||||||
|
|||||||
@@ -30,9 +30,22 @@ export default function ToolConfig({
|
|||||||
handleGoBack: () => void;
|
handleGoBack: () => void;
|
||||||
}) {
|
}) {
|
||||||
const token = useSelector(selectToken);
|
const token = useSelector(selectToken);
|
||||||
const [authKey, setAuthKey] = React.useState<string>(
|
const [authKey, setAuthKey] = React.useState<string>(() => {
|
||||||
'token' in tool.config ? tool.config.token : '',
|
if (tool.name === 'mcp_tool') {
|
||||||
);
|
const config = tool.config as any;
|
||||||
|
if (config.auth_type === 'api_key') {
|
||||||
|
return config.api_key || '';
|
||||||
|
} else if (config.auth_type === 'bearer') {
|
||||||
|
return config.encrypted_token || '';
|
||||||
|
} else if (config.auth_type === 'basic') {
|
||||||
|
return config.password || '';
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
} else if ('token' in tool.config) {
|
||||||
|
return tool.config.token;
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
});
|
||||||
const [customName, setCustomName] = React.useState<string>(
|
const [customName, setCustomName] = React.useState<string>(
|
||||||
tool.customName || '',
|
tool.customName || '',
|
||||||
);
|
);
|
||||||
@@ -97,6 +110,26 @@ export default function ToolConfig({
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleSaveChanges = () => {
|
const handleSaveChanges = () => {
|
||||||
|
let configToSave;
|
||||||
|
if (tool.name === 'api_tool') {
|
||||||
|
configToSave = tool.config;
|
||||||
|
} else if (tool.name === 'mcp_tool') {
|
||||||
|
configToSave = { ...tool.config } as any;
|
||||||
|
const mcpConfig = tool.config as any;
|
||||||
|
|
||||||
|
if (authKey.trim()) {
|
||||||
|
if (mcpConfig.auth_type === 'api_key') {
|
||||||
|
configToSave.api_key = authKey;
|
||||||
|
} else if (mcpConfig.auth_type === 'bearer') {
|
||||||
|
configToSave.encrypted_token = authKey;
|
||||||
|
} else if (mcpConfig.auth_type === 'basic') {
|
||||||
|
configToSave.password = authKey;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
configToSave = { token: authKey };
|
||||||
|
}
|
||||||
|
|
||||||
userService
|
userService
|
||||||
.updateTool(
|
.updateTool(
|
||||||
{
|
{
|
||||||
@@ -105,7 +138,7 @@ export default function ToolConfig({
|
|||||||
displayName: tool.displayName,
|
displayName: tool.displayName,
|
||||||
customName: customName,
|
customName: customName,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
config: tool.name === 'api_tool' ? tool.config : { token: authKey },
|
config: configToSave,
|
||||||
actions: 'actions' in tool ? tool.actions : [],
|
actions: 'actions' in tool ? tool.actions : [],
|
||||||
status: tool.status,
|
status: tool.status,
|
||||||
},
|
},
|
||||||
@@ -196,7 +229,15 @@ export default function ToolConfig({
|
|||||||
<div className="mt-1">
|
<div className="mt-1">
|
||||||
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
|
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
|
||||||
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
|
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
|
||||||
{t('settings.tools.authentication')}
|
{tool.name === 'mcp_tool'
|
||||||
|
? (tool.config as any)?.auth_type === 'bearer'
|
||||||
|
? 'Bearer Token'
|
||||||
|
: (tool.config as any)?.auth_type === 'api_key'
|
||||||
|
? 'API Key'
|
||||||
|
: (tool.config as any)?.auth_type === 'basic'
|
||||||
|
? 'Password'
|
||||||
|
: t('settings.tools.authentication')
|
||||||
|
: t('settings.tools.authentication')}
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
|
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
|
||||||
@@ -208,7 +249,17 @@ export default function ToolConfig({
|
|||||||
value={authKey}
|
value={authKey}
|
||||||
onChange={(e) => setAuthKey(e.target.value)}
|
onChange={(e) => setAuthKey(e.target.value)}
|
||||||
borderVariant="thin"
|
borderVariant="thin"
|
||||||
placeholder={t('modals.configTool.apiKeyPlaceholder')}
|
placeholder={
|
||||||
|
tool.name === 'mcp_tool'
|
||||||
|
? (tool.config as any)?.auth_type === 'bearer'
|
||||||
|
? 'Bearer Token'
|
||||||
|
: (tool.config as any)?.auth_type === 'api_key'
|
||||||
|
? 'API Key'
|
||||||
|
: (tool.config as any)?.auth_type === 'basic'
|
||||||
|
? 'Password'
|
||||||
|
: t('modals.configTool.apiKeyPlaceholder')
|
||||||
|
: t('modals.configTool.apiKeyPlaceholder')
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
@@ -450,6 +501,26 @@ export default function ToolConfig({
|
|||||||
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
|
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
|
||||||
submitLabel={t('settings.tools.saveAndLeave')}
|
submitLabel={t('settings.tools.saveAndLeave')}
|
||||||
handleSubmit={() => {
|
handleSubmit={() => {
|
||||||
|
let configToSave;
|
||||||
|
if (tool.name === 'api_tool') {
|
||||||
|
configToSave = tool.config;
|
||||||
|
} else if (tool.name === 'mcp_tool') {
|
||||||
|
configToSave = { ...tool.config } as any;
|
||||||
|
const mcpConfig = tool.config as any;
|
||||||
|
|
||||||
|
if (authKey.trim()) {
|
||||||
|
if (mcpConfig.auth_type === 'api_key') {
|
||||||
|
configToSave.api_key = authKey;
|
||||||
|
} else if (mcpConfig.auth_type === 'bearer') {
|
||||||
|
configToSave.encrypted_token = authKey;
|
||||||
|
} else if (mcpConfig.auth_type === 'basic') {
|
||||||
|
configToSave.password = authKey;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
configToSave = { token: authKey };
|
||||||
|
}
|
||||||
|
|
||||||
userService
|
userService
|
||||||
.updateTool(
|
.updateTool(
|
||||||
{
|
{
|
||||||
@@ -458,10 +529,7 @@ export default function ToolConfig({
|
|||||||
displayName: tool.displayName,
|
displayName: tool.displayName,
|
||||||
customName: customName,
|
customName: customName,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
config:
|
config: configToSave,
|
||||||
tool.name === 'api_tool'
|
|
||||||
? tool.config
|
|
||||||
: { token: authKey },
|
|
||||||
actions: 'actions' in tool ? tool.actions : [],
|
actions: 'actions' in tool ? tool.actions : [],
|
||||||
status: tool.status,
|
status: tool.status,
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user