diff --git a/application/agents/base.py b/application/agents/base.py index f2cabdb7..77729fe6 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -140,28 +140,28 @@ class BaseAgent(ABC): tool_id, action_name, call_args = parser.parse_args(call) call_id = getattr(call, "id", None) or str(uuid.uuid4()) - + # Check if parsing failed if tool_id is None or action_name is None: error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}" logger.error(error_message) - + tool_call_data = { "tool_name": "unknown", "call_id": call_id, - "action_name": getattr(call, 'name', 'unknown'), + "action_name": getattr(call, "name", "unknown"), "arguments": call_args or {}, "result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}", } yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) return "Failed to parse tool call.", call_id - + # Check if tool_id exists in available tools if tool_id not in tools_dict: error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}" logger.error(error_message) - + # Return error result tool_call_data = { "tool_name": "unknown", @@ -173,7 +173,7 @@ class BaseAgent(ABC): yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) return f"Tool with ID {tool_id} not found.", call_id - + tool_call_data = { "tool_name": tools_dict[tool_id]["name"], "call_id": call_id, @@ -225,6 +225,7 @@ class BaseAgent(ABC): if tool_data["name"] == "api_tool" else tool_data["config"] ), + user_id=self.user, # Pass user ID for MCP tools credential decryption ) if tool_data["name"] == "api_tool": print( diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py new file mode 100644 index 00000000..dc689367 --- /dev/null +++ b/application/agents/tools/mcp_tool.py @@ -0,0 +1,405 @@ +import json +import time +from typing import Any, Dict, List, Optional + +import requests + +from application.agents.tools.base import Tool +from application.security.encryption import decrypt_credentials + + +_mcp_session_cache = {} + + +class MCPTool(Tool): + """ + MCP Tool + Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol. + """ + + def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None): + """ + Initialize the MCP Tool with configuration. + + Args: + config: Dictionary containing MCP server configuration: + - server_url: URL of the remote MCP server + - auth_type: Type of authentication (api_key, bearer, basic, none) + - encrypted_credentials: Encrypted credentials (if available) + - timeout: Request timeout in seconds (default: 30) + user_id: User ID for decrypting credentials (required if encrypted_credentials exist) + """ + self.config = config + self.server_url = config.get("server_url", "") + self.auth_type = config.get("auth_type", "none") + self.timeout = config.get("timeout", 30) + + self.auth_credentials = {} + if config.get("encrypted_credentials") and user_id: + self.auth_credentials = decrypt_credentials( + config["encrypted_credentials"], user_id + ) + else: + self.auth_credentials = config.get("auth_credentials", {}) + self.available_tools = [] + self._session = requests.Session() + self._mcp_session_id = None + self._setup_authentication() + self._cache_key = self._generate_cache_key() + + def _setup_authentication(self): + """Setup authentication for the MCP server connection.""" + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + self._session.headers.update({header_name: api_key}) + elif self.auth_type == "bearer": + token = self.auth_credentials.get("bearer_token", "") + if token: + self._session.headers.update({"Authorization": f"Bearer {token}"}) + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + self._session.auth = (username, password) + + def _generate_cache_key(self) -> str: + """Generate a unique cache key for this MCP server configuration.""" + auth_key = "" + if self.auth_type == "bearer": + token = self.auth_credentials.get("bearer_token", "") + auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" + elif self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none" + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + auth_key = f"basic:{username}" + else: + auth_key = "none" + return f"{self.server_url}#{auth_key}" + + def _get_cached_session(self) -> Optional[str]: + """Get cached session ID if available and not expired.""" + global _mcp_session_cache + + if self._cache_key in _mcp_session_cache: + session_data = _mcp_session_cache[self._cache_key] + if time.time() - session_data["created_at"] < 1800: + return session_data["session_id"] + else: + del _mcp_session_cache[self._cache_key] + return None + + def _cache_session(self, session_id: str): + """Cache the session ID for reuse.""" + global _mcp_session_cache + _mcp_session_cache[self._cache_key] = { + "session_id": session_id, + "created_at": time.time(), + } + + def _initialize_mcp_connection(self) -> Dict: + """ + Initialize MCP connection with the server, using cached session if available. + + Returns: + Server capabilities and information + """ + cached_session = self._get_cached_session() + if cached_session: + self._mcp_session_id = cached_session + return {"cached": True} + try: + init_params = { + "protocolVersion": "2024-11-05", + "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, + "clientInfo": {"name": "DocsGPT", "version": "1.0.0"}, + } + response = self._make_mcp_request("initialize", init_params) + self._make_mcp_request("notifications/initialized") + + return response + except Exception as e: + return {"error": str(e), "fallback": True} + + def _ensure_valid_session(self): + """Ensure we have a valid MCP session, reinitializing if needed.""" + if not self._mcp_session_id: + self._initialize_mcp_connection() + + def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict: + """ + Make an MCP protocol request to the server with automatic session recovery. + + Args: + method: MCP method name (e.g., "tools/list", "tools/call") + params: Parameters for the MCP method + + Returns: + Response data as dictionary + + Raises: + Exception: If request fails after retry + """ + mcp_message = {"jsonrpc": "2.0", "method": method} + + if not method.startswith("notifications/"): + mcp_message["id"] = 1 + if params: + mcp_message["params"] = params + return self._execute_mcp_request(mcp_message, method) + + def _execute_mcp_request( + self, mcp_message: Dict, method: str, is_retry: bool = False + ) -> Dict: + """Execute MCP request with optional retry on session failure.""" + try: + final_headers = self._session.headers.copy() + final_headers.update( + { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + ) + + if self._mcp_session_id: + final_headers["Mcp-Session-Id"] = self._mcp_session_id + response = self._session.post( + self.server_url.rstrip("/"), + json=mcp_message, + headers=final_headers, + timeout=self.timeout, + ) + + if "mcp-session-id" in response.headers: + self._mcp_session_id = response.headers["mcp-session-id"] + self._cache_session(self._mcp_session_id) + response.raise_for_status() + + if method.startswith("notifications/"): + return {} + response_text = response.text.strip() + if response_text.startswith("event:") and "data:" in response_text: + lines = response_text.split("\n") + data_line = None + for line in lines: + if line.startswith("data:"): + data_line = line[5:].strip() + break + if data_line: + try: + result = json.loads(data_line) + except json.JSONDecodeError: + raise Exception(f"Invalid JSON in SSE data: {data_line}") + else: + raise Exception(f"No data found in SSE response: {response_text}") + else: + try: + result = response.json() + except json.JSONDecodeError: + raise Exception(f"Invalid JSON response: {response.text}") + if "error" in result: + error_msg = result["error"] + if isinstance(error_msg, dict): + error_msg = error_msg.get("message", str(error_msg)) + raise Exception(f"MCP server error: {error_msg}") + return result.get("result", result) + except requests.exceptions.RequestException as e: + if not is_retry and self._should_retry_with_new_session(e): + self._invalidate_and_refresh_session() + return self._execute_mcp_request(mcp_message, method, is_retry=True) + raise Exception(f"MCP server request failed: {str(e)}") + + def _should_retry_with_new_session(self, error: Exception) -> bool: + """Check if error indicates session invalidation and retry is warranted.""" + error_str = str(error).lower() + return ( + any( + indicator in error_str + for indicator in [ + "invalid session", + "session expired", + "unauthorized", + "401", + "403", + ] + ) + and self._mcp_session_id is not None + ) + + def _invalidate_and_refresh_session(self) -> None: + """Invalidate current session and create a new one.""" + global _mcp_session_cache + if self._cache_key in _mcp_session_cache: + del _mcp_session_cache[self._cache_key] + self._mcp_session_id = None + self._initialize_mcp_connection() + + def discover_tools(self) -> List[Dict]: + """ + Discover available tools from the MCP server using MCP protocol. + + Returns: + List of tool definitions from the server + """ + try: + self._ensure_valid_session() + + response = self._make_mcp_request("tools/list") + + # Handle both formats: response with 'tools' key or response that IS the tools list + + if isinstance(response, dict): + if "tools" in response: + self.available_tools = response["tools"] + elif ( + "result" in response + and isinstance(response["result"], dict) + and "tools" in response["result"] + ): + self.available_tools = response["result"]["tools"] + else: + self.available_tools = [response] if response else [] + elif isinstance(response, list): + self.available_tools = response + else: + self.available_tools = [] + return self.available_tools + except Exception as e: + raise Exception(f"Failed to discover tools from MCP server: {str(e)}") + + def execute_action(self, action_name: str, **kwargs) -> Any: + """ + Execute an action on the remote MCP server using MCP protocol. + + Args: + action_name: Name of the action to execute + **kwargs: Parameters for the action + + Returns: + Result from the MCP server + """ + self._ensure_valid_session() + + # Skipping empty/None values - letting the server use defaults + + cleaned_kwargs = {} + for key, value in kwargs.items(): + if value == "" or value is None: + continue + cleaned_kwargs[key] = value + call_params = {"name": action_name, "arguments": cleaned_kwargs} + try: + result = self._make_mcp_request("tools/call", call_params) + return result + except Exception as e: + raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + + def get_actions_metadata(self) -> List[Dict]: + """ + Get metadata for all available actions. + + Returns: + List of action metadata dictionaries + """ + actions = [] + for tool in self.available_tools: + input_schema = ( + tool.get("inputSchema") + or tool.get("input_schema") + or tool.get("schema") + or tool.get("parameters") + ) + + parameters_schema = { + "type": "object", + "properties": {}, + "required": [], + } + + if input_schema: + if isinstance(input_schema, dict): + if "properties" in input_schema: + parameters_schema = { + "type": input_schema.get("type", "object"), + "properties": input_schema.get("properties", {}), + "required": input_schema.get("required", []), + } + + for key in ["additionalProperties", "description"]: + if key in input_schema: + parameters_schema[key] = input_schema[key] + else: + parameters_schema["properties"] = input_schema + action = { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": parameters_schema, + } + actions.append(action) + return actions + + def test_connection(self) -> Dict: + """ + Test the connection to the MCP server and validate functionality. + + Returns: + Dictionary with connection test results including tool count + """ + try: + self._mcp_session_id = None + + init_result = self._initialize_mcp_connection() + + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if init_result.get("cached"): + message += " (Using cached session)" + elif init_result.get("fallback"): + message += " (No formal initialization required)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "session_id": self._mcp_session_id, + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + except Exception as e: + return { + "success": False, + "message": f"Connection failed: {str(e)}", + "tools_count": 0, + "error_type": type(e).__name__, + } + + def get_config_requirements(self) -> Dict: + return { + "server_url": { + "type": "string", + "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "required": True, + }, + "auth_type": { + "type": "string", + "description": "Authentication type", + "enum": ["none", "api_key", "bearer", "basic"], + "default": "none", + "required": True, + }, + "auth_credentials": { + "type": "object", + "description": "Authentication credentials (varies by auth_type)", + "required": False, + }, + "timeout": { + "type": "integer", + "description": "Request timeout in seconds", + "default": 30, + "minimum": 1, + "maximum": 300, + "required": False, + }, + } diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index ad71db28..d602b762 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -23,16 +23,23 @@ class ToolManager: tool_config = self.config.get(name, {}) self.tools[name] = obj(tool_config) - def load_tool(self, tool_name, tool_config): + def load_tool(self, tool_name, tool_config, user_id=None): self.config[tool_name] = tool_config module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: - return obj(tool_config) + if tool_name == "mcp_tool" and user_id: + return obj(tool_config, user_id) + else: + return obj(tool_config) - def execute_action(self, tool_name, action_name, **kwargs): + def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") + if tool_name == "mcp_tool" and user_id: + tool_config = self.config.get(tool_name, {}) + tool = self.load_tool(tool_name, tool_config, user_id) + return tool.execute_action(action_name, **kwargs) return self.tools[tool_name].execute_action(action_name, **kwargs) def get_all_actions_metadata(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index f508b7cf..f0493c7c 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -25,6 +25,8 @@ from flask_restx import fields, inputs, Namespace, Resource from pymongo import ReturnDocument from werkzeug.utils import secure_filename +from application.agents.tools.mcp_tool import MCPTool + from application.agents.tools.tool_manager import ToolManager from application.api import api @@ -38,6 +40,7 @@ from application.api.user.tasks import ( from application.core.mongo_db import MongoDB from application.core.settings import settings from application.parser.connectors.connector_creator import ConnectorCreator +from application.security.encryption import decrypt_credentials, encrypt_credentials from application.storage.storage_creator import StorageCreator from application.tts.google_tts import GoogleTTS from application.utils import ( @@ -491,6 +494,7 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) + storage = StorageCreator.get_storage() try: @@ -507,6 +511,7 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() + if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -515,6 +520,7 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) + except FileNotFoundError: pass except Exception as err: @@ -522,6 +528,7 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) + sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -593,6 +600,7 @@ class UploadFile(Resource): == temp_file_path ): continue + rel_path = os.path.relpath( os.path.join(root, extracted_file), temp_dir ) @@ -617,6 +625,7 @@ class UploadFile(Resource): file_path = f"{base_path}/{safe_file}" with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) + task = ingest.delay( settings.UPLOAD_FOLDER, [ @@ -688,6 +697,7 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Unauthorized"}), 401 ) + user = decoded_token.get("sub") source_id = request.form.get("source_id") operation = request.form.get("operation") @@ -737,6 +747,7 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Database error"}), 500 ) + try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") @@ -793,6 +804,7 @@ class ManageSourceFiles(Resource): ), 200, ) + elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: @@ -846,6 +858,7 @@ class ManageSourceFiles(Resource): ), 200, ) + elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: @@ -871,6 +884,7 @@ class ManageSourceFiles(Resource): ), 400, ) + full_directory_path = ( f"{source_file_path}/{directory_path}" if directory_path @@ -929,6 +943,7 @@ class ManageSourceFiles(Resource): ), 200, ) + except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" if operation == "remove_directory": @@ -940,6 +955,7 @@ class ManageSourceFiles(Resource): elif operation == "add": parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" + current_app.logger.error( f"Error managing source files: {err} ({error_context})", exc_info=True ) @@ -1616,6 +1632,7 @@ class CreateAgent(Resource): ), 400, ) + # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: @@ -3625,7 +3642,60 @@ class UpdateTool(Resource): ), 400, ) - update_data["config"] = data["config"] + tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(data["id"]), "user": user} + ) + if tool_doc and tool_doc.get("name") == "mcp_tool": + config = data["config"] + existing_config = tool_doc.get("config", {}) + storage_config = existing_config.copy() + + storage_config.update(config) + existing_credentials = {} + if "encrypted_credentials" in existing_config: + existing_credentials = decrypt_credentials( + existing_config["encrypted_credentials"], user + ) + auth_credentials = existing_credentials.copy() + auth_type = storage_config.get("auth_type", "none") + if auth_type == "api_key": + if "api_key" in config and config["api_key"]: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config[ + "api_key_header" + ] + elif auth_type == "bearer": + if "bearer_token" in config and config["bearer_token"]: + auth_credentials["bearer_token"] = config["bearer_token"] + elif "encrypted_token" in config and config["encrypted_token"]: + auth_credentials["bearer_token"] = config["encrypted_token"] + elif auth_type == "basic": + if "username" in config and config["username"]: + auth_credentials["username"] = config["username"] + if "password" in config and config["password"]: + auth_credentials["password"] = config["password"] + if auth_type != "none" and auth_credentials: + encrypted_credentials_string = encrypt_credentials( + auth_credentials, user + ) + storage_config["encrypted_credentials"] = ( + encrypted_credentials_string + ) + elif auth_type == "none": + storage_config.pop("encrypted_credentials", None) + for field in [ + "api_key", + "bearer_token", + "encrypted_token", + "username", + "password", + "api_key_header", + ]: + storage_config.pop(field, None) + update_data["config"] = storage_config + else: + update_data["config"] = data["config"] if "status" in data: update_data["status"] = data["status"] user_tools_collection.update_one( @@ -3837,6 +3907,7 @@ class GetChunks(Resource): if not (text_match or title_match): continue filtered_chunks.append(chunk) + chunks = filtered_chunks total_chunks = len(chunks) @@ -4027,6 +4098,7 @@ class UpdateChunk(Resource): current_app.logger.warning( f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" ) + return make_response( jsonify( { @@ -4154,19 +4226,23 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") doc_id = request.args.get("id") if not doc_id: return make_response(jsonify({"error": "Document ID is required"}), 400) + if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) + try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) + directory_structure = doc.get("directory_structure", {}) base_path = doc.get("file_path", "") @@ -4196,3 +4272,204 @@ class DirectoryStructure(Resource): f"Error retrieving directory structure: {e}", exc_info=True ) return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@user_ns.route("/api/mcp_server/test") +class TestMCPServerConfig(Resource): + @api.expect( + api.model( + "MCPServerTestModel", + { + "config": fields.Raw( + required=True, description="MCP server configuration to test" + ), + }, + ) + ) + @api.doc(description="Test MCP server connection with provided configuration") + def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + data = request.get_json() + + required_fields = ["config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + try: + config = data["config"] + + auth_credentials = {} + auth_type = config.get("auth_type", "none") + + if auth_type == "api_key" and "api_key" in config: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config["api_key_header"] + elif auth_type == "bearer" and "bearer_token" in config: + auth_credentials["bearer_token"] = config["bearer_token"] + elif auth_type == "basic": + if "username" in config: + auth_credentials["username"] = config["username"] + if "password" in config: + auth_credentials["password"] = config["password"] + + test_config = config.copy() + test_config["auth_credentials"] = auth_credentials + + mcp_tool = MCPTool(test_config, user) + result = mcp_tool.test_connection() + + return make_response(jsonify(result), 200) + except Exception as e: + current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True) + return make_response( + jsonify( + {"success": False, "error": f"Connection test failed: {str(e)}"} + ), + 500, + ) + + +@user_ns.route("/api/mcp_server/save") +class MCPServerSave(Resource): + @api.expect( + api.model( + "MCPServerSaveModel", + { + "id": fields.String( + required=False, description="Tool ID for updates (optional)" + ), + "displayName": fields.String( + required=True, description="Display name for the MCP server" + ), + "config": fields.Raw( + required=True, description="MCP server configuration" + ), + "status": fields.Boolean( + required=False, default=True, description="Tool status" + ), + }, + ) + ) + @api.doc(description="Create or update MCP server with automatic tool discovery") + def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + data = request.get_json() + + required_fields = ["displayName", "config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + try: + config = data["config"] + + auth_credentials = {} + auth_type = config.get("auth_type", "none") + if auth_type == "api_key": + if "api_key" in config and config["api_key"]: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config["api_key_header"] + elif auth_type == "bearer": + if "bearer_token" in config and config["bearer_token"]: + auth_credentials["bearer_token"] = config["bearer_token"] + elif auth_type == "basic": + if "username" in config and config["username"]: + auth_credentials["username"] = config["username"] + if "password" in config and config["password"]: + auth_credentials["password"] = config["password"] + mcp_config = config.copy() + mcp_config["auth_credentials"] = auth_credentials + + if auth_type == "none" or auth_credentials: + mcp_tool = MCPTool(mcp_config, user) + mcp_tool.discover_tools() + actions_metadata = mcp_tool.get_actions_metadata() + else: + raise Exception( + "No valid credentials provided for the selected authentication type" + ) + + storage_config = config.copy() + if auth_credentials: + encrypted_credentials_string = encrypt_credentials( + auth_credentials, user + ) + storage_config["encrypted_credentials"] = encrypted_credentials_string + + for field in [ + "api_key", + "bearer_token", + "username", + "password", + "api_key_header", + ]: + storage_config.pop(field, None) + transformed_actions = [] + for action in actions_metadata: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) + tool_data = { + "name": "mcp_tool", + "displayName": data["displayName"], + "customName": data["displayName"], + "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}", + "config": storage_config, + "actions": transformed_actions, + "status": data.get("status", True), + "user": user, + } + + tool_id = data.get("id") + if tool_id: + result = user_tools_collection.update_one( + {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}, + {"$set": {k: v for k, v in tool_data.items() if k != "user"}}, + ) + if result.matched_count == 0: + return make_response( + jsonify( + { + "success": False, + "error": "Tool not found or access denied", + } + ), + 404, + ) + response_data = { + "success": True, + "id": tool_id, + "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.", + "tools_count": len(transformed_actions), + } + else: + result = user_tools_collection.insert_one(tool_data) + tool_id = str(result.inserted_id) + response_data = { + "success": True, + "id": tool_id, + "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.", + "tools_count": len(transformed_actions), + } + return make_response(jsonify(response_data), 200) + except Exception as e: + current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True) + return make_response( + jsonify( + {"success": False, "error": f"Failed to save MCP server: {str(e)}"} + ), + 500, + ) diff --git a/application/core/settings.py b/application/core/settings.py index cb7d75e3..7ede4e86 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -26,7 +26,7 @@ class Settings(BaseSettings): "gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5, - "gemini-2.0-flash-exp": 1e6, + "gemini-2.5-flash": 1e6, } UPLOAD_FOLDER: str = "inputs" PARSE_PDF_AS_IMAGE: bool = False @@ -96,7 +96,7 @@ class Settings(BaseSettings): QDRANT_HOST: Optional[str] = None QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" - + # PGVector vectorstore config PGVECTOR_CONNECTION_STRING: Optional[str] = None # Milvus vectorstore config @@ -116,6 +116,9 @@ class Settings(BaseSettings): JWT_SECRET_KEY: str = "" + # Encryption settings + ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key" + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 91065b74..b88e1d9f 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM): raise def _clean_messages_google(self, messages): + """Convert OpenAI format messages to Google AI format.""" cleaned_messages = [] for message in messages: role = message.get("role") @@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM): if role == "assistant": role = "model" + elif role == "tool": + role = "model" parts = [] if role and content is not None: @@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM): else: raise ValueError(f"Unexpected content type: {type(content)}") - cleaned_messages.append(types.Content(role=role, parts=parts)) + if parts: + cleaned_messages.append(types.Content(role=role, parts=parts)) return cleaned_messages + def _clean_schema(self, schema_obj): + """ + Recursively remove unsupported fields from schema objects + and validate required properties. + """ + if not isinstance(schema_obj, dict): + return schema_obj + allowed_fields = { + "type", + "description", + "items", + "properties", + "required", + "enum", + "pattern", + "minimum", + "maximum", + "nullable", + "default", + } + + cleaned = {} + for key, value in schema_obj.items(): + if key not in allowed_fields: + continue + elif key == "type" and isinstance(value, str): + cleaned[key] = value.upper() + elif isinstance(value, dict): + cleaned[key] = self._clean_schema(value) + elif isinstance(value, list): + cleaned[key] = [self._clean_schema(item) for item in value] + else: + cleaned[key] = value + + # Validate that required properties actually exist in properties + if "required" in cleaned and "properties" in cleaned: + valid_required = [] + properties_keys = set(cleaned["properties"].keys()) + for required_prop in cleaned["required"]: + if required_prop in properties_keys: + valid_required.append(required_prop) + if valid_required: + cleaned["required"] = valid_required + else: + cleaned.pop("required", None) + elif "required" in cleaned and "properties" not in cleaned: + cleaned.pop("required", None) + + return cleaned + def _clean_tools_format(self, tools_list): + """Convert OpenAI format tools to Google AI format.""" genai_tools = [] for tool_data in tools_list: if tool_data["type"] == "function": @@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM): properties = parameters.get("properties", {}) if properties: + cleaned_properties = {} + for k, v in properties.items(): + cleaned_properties[k] = self._clean_schema(v) + genai_function = dict( name=function["name"], description=function["description"], parameters={ "type": "OBJECT", - "properties": { - k: { - **v, - "type": v["type"].upper() if v["type"] else None, - } - for k, v in properties.items() - }, + "properties": cleaned_properties, "required": ( parameters["required"] if "required" in parameters @@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API without streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API with streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM): yield chunk.text def _supports_tools(self): + """Return whether this LLM supports function calling.""" return True def _supports_structured_output(self): + """Return whether this LLM supports structured JSON output.""" return True def prepare_structured_output_format(self, json_schema): + """Convert JSON schema to Google AI structured output format.""" if not json_schema: return None diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index 43205472..96ed4c00 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -205,7 +205,6 @@ class LLMHandler(ABC): except StopIteration as e: tool_response, call_id = e.value break - updated_messages.append( { "role": "assistant", @@ -222,17 +221,36 @@ class LLMHandler(ABC): ) updated_messages.append(self.create_tool_message(call, tool_response)) - except Exception as e: logger.error(f"Error executing tool: {str(e)}", exc_info=True) - updated_messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call.id, - } + error_call = ToolCall( + id=call.id, name=call.name, arguments=call.arguments ) + error_response = f"Error executing tool: {str(e)}" + error_message = self.create_tool_message(error_call, error_response) + updated_messages.append(error_message) + call_parts = call.name.split("_") + if len(call_parts) >= 2: + tool_id = call_parts[-1] # Last part is tool ID (e.g., "1") + action_name = "_".join(call_parts[:-1]) + tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool") + full_action_name = f"{action_name}_{tool_id}" + else: + tool_name = "unknown_tool" + action_name = call.name + full_action_name = call.name + yield { + "type": "tool_call", + "data": { + "tool_name": tool_name, + "call_id": call.id, + "action_name": full_action_name, + "arguments": call.arguments, + "error": error_response, + "status": "error", + }, + } return updated_messages def handle_non_streaming( @@ -263,13 +281,11 @@ class LLMHandler(ABC): except StopIteration as e: messages = e.value break - response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) parsed = self.parse_response(response) self.llm_calls.append(build_stack_data(agent.llm)) - return parsed.content def handle_streaming( diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py index b43f2a16..7fa44cb6 100644 --- a/application/llm/handlers/google.py +++ b/application/llm/handlers/google.py @@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler): finish_reason="stop", raw_response=response, ) - if hasattr(response, "candidates"): parts = response.candidates[0].content.parts if response.candidates else [] tool_calls = [ @@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler): finish_reason="tool_calls" if tool_calls else "stop", raw_response=response, ) - else: tool_calls = [] if hasattr(response, "function_call"): @@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler): def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: """Create Google-style tool message.""" - from google.genai import types return { - "role": "tool", + "role": "model", "content": [ - types.Part.from_function_response( - name=tool_call.name, response={"result": result} - ).to_json_dict() + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + } + } ], } diff --git a/application/requirements.txt b/application/requirements.txt index b7076ed8..80564689 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -2,6 +2,7 @@ anthropic==0.49.0 boto3==1.38.18 beautifulsoup4==4.13.4 celery==5.4.0 +cryptography==42.0.8 dataclasses-json==0.6.7 docx2txt==0.8 duckduckgo-search==7.5.2 diff --git a/application/security/__init__.py b/application/security/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/security/encryption.py b/application/security/encryption.py new file mode 100644 index 00000000..4cb3a4d5 --- /dev/null +++ b/application/security/encryption.py @@ -0,0 +1,85 @@ +import base64 +import json +import os + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from application.core.settings import settings + + +def _derive_key(user_id: str, salt: bytes) -> bytes: + app_secret = settings.ENCRYPTION_SECRET_KEY + + password = f"{app_secret}#{user_id}".encode() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + backend=default_backend(), + ) + + return kdf.derive(password) + + +def encrypt_credentials(credentials: dict, user_id: str) -> str: + if not credentials: + return "" + try: + salt = os.urandom(16) + iv = os.urandom(16) + key = _derive_key(user_id, salt) + + json_str = json.dumps(credentials) + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + padded_data = _pad_data(json_str.encode()) + encrypted_data = encryptor.update(padded_data) + encryptor.finalize() + + result = salt + iv + encrypted_data + return base64.b64encode(result).decode() + except Exception as e: + print(f"Warning: Failed to encrypt credentials: {e}") + return "" + + +def decrypt_credentials(encrypted_data: str, user_id: str) -> dict: + if not encrypted_data: + return {} + try: + data = base64.b64decode(encrypted_data.encode()) + + salt = data[:16] + iv = data[16:32] + encrypted_content = data[32:] + + key = _derive_key(user_id, salt) + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + decryptor = cipher.decryptor() + + decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize() + decrypted_data = _unpad_data(decrypted_padded) + + return json.loads(decrypted_data.decode()) + except Exception as e: + print(f"Warning: Failed to decrypt credentials: {e}") + return {} + + +def _pad_data(data: bytes) -> bytes: + block_size = 16 + padding_len = block_size - (len(data) % block_size) + padding = bytes([padding_len]) * padding_len + return data + padding + + +def _unpad_data(data: bytes) -> bytes: + padding_len = data[-1] + return data[:-padding_len] diff --git a/frontend/public/toolIcons/tool_mcp_tool.svg b/frontend/public/toolIcons/tool_mcp_tool.svg new file mode 100644 index 00000000..22c980e3 --- /dev/null +++ b/frontend/public/toolIcons/tool_mcp_tool.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 955f43ee..dad008da 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -57,6 +57,8 @@ const endpoints = { DIRECTORY_STRUCTURE: (docId: string) => `/api/directory_structure?id=${docId}`, MANAGE_SOURCE_FILES: '/api/manage_source_files', + MCP_TEST_CONNECTION: '/api/mcp_server/test', + MCP_SAVE_SERVER: '/api/mcp_server/save', }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 1cb4bbd6..5dda8ddf 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -108,6 +108,10 @@ const userService = { apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token), manageSourceFiles: (data: FormData, token: string | null): Promise => apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token), + testMCPConnection: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token), + saveMCPServer: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token), syncConnector: ( docId: string, provider: string, diff --git a/frontend/src/assets/server.svg b/frontend/src/assets/server.svg new file mode 100644 index 00000000..e69de29b diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 3be40df7..bbdf5e00 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -1,6 +1,6 @@ import 'katex/dist/katex.min.css'; -import { forwardRef, Fragment, useRef, useState, useEffect } from 'react'; +import { forwardRef, Fragment, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import ReactMarkdown from 'react-markdown'; import { useSelector } from 'react-redux'; @@ -12,12 +12,13 @@ import { import rehypeKatex from 'rehype-katex'; import remarkGfm from 'remark-gfm'; import remarkMath from 'remark-math'; -import DocumentationDark from '../assets/documentation-dark.svg'; + import ChevronDown from '../assets/chevron-down.svg'; import Cloud from '../assets/cloud.svg'; import DocsGPT3 from '../assets/cute_docsgpt3.svg'; import Dislike from '../assets/dislike.svg?react'; import Document from '../assets/document.svg'; +import DocumentationDark from '../assets/documentation-dark.svg'; import Edit from '../assets/edit.svg'; import Like from '../assets/like.svg?react'; import Link from '../assets/link.svg'; @@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { Response {' '}

{toolCall.status === 'pending' && ( @@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {

)} + {toolCall.status === 'error' && ( +

+ + {toolCall.error} + +

+ )} diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index 4ccb04a1..d962e4bc 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -4,5 +4,6 @@ export type ToolCallsType = { call_id: string; arguments: Record; result?: Record; - status?: 'pending' | 'completed'; + error?: string; + status?: 'pending' | 'completed' | 'error'; }; diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index 004e4bb1..17568c59 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -18,7 +18,10 @@ export default function useDefaultDocument() { const fetchDocs = () => { getDocs(token).then((data) => { dispatch(setSourceDocs(data)); - if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0)) + if ( + !selectedDoc || + (Array.isArray(selectedDoc) && selectedDoc.length === 0) + ) Array.isArray(data) && data?.forEach((doc: Doc) => { if (doc.model && doc.name === 'default') { diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 39e2bee7..3992f043 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -184,7 +184,39 @@ "cancel": "Cancel", "addNew": "Add New", "name": "Name", - "type": "Type" + "type": "Type", + "mcp": { + "addServer": "Add MCP Server", + "editServer": "Edit Server", + "serverName": "Server Name", + "serverUrl": "Server URL", + "headerName": "Header Name", + "timeout": "Timeout (seconds)", + "testConnection": "Test Connection", + "testing": "Testing...", + "saving": "Saving...", + "save": "Save", + "cancel": "Cancel", + "noAuth": "No Authentication", + "placeholders": { + "serverUrl": "https://api.example.com", + "apiKey": "Your secret API key", + "bearerToken": "Your secret token", + "username": "Your username", + "password": "Your password" + }, + "errors": { + "nameRequired": "Server name is required", + "urlRequired": "Server URL is required", + "invalidUrl": "Please enter a valid URL", + "apiKeyRequired": "API key is required", + "tokenRequired": "Bearer token is required", + "usernameRequired": "Username is required", + "passwordRequired": "Password is required", + "testFailed": "Connection test failed", + "saveFailed": "Failed to save MCP server" + } + } } }, "modals": { diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx index 885ef467..28766bd1 100644 --- a/frontend/src/modals/AddToolModal.tsx +++ b/frontend/src/modals/AddToolModal.tsx @@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks'; import { ActiveState } from '../models/misc'; import { selectToken } from '../preferences/preferenceSlice'; import ConfigToolModal from './ConfigToolModal'; +import MCPServerModal from './MCPServerModal'; import { AvailableToolType } from './types'; import WrapperComponent from './WrapperModal'; @@ -34,6 +35,8 @@ export default function AddToolModal({ React.useState(null); const [configModalState, setConfigModalState] = React.useState('INACTIVE'); + const [mcpModalState, setMcpModalState] = + React.useState('INACTIVE'); const [loading, setLoading] = React.useState(false); useOutsideAlerter(modalRef, () => { @@ -86,6 +89,9 @@ export default function AddToolModal({ .catch((error) => { console.error('Failed to create tool:', error); }); + } else if (tool.name === 'mcp_tool') { + setModalState('INACTIVE'); + setMcpModalState('ACTIVE'); } else { setModalState('INACTIVE'); setConfigModalState('ACTIVE'); @@ -95,6 +101,12 @@ export default function AddToolModal({ React.useEffect(() => { if (modalState === 'ACTIVE') getAvailableTools(); }, [modalState]); + + const handleMcpServerAdded = () => { + getUserTools(); + setMcpModalState('INACTIVE'); + }; + return ( <> {modalState === 'ACTIVE' && ( @@ -166,6 +178,11 @@ export default function AddToolModal({ tool={selectedTool} getUserTools={getUserTools} /> + ); } diff --git a/frontend/src/modals/MCPServerModal.tsx b/frontend/src/modals/MCPServerModal.tsx new file mode 100644 index 00000000..5e916210 --- /dev/null +++ b/frontend/src/modals/MCPServerModal.tsx @@ -0,0 +1,482 @@ +import { useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; + +import userService from '../api/services/userService'; +import Dropdown from '../components/Dropdown'; +import Input from '../components/Input'; +import Spinner from '../components/Spinner'; +import { useOutsideAlerter } from '../hooks'; +import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; +import WrapperComponent from './WrapperModal'; + +interface MCPServerModalProps { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + server?: any; + onServerSaved: () => void; +} + +const authTypes = [ + { label: 'No Authentication', value: 'none' }, + { label: 'API Key', value: 'api_key' }, + { label: 'Bearer Token', value: 'bearer' }, + // { label: 'Basic Authentication', value: 'basic' }, +]; + +export default function MCPServerModal({ + modalState, + setModalState, + server, + onServerSaved, +}: MCPServerModalProps) { + const { t } = useTranslation(); + const token = useSelector(selectToken); + const modalRef = useRef(null); + + const [formData, setFormData] = useState({ + name: server?.displayName || 'My MCP Server', + server_url: server?.server_url || '', + auth_type: server?.auth_type || 'none', + api_key: '', + header_name: 'X-API-Key', + bearer_token: '', + username: '', + password: '', + timeout: server?.timeout || 30, + }); + + const [loading, setLoading] = useState(false); + const [testing, setTesting] = useState(false); + const [testResult, setTestResult] = useState<{ + success: boolean; + message: string; + } | null>(null); + const [errors, setErrors] = useState<{ [key: string]: string }>({}); + + useOutsideAlerter(modalRef, () => { + if (modalState === 'ACTIVE') { + setModalState('INACTIVE'); + resetForm(); + } + }, [modalState]); + + const resetForm = () => { + setFormData({ + name: 'My MCP Server', + server_url: '', + auth_type: 'none', + api_key: '', + header_name: 'X-API-Key', + bearer_token: '', + username: '', + password: '', + timeout: 30, + }); + setErrors({}); + setTestResult(null); + }; + + const validateForm = () => { + const requiredFields: { [key: string]: boolean } = { + name: !formData.name.trim(), + server_url: !formData.server_url.trim(), + }; + + const authFieldChecks: { [key: string]: () => void } = { + api_key: () => { + if (!formData.api_key.trim()) + newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired'); + }, + bearer: () => { + if (!formData.bearer_token.trim()) + newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired'); + }, + basic: () => { + if (!formData.username.trim()) + newErrors.username = t('settings.tools.mcp.errors.usernameRequired'); + if (!formData.password.trim()) + newErrors.password = t('settings.tools.mcp.errors.passwordRequired'); + }, + }; + + const newErrors: { [key: string]: string } = {}; + Object.entries(requiredFields).forEach(([field, isEmpty]) => { + if (isEmpty) + newErrors[field] = t( + `settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`, + ); + }); + + if (formData.server_url.trim()) { + try { + new URL(formData.server_url); + } catch { + newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl'); + } + } + + const timeoutValue = formData.timeout === '' ? 30 : formData.timeout; + if ( + typeof timeoutValue === 'number' && + (timeoutValue < 1 || timeoutValue > 300) + ) + newErrors.timeout = 'Timeout must be between 1 and 300 seconds'; + + if (authFieldChecks[formData.auth_type]) + authFieldChecks[formData.auth_type](); + + setErrors(newErrors); + return Object.keys(newErrors).length === 0; + }; + + const handleInputChange = (name: string, value: string | number) => { + setFormData((prev) => ({ ...prev, [name]: value })); + if (errors[name]) { + setErrors((prev) => ({ ...prev, [name]: '' })); + } + setTestResult(null); + }; + + const buildToolConfig = () => { + const config: any = { + server_url: formData.server_url.trim(), + auth_type: formData.auth_type, + timeout: formData.timeout === '' ? 30 : formData.timeout, + }; + + if (formData.auth_type === 'api_key') { + config.api_key = formData.api_key.trim(); + config.api_key_header = formData.header_name.trim() || 'X-API-Key'; + } else if (formData.auth_type === 'bearer') { + config.bearer_token = formData.bearer_token.trim(); + } else if (formData.auth_type === 'basic') { + config.username = formData.username.trim(); + config.password = formData.password.trim(); + } + return config; + }; + + const testConnection = async () => { + if (!validateForm()) return; + setTesting(true); + setTestResult(null); + try { + const config = buildToolConfig(); + const response = await userService.testMCPConnection({ config }, token); + const result = await response.json(); + + setTestResult(result); + } catch (error) { + setTestResult({ + success: false, + message: t('settings.tools.mcp.errors.testFailed'), + }); + } finally { + setTesting(false); + } + }; + + const handleSave = async () => { + if (!validateForm()) return; + setLoading(true); + try { + const config = buildToolConfig(); + const serverData = { + displayName: formData.name, + config, + status: true, + ...(server?.id && { id: server.id }), + }; + + const response = await userService.saveMCPServer(serverData, token); + const result = await response.json(); + + if (response.ok && result.success) { + setTestResult({ + success: true, + message: result.message, + }); + onServerSaved(); + setModalState('INACTIVE'); + resetForm(); + } else { + setErrors({ + general: result.error || t('settings.tools.mcp.errors.saveFailed'), + }); + } + } catch (error) { + console.error('Error saving MCP server:', error); + setErrors({ general: t('settings.tools.mcp.errors.saveFailed') }); + } finally { + setLoading(false); + } + }; + + const renderAuthFields = () => { + switch (formData.auth_type) { + case 'api_key': + return ( +
+
+ handleInputChange('api_key', e.target.value)} + placeholder={t('settings.tools.mcp.placeholders.apiKey')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.api_key && ( +

{errors.api_key}

+ )} +
+
+ + handleInputChange('header_name', e.target.value) + } + placeholder={t('settings.tools.mcp.headerName')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> +
+
+ ); + case 'bearer': + return ( +
+ + handleInputChange('bearer_token', e.target.value) + } + placeholder={t('settings.tools.mcp.placeholders.bearerToken')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.bearer_token && ( +

{errors.bearer_token}

+ )} +
+ ); + case 'basic': + return ( +
+
+ handleInputChange('username', e.target.value)} + placeholder={t('settings.tools.mcp.username')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.username && ( +

{errors.username}

+ )} +
+
+ handleInputChange('password', e.target.value)} + placeholder={t('settings.tools.mcp.password')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.password && ( +

{errors.password}

+ )} +
+
+ ); + default: + return null; + } + }; + + return ( + modalState === 'ACTIVE' && ( + { + setModalState('INACTIVE'); + resetForm(); + }} + className="max-w-[600px] md:w-[80vw] lg:w-[60vw]" + > +
+
+

+ {server + ? t('settings.tools.mcp.editServer') + : t('settings.tools.mcp.addServer')} +

+
+
+
+
+ handleInputChange('name', e.target.value)} + borderVariant="thin" + placeholder={t('settings.tools.mcp.serverName')} + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.name && ( +

{errors.name}

+ )} +
+ +
+ + handleInputChange('server_url', e.target.value) + } + placeholder={t('settings.tools.mcp.serverUrl')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.server_url && ( +

+ {errors.server_url} +

+ )} +
+ + type.value === formData.auth_type) + ?.label || null + } + onSelect={(selection: { label: string; value: string }) => { + handleInputChange('auth_type', selection.value); + }} + options={authTypes} + size="w-full" + rounded="3xl" + border="border" + /> + + {renderAuthFields()} + +
+ { + const value = e.target.value; + if (value === '') { + handleInputChange('timeout', ''); + } else { + const numValue = parseInt(value); + if (!isNaN(numValue) && numValue >= 1) { + handleInputChange('timeout', numValue); + } + } + }} + placeholder={t('settings.tools.mcp.timeout')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.timeout && ( +

{errors.timeout}

+ )} +
+ + {testResult && ( +
+ {testResult.message} +
+ )} + {errors.general && ( +
+ {errors.general} +
+ )} +
+
+ +
+
+ + +
+ + +
+
+
+
+
+ ) + ); +} diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 40dc4bcc..4e5b5d00 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -92,7 +92,7 @@ export function getLocalApiKey(): string | null { export function getLocalRecentDocs(): Doc[] | null { const docs = localStorage.getItem('DocsGPTRecentDocs'); - return docs ? JSON.parse(docs) as Doc[] : null; + return docs ? (JSON.parse(docs) as Doc[]) : null; } export function getLocalPrompt(): string | null { diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx index 61a1d850..bca5c6ce 100644 --- a/frontend/src/settings/ToolConfig.tsx +++ b/frontend/src/settings/ToolConfig.tsx @@ -30,9 +30,22 @@ export default function ToolConfig({ handleGoBack: () => void; }) { const token = useSelector(selectToken); - const [authKey, setAuthKey] = React.useState( - 'token' in tool.config ? tool.config.token : '', - ); + const [authKey, setAuthKey] = React.useState(() => { + if (tool.name === 'mcp_tool') { + const config = tool.config as any; + if (config.auth_type === 'api_key') { + return config.api_key || ''; + } else if (config.auth_type === 'bearer') { + return config.encrypted_token || ''; + } else if (config.auth_type === 'basic') { + return config.password || ''; + } + return ''; + } else if ('token' in tool.config) { + return tool.config.token; + } + return ''; + }); const [customName, setCustomName] = React.useState( tool.customName || '', ); @@ -97,6 +110,26 @@ export default function ToolConfig({ }; const handleSaveChanges = () => { + let configToSave; + if (tool.name === 'api_tool') { + configToSave = tool.config; + } else if (tool.name === 'mcp_tool') { + configToSave = { ...tool.config } as any; + const mcpConfig = tool.config as any; + + if (authKey.trim()) { + if (mcpConfig.auth_type === 'api_key') { + configToSave.api_key = authKey; + } else if (mcpConfig.auth_type === 'bearer') { + configToSave.encrypted_token = authKey; + } else if (mcpConfig.auth_type === 'basic') { + configToSave.password = authKey; + } + } + } else { + configToSave = { token: authKey }; + } + userService .updateTool( { @@ -105,7 +138,7 @@ export default function ToolConfig({ displayName: tool.displayName, customName: customName, description: tool.description, - config: tool.name === 'api_tool' ? tool.config : { token: authKey }, + config: configToSave, actions: 'actions' in tool ? tool.actions : [], status: tool.status, }, @@ -196,7 +229,15 @@ export default function ToolConfig({
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (

- {t('settings.tools.authentication')} + {tool.name === 'mcp_tool' + ? (tool.config as any)?.auth_type === 'bearer' + ? 'Bearer Token' + : (tool.config as any)?.auth_type === 'api_key' + ? 'API Key' + : (tool.config as any)?.auth_type === 'basic' + ? 'Password' + : t('settings.tools.authentication') + : t('settings.tools.authentication')}

)}
@@ -208,7 +249,17 @@ export default function ToolConfig({ value={authKey} onChange={(e) => setAuthKey(e.target.value)} borderVariant="thin" - placeholder={t('modals.configTool.apiKeyPlaceholder')} + placeholder={ + tool.name === 'mcp_tool' + ? (tool.config as any)?.auth_type === 'bearer' + ? 'Bearer Token' + : (tool.config as any)?.auth_type === 'api_key' + ? 'API Key' + : (tool.config as any)?.auth_type === 'basic' + ? 'Password' + : t('modals.configTool.apiKeyPlaceholder') + : t('modals.configTool.apiKeyPlaceholder') + } />
)} @@ -450,6 +501,26 @@ export default function ToolConfig({ setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')} submitLabel={t('settings.tools.saveAndLeave')} handleSubmit={() => { + let configToSave; + if (tool.name === 'api_tool') { + configToSave = tool.config; + } else if (tool.name === 'mcp_tool') { + configToSave = { ...tool.config } as any; + const mcpConfig = tool.config as any; + + if (authKey.trim()) { + if (mcpConfig.auth_type === 'api_key') { + configToSave.api_key = authKey; + } else if (mcpConfig.auth_type === 'bearer') { + configToSave.encrypted_token = authKey; + } else if (mcpConfig.auth_type === 'basic') { + configToSave.password = authKey; + } + } + } else { + configToSave = { token: authKey }; + } + userService .updateTool( { @@ -458,10 +529,7 @@ export default function ToolConfig({ displayName: tool.displayName, customName: customName, description: tool.description, - config: - tool.name === 'api_tool' - ? tool.config - : { token: authKey }, + config: configToSave, actions: 'actions' in tool ? tool.actions : [], status: tool.status, },