Merge pull request #1947 from siiddhantt/feat/remote-mcp

feat: remote mcp
This commit is contained in:
Alex
2025-09-15 09:31:36 +01:00
committed by GitHub
23 changed files with 1535 additions and 54 deletions

View File

@@ -140,28 +140,28 @@ class BaseAgent(ABC):
tool_id, action_name, call_args = parser.parse_args(call) tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4()) call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed # Check if parsing failed
if tool_id is None or action_name is None: if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}" error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message) logger.error(error_message)
tool_call_data = { tool_call_data = {
"tool_name": "unknown", "tool_name": "unknown",
"call_id": call_id, "call_id": call_id,
"action_name": getattr(call, 'name', 'unknown'), "action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {}, "arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}", "result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
} }
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data) self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools # Check if tool_id exists in available tools
if tool_id not in tools_dict: if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}" error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message) logger.error(error_message)
# Return error result # Return error result
tool_call_data = { tool_call_data = {
"tool_name": "unknown", "tool_name": "unknown",
@@ -173,7 +173,7 @@ class BaseAgent(ABC):
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data) self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = { tool_call_data = {
"tool_name": tools_dict[tool_id]["name"], "tool_name": tools_dict[tool_id]["name"],
"call_id": call_id, "call_id": call_id,
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
if tool_data["name"] == "api_tool" if tool_data["name"] == "api_tool"
else tool_data["config"] else tool_data["config"]
), ),
user_id=self.user, # Pass user ID for MCP tools credential decryption
) )
if tool_data["name"] == "api_tool": if tool_data["name"] == "api_tool":
print( print(

View File

@@ -0,0 +1,405 @@
import json
import time
from typing import Any, Dict, List, Optional
import requests
from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials
_mcp_session_cache = {}
class MCPTool(Tool):
"""
MCP Tool
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
"""
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
"""
Initialize the MCP Tool with configuration.
Args:
config: Dictionary containing MCP server configuration:
- server_url: URL of the remote MCP server
- auth_type: Type of authentication (api_key, bearer, basic, none)
- encrypted_credentials: Encrypted credentials (if available)
- timeout: Request timeout in seconds (default: 30)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.server_url = config.get("server_url", "")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
self.auth_credentials = config.get("auth_credentials", {})
self.available_tools = []
self._session = requests.Session()
self._mcp_session_id = None
self._setup_authentication()
self._cache_key = self._generate_cache_key()
def _setup_authentication(self):
"""Setup authentication for the MCP server connection."""
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
self._session.headers.update({header_name: api_key})
elif self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
self._session.auth = (username, password)
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
elif self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
auth_key = f"basic:{username}"
else:
auth_key = "none"
return f"{self.server_url}#{auth_key}"
def _get_cached_session(self) -> Optional[str]:
"""Get cached session ID if available and not expired."""
global _mcp_session_cache
if self._cache_key in _mcp_session_cache:
session_data = _mcp_session_cache[self._cache_key]
if time.time() - session_data["created_at"] < 1800:
return session_data["session_id"]
else:
del _mcp_session_cache[self._cache_key]
return None
def _cache_session(self, session_id: str):
"""Cache the session ID for reuse."""
global _mcp_session_cache
_mcp_session_cache[self._cache_key] = {
"session_id": session_id,
"created_at": time.time(),
}
def _initialize_mcp_connection(self) -> Dict:
"""
Initialize MCP connection with the server, using cached session if available.
Returns:
Server capabilities and information
"""
cached_session = self._get_cached_session()
if cached_session:
self._mcp_session_id = cached_session
return {"cached": True}
try:
init_params = {
"protocolVersion": "2024-11-05",
"capabilities": {"roots": {"listChanged": True}, "sampling": {}},
"clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
}
response = self._make_mcp_request("initialize", init_params)
self._make_mcp_request("notifications/initialized")
return response
except Exception as e:
return {"error": str(e), "fallback": True}
def _ensure_valid_session(self):
"""Ensure we have a valid MCP session, reinitializing if needed."""
if not self._mcp_session_id:
self._initialize_mcp_connection()
def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
"""
Make an MCP protocol request to the server with automatic session recovery.
Args:
method: MCP method name (e.g., "tools/list", "tools/call")
params: Parameters for the MCP method
Returns:
Response data as dictionary
Raises:
Exception: If request fails after retry
"""
mcp_message = {"jsonrpc": "2.0", "method": method}
if not method.startswith("notifications/"):
mcp_message["id"] = 1
if params:
mcp_message["params"] = params
return self._execute_mcp_request(mcp_message, method)
def _execute_mcp_request(
self, mcp_message: Dict, method: str, is_retry: bool = False
) -> Dict:
"""Execute MCP request with optional retry on session failure."""
try:
final_headers = self._session.headers.copy()
final_headers.update(
{
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
)
if self._mcp_session_id:
final_headers["Mcp-Session-Id"] = self._mcp_session_id
response = self._session.post(
self.server_url.rstrip("/"),
json=mcp_message,
headers=final_headers,
timeout=self.timeout,
)
if "mcp-session-id" in response.headers:
self._mcp_session_id = response.headers["mcp-session-id"]
self._cache_session(self._mcp_session_id)
response.raise_for_status()
if method.startswith("notifications/"):
return {}
response_text = response.text.strip()
if response_text.startswith("event:") and "data:" in response_text:
lines = response_text.split("\n")
data_line = None
for line in lines:
if line.startswith("data:"):
data_line = line[5:].strip()
break
if data_line:
try:
result = json.loads(data_line)
except json.JSONDecodeError:
raise Exception(f"Invalid JSON in SSE data: {data_line}")
else:
raise Exception(f"No data found in SSE response: {response_text}")
else:
try:
result = response.json()
except json.JSONDecodeError:
raise Exception(f"Invalid JSON response: {response.text}")
if "error" in result:
error_msg = result["error"]
if isinstance(error_msg, dict):
error_msg = error_msg.get("message", str(error_msg))
raise Exception(f"MCP server error: {error_msg}")
return result.get("result", result)
except requests.exceptions.RequestException as e:
if not is_retry and self._should_retry_with_new_session(e):
self._invalidate_and_refresh_session()
return self._execute_mcp_request(mcp_message, method, is_retry=True)
raise Exception(f"MCP server request failed: {str(e)}")
def _should_retry_with_new_session(self, error: Exception) -> bool:
"""Check if error indicates session invalidation and retry is warranted."""
error_str = str(error).lower()
return (
any(
indicator in error_str
for indicator in [
"invalid session",
"session expired",
"unauthorized",
"401",
"403",
]
)
and self._mcp_session_id is not None
)
def _invalidate_and_refresh_session(self) -> None:
"""Invalidate current session and create a new one."""
global _mcp_session_cache
if self._cache_key in _mcp_session_cache:
del _mcp_session_cache[self._cache_key]
self._mcp_session_id = None
self._initialize_mcp_connection()
def discover_tools(self) -> List[Dict]:
"""
Discover available tools from the MCP server using MCP protocol.
Returns:
List of tool definitions from the server
"""
try:
self._ensure_valid_session()
response = self._make_mcp_request("tools/list")
# Handle both formats: response with 'tools' key or response that IS the tools list
if isinstance(response, dict):
if "tools" in response:
self.available_tools = response["tools"]
elif (
"result" in response
and isinstance(response["result"], dict)
and "tools" in response["result"]
):
self.available_tools = response["result"]["tools"]
else:
self.available_tools = [response] if response else []
elif isinstance(response, list):
self.available_tools = response
else:
self.available_tools = []
return self.available_tools
except Exception as e:
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using MCP protocol.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
self._ensure_valid_session()
# Skipping empty/None values - letting the server use defaults
cleaned_kwargs = {}
for key, value in kwargs.items():
if value == "" or value is None:
continue
cleaned_kwargs[key] = value
call_params = {"name": action_name, "arguments": cleaned_kwargs}
try:
result = self._make_mcp_request("tools/call", call_params)
return result
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
Returns:
List of action metadata dictionaries
"""
actions = []
for tool in self.available_tools:
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
or tool.get("schema")
or tool.get("parameters")
)
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
if input_schema:
if isinstance(input_schema, dict):
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
"properties": input_schema.get("properties", {}),
"required": input_schema.get("required", []),
}
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": parameters_schema,
}
actions.append(action)
return actions
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
try:
self._mcp_session_id = None
init_result = self._initialize_mcp_connection()
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if init_result.get("cached"):
message += " (Using cached session)"
elif init_result.get("fallback"):
message += " (No formal initialization required)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"session_id": self._mcp_session_id,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"error_type": type(e).__name__,
}
def get_config_requirements(self) -> Dict:
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
"required": True,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "api_key", "bearer", "basic"],
"default": "none",
"required": True,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
}

View File

@@ -23,16 +23,23 @@ class ToolManager:
tool_config = self.config.get(name, {}) tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config) self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config): def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}") module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass): for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool: if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config) if tool_name == "mcp_tool" and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs): def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools: if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded") raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name == "mcp_tool" and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs) return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self): def get_all_actions_metadata(self):

View File

@@ -25,6 +25,8 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument from pymongo import ReturnDocument
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from application.agents.tools.mcp_tool import MCPTool
from application.agents.tools.tool_manager import ToolManager from application.agents.tools.tool_manager import ToolManager
from application.api import api from application.api import api
@@ -38,6 +40,7 @@ from application.api.user.tasks import (
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.core.settings import settings from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator from application.parser.connectors.connector_creator import ConnectorCreator
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.storage_creator import StorageCreator from application.storage.storage_creator import StorageCreator
from application.tts.google_tts import GoogleTTS from application.tts.google_tts import GoogleTTS
from application.utils import ( from application.utils import (
@@ -491,6 +494,7 @@ class DeleteOldIndexes(Resource):
) )
if not doc: if not doc:
return make_response(jsonify({"status": "not found"}), 404) return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
try: try:
@@ -507,6 +511,7 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"]) settings.VECTOR_STORE, source_id=str(doc["_id"])
) )
vectorstore.delete_index() vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]: if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"] file_path = doc["file_path"]
if storage.is_directory(file_path): if storage.is_directory(file_path):
@@ -515,6 +520,7 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f) storage.delete_file(f)
else: else:
storage.delete_file(file_path) storage.delete_file(file_path)
except FileNotFoundError: except FileNotFoundError:
pass pass
except Exception as err: except Exception as err:
@@ -522,6 +528,7 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True f"Error deleting files and indexes: {err}", exc_info=True
) )
return make_response(jsonify({"success": False}), 400) return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)}) sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200) return make_response(jsonify({"success": True}), 200)
@@ -593,6 +600,7 @@ class UploadFile(Resource):
== temp_file_path == temp_file_path
): ):
continue continue
rel_path = os.path.relpath( rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir os.path.join(root, extracted_file), temp_dir
) )
@@ -617,6 +625,7 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}" file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f: with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path) storage.save_file(f, file_path)
task = ingest.delay( task = ingest.delay(
settings.UPLOAD_FOLDER, settings.UPLOAD_FOLDER,
[ [
@@ -688,6 +697,7 @@ class ManageSourceFiles(Resource):
return make_response( return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401 jsonify({"success": False, "message": "Unauthorized"}), 401
) )
user = decoded_token.get("sub") user = decoded_token.get("sub")
source_id = request.form.get("source_id") source_id = request.form.get("source_id")
operation = request.form.get("operation") operation = request.form.get("operation")
@@ -737,6 +747,7 @@ class ManageSourceFiles(Resource):
return make_response( return make_response(
jsonify({"success": False, "message": "Database error"}), 500 jsonify({"success": False, "message": "Database error"}), 500
) )
try: try:
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "") source_file_path = source.get("file_path", "")
@@ -793,6 +804,7 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
elif operation == "remove": elif operation == "remove":
file_paths_str = request.form.get("file_paths") file_paths_str = request.form.get("file_paths")
if not file_paths_str: if not file_paths_str:
@@ -846,6 +858,7 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
elif operation == "remove_directory": elif operation == "remove_directory":
directory_path = request.form.get("directory_path") directory_path = request.form.get("directory_path")
if not directory_path: if not directory_path:
@@ -871,6 +884,7 @@ class ManageSourceFiles(Resource):
), ),
400, 400,
) )
full_directory_path = ( full_directory_path = (
f"{source_file_path}/{directory_path}" f"{source_file_path}/{directory_path}"
if directory_path if directory_path
@@ -929,6 +943,7 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
except Exception as err: except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}" error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory": if operation == "remove_directory":
@@ -940,6 +955,7 @@ class ManageSourceFiles(Resource):
elif operation == "add": elif operation == "add":
parent_dir = request.form.get("parent_dir", "") parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}" error_context += f", parent_dir={parent_dir}"
current_app.logger.error( current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True f"Error managing source files: {err} ({error_context})", exc_info=True
) )
@@ -1616,6 +1632,7 @@ class CreateAgent(Resource):
), ),
400, 400,
) )
# Validate that it has either a 'schema' property or is itself a schema # Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema: if "schema" not in json_schema and "type" not in json_schema:
@@ -3625,7 +3642,60 @@ class UpdateTool(Resource):
), ),
400, 400,
) )
update_data["config"] = data["config"] tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data: if "status" in data:
update_data["status"] = data["status"] update_data["status"] = data["status"]
user_tools_collection.update_one( user_tools_collection.update_one(
@@ -3837,6 +3907,7 @@ class GetChunks(Resource):
if not (text_match or title_match): if not (text_match or title_match):
continue continue
filtered_chunks.append(chunk) filtered_chunks.append(chunk)
chunks = filtered_chunks chunks = filtered_chunks
total_chunks = len(chunks) total_chunks = len(chunks)
@@ -4027,6 +4098,7 @@ class UpdateChunk(Resource):
current_app.logger.warning( current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
) )
return make_response( return make_response(
jsonify( jsonify(
{ {
@@ -4154,19 +4226,23 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token decoded_token = request.decoded_token
if not decoded_token: if not decoded_token:
return make_response(jsonify({"success": False}), 401) return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub") user = decoded_token.get("sub")
doc_id = request.args.get("id") doc_id = request.args.get("id")
if not doc_id: if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400) return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id): if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400) return make_response(jsonify({"error": "Invalid document ID"}), 400)
try: try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc: if not doc:
return make_response( return make_response(
jsonify({"error": "Document not found or access denied"}), 404 jsonify({"error": "Document not found or access denied"}), 404
) )
directory_structure = doc.get("directory_structure", {}) directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "") base_path = doc.get("file_path", "")
@@ -4196,3 +4272,204 @@ class DirectoryStructure(Resource):
f"Error retrieving directory structure: {e}", exc_info=True f"Error retrieving directory structure: {e}", exc_info=True
) )
return make_response(jsonify({"success": False, "error": str(e)}), 500) return make_response(jsonify({"success": False, "error": str(e)}), 500)
@user_ns.route("/api/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
api.model(
"MCPServerTestModel",
{
"config": fields.Raw(
required=True, description="MCP server configuration to test"
),
},
)
)
@api.doc(description="Test MCP server connection with provided configuration")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(test_config, user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Connection test failed: {str(e)}"}
),
500,
)
@user_ns.route("/api/mcp_server/save")
class MCPServerSave(Resource):
@api.expect(
api.model(
"MCPServerSaveModel",
{
"id": fields.String(
required=False, description="Tool ID for updates (optional)"
),
"displayName": fields.String(
required=True, description="Display name for the MCP server"
),
"config": fields.Raw(
required=True, description="MCP server configuration"
),
"status": fields.Boolean(
required=False, default=True, description="Tool status"
),
},
)
)
@api.doc(description="Create or update MCP server with automatic tool discovery")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["displayName", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(mcp_config, user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
raise Exception(
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
"customName": data["displayName"],
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
"config": storage_config,
"actions": transformed_actions,
"status": data.get("status", True),
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
result = user_tools_collection.insert_one(tool_data)
tool_id = str(result.inserted_id)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
),
500,
)

View File

@@ -26,7 +26,7 @@ class Settings(BaseSettings):
"gpt-4o-mini": 128000, "gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096, "gpt-3.5-turbo": 4096,
"claude-2": 1e5, "claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6, "gemini-2.5-flash": 1e6,
} }
UPLOAD_FOLDER: str = "inputs" UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False PARSE_PDF_AS_IMAGE: bool = False
@@ -96,7 +96,7 @@ class Settings(BaseSettings):
QDRANT_HOST: Optional[str] = None QDRANT_HOST: Optional[str] = None
QDRANT_PATH: Optional[str] = None QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine" QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config # PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config # Milvus vectorstore config
@@ -116,6 +116,9 @@ class Settings(BaseSettings):
JWT_SECRET_KEY: str = "" JWT_SECRET_KEY: str = ""
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
path = Path(__file__).parent.parent.absolute() path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
raise raise
def _clean_messages_google(self, messages): def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
cleaned_messages = [] cleaned_messages = []
for message in messages: for message in messages:
role = message.get("role") role = message.get("role")
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):
if role == "assistant": if role == "assistant":
role = "model" role = "model"
elif role == "tool":
role = "model"
parts = [] parts = []
if role and content is not None: if role and content is not None:
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
else: else:
raise ValueError(f"Unexpected content type: {type(content)}") raise ValueError(f"Unexpected content type: {type(content)}")
cleaned_messages.append(types.Content(role=role, parts=parts)) if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages return cleaned_messages
def _clean_schema(self, schema_obj):
"""
Recursively remove unsupported fields from schema objects
and validate required properties.
"""
if not isinstance(schema_obj, dict):
return schema_obj
allowed_fields = {
"type",
"description",
"items",
"properties",
"required",
"enum",
"pattern",
"minimum",
"maximum",
"nullable",
"default",
}
cleaned = {}
for key, value in schema_obj.items():
if key not in allowed_fields:
continue
elif key == "type" and isinstance(value, str):
cleaned[key] = value.upper()
elif isinstance(value, dict):
cleaned[key] = self._clean_schema(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
for required_prop in cleaned["required"]:
if required_prop in properties_keys:
valid_required.append(required_prop)
if valid_required:
cleaned["required"] = valid_required
else:
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list): def _clean_tools_format(self, tools_list):
"""Convert OpenAI format tools to Google AI format."""
genai_tools = [] genai_tools = []
for tool_data in tools_list: for tool_data in tools_list:
if tool_data["type"] == "function": if tool_data["type"] == "function":
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
properties = parameters.get("properties", {}) properties = parameters.get("properties", {})
if properties: if properties:
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict( genai_function = dict(
name=function["name"], name=function["name"],
description=function["description"], description=function["description"],
parameters={ parameters={
"type": "OBJECT", "type": "OBJECT",
"properties": { "properties": cleaned_properties,
k: {
**v,
"type": v["type"].upper() if v["type"] else None,
}
for k, v in properties.items()
},
"required": ( "required": (
parameters["required"] parameters["required"]
if "required" in parameters if "required" in parameters
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
response_schema=None, response_schema=None,
**kwargs, **kwargs,
): ):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key) client = genai.Client(api_key=self.api_key)
if formatting == "openai": if formatting == "openai":
messages = self._clean_messages_google(messages) messages = self._clean_messages_google(messages)
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
response_schema=None, response_schema=None,
**kwargs, **kwargs,
): ):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key) client = genai.Client(api_key=self.api_key)
if formatting == "openai": if formatting == "openai":
messages = self._clean_messages_google(messages) messages = self._clean_messages_google(messages)
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
yield chunk.text yield chunk.text
def _supports_tools(self): def _supports_tools(self):
"""Return whether this LLM supports function calling."""
return True return True
def _supports_structured_output(self): def _supports_structured_output(self):
"""Return whether this LLM supports structured JSON output."""
return True return True
def prepare_structured_output_format(self, json_schema): def prepare_structured_output_format(self, json_schema):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema: if not json_schema:
return None return None

View File

@@ -205,7 +205,6 @@ class LLMHandler(ABC):
except StopIteration as e: except StopIteration as e:
tool_response, call_id = e.value tool_response, call_id = e.value
break break
updated_messages.append( updated_messages.append(
{ {
"role": "assistant", "role": "assistant",
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
) )
updated_messages.append(self.create_tool_message(call, tool_response)) updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e: except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True) logger.error(f"Error executing tool: {str(e)}", exc_info=True)
updated_messages.append( error_call = ToolCall(
{ id=call.id, name=call.name, arguments=call.arguments
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
) )
error_response = f"Error executing tool: {str(e)}"
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
"tool_name": tool_name,
"call_id": call.id,
"action_name": full_action_name,
"arguments": call.arguments,
"error": error_response,
"status": "error",
},
}
return updated_messages return updated_messages
def handle_non_streaming( def handle_non_streaming(
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
except StopIteration as e: except StopIteration as e:
messages = e.value messages = e.value
break break
response = agent.llm.gen( response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools model=agent.gpt_model, messages=messages, tools=agent.tools
) )
parsed = self.parse_response(response) parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm)) self.llm_calls.append(build_stack_data(agent.llm))
return parsed.content return parsed.content
def handle_streaming( def handle_streaming(

View File

@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="stop", finish_reason="stop",
raw_response=response, raw_response=response,
) )
if hasattr(response, "candidates"): if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else [] parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [ tool_calls = [
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="tool_calls" if tool_calls else "stop", finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response, raw_response=response,
) )
else: else:
tool_calls = [] tool_calls = []
if hasattr(response, "function_call"): if hasattr(response, "function_call"):
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message.""" """Create Google-style tool message."""
from google.genai import types
return { return {
"role": "tool", "role": "model",
"content": [ "content": [
types.Part.from_function_response( {
name=tool_call.name, response={"result": result} "function_response": {
).to_json_dict() "name": tool_call.name,
"response": {"result": result},
}
}
], ],
} }

View File

@@ -2,6 +2,7 @@ anthropic==0.49.0
boto3==1.38.18 boto3==1.38.18
beautifulsoup4==4.13.4 beautifulsoup4==4.13.4
celery==5.4.0 celery==5.4.0
cryptography==42.0.8
dataclasses-json==0.6.7 dataclasses-json==0.6.7
docx2txt==0.8 docx2txt==0.8
duckduckgo-search==7.5.2 duckduckgo-search==7.5.2

View File

View File

@@ -0,0 +1,85 @@
import base64
import json
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from application.core.settings import settings
def _derive_key(user_id: str, salt: bytes) -> bytes:
app_secret = settings.ENCRYPTION_SECRET_KEY
password = f"{app_secret}#{user_id}".encode()
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
backend=default_backend(),
)
return kdf.derive(password)
def encrypt_credentials(credentials: dict, user_id: str) -> str:
if not credentials:
return ""
try:
salt = os.urandom(16)
iv = os.urandom(16)
key = _derive_key(user_id, salt)
json_str = json.dumps(credentials)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
padded_data = _pad_data(json_str.encode())
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
result = salt + iv + encrypted_data
return base64.b64encode(result).decode()
except Exception as e:
print(f"Warning: Failed to encrypt credentials: {e}")
return ""
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
if not encrypted_data:
return {}
try:
data = base64.b64decode(encrypted_data.encode())
salt = data[:16]
iv = data[16:32]
encrypted_content = data[32:]
key = _derive_key(user_id, salt)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
decrypted_data = _unpad_data(decrypted_padded)
return json.loads(decrypted_data.decode())
except Exception as e:
print(f"Warning: Failed to decrypt credentials: {e}")
return {}
def _pad_data(data: bytes) -> bytes:
block_size = 16
padding_len = block_size - (len(data) % block_size)
padding = bytes([padding_len]) * padding_len
return data + padding
def _unpad_data(data: bytes) -> bytes:
padding_len = data[-1]
return data[:-padding_len]

View File

@@ -0,0 +1,4 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
</svg>

After

Width:  |  Height:  |  Size: 831 B

View File

@@ -57,6 +57,8 @@ const endpoints = {
DIRECTORY_STRUCTURE: (docId: string) => DIRECTORY_STRUCTURE: (docId: string) =>
`/api/directory_structure?id=${docId}`, `/api/directory_structure?id=${docId}`,
MANAGE_SOURCE_FILES: '/api/manage_source_files', MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
}, },
CONVERSATION: { CONVERSATION: {
ANSWER: '/api/answer', ANSWER: '/api/answer',

View File

@@ -108,6 +108,10 @@ const userService = {
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token), apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> => manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token), apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
syncConnector: ( syncConnector: (
docId: string, docId: string,
provider: string, provider: string,

View File

View File

@@ -1,6 +1,6 @@
import 'katex/dist/katex.min.css'; import 'katex/dist/katex.min.css';
import { forwardRef, Fragment, useRef, useState, useEffect } from 'react'; import { forwardRef, Fragment, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown'; import ReactMarkdown from 'react-markdown';
import { useSelector } from 'react-redux'; import { useSelector } from 'react-redux';
@@ -12,12 +12,13 @@ import {
import rehypeKatex from 'rehype-katex'; import rehypeKatex from 'rehype-katex';
import remarkGfm from 'remark-gfm'; import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math'; import remarkMath from 'remark-math';
import DocumentationDark from '../assets/documentation-dark.svg';
import ChevronDown from '../assets/chevron-down.svg'; import ChevronDown from '../assets/chevron-down.svg';
import Cloud from '../assets/cloud.svg'; import Cloud from '../assets/cloud.svg';
import DocsGPT3 from '../assets/cute_docsgpt3.svg'; import DocsGPT3 from '../assets/cute_docsgpt3.svg';
import Dislike from '../assets/dislike.svg?react'; import Dislike from '../assets/dislike.svg?react';
import Document from '../assets/document.svg'; import Document from '../assets/document.svg';
import DocumentationDark from '../assets/documentation-dark.svg';
import Edit from '../assets/edit.svg'; import Edit from '../assets/edit.svg';
import Like from '../assets/like.svg?react'; import Like from '../assets/like.svg?react';
import Link from '../assets/link.svg'; import Link from '../assets/link.svg';
@@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
Response Response
</span>{' '} </span>{' '}
<CopyButton <CopyButton
textToCopy={JSON.stringify(toolCall.result, null, 2)} textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/> />
</p> </p>
{toolCall.status === 'pending' && ( {toolCall.status === 'pending' && (
@@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
</span> </span>
</p> </p>
)} )}
{toolCall.status === 'error' && (
<p className="dark:bg-raisin-black rounded-b-2xl p-2 font-mono text-sm break-words">
<span
className="leading-[23px] text-red-500 dark:text-red-400"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
</div> </div>
</div> </div>
</Accordion> </Accordion>

View File

@@ -4,5 +4,6 @@ export type ToolCallsType = {
call_id: string; call_id: string;
arguments: Record<string, any>; arguments: Record<string, any>;
result?: Record<string, any>; result?: Record<string, any>;
status?: 'pending' | 'completed'; error?: string;
status?: 'pending' | 'completed' | 'error';
}; };

View File

@@ -18,7 +18,10 @@ export default function useDefaultDocument() {
const fetchDocs = () => { const fetchDocs = () => {
getDocs(token).then((data) => { getDocs(token).then((data) => {
dispatch(setSourceDocs(data)); dispatch(setSourceDocs(data));
if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0)) if (
!selectedDoc ||
(Array.isArray(selectedDoc) && selectedDoc.length === 0)
)
Array.isArray(data) && Array.isArray(data) &&
data?.forEach((doc: Doc) => { data?.forEach((doc: Doc) => {
if (doc.model && doc.name === 'default') { if (doc.model && doc.name === 'default') {

View File

@@ -184,7 +184,39 @@
"cancel": "Cancel", "cancel": "Cancel",
"addNew": "Add New", "addNew": "Add New",
"name": "Name", "name": "Name",
"type": "Type" "type": "Type",
"mcp": {
"addServer": "Add MCP Server",
"editServer": "Edit Server",
"serverName": "Server Name",
"serverUrl": "Server URL",
"headerName": "Header Name",
"timeout": "Timeout (seconds)",
"testConnection": "Test Connection",
"testing": "Testing...",
"saving": "Saving...",
"save": "Save",
"cancel": "Cancel",
"noAuth": "No Authentication",
"placeholders": {
"serverUrl": "https://api.example.com",
"apiKey": "Your secret API key",
"bearerToken": "Your secret token",
"username": "Your username",
"password": "Your password"
},
"errors": {
"nameRequired": "Server name is required",
"urlRequired": "Server URL is required",
"invalidUrl": "Please enter a valid URL",
"apiKeyRequired": "API key is required",
"tokenRequired": "Bearer token is required",
"usernameRequired": "Username is required",
"passwordRequired": "Password is required",
"testFailed": "Connection test failed",
"saveFailed": "Failed to save MCP server"
}
}
} }
}, },
"modals": { "modals": {

View File

@@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc'; import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice'; import { selectToken } from '../preferences/preferenceSlice';
import ConfigToolModal from './ConfigToolModal'; import ConfigToolModal from './ConfigToolModal';
import MCPServerModal from './MCPServerModal';
import { AvailableToolType } from './types'; import { AvailableToolType } from './types';
import WrapperComponent from './WrapperModal'; import WrapperComponent from './WrapperModal';
@@ -34,6 +35,8 @@ export default function AddToolModal({
React.useState<AvailableToolType | null>(null); React.useState<AvailableToolType | null>(null);
const [configModalState, setConfigModalState] = const [configModalState, setConfigModalState] =
React.useState<ActiveState>('INACTIVE'); React.useState<ActiveState>('INACTIVE');
const [mcpModalState, setMcpModalState] =
React.useState<ActiveState>('INACTIVE');
const [loading, setLoading] = React.useState(false); const [loading, setLoading] = React.useState(false);
useOutsideAlerter(modalRef, () => { useOutsideAlerter(modalRef, () => {
@@ -86,6 +89,9 @@ export default function AddToolModal({
.catch((error) => { .catch((error) => {
console.error('Failed to create tool:', error); console.error('Failed to create tool:', error);
}); });
} else if (tool.name === 'mcp_tool') {
setModalState('INACTIVE');
setMcpModalState('ACTIVE');
} else { } else {
setModalState('INACTIVE'); setModalState('INACTIVE');
setConfigModalState('ACTIVE'); setConfigModalState('ACTIVE');
@@ -95,6 +101,12 @@ export default function AddToolModal({
React.useEffect(() => { React.useEffect(() => {
if (modalState === 'ACTIVE') getAvailableTools(); if (modalState === 'ACTIVE') getAvailableTools();
}, [modalState]); }, [modalState]);
const handleMcpServerAdded = () => {
getUserTools();
setMcpModalState('INACTIVE');
};
return ( return (
<> <>
{modalState === 'ACTIVE' && ( {modalState === 'ACTIVE' && (
@@ -166,6 +178,11 @@ export default function AddToolModal({
tool={selectedTool} tool={selectedTool}
getUserTools={getUserTools} getUserTools={getUserTools}
/> />
<MCPServerModal
modalState={mcpModalState}
setModalState={setMcpModalState}
onServerSaved={handleMcpServerAdded}
/>
</> </>
); );
} }

View File

@@ -0,0 +1,482 @@
import { useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
import Spinner from '../components/Spinner';
import { useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
interface MCPServerModalProps {
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
server?: any;
onServerSaved: () => void;
}
const authTypes = [
{ label: 'No Authentication', value: 'none' },
{ label: 'API Key', value: 'api_key' },
{ label: 'Bearer Token', value: 'bearer' },
// { label: 'Basic Authentication', value: 'basic' },
];
export default function MCPServerModal({
modalState,
setModalState,
server,
onServerSaved,
}: MCPServerModalProps) {
const { t } = useTranslation();
const token = useSelector(selectToken);
const modalRef = useRef<HTMLDivElement>(null);
const [formData, setFormData] = useState({
name: server?.displayName || 'My MCP Server',
server_url: server?.server_url || '',
auth_type: server?.auth_type || 'none',
api_key: '',
header_name: 'X-API-Key',
bearer_token: '',
username: '',
password: '',
timeout: server?.timeout || 30,
});
const [loading, setLoading] = useState(false);
const [testing, setTesting] = useState(false);
const [testResult, setTestResult] = useState<{
success: boolean;
message: string;
} | null>(null);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
setModalState('INACTIVE');
resetForm();
}
}, [modalState]);
const resetForm = () => {
setFormData({
name: 'My MCP Server',
server_url: '',
auth_type: 'none',
api_key: '',
header_name: 'X-API-Key',
bearer_token: '',
username: '',
password: '',
timeout: 30,
});
setErrors({});
setTestResult(null);
};
const validateForm = () => {
const requiredFields: { [key: string]: boolean } = {
name: !formData.name.trim(),
server_url: !formData.server_url.trim(),
};
const authFieldChecks: { [key: string]: () => void } = {
api_key: () => {
if (!formData.api_key.trim())
newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired');
},
bearer: () => {
if (!formData.bearer_token.trim())
newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired');
},
basic: () => {
if (!formData.username.trim())
newErrors.username = t('settings.tools.mcp.errors.usernameRequired');
if (!formData.password.trim())
newErrors.password = t('settings.tools.mcp.errors.passwordRequired');
},
};
const newErrors: { [key: string]: string } = {};
Object.entries(requiredFields).forEach(([field, isEmpty]) => {
if (isEmpty)
newErrors[field] = t(
`settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`,
);
});
if (formData.server_url.trim()) {
try {
new URL(formData.server_url);
} catch {
newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl');
}
}
const timeoutValue = formData.timeout === '' ? 30 : formData.timeout;
if (
typeof timeoutValue === 'number' &&
(timeoutValue < 1 || timeoutValue > 300)
)
newErrors.timeout = 'Timeout must be between 1 and 300 seconds';
if (authFieldChecks[formData.auth_type])
authFieldChecks[formData.auth_type]();
setErrors(newErrors);
return Object.keys(newErrors).length === 0;
};
const handleInputChange = (name: string, value: string | number) => {
setFormData((prev) => ({ ...prev, [name]: value }));
if (errors[name]) {
setErrors((prev) => ({ ...prev, [name]: '' }));
}
setTestResult(null);
};
const buildToolConfig = () => {
const config: any = {
server_url: formData.server_url.trim(),
auth_type: formData.auth_type,
timeout: formData.timeout === '' ? 30 : formData.timeout,
};
if (formData.auth_type === 'api_key') {
config.api_key = formData.api_key.trim();
config.api_key_header = formData.header_name.trim() || 'X-API-Key';
} else if (formData.auth_type === 'bearer') {
config.bearer_token = formData.bearer_token.trim();
} else if (formData.auth_type === 'basic') {
config.username = formData.username.trim();
config.password = formData.password.trim();
}
return config;
};
const testConnection = async () => {
if (!validateForm()) return;
setTesting(true);
setTestResult(null);
try {
const config = buildToolConfig();
const response = await userService.testMCPConnection({ config }, token);
const result = await response.json();
setTestResult(result);
} catch (error) {
setTestResult({
success: false,
message: t('settings.tools.mcp.errors.testFailed'),
});
} finally {
setTesting(false);
}
};
const handleSave = async () => {
if (!validateForm()) return;
setLoading(true);
try {
const config = buildToolConfig();
const serverData = {
displayName: formData.name,
config,
status: true,
...(server?.id && { id: server.id }),
};
const response = await userService.saveMCPServer(serverData, token);
const result = await response.json();
if (response.ok && result.success) {
setTestResult({
success: true,
message: result.message,
});
onServerSaved();
setModalState('INACTIVE');
resetForm();
} else {
setErrors({
general: result.error || t('settings.tools.mcp.errors.saveFailed'),
});
}
} catch (error) {
console.error('Error saving MCP server:', error);
setErrors({ general: t('settings.tools.mcp.errors.saveFailed') });
} finally {
setLoading(false);
}
};
const renderAuthFields = () => {
switch (formData.auth_type) {
case 'api_key':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="api_key"
type="text"
className="rounded-md"
value={formData.api_key}
onChange={(e) => handleInputChange('api_key', e.target.value)}
placeholder={t('settings.tools.mcp.placeholders.apiKey')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.api_key && (
<p className="mt-1 text-sm text-red-600">{errors.api_key}</p>
)}
</div>
<div className="mt-5">
<Input
name="header_name"
type="text"
className="rounded-md"
value={formData.header_name}
onChange={(e) =>
handleInputChange('header_name', e.target.value)
}
placeholder={t('settings.tools.mcp.headerName')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
</div>
</div>
);
case 'bearer':
return (
<div className="mb-10">
<Input
name="bearer_token"
type="text"
className="rounded-md"
value={formData.bearer_token}
onChange={(e) =>
handleInputChange('bearer_token', e.target.value)
}
placeholder={t('settings.tools.mcp.placeholders.bearerToken')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.bearer_token && (
<p className="mt-1 text-sm text-red-600">{errors.bearer_token}</p>
)}
</div>
);
case 'basic':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="username"
type="text"
className="rounded-md"
value={formData.username}
onChange={(e) => handleInputChange('username', e.target.value)}
placeholder={t('settings.tools.mcp.username')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.username && (
<p className="mt-1 text-sm text-red-600">{errors.username}</p>
)}
</div>
<div className="mt-5">
<Input
name="password"
type="text"
className="rounded-md"
value={formData.password}
onChange={(e) => handleInputChange('password', e.target.value)}
placeholder={t('settings.tools.mcp.password')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.password && (
<p className="mt-1 text-sm text-red-600">{errors.password}</p>
)}
</div>
</div>
);
default:
return null;
}
};
return (
modalState === 'ACTIVE' && (
<WrapperComponent
close={() => {
setModalState('INACTIVE');
resetForm();
}}
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
>
<div className="flex h-full flex-col">
<div className="px-6 py-4">
<h2 className="text-jet dark:text-bright-gray text-xl font-semibold">
{server
? t('settings.tools.mcp.editServer')
: t('settings.tools.mcp.addServer')}
</h2>
</div>
<div className="flex-1 px-6">
<div className="space-y-6 py-6">
<div>
<Input
name="name"
type="text"
className="rounded-md"
value={formData.name}
onChange={(e) => handleInputChange('name', e.target.value)}
borderVariant="thin"
placeholder={t('settings.tools.mcp.serverName')}
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.name && (
<p className="mt-1 text-sm text-red-600">{errors.name}</p>
)}
</div>
<div>
<Input
name="server_url"
type="text"
className="rounded-md"
value={formData.server_url}
onChange={(e) =>
handleInputChange('server_url', e.target.value)
}
placeholder={t('settings.tools.mcp.serverUrl')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.server_url && (
<p className="mt-1 text-sm text-red-600">
{errors.server_url}
</p>
)}
</div>
<Dropdown
placeholder={t('settings.tools.mcp.authType')}
selectedValue={
authTypes.find((type) => type.value === formData.auth_type)
?.label || null
}
onSelect={(selection: { label: string; value: string }) => {
handleInputChange('auth_type', selection.value);
}}
options={authTypes}
size="w-full"
rounded="3xl"
border="border"
/>
{renderAuthFields()}
<div>
<Input
name="timeout"
type="number"
className="rounded-md"
value={formData.timeout}
onChange={(e) => {
const value = e.target.value;
if (value === '') {
handleInputChange('timeout', '');
} else {
const numValue = parseInt(value);
if (!isNaN(numValue) && numValue >= 1) {
handleInputChange('timeout', numValue);
}
}
}}
placeholder={t('settings.tools.mcp.timeout')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.timeout && (
<p className="mt-2 text-sm text-red-600">{errors.timeout}</p>
)}
</div>
{testResult && (
<div
className={`rounded-md p-5 ${
testResult.success
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
}`}
>
{testResult.message}
</div>
)}
{errors.general && (
<div className="rounded-2xl bg-red-50 p-5 text-red-700 dark:bg-red-900 dark:text-red-300">
{errors.general}
</div>
)}
</div>
</div>
<div className="px-6 py-2">
<div className="flex flex-col gap-4 sm:flex-row sm:justify-between">
<button
onClick={testConnection}
disabled={testing}
className="border-silver dark:border-dim-gray dark:text-light-gray w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all hover:bg-gray-100 disabled:opacity-50 sm:w-auto dark:hover:bg-[#767183]/50"
>
{testing ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.tools.mcp.testing')}
</span>
</div>
) : (
t('settings.tools.mcp.testConnection')
)}
</button>
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
<button
onClick={() => {
setModalState('INACTIVE');
resetForm();
}}
className="dark:text-light-gray w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium hover:bg-gray-100 sm:w-auto dark:bg-transparent dark:hover:bg-[#767183]/50"
>
{t('settings.tools.mcp.cancel')}
</button>
<button
onClick={handleSave}
disabled={loading}
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
>
{loading ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.tools.mcp.saving')}
</span>
</div>
) : (
t('settings.tools.mcp.save')
)}
</button>
</div>
</div>
</div>
</div>
</WrapperComponent>
)
);
}

View File

@@ -92,7 +92,7 @@ export function getLocalApiKey(): string | null {
export function getLocalRecentDocs(): Doc[] | null { export function getLocalRecentDocs(): Doc[] | null {
const docs = localStorage.getItem('DocsGPTRecentDocs'); const docs = localStorage.getItem('DocsGPTRecentDocs');
return docs ? JSON.parse(docs) as Doc[] : null; return docs ? (JSON.parse(docs) as Doc[]) : null;
} }
export function getLocalPrompt(): string | null { export function getLocalPrompt(): string | null {

View File

@@ -30,9 +30,22 @@ export default function ToolConfig({
handleGoBack: () => void; handleGoBack: () => void;
}) { }) {
const token = useSelector(selectToken); const token = useSelector(selectToken);
const [authKey, setAuthKey] = React.useState<string>( const [authKey, setAuthKey] = React.useState<string>(() => {
'token' in tool.config ? tool.config.token : '', if (tool.name === 'mcp_tool') {
); const config = tool.config as any;
if (config.auth_type === 'api_key') {
return config.api_key || '';
} else if (config.auth_type === 'bearer') {
return config.encrypted_token || '';
} else if (config.auth_type === 'basic') {
return config.password || '';
}
return '';
} else if ('token' in tool.config) {
return tool.config.token;
}
return '';
});
const [customName, setCustomName] = React.useState<string>( const [customName, setCustomName] = React.useState<string>(
tool.customName || '', tool.customName || '',
); );
@@ -97,6 +110,26 @@ export default function ToolConfig({
}; };
const handleSaveChanges = () => { const handleSaveChanges = () => {
let configToSave;
if (tool.name === 'api_tool') {
configToSave = tool.config;
} else if (tool.name === 'mcp_tool') {
configToSave = { ...tool.config } as any;
const mcpConfig = tool.config as any;
if (authKey.trim()) {
if (mcpConfig.auth_type === 'api_key') {
configToSave.api_key = authKey;
} else if (mcpConfig.auth_type === 'bearer') {
configToSave.encrypted_token = authKey;
} else if (mcpConfig.auth_type === 'basic') {
configToSave.password = authKey;
}
}
} else {
configToSave = { token: authKey };
}
userService userService
.updateTool( .updateTool(
{ {
@@ -105,7 +138,7 @@ export default function ToolConfig({
displayName: tool.displayName, displayName: tool.displayName,
customName: customName, customName: customName,
description: tool.description, description: tool.description,
config: tool.name === 'api_tool' ? tool.config : { token: authKey }, config: configToSave,
actions: 'actions' in tool ? tool.actions : [], actions: 'actions' in tool ? tool.actions : [],
status: tool.status, status: tool.status,
}, },
@@ -196,7 +229,15 @@ export default function ToolConfig({
<div className="mt-1"> <div className="mt-1">
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && ( {Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold"> <p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
{t('settings.tools.authentication')} {tool.name === 'mcp_tool'
? (tool.config as any)?.auth_type === 'bearer'
? 'Bearer Token'
: (tool.config as any)?.auth_type === 'api_key'
? 'API Key'
: (tool.config as any)?.auth_type === 'basic'
? 'Password'
: t('settings.tools.authentication')
: t('settings.tools.authentication')}
</p> </p>
)} )}
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center"> <div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
@@ -208,7 +249,17 @@ export default function ToolConfig({
value={authKey} value={authKey}
onChange={(e) => setAuthKey(e.target.value)} onChange={(e) => setAuthKey(e.target.value)}
borderVariant="thin" borderVariant="thin"
placeholder={t('modals.configTool.apiKeyPlaceholder')} placeholder={
tool.name === 'mcp_tool'
? (tool.config as any)?.auth_type === 'bearer'
? 'Bearer Token'
: (tool.config as any)?.auth_type === 'api_key'
? 'API Key'
: (tool.config as any)?.auth_type === 'basic'
? 'Password'
: t('modals.configTool.apiKeyPlaceholder')
: t('modals.configTool.apiKeyPlaceholder')
}
/> />
</div> </div>
)} )}
@@ -450,6 +501,26 @@ export default function ToolConfig({
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')} setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
submitLabel={t('settings.tools.saveAndLeave')} submitLabel={t('settings.tools.saveAndLeave')}
handleSubmit={() => { handleSubmit={() => {
let configToSave;
if (tool.name === 'api_tool') {
configToSave = tool.config;
} else if (tool.name === 'mcp_tool') {
configToSave = { ...tool.config } as any;
const mcpConfig = tool.config as any;
if (authKey.trim()) {
if (mcpConfig.auth_type === 'api_key') {
configToSave.api_key = authKey;
} else if (mcpConfig.auth_type === 'bearer') {
configToSave.encrypted_token = authKey;
} else if (mcpConfig.auth_type === 'basic') {
configToSave.password = authKey;
}
}
} else {
configToSave = { token: authKey };
}
userService userService
.updateTool( .updateTool(
{ {
@@ -458,10 +529,7 @@ export default function ToolConfig({
displayName: tool.displayName, displayName: tool.displayName,
customName: customName, customName: customName,
description: tool.description, description: tool.description,
config: config: configToSave,
tool.name === 'api_tool'
? tool.config
: { token: authKey },
actions: 'actions' in tool ? tool.actions : [], actions: 'actions' in tool ? tool.actions : [],
status: tool.status, status: tool.status,
}, },