diff --git a/application/agents/base.py b/application/agents/base.py
index f2cabdb7..77729fe6 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -140,28 +140,28 @@ class BaseAgent(ABC):
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
-
+
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
-
+
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
- "action_name": getattr(call, 'name', 'unknown'),
+ "action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
-
+
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
-
+
# Return error result
tool_call_data = {
"tool_name": "unknown",
@@ -173,7 +173,7 @@ class BaseAgent(ABC):
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
-
+
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
+ user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(
diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py
new file mode 100644
index 00000000..dc689367
--- /dev/null
+++ b/application/agents/tools/mcp_tool.py
@@ -0,0 +1,405 @@
+import json
+import time
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from application.agents.tools.base import Tool
+from application.security.encryption import decrypt_credentials
+
+
+_mcp_session_cache = {}
+
+
+class MCPTool(Tool):
+ """
+ MCP Tool
+ Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
+ """
+
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
    """
    Initialize the MCP Tool with configuration.

    Args:
        config: Dictionary containing MCP server configuration:
            - server_url: URL of the remote MCP server
            - auth_type: Type of authentication (api_key, bearer, basic, none)
            - encrypted_credentials: Encrypted credentials (if available)
            - timeout: Request timeout in seconds (default: 30)
        user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
    """
    self.config = config
    self.server_url = config.get("server_url", "")
    self.auth_type = config.get("auth_type", "none")
    self.timeout = config.get("timeout", 30)

    # Prefer encrypted credentials (decryptable only with user_id);
    # otherwise fall back to plaintext auth_credentials from the config.
    self.auth_credentials = {}
    if config.get("encrypted_credentials") and user_id:
        self.auth_credentials = decrypt_credentials(
            config["encrypted_credentials"], user_id
        )
    else:
        self.auth_credentials = config.get("auth_credentials", {})
    self.available_tools = []  # populated by discover_tools()
    self._session = requests.Session()
    self._mcp_session_id = None  # MCP session id assigned by the server, if any
    # Attach credentials to the session before any request is made.
    self._setup_authentication()
    # Cache key identifies server+credentials so MCP sessions can be reused.
    self._cache_key = self._generate_cache_key()
+
+ def _setup_authentication(self):
+ """Setup authentication for the MCP server connection."""
+ if self.auth_type == "api_key":
+ api_key = self.auth_credentials.get("api_key", "")
+ header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
+ if api_key:
+ self._session.headers.update({header_name: api_key})
+ elif self.auth_type == "bearer":
+ token = self.auth_credentials.get("bearer_token", "")
+ if token:
+ self._session.headers.update({"Authorization": f"Bearer {token}"})
+ elif self.auth_type == "basic":
+ username = self.auth_credentials.get("username", "")
+ password = self.auth_credentials.get("password", "")
+ if username and password:
+ self._session.auth = (username, password)
+
+ def _generate_cache_key(self) -> str:
+ """Generate a unique cache key for this MCP server configuration."""
+ auth_key = ""
+ if self.auth_type == "bearer":
+ token = self.auth_credentials.get("bearer_token", "")
+ auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
+ elif self.auth_type == "api_key":
+ api_key = self.auth_credentials.get("api_key", "")
+ auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
+ elif self.auth_type == "basic":
+ username = self.auth_credentials.get("username", "")
+ auth_key = f"basic:{username}"
+ else:
+ auth_key = "none"
+ return f"{self.server_url}#{auth_key}"
+
def _get_cached_session(self) -> Optional[str]:
    """Return a cached MCP session id for this server, or None if absent/expired."""
    entry = _mcp_session_cache.get(self._cache_key)
    if entry is None:
        return None
    # Sessions older than 30 minutes are treated as expired and evicted.
    if time.time() - entry["created_at"] < 1800:
        return entry["session_id"]
    del _mcp_session_cache[self._cache_key]
    return None
+
def _cache_session(self, session_id: str):
    """Record *session_id* in the module-level cache with a fresh timestamp."""
    entry = {"session_id": session_id, "created_at": time.time()}
    _mcp_session_cache[self._cache_key] = entry
+
+ def _initialize_mcp_connection(self) -> Dict:
+ """
+ Initialize MCP connection with the server, using cached session if available.
+
+ Returns:
+ Server capabilities and information
+ """
+ cached_session = self._get_cached_session()
+ if cached_session:
+ self._mcp_session_id = cached_session
+ return {"cached": True}
+ try:
+ init_params = {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
+ "clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
+ }
+ response = self._make_mcp_request("initialize", init_params)
+ self._make_mcp_request("notifications/initialized")
+
+ return response
+ except Exception as e:
+ return {"error": str(e), "fallback": True}
+
+ def _ensure_valid_session(self):
+ """Ensure we have a valid MCP session, reinitializing if needed."""
+ if not self._mcp_session_id:
+ self._initialize_mcp_connection()
+
+ def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
+ """
+ Make an MCP protocol request to the server with automatic session recovery.
+
+ Args:
+ method: MCP method name (e.g., "tools/list", "tools/call")
+ params: Parameters for the MCP method
+
+ Returns:
+ Response data as dictionary
+
+ Raises:
+ Exception: If request fails after retry
+ """
+ mcp_message = {"jsonrpc": "2.0", "method": method}
+
+ if not method.startswith("notifications/"):
+ mcp_message["id"] = 1
+ if params:
+ mcp_message["params"] = params
+ return self._execute_mcp_request(mcp_message, method)
+
def _execute_mcp_request(
    self, mcp_message: Dict, method: str, is_retry: bool = False
) -> Dict:
    """Execute MCP request with optional retry on session failure.

    Sends *mcp_message* to the server, handles both plain JSON and
    SSE-framed ("event:"/"data:") response bodies, and retries once
    with a fresh session when the failure looks session-related.
    """
    try:
        # Copy so per-request headers don't pollute the shared session.
        final_headers = self._session.headers.copy()
        final_headers.update(
            {
                "Content-Type": "application/json",
                "Accept": "application/json, text/event-stream",
            }
        )

        if self._mcp_session_id:
            final_headers["Mcp-Session-Id"] = self._mcp_session_id
        response = self._session.post(
            self.server_url.rstrip("/"),
            json=mcp_message,
            headers=final_headers,
            timeout=self.timeout,
        )

        # The server may issue/rotate a session id on any response; keep it fresh.
        if "mcp-session-id" in response.headers:
            self._mcp_session_id = response.headers["mcp-session-id"]
            self._cache_session(self._mcp_session_id)
        response.raise_for_status()

        # Notifications have no meaningful response body.
        if method.startswith("notifications/"):
            return {}
        response_text = response.text.strip()
        # SSE framing: the JSON payload is on the first "data:" line.
        if response_text.startswith("event:") and "data:" in response_text:
            lines = response_text.split("\n")
            data_line = None
            for line in lines:
                if line.startswith("data:"):
                    data_line = line[5:].strip()
                    break
            if data_line:
                try:
                    result = json.loads(data_line)
                except json.JSONDecodeError:
                    raise Exception(f"Invalid JSON in SSE data: {data_line}")
            else:
                raise Exception(f"No data found in SSE response: {response_text}")
        else:
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"Invalid JSON response: {response.text}")
        # JSON-RPC error objects are surfaced as exceptions.
        if "error" in result:
            error_msg = result["error"]
            if isinstance(error_msg, dict):
                error_msg = error_msg.get("message", str(error_msg))
            raise Exception(f"MCP server error: {error_msg}")
        # Unwrap the JSON-RPC envelope when present.
        return result.get("result", result)
    except requests.exceptions.RequestException as e:
        # Single retry with a refreshed session for session-style failures.
        if not is_retry and self._should_retry_with_new_session(e):
            self._invalidate_and_refresh_session()
            return self._execute_mcp_request(mcp_message, method, is_retry=True)
        raise Exception(f"MCP server request failed: {str(e)}")
+
+ def _should_retry_with_new_session(self, error: Exception) -> bool:
+ """Check if error indicates session invalidation and retry is warranted."""
+ error_str = str(error).lower()
+ return (
+ any(
+ indicator in error_str
+ for indicator in [
+ "invalid session",
+ "session expired",
+ "unauthorized",
+ "401",
+ "403",
+ ]
+ )
+ and self._mcp_session_id is not None
+ )
+
def _invalidate_and_refresh_session(self) -> None:
    """Drop the cached/current session and perform a fresh handshake."""
    _mcp_session_cache.pop(self._cache_key, None)
    self._mcp_session_id = None
    self._initialize_mcp_connection()
+
def discover_tools(self) -> List[Dict]:
    """
    Query the MCP server's tool catalogue via "tools/list".

    Returns:
        List of tool definitions; also stored on self.available_tools.

    Raises:
        Exception: if the server cannot be reached or replies abnormally.
    """
    try:
        self._ensure_valid_session()
        response = self._make_mcp_request("tools/list")

        # Servers answer in several shapes: {"tools": [...]}, a full
        # JSON-RPC envelope, a bare list, or a single tool object.
        if isinstance(response, list):
            tools = response
        elif isinstance(response, dict):
            if "tools" in response:
                tools = response["tools"]
            elif isinstance(response.get("result"), dict) and "tools" in response["result"]:
                tools = response["result"]["tools"]
            else:
                tools = [response] if response else []
        else:
            tools = []
        self.available_tools = tools
        return tools
    except Exception as e:
        raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
+
def execute_action(self, action_name: str, **kwargs) -> Any:
    """
    Invoke *action_name* on the remote MCP server via "tools/call".

    Args:
        action_name: Name of the tool/action to execute
        **kwargs: Arguments for the action; empty-string and None values
            are dropped so the server can apply its own defaults.

    Returns:
        Result payload from the MCP server.

    Raises:
        Exception: if the call fails.
    """
    self._ensure_valid_session()

    arguments = {
        key: value
        for key, value in kwargs.items()
        if value is not None and value != ""
    }
    try:
        return self._make_mcp_request(
            "tools/call", {"name": action_name, "arguments": arguments}
        )
    except Exception as e:
        raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
+
def get_actions_metadata(self) -> List[Dict]:
    """
    Convert discovered MCP tool definitions into action metadata dicts.

    Returns:
        One {"name", "description", "parameters"} dict per tool, where
        "parameters" is a JSON-schema-style object description.
    """
    actions = []
    for tool in self.available_tools:
        # Servers label the schema differently; take the first present.
        schema = (
            tool.get("inputSchema")
            or tool.get("input_schema")
            or tool.get("schema")
            or tool.get("parameters")
        )

        params = {"type": "object", "properties": {}, "required": []}
        if isinstance(schema, dict) and schema:
            if "properties" in schema:
                params = {
                    "type": schema.get("type", "object"),
                    "properties": schema.get("properties", {}),
                    "required": schema.get("required", []),
                }
                for extra in ("additionalProperties", "description"):
                    if extra in schema:
                        params[extra] = schema[extra]
            else:
                # Bare property map without a wrapping schema object.
                params["properties"] = schema
        actions.append(
            {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                "parameters": params,
            }
        )
    return actions
+
def test_connection(self) -> Dict:
    """
    Test the connection to the MCP server and validate functionality.

    Returns:
        Dictionary with connection test results including tool count
    """
    try:
        # Clear this instance's session so the handshake is exercised.
        # NOTE(review): a module-level cached session can still be picked
        # up inside _initialize_mcp_connection — confirm that is intended.
        self._mcp_session_id = None

        init_result = self._initialize_mcp_connection()

        tools = self.discover_tools()

        message = f"Successfully connected to MCP server. Found {len(tools)} tools."
        if init_result.get("cached"):
            message += " (Using cached session)"
        elif init_result.get("fallback"):
            message += " (No formal initialization required)"
        return {
            "success": True,
            "message": message,
            "tools_count": len(tools),
            "session_id": self._mcp_session_id,
            # Preview at most five tool names for display.
            "tools": [tool.get("name", "unknown") for tool in tools[:5]],
        }
    except Exception as e:
        # Report failure details instead of raising so callers can surface
        # the result directly to the client.
        return {
            "success": False,
            "message": f"Connection failed: {str(e)}",
            "tools_count": 0,
            "error_type": type(e).__name__,
        }
+
def get_config_requirements(self) -> Dict:
    """Describe the configuration fields this tool accepts."""
    server_url = {
        "type": "string",
        "description": "URL of the remote MCP server (e.g., https://api.example.com)",
        "required": True,
    }
    auth_type = {
        "type": "string",
        "description": "Authentication type",
        "enum": ["none", "api_key", "bearer", "basic"],
        "default": "none",
        "required": True,
    }
    auth_credentials = {
        "type": "object",
        "description": "Authentication credentials (varies by auth_type)",
        "required": False,
    }
    timeout = {
        "type": "integer",
        "description": "Request timeout in seconds",
        "default": 30,
        "minimum": 1,
        "maximum": 300,
        "required": False,
    }
    return {
        "server_url": server_url,
        "auth_type": auth_type,
        "auth_credentials": auth_credentials,
        "timeout": timeout,
    }
diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py
index ad71db28..d602b762 100644
--- a/application/agents/tools/tool_manager.py
+++ b/application/agents/tools/tool_manager.py
@@ -23,16 +23,23 @@ class ToolManager:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
- def load_tool(self, tool_name, tool_config):
+ def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
- return obj(tool_config)
+ if tool_name == "mcp_tool" and user_id:
+ return obj(tool_config, user_id)
+ else:
+ return obj(tool_config)
- def execute_action(self, tool_name, action_name, **kwargs):
+ def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
+ if tool_name == "mcp_tool" and user_id:
+ tool_config = self.config.get(tool_name, {})
+ tool = self.load_tool(tool_name, tool_config, user_id)
+ return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index f508b7cf..f0493c7c 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -25,6 +25,8 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
+from application.agents.tools.mcp_tool import MCPTool
+
from application.agents.tools.tool_manager import ToolManager
from application.api import api
@@ -38,6 +40,7 @@ from application.api.user.tasks import (
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
+from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.storage_creator import StorageCreator
from application.tts.google_tts import GoogleTTS
from application.utils import (
@@ -491,6 +494,7 @@ class DeleteOldIndexes(Resource):
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
+
storage = StorageCreator.get_storage()
try:
@@ -507,6 +511,7 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
+
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
@@ -515,6 +520,7 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f)
else:
storage.delete_file(file_path)
+
except FileNotFoundError:
pass
except Exception as err:
@@ -522,6 +528,7 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
+
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -593,6 +600,7 @@ class UploadFile(Resource):
== temp_file_path
):
continue
+
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
@@ -617,6 +625,7 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
+
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
@@ -688,6 +697,7 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
+
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
operation = request.form.get("operation")
@@ -737,6 +747,7 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
+
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -793,6 +804,7 @@ class ManageSourceFiles(Resource):
),
200,
)
+
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
@@ -846,6 +858,7 @@ class ManageSourceFiles(Resource):
),
200,
)
+
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
@@ -871,6 +884,7 @@ class ManageSourceFiles(Resource):
),
400,
)
+
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
@@ -929,6 +943,7 @@ class ManageSourceFiles(Resource):
),
200,
)
+
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
@@ -940,6 +955,7 @@ class ManageSourceFiles(Resource):
elif operation == "add":
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
+
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
@@ -1616,6 +1632,7 @@ class CreateAgent(Resource):
),
400,
)
+
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
@@ -3625,7 +3642,60 @@ class UpdateTool(Resource):
),
400,
)
- update_data["config"] = data["config"]
+ tool_doc = user_tools_collection.find_one(
+ {"_id": ObjectId(data["id"]), "user": user}
+ )
+ if tool_doc and tool_doc.get("name") == "mcp_tool":
+ config = data["config"]
+ existing_config = tool_doc.get("config", {})
+ storage_config = existing_config.copy()
+
+ storage_config.update(config)
+ existing_credentials = {}
+ if "encrypted_credentials" in existing_config:
+ existing_credentials = decrypt_credentials(
+ existing_config["encrypted_credentials"], user
+ )
+ auth_credentials = existing_credentials.copy()
+ auth_type = storage_config.get("auth_type", "none")
+ if auth_type == "api_key":
+ if "api_key" in config and config["api_key"]:
+ auth_credentials["api_key"] = config["api_key"]
+ if "api_key_header" in config:
+ auth_credentials["api_key_header"] = config[
+ "api_key_header"
+ ]
+ elif auth_type == "bearer":
+ if "bearer_token" in config and config["bearer_token"]:
+ auth_credentials["bearer_token"] = config["bearer_token"]
+ elif "encrypted_token" in config and config["encrypted_token"]:
+ auth_credentials["bearer_token"] = config["encrypted_token"]
+ elif auth_type == "basic":
+ if "username" in config and config["username"]:
+ auth_credentials["username"] = config["username"]
+ if "password" in config and config["password"]:
+ auth_credentials["password"] = config["password"]
+ if auth_type != "none" and auth_credentials:
+ encrypted_credentials_string = encrypt_credentials(
+ auth_credentials, user
+ )
+ storage_config["encrypted_credentials"] = (
+ encrypted_credentials_string
+ )
+ elif auth_type == "none":
+ storage_config.pop("encrypted_credentials", None)
+ for field in [
+ "api_key",
+ "bearer_token",
+ "encrypted_token",
+ "username",
+ "password",
+ "api_key_header",
+ ]:
+ storage_config.pop(field, None)
+ update_data["config"] = storage_config
+ else:
+ update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -3837,6 +3907,7 @@ class GetChunks(Resource):
if not (text_match or title_match):
continue
filtered_chunks.append(chunk)
+
chunks = filtered_chunks
total_chunks = len(chunks)
@@ -4027,6 +4098,7 @@ class UpdateChunk(Resource):
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
+
return make_response(
jsonify(
{
@@ -4154,19 +4226,23 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
+
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
+
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
+
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
+
directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "")
@@ -4196,3 +4272,204 @@ class DirectoryStructure(Resource):
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@user_ns.route("/api/mcp_server/test")
+class TestMCPServerConfig(Resource):
+ @api.expect(
+ api.model(
+ "MCPServerTestModel",
+ {
+ "config": fields.Raw(
+ required=True, description="MCP server configuration to test"
+ ),
+ },
+ )
+ )
+ @api.doc(description="Test MCP server connection with provided configuration")
+ def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
+ data = request.get_json()
+
+ required_fields = ["config"]
+ missing_fields = check_required_fields(data, required_fields)
+ if missing_fields:
+ return missing_fields
+ try:
+ config = data["config"]
+
+ auth_credentials = {}
+ auth_type = config.get("auth_type", "none")
+
+ if auth_type == "api_key" and "api_key" in config:
+ auth_credentials["api_key"] = config["api_key"]
+ if "api_key_header" in config:
+ auth_credentials["api_key_header"] = config["api_key_header"]
+ elif auth_type == "bearer" and "bearer_token" in config:
+ auth_credentials["bearer_token"] = config["bearer_token"]
+ elif auth_type == "basic":
+ if "username" in config:
+ auth_credentials["username"] = config["username"]
+ if "password" in config:
+ auth_credentials["password"] = config["password"]
+
+ test_config = config.copy()
+ test_config["auth_credentials"] = auth_credentials
+
+ mcp_tool = MCPTool(test_config, user)
+ result = mcp_tool.test_connection()
+
+ return make_response(jsonify(result), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
+ return make_response(
+ jsonify(
+ {"success": False, "error": f"Connection test failed: {str(e)}"}
+ ),
+ 500,
+ )
+
+
+@user_ns.route("/api/mcp_server/save")
+class MCPServerSave(Resource):
+ @api.expect(
+ api.model(
+ "MCPServerSaveModel",
+ {
+ "id": fields.String(
+ required=False, description="Tool ID for updates (optional)"
+ ),
+ "displayName": fields.String(
+ required=True, description="Display name for the MCP server"
+ ),
+ "config": fields.Raw(
+ required=True, description="MCP server configuration"
+ ),
+ "status": fields.Boolean(
+ required=False, default=True, description="Tool status"
+ ),
+ },
+ )
+ )
+ @api.doc(description="Create or update MCP server with automatic tool discovery")
+ def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
+ data = request.get_json()
+
+ required_fields = ["displayName", "config"]
+ missing_fields = check_required_fields(data, required_fields)
+ if missing_fields:
+ return missing_fields
+ try:
+ config = data["config"]
+
+ auth_credentials = {}
+ auth_type = config.get("auth_type", "none")
+ if auth_type == "api_key":
+ if "api_key" in config and config["api_key"]:
+ auth_credentials["api_key"] = config["api_key"]
+ if "api_key_header" in config:
+ auth_credentials["api_key_header"] = config["api_key_header"]
+ elif auth_type == "bearer":
+ if "bearer_token" in config and config["bearer_token"]:
+ auth_credentials["bearer_token"] = config["bearer_token"]
+ elif auth_type == "basic":
+ if "username" in config and config["username"]:
+ auth_credentials["username"] = config["username"]
+ if "password" in config and config["password"]:
+ auth_credentials["password"] = config["password"]
+ mcp_config = config.copy()
+ mcp_config["auth_credentials"] = auth_credentials
+
+ if auth_type == "none" or auth_credentials:
+ mcp_tool = MCPTool(mcp_config, user)
+ mcp_tool.discover_tools()
+ actions_metadata = mcp_tool.get_actions_metadata()
+ else:
+ raise Exception(
+ "No valid credentials provided for the selected authentication type"
+ )
+
+ storage_config = config.copy()
+ if auth_credentials:
+ encrypted_credentials_string = encrypt_credentials(
+ auth_credentials, user
+ )
+ storage_config["encrypted_credentials"] = encrypted_credentials_string
+
+ for field in [
+ "api_key",
+ "bearer_token",
+ "username",
+ "password",
+ "api_key_header",
+ ]:
+ storage_config.pop(field, None)
+ transformed_actions = []
+ for action in actions_metadata:
+ action["active"] = True
+ if "parameters" in action:
+ if "properties" in action["parameters"]:
+ for param_name, param_details in action["parameters"][
+ "properties"
+ ].items():
+ param_details["filled_by_llm"] = True
+ param_details["value"] = ""
+ transformed_actions.append(action)
+ tool_data = {
+ "name": "mcp_tool",
+ "displayName": data["displayName"],
+ "customName": data["displayName"],
+ "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
+ "config": storage_config,
+ "actions": transformed_actions,
+ "status": data.get("status", True),
+ "user": user,
+ }
+
+ tool_id = data.get("id")
+ if tool_id:
+ result = user_tools_collection.update_one(
+ {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
+ {"$set": {k: v for k, v in tool_data.items() if k != "user"}},
+ )
+ if result.matched_count == 0:
+ return make_response(
+ jsonify(
+ {
+ "success": False,
+ "error": "Tool not found or access denied",
+ }
+ ),
+ 404,
+ )
+ response_data = {
+ "success": True,
+ "id": tool_id,
+ "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
+ "tools_count": len(transformed_actions),
+ }
+ else:
+ result = user_tools_collection.insert_one(tool_data)
+ tool_id = str(result.inserted_id)
+ response_data = {
+ "success": True,
+ "id": tool_id,
+ "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
+ "tools_count": len(transformed_actions),
+ }
+ return make_response(jsonify(response_data), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
+ return make_response(
+ jsonify(
+ {"success": False, "error": f"Failed to save MCP server: {str(e)}"}
+ ),
+ 500,
+ )
diff --git a/application/core/settings.py b/application/core/settings.py
index cb7d75e3..7ede4e86 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
- "gemini-2.0-flash-exp": 1e6,
+ "gemini-2.5-flash": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
@@ -96,7 +96,7 @@ class Settings(BaseSettings):
QDRANT_HOST: Optional[str] = None
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
-
+
# PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
@@ -116,6 +116,9 @@ class Settings(BaseSettings):
JWT_SECRET_KEY: str = ""
+ # Encryption settings
+ ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
+
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 91065b74..b88e1d9f 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
+ """Convert OpenAI format messages to Google AI format."""
cleaned_messages = []
for message in messages:
role = message.get("role")
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
+ elif role == "tool":
+ role = "model"
parts = []
if role and content is not None:
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
else:
raise ValueError(f"Unexpected content type: {type(content)}")
- cleaned_messages.append(types.Content(role=role, parts=parts))
+ if parts:
+ cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
+ def _clean_schema(self, schema_obj):
+ """
+ Recursively remove unsupported fields from schema objects
+ and validate required properties.
+ """
+ if not isinstance(schema_obj, dict):
+ return schema_obj
+ allowed_fields = {
+ "type",
+ "description",
+ "items",
+ "properties",
+ "required",
+ "enum",
+ "pattern",
+ "minimum",
+ "maximum",
+ "nullable",
+ "default",
+ }
+
+ cleaned = {}
+ for key, value in schema_obj.items():
+ if key not in allowed_fields:
+ continue
+ elif key == "type" and isinstance(value, str):
+ cleaned[key] = value.upper()
+ elif isinstance(value, dict):
+ cleaned[key] = self._clean_schema(value)
+ elif isinstance(value, list):
+ cleaned[key] = [self._clean_schema(item) for item in value]
+ else:
+ cleaned[key] = value
+
+ # Validate that required properties actually exist in properties
+ if "required" in cleaned and "properties" in cleaned:
+ valid_required = []
+ properties_keys = set(cleaned["properties"].keys())
+ for required_prop in cleaned["required"]:
+ if required_prop in properties_keys:
+ valid_required.append(required_prop)
+ if valid_required:
+ cleaned["required"] = valid_required
+ else:
+ cleaned.pop("required", None)
+ elif "required" in cleaned and "properties" not in cleaned:
+ cleaned.pop("required", None)
+
+ return cleaned
+
def _clean_tools_format(self, tools_list):
+ """Convert OpenAI format tools to Google AI format."""
genai_tools = []
for tool_data in tools_list:
if tool_data["type"] == "function":
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
properties = parameters.get("properties", {})
if properties:
+ cleaned_properties = {}
+ for k, v in properties.items():
+ cleaned_properties[k] = self._clean_schema(v)
+
genai_function = dict(
name=function["name"],
description=function["description"],
parameters={
"type": "OBJECT",
- "properties": {
- k: {
- **v,
- "type": v["type"].upper() if v["type"] else None,
- }
- for k, v in properties.items()
- },
+ "properties": cleaned_properties,
"required": (
parameters["required"]
if "required" in parameters
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
response_schema=None,
**kwargs,
):
+ """Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
response_schema=None,
**kwargs,
):
+ """Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
yield chunk.text
def _supports_tools(self):
+ """Return whether this LLM supports function calling."""
return True
def _supports_structured_output(self):
+ """Return whether this LLM supports structured JSON output."""
return True
def prepare_structured_output_format(self, json_schema):
+ """Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py
index 43205472..96ed4c00 100644
--- a/application/llm/handlers/base.py
+++ b/application/llm/handlers/base.py
@@ -205,7 +205,6 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
-
updated_messages.append(
{
"role": "assistant",
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
)
updated_messages.append(self.create_tool_message(call, tool_response))
-
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
- updated_messages.append(
- {
- "role": "tool",
- "content": f"Error executing tool: {str(e)}",
- "tool_call_id": call.id,
- }
+ error_call = ToolCall(
+ id=call.id, name=call.name, arguments=call.arguments
)
+ error_response = f"Error executing tool: {str(e)}"
+ error_message = self.create_tool_message(error_call, error_response)
+ updated_messages.append(error_message)
+ call_parts = call.name.split("_")
+ if len(call_parts) >= 2:
+ tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
+ action_name = "_".join(call_parts[:-1])
+ tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
+ full_action_name = f"{action_name}_{tool_id}"
+ else:
+ tool_name = "unknown_tool"
+ action_name = call.name
+ full_action_name = call.name
+ yield {
+ "type": "tool_call",
+ "data": {
+ "tool_name": tool_name,
+ "call_id": call.id,
+ "action_name": full_action_name,
+ "arguments": call.arguments,
+ "error": error_response,
+ "status": "error",
+ },
+ }
return updated_messages
def handle_non_streaming(
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
except StopIteration as e:
messages = e.value
break
-
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
-
return parsed.content
def handle_streaming(
diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py
index b43f2a16..7fa44cb6 100644
--- a/application/llm/handlers/google.py
+++ b/application/llm/handlers/google.py
@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="stop",
raw_response=response,
)
-
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
-
else:
tool_calls = []
if hasattr(response, "function_call"):
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
- from google.genai import types
return {
- "role": "tool",
+ "role": "model",
"content": [
- types.Part.from_function_response(
- name=tool_call.name, response={"result": result}
- ).to_json_dict()
+ {
+ "function_response": {
+ "name": tool_call.name,
+ "response": {"result": result},
+ }
+ }
],
}
diff --git a/application/requirements.txt b/application/requirements.txt
index b7076ed8..80564689 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -2,6 +2,7 @@ anthropic==0.49.0
boto3==1.38.18
beautifulsoup4==4.13.4
celery==5.4.0
+cryptography==42.0.8
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==7.5.2
diff --git a/application/security/__init__.py b/application/security/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/application/security/encryption.py b/application/security/encryption.py
new file mode 100644
index 00000000..4cb3a4d5
--- /dev/null
+++ b/application/security/encryption.py
@@ -0,0 +1,85 @@
+import base64
+import json
+import os
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
+from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+
+from application.core.settings import settings
+
+
+def _derive_key(user_id: str, salt: bytes) -> bytes:
+ app_secret = settings.ENCRYPTION_SECRET_KEY
+
+ password = f"{app_secret}#{user_id}".encode()
+
+ kdf = PBKDF2HMAC(
+ algorithm=hashes.SHA256(),
+ length=32,
+ salt=salt,
+ iterations=100000,
+ backend=default_backend(),
+ )
+
+ return kdf.derive(password)
+
+
+def encrypt_credentials(credentials: dict, user_id: str) -> str:
+ if not credentials:
+ return ""
+ try:
+ salt = os.urandom(16)
+ iv = os.urandom(16)
+ key = _derive_key(user_id, salt)
+
+ json_str = json.dumps(credentials)
+
+ cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+ encryptor = cipher.encryptor()
+
+ padded_data = _pad_data(json_str.encode())
+ encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
+
+ result = salt + iv + encrypted_data
+ return base64.b64encode(result).decode()
+ except Exception as e:
+ print(f"Warning: Failed to encrypt credentials: {e}")
+ return ""
+
+
+def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
+ if not encrypted_data:
+ return {}
+ try:
+ data = base64.b64decode(encrypted_data.encode())
+
+ salt = data[:16]
+ iv = data[16:32]
+ encrypted_content = data[32:]
+
+ key = _derive_key(user_id, salt)
+
+ cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+ decryptor = cipher.decryptor()
+
+ decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
+ decrypted_data = _unpad_data(decrypted_padded)
+
+ return json.loads(decrypted_data.decode())
+ except Exception as e:
+ print(f"Warning: Failed to decrypt credentials: {e}")
+ return {}
+
+
+def _pad_data(data: bytes) -> bytes:
+ block_size = 16
+ padding_len = block_size - (len(data) % block_size)
+ padding = bytes([padding_len]) * padding_len
+ return data + padding
+
+
+def _unpad_data(data: bytes) -> bytes:
+ padding_len = data[-1]
+ return data[:-padding_len]
diff --git a/frontend/public/toolIcons/tool_mcp_tool.svg b/frontend/public/toolIcons/tool_mcp_tool.svg
new file mode 100644
index 00000000..22c980e3
--- /dev/null
+++ b/frontend/public/toolIcons/tool_mcp_tool.svg
@@ -0,0 +1,4 @@
+
\ No newline at end of file
diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts
index 955f43ee..dad008da 100644
--- a/frontend/src/api/endpoints.ts
+++ b/frontend/src/api/endpoints.ts
@@ -57,6 +57,8 @@ const endpoints = {
DIRECTORY_STRUCTURE: (docId: string) =>
`/api/directory_structure?id=${docId}`,
MANAGE_SOURCE_FILES: '/api/manage_source_files',
+ MCP_TEST_CONNECTION: '/api/mcp_server/test',
+ MCP_SAVE_SERVER: '/api/mcp_server/save',
},
CONVERSATION: {
ANSWER: '/api/answer',
diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts
index 1cb4bbd6..5dda8ddf 100644
--- a/frontend/src/api/services/userService.ts
+++ b/frontend/src/api/services/userService.ts
@@ -108,6 +108,10 @@ const userService = {
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise
+ + {toolCall.error} + +
+ )} diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index 4ccb04a1..d962e4bc 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -4,5 +4,6 @@ export type ToolCallsType = { call_id: string; arguments: Record{errors.api_key}
+ )} +{errors.bearer_token}
+ )} +{errors.username}
+ )} +{errors.password}
+ )} +{errors.name}
+ )} ++ {errors.server_url} +
+ )} +{errors.timeout}
+ )} +- {t('settings.tools.authentication')} + {tool.name === 'mcp_tool' + ? (tool.config as any)?.auth_type === 'bearer' + ? 'Bearer Token' + : (tool.config as any)?.auth_type === 'api_key' + ? 'API Key' + : (tool.config as any)?.auth_type === 'basic' + ? 'Password' + : t('settings.tools.authentication') + : t('settings.tools.authentication')}
)}